#!/usr/bin/env bash
#
# Run LLM-Pruner block-wise structured pruning on a causal LM and record the
# exact command (plus git commit) used.
#
# Usage: <this script> [extra args...]
#   Extra args are appended verbatim to the prune script's argument list.
#
# Environment (all optional):
#   CUDA_VISIBLE_DEVICES             GPU(s) to expose (default: 2)
#   BASE_MODEL                       model id/path (default: meta-llama/Llama-2-7b-hf)
#   PRUNE_CKPT_PATH                  --save_ckpt_log_name value (default: llama2_7b_prune)
#   PRUNING_RATIO                    --pruning_ratio (default: 0.25)
#   BLOCK_MLP_LAYER_START/END        MLP layer range (default: 4 / 30)
#   BLOCK_ATTENTION_LAYER_START/END  attention layer range (default: 4 / 30)
#   PRUNER_TYPE                      --pruner_type (default: taylor)
#   TAYLOR_MODE                      --taylor (default: param_first)
#   DEVICE / EVAL_DEVICE             --device / --eval_device (default: cpu / cuda)
#   PRUNE_SCRIPT                     script to run, relative to the LLM-Pruner
#                                    checkout or absolute (default: llama3.py
#                                    for Llama-3 models, else hf_prune.py)
#   OUTDIR                           where run_args.txt is written (default:
#                                    <LLM-Pruner>/prune_log/$PRUNE_CKPT_PATH)
#   PYTHON                           interpreter to use (default: python)
set -euo pipefail

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-2}"

# Resolve paths from this script's location so it can be invoked from any cwd.
repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
workdir="$repo_root/compare_model/LLM-Pruner"

base_model="${BASE_MODEL:-meta-llama/Llama-2-7b-hf}"
prune_ckpt_path="${PRUNE_CKPT_PATH:-llama2_7b_prune}"
pruning_ratio="${PRUNING_RATIO:-0.25}"
block_mlp_layer_start="${BLOCK_MLP_LAYER_START:-4}"
block_mlp_layer_end="${BLOCK_MLP_LAYER_END:-30}"
block_attention_layer_start="${BLOCK_ATTENTION_LAYER_START:-4}"
block_attention_layer_end="${BLOCK_ATTENTION_LAYER_END:-30}"
pruner_type="${PRUNER_TYPE:-taylor}"
taylor_mode="${TAYLOR_MODE:-param_first}"
device="${DEVICE:-cpu}"
eval_device="${EVAL_DEVICE:-cuda}"
python_bin="${PYTHON:-python}"

# Llama-3.x models default to llama3.py; everything else to hf_prune.py.
# Match case-insensitively so Llama-3, llama-3, LLaMA-3.1, ... all qualify.
default_script="hf_prune.py"
skip_eval_flag="--skip_post_eval"
shopt -s nocasematch
if [[ "$base_model" == *llama-3* ]]; then
  default_script="llama3.py"
  skip_eval_flag="--skip_eval_after_prune"
fi
shopt -u nocasematch
script_name="${PRUNE_SCRIPT:-$default_script}"

# The two entry points spell their skip-eval flag differently. When
# PRUNE_SCRIPT overrides the default, follow the script actually being run;
# for any other script keep the model-based choice above.
case "${script_name##*/}" in
  llama3.py) skip_eval_flag="--skip_eval_after_prune" ;;
  hf_prune.py) skip_eval_flag="--skip_post_eval" ;;
esac

if [[ "$script_name" == /* ]]; then
  script_path="$script_name"
else
  script_path="$workdir/$script_name"
fi

# Validate before the mkdir -p below: with the default OUTDIR it would
# otherwise silently create a missing checkout directory.
[[ -d "$workdir" ]] || die "LLM-Pruner checkout not found: $workdir"
[[ -f "$script_path" ]] || die "prune script not found: $script_path"

output_dir="${OUTDIR:-$workdir/prune_log/$prune_ckpt_path}"

python_args=(
  --base_model "$base_model"
  --pruning_ratio "$pruning_ratio"
  --block_wise
  --block_mlp_layer_start "$block_mlp_layer_start"
  --block_mlp_layer_end "$block_mlp_layer_end"
  --block_attention_layer_start "$block_attention_layer_start"
  --block_attention_layer_end "$block_attention_layer_end"
  --pruner_type "$pruner_type"
  --taylor "$taylor_mode"
  --device "$device"
  --eval_device "$eval_device"
  --save_ckpt_log_name "$prune_ckpt_path"
  --save_model
  "$skip_eval_flag"
)
# Caller-supplied args go last so they can add to (or override) the above.
python_args+=("$@")

mkdir -p "$output_dir"

# Record commit + command for reproducibility. Resolving HEAD inside the `if`
# means "not a repo", "git not installed" and "no commits yet" all fall back
# to "unknown" instead of tripping set -e.
git_commit="unknown"
if head_sha=$(git -C "$repo_root" rev-parse --verify -q HEAD 2>/dev/null); then
  git_commit="$head_sha"
fi
{
  printf 'git_commit=%s\n' "$git_commit"
  printf 'command:\n'
  printf '%q ' "$python_bin" "$script_path" "${python_args[@]}"
  printf '\n'
} > "$output_dir/run_args.txt"

# Run from the checkout so the script's relative paths resolve there; expose
# both the checkout and the repo root to Python imports.
cd "$workdir" || die "cannot cd to $workdir"
PYTHONPATH="$workdir:$repo_root${PYTHONPATH:+:$PYTHONPATH}" \
  "$python_bin" "$script_name" "${python_args[@]}"
|
|