#!/usr/bin/env bash
#
# Builds the argument list for src/fuse_layers.py from positional arguments
# and environment overrides, writes the invocation to
# <output_dir>/run_args.txt, and then runs it.
set -euo pipefail

# Use device 0 unless the caller already chose (unset or empty -> "0").
: "${CUDA_VISIBLE_DEVICES:=0}"
export CUDA_VISIBLE_DEVICES

# Absolute path of the directory one level above this script's directory.
script_dir="$(dirname "${BASH_SOURCE[0]}")"
repo_root="$(cd "${script_dir}/.." && pwd)"

# The model (first positional argument) is mandatory. On a usage error the
# help text goes to stderr so it cannot be mistaken for program output.
if (( $# < 1 )); then
  cat >&2 <<'USAGE'
Usage:
script/run_abprune.sh <model> [output_dir] [extra fuse_layers args...]


Examples:
script/run_abprune.sh Qwen/Qwen3-1.7B
script/run_abprune.sh /path/to/model /path/to/output --num_progressive 8
USAGE
  exit 1
fi

# Consume the model so "$@" holds only [output_dir] and any extra arguments.
model="$1"
shift

# Tunables. Each lowercase variable takes the value of the upper-case
# environment variable on its right-hand side when that is set and non-empty,
# otherwise the default written after ":-".
dataset="${DATASET:-slimpajama}"
dataset_config="${DATASET_CONFIG:-none}"
num_progressive="${NUM_PROGRESSIVE:-16}"
seq_len="${SEQ_LEN:-1024}"
target_tokens="${TARGET_TOKENS:-500000}"
calib_sequences="${CALIB_SEQUENCES:-128}"
distill_batch_size="${DISTILL_BATCH_SIZE:-1}"
eval_batch_size="${EVAL_BATCH_SIZE:-1}"
eval_num_samples="${EVAL_NUM_SAMPLES:-200}"
distill_seq_len="${DISTILL_SEQ_LEN:-1024}"
lora_epochs="${LORA_EPOCHS:-0}"
distill_epochs="${DISTILL_EPOCHS:-1.0}"
distill_kl_weight="${DISTILL_KL_WEIGHT:-0.02}"
distill_kl_temp="${DISTILL_KL_TEMP:-4.0}"
distill_hidden_mse_weight="${DISTILL_HIDDEN_MSE_WEIGHT:-1.0}"
distill_attn_mse_weight="${DISTILL_ATTN_MSE_WEIGHT:-0.25}"
distill_mlp_mse_weight="${DISTILL_MLP_MSE_WEIGHT:-1.0}"
reparam_eta="${REPARAM_ETA:-0}"
reparam_gamma="${REPARAM_GAMMA:-0}"
reparam_attn_reg_scale="${REPARAM_ATTN_REG_SCALE:-1.0}"
reparam_mlp_reg_scale="${REPARAM_MLP_REG_SCALE:-1.0}"
reparam_param_subset="${REPARAM_PARAM_SUBSET:-mlp}"
dtype="${DTYPE:-bfloat16}"
batch_size="${BATCH_SIZE:-2}"
use_pertensor_fisher="${USE_PERTENSOR_FISHER:-0}"
# Plain "-" (not ":-"): an explicitly empty SAVE_FULL_MODEL_CYCLES stays empty,
# so the later non-empty check can actually omit --save_full_model_cycles.
# Only an unset variable falls back to the default.
save_full_model_cycles="${SAVE_FULL_MODEL_CYCLES-6,11}"
comm_skip_post_reselect="${COMM_SKIP_POST_RESELECT:-1}"
# HEAD_PERMUTE is the shared default for the two more specific switches.
head_permute="${HEAD_PERMUTE:-0}"
head_permute_select="${HEAD_PERMUTE_SELECT:-$head_permute}"
head_permute_merge="${HEAD_PERMUTE_MERGE:-$head_permute}"

# Choose the Fisher mode and the default results-directory suffix in one
# place: a value of exactly "1" selects tensor mode and tags the suffix,
# anything else keeps param mode.
output_dir_suffix="progressive_common_${num_progressive}_nopost_only_last"
case "$use_pertensor_fisher" in
  1)
    fisher_args=(--fisher_mode tensor)
    output_dir_suffix+="_pertensor"
    ;;
  *)
    fisher_args=(--fisher_mode param)
    ;;
esac

#######################################
# Turn a model id or path into a string usable as a directory-name component.
# '/', ':' and '@' and every character outside the set [:alnum:]_.- become
# '_', runs of '_' collapse to one, and leading/trailing '_' are removed.
# Arguments: $1 - model id or path
# Outputs:   the slug on stdout
#######################################
slugify_model_name() {
  # printf rather than echo: echo's trailing newline would be turned into a
  # trailing '_' by tr -c. sed -E makes '+' a real quantifier so the edge
  # underscores are actually trimmed.
  printf '%s' "$1" \
    | tr '/:@' '___' \
    | tr -cs '[:alnum:]_.-' '_' \
    | sed -E 's/^_+//; s/_+$//'
}

model_slug="$(slugify_model_name "$model")"
output_dir_default="$repo_root/results/${model_slug}_${output_dir_suffix}"

# Output directory precedence: a positional argument that does not start
# with "--", then $OUTDIR, then the default under results/.
output_dir=""
if [[ $# -gt 0 && "${1:0:2}" != "--" ]]; then
  output_dir="$1"
  shift
elif [[ -n "${OUTDIR:-}" ]]; then
  output_dir="${OUTDIR}"
else
  output_dir="${output_dir_default}"
fi
# RUN_NAME, when non-empty, is appended to whichever directory was chosen.
if [[ -n "${RUN_NAME:-}" ]]; then
  output_dir="${output_dir}_${RUN_NAME}"
fi

# Argument vector for fuse_layers.py, appended to in a fixed order. Values
# are quoted so that each one stays a single argument.

# Model and data selection.
python_args=(
  --model "$model"
  --dataset "$dataset"
  --dataset_config "$dataset_config"
  --target_tokens "$target_tokens"
  --num_samples "$calib_sequences"
  --seq_len "$seq_len"
  --batch_size "$batch_size"
)

# Distillation / evaluation batch shapes (evaluation reuses seq_len).
python_args+=(
  --distill_batch_size "$distill_batch_size"
  --distill_seq_len "$distill_seq_len"
  --distill_epochs "$distill_epochs"
  --eval_batch_size "$eval_batch_size"
  --eval_seq_len "$seq_len"
  --eval_num_samples "$eval_num_samples"
)

# Distillation loss weights.
python_args+=(
  --distill_kl_weight "$distill_kl_weight"
  --distill_kl_temp "$distill_kl_temp"
  --distill_hidden_mse_weight "$distill_hidden_mse_weight"
  --distill_attn_mse_weight "$distill_attn_mse_weight"
  --distill_mlp_mse_weight "$distill_mlp_mse_weight"
)

# Reparameterisation settings.
python_args+=(
  --reparam_eta "$reparam_eta"
  --reparam_gamma "$reparam_gamma"
  --reparam_attn_reg_scale "$reparam_attn_reg_scale"
  --reparam_mlp_reg_scale "$reparam_mlp_reg_scale"
  --reparam_param_subset "$reparam_param_subset"
)

# Values fixed by this wrapper, plus the LoRA epoch count.
python_args+=(
  --distill_weight_decay 0.0
  --distill_max_grad_norm 1.0
  --distill_grad_accum_steps 1
  --distill_eval_every 2000
  --lora_eval_every 2000
  --lora_epochs "$lora_epochs"
)

python_args+=("${fisher_args[@]}")

# Only pass --save_full_model_cycles when there is a value to pass.
if [[ -n "$save_full_model_cycles" ]]; then
  python_args+=(--save_full_model_cycles "$save_full_model_cycles")
fi

python_args+=(
  --distill_method reparam
  --redistrib_teacher_source previous_cycle
  --comm_enabled
  --comm_mu_auto
  --layer auto
  --exclude_pairs 0,1,-1
  --num_progressive "$num_progressive"
  --output_dir "$output_dir"
  --dtype "$dtype"
)

# Boolean switches: "1" enables the skip flag, "0" adds the matching --no_* flag.
if [[ "$comm_skip_post_reselect" == "1" ]]; then
  python_args+=(--comm_skip_post_reselect)
fi
if [[ "$head_permute_select" == "0" ]]; then
  python_args+=(--no_head_permute_select)
fi
if [[ "$head_permute_merge" == "0" ]]; then
  python_args+=(--no_head_permute_merge)
fi

# Remaining command-line arguments go last.
python_args+=("$@")

# Record how this run was launched in <output_dir>/run_args.txt.
mkdir -p "$output_dir"
run_args_file="$output_dir/run_args.txt"

# Best-effort commit lookup. Any failure (git missing, not a repository, or a
# repository with no commits yet) yields "unknown" instead of aborting the
# run under `set -e`.
git_commit="$(git -C "$repo_root" rev-parse --verify HEAD 2>/dev/null)" \
  || git_commit="unknown"

# start_epoch is read again by the EXIT handler to compute the elapsed time.
start_epoch=$(date +%s)
start_time=$(date --iso-8601=seconds)
{
  echo "git_commit=$git_commit"
  echo "start_time=$start_time"
  echo "HEAD_PERMUTE=$head_permute"
  echo "HEAD_PERMUTE_SELECT=$head_permute_select"
  echo "HEAD_PERMUTE_MERGE=$head_permute_merge"
  echo "command:"
  # %q shell-quotes each word so the logged command can be pasted back.
  printf '%q ' python "$repo_root/src/fuse_layers.py" "${python_args[@]}"
  echo
} > "$run_args_file"

#######################################
# EXIT-trap handler: appends the end time, elapsed seconds and the exit
# status to the run-args file.
# Globals: start_epoch (read), run_args_file (appended to)
#######################################
write_run_summary() {
  # Must stay the first statement: $? is still the status being exited with.
  local rc=$?
  local finished_epoch finished_at
  finished_epoch=$(date +%s)
  finished_at=$(date --iso-8601=seconds)
  printf '%s\n' \
    "end_time=$finished_at" \
    "elapsed_seconds=$((finished_epoch - start_epoch))" \
    "exit_code=$rc" \
    >> "$run_args_file"
}
trap write_run_summary EXIT

# Run fuse_layers.py as a child process (not via exec) so the EXIT trap
# still fires and can record the exit status.
fuse_layers_cmd=(python "$repo_root/src/fuse_layers.py" "${python_args[@]}")
"${fuse_layers_cmd[@]}"