#!/usr/bin/env bash
#
# Launch src/fuse_layers.py with the progressive layer-fusion ("abprune") recipe.
#
# Usage: script/run_abprune.sh <model> [output_dir] [extra fuse_layers args...]
#
# Every hyperparameter below can be overridden through the environment variable
# named in its default expansion (e.g. NUM_PROGRESSIVE=8 script/run_abprune.sh ...).
set -euo pipefail

export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"

# Absolute path of the repository root (parent of this script's directory).
repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"

if [[ $# -lt 1 ]]; then
  # The model is the one mandatory positional argument; print usage to stderr.
  cat >&2 <<'USAGE'
Usage: script/run_abprune.sh <model> [output_dir] [extra fuse_layers args...]

Examples:
  script/run_abprune.sh Qwen/Qwen3-1.7B
  script/run_abprune.sh /path/to/model /path/to/output --num_progressive 8
USAGE
  exit 1
fi

# Model id or local path forwarded to --model
# (e.g. meta-llama/Llama-2-7b-hf, meta-llama/Llama-3.1-8B, Qwen/Qwen3-1.7B).
model="$1"
shift

dataset="${DATASET:-slimpajama}"
dataset_config="${DATASET_CONFIG:-none}"
num_progressive="${NUM_PROGRESSIVE:-16}"
seq_len="${SEQ_LEN:-1024}"
target_tokens="${TARGET_TOKENS:-500000}"
calib_sequences="${CALIB_SEQUENCES:-128}"
distill_batch_size="${DISTILL_BATCH_SIZE:-1}"
eval_batch_size="${EVAL_BATCH_SIZE:-1}"
eval_num_samples="${EVAL_NUM_SAMPLES:-200}"
distill_seq_len="${DISTILL_SEQ_LEN:-1024}"
lora_epochs="${LORA_EPOCHS:-0}"
distill_epochs="${DISTILL_EPOCHS:-1.0}"
distill_kl_weight="${DISTILL_KL_WEIGHT:-0.02}"
distill_kl_temp="${DISTILL_KL_TEMP:-4.0}"
distill_hidden_mse_weight="${DISTILL_HIDDEN_MSE_WEIGHT:-1.0}"
distill_attn_mse_weight="${DISTILL_ATTN_MSE_WEIGHT:-0.25}"
distill_mlp_mse_weight="${DISTILL_MLP_MSE_WEIGHT:-1.0}"
reparam_eta="${REPARAM_ETA:-0}"
reparam_gamma="${REPARAM_GAMMA:-0}"
reparam_attn_reg_scale="${REPARAM_ATTN_REG_SCALE:-1.0}"
reparam_mlp_reg_scale="${REPARAM_MLP_REG_SCALE:-1.0}"
reparam_param_subset="${REPARAM_PARAM_SUBSET:-mlp}"
dtype="${DTYPE:-bfloat16}"
batch_size="${BATCH_SIZE:-2}"
use_pertensor_fisher="${USE_PERTENSOR_FISHER:-0}"
# Uses '-' rather than ':-' so that SAVE_FULL_MODEL_CYCLES="" (explicitly empty)
# is honoured and disables --save_full_model_cycles; only an unset variable
# falls back to the default.
save_full_model_cycles="${SAVE_FULL_MODEL_CYCLES-6,11}"
comm_skip_post_reselect="${COMM_SKIP_POST_RESELECT:-1}"
# HEAD_PERMUTE sets the default for both the select and merge switches below.
head_permute="${HEAD_PERMUTE:-0}"
head_permute_select="${HEAD_PERMUTE_SELECT:-$head_permute}"
head_permute_merge="${HEAD_PERMUTE_MERGE:-$head_permute}"

# USE_PERTENSOR_FISHER=1 selects --fisher_mode tensor instead of the default param.
fisher_args=(--fisher_mode param)
if [[ "$use_pertensor_fisher" == "1" ]]; then
  fisher_args=(--fisher_mode tensor)
fi
# Default output-directory suffix encodes the progressive-step count and Fisher mode.
# NOTE(review): it reflects $NUM_PROGRESSIVE only; a --num_progressive passed as an
# extra argument is forwarded to python but not reflected in this name.
output_dir_suffix="progressive_common_${num_progressive}_nopost_only_last"
if [[ "$use_pertensor_fisher" == "1" ]]; then
  output_dir_suffix="${output_dir_suffix}_pertensor"
fi

# Filesystem-safe slug of the model id: '/', ':' and '@' become '_', any other
# character outside [[:alnum:]_.-] becomes '_', runs of '_' are squeezed, and
# leading/trailing '_' are stripped. printf is used instead of echo so no
# trailing newline reaches tr (it would otherwise become a trailing '_').
model_slug="$(printf '%s' "$model" \
  | tr '/:@' '___' \
  | tr -cs '[:alnum:]_.-' '_' \
  | sed -E 's/^_+//; s/_+$//')"
output_dir_default="$repo_root/results/${model_slug}_${output_dir_suffix}"

# Output directory precedence: a positional argument that does not start with
# "--", then $OUTDIR, then the derived default. $RUN_NAME, if set, is appended.
output_dir=""
if [[ $# -gt 0 && "${1:0:2}" != "--" ]]; then
  output_dir="$1"
  shift
elif [[ -n "${OUTDIR:-}" ]]; then
  output_dir="${OUTDIR}"
else
  output_dir="${output_dir_default}"
fi
if [[ -n "${RUN_NAME:-}" ]]; then
  output_dir="${output_dir}_${RUN_NAME}"
fi

# Arguments for src/fuse_layers.py, assembled as an array so values with
# spaces survive intact.
python_args=(
  --model "$model"
  --dataset "$dataset"
  --dataset_config "$dataset_config"
  --target_tokens "$target_tokens"
  --num_samples "$calib_sequences"
  --seq_len "$seq_len"
  --batch_size "$batch_size"
  --distill_batch_size "$distill_batch_size"
  --distill_seq_len "$distill_seq_len"
  --distill_epochs "$distill_epochs"
  --eval_batch_size "$eval_batch_size"
  --eval_seq_len "$seq_len"
  --eval_num_samples "$eval_num_samples"
  --distill_kl_weight "$distill_kl_weight"
  --distill_kl_temp "$distill_kl_temp"
  --distill_hidden_mse_weight "$distill_hidden_mse_weight"
  --distill_attn_mse_weight "$distill_attn_mse_weight"
  --distill_mlp_mse_weight "$distill_mlp_mse_weight"
  --reparam_eta "$reparam_eta"
  --reparam_gamma "$reparam_gamma"
  --reparam_attn_reg_scale "$reparam_attn_reg_scale"
  --reparam_mlp_reg_scale "$reparam_mlp_reg_scale"
  --reparam_param_subset "$reparam_param_subset"
  --distill_weight_decay 0.0
  --distill_max_grad_norm 1.0
  --distill_grad_accum_steps 1
  --distill_eval_every 2000
  --lora_eval_every 2000
  --lora_epochs "$lora_epochs"
)
python_args+=("${fisher_args[@]}")
if [[ -n "$save_full_model_cycles" ]]; then
  python_args+=(--save_full_model_cycles "$save_full_model_cycles")
fi
python_args+=(
  --distill_method reparam
  --redistrib_teacher_source previous_cycle
  --comm_enabled
  --comm_mu_auto
  --layer auto
  --exclude_pairs 0,1,-1
  --num_progressive "$num_progressive"
  --output_dir "$output_dir"
  --dtype "$dtype"
)
if [[ "$comm_skip_post_reselect" == "1" ]]; then
  python_args+=(--comm_skip_post_reselect)
fi
if [[ "$head_permute_select" == "0" ]]; then
  python_args+=(--no_head_permute_select)
fi
if [[ "$head_permute_merge" == "0" ]]; then
  python_args+=(--no_head_permute_merge)
fi
# Caller-supplied extra arguments go last.
python_args+=("$@")

mkdir -p "$output_dir"
run_args_file="$output_dir/run_args.txt"

# Record the commit being run; fall back to "unknown" instead of aborting under
# set -e when HEAD cannot be resolved (e.g. a repository with no commits yet).
git_commit="unknown"
if git -C "$repo_root" rev-parse --is-inside-work-tree >/dev/null 2>&1; then
  git_commit="$(git -C "$repo_root" rev-parse HEAD 2>/dev/null)" || git_commit="unknown"
fi

start_epoch=$(date +%s)
start_time=$(date --iso-8601=seconds)

# Run metadata plus the exact, shell-quoted command line (printf %q).
{
  echo "git_commit=$git_commit"
  echo "start_time=$start_time"
  echo "HEAD_PERMUTE=$head_permute"
  echo "HEAD_PERMUTE_SELECT=$head_permute_select"
  echo "HEAD_PERMUTE_MERGE=$head_permute_merge"
  echo "command:"
  printf '%q ' python "$repo_root/src/fuse_layers.py" "${python_args[@]}"
  echo
} > "$run_args_file"

# EXIT-trap handler: append end time, wall-clock duration and the script's exit
# status to the run-args file. $? must be read first, before any other command.
write_run_summary() {
  local exit_code=$?
  local end_epoch end_time elapsed_seconds
  end_epoch=$(date +%s)
  end_time=$(date --iso-8601=seconds)
  elapsed_seconds=$((end_epoch - start_epoch))
  {
    echo "end_time=$end_time"
    echo "elapsed_seconds=$elapsed_seconds"
    echo "exit_code=$exit_code"
  } >> "$run_args_file"
}
trap write_run_summary EXIT

python "$repo_root/src/fuse_layers.py" "${python_args[@]}"