temp_ss / script /run_abprune_small.sh
LJYAI's picture
upload script
3738140 verified
#!/usr/bin/env bash
# Driver for src/fuse_layers.py (the "small" abprune preset).
#
# Builds the fuse_layers.py argument list from environment-variable overrides,
# writes git commit, start time and the exact command to
# <output_dir>/run_args.txt, runs the job, and appends end time, elapsed
# seconds and exit code to the same file when the script exits.
set -euo pipefail

# Pin to GPU 3 unless the caller already chose devices.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-3}"

# Repository root = parent of the directory that contains this script.
repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"

# Use this file's real name in the help text (it previously advertised
# run_abprune.sh, a different script).
script_name="${BASH_SOURCE[0]##*/}"

#######################################
# Print usage to stdout.
# Globals: script_name (read)
#######################################
usage() {
  cat <<USAGE
Usage:
script/${script_name} <model> [output_dir] [extra fuse_layers args...]
Examples:
script/${script_name} Qwen/Qwen3-1.7B
script/${script_name} /path/to/model /path/to/output --num_progressive 8
Common environment overrides:
CUDA_VISIBLE_DEVICES (default 3), NUM_PROGRESSIVE, DATASET, SEQ_LEN,
SAVE_FULL_MODEL_CYCLES, OUTDIR, RUN_NAME
USAGE
}

if [[ $# -lt 1 ]]; then
  # Misuse: diagnostics belong on stderr, non-zero exit as before.
  usage >&2
  exit 1
fi
if [[ "$1" == "-h" || "$1" == "--help" ]]; then
  usage
  exit 0
fi
model="$1"
shift
# ---------------------------------------------------------------------------
# Run settings. Each value can be overridden through the upper-case
# environment variable named in its default expansion (e.g. NUM_PROGRESSIVE=8).
# ---------------------------------------------------------------------------

# General.
dataset="${DATASET:-slimpajama}"
dataset_config="${DATASET_CONFIG:-none}"  # forwarded verbatim, even the literal "none"
num_progressive="${NUM_PROGRESSIVE:-14}"
dtype="${DTYPE:-bfloat16}"
use_pertensor_fisher="${USE_PERTENSOR_FISHER:-0}"  # "1" selects --fisher_mode tensor
# No colon in this expansion on purpose. An explicitly empty value
# (SAVE_FULL_MODEL_CYCLES="") must stay empty so that the
# `[[ -n "$save_full_model_cycles" ]]` check further down can omit
# --save_full_model_cycles. With ":-", an empty value was silently replaced
# by the default, and saving full models could never be turned off.
save_full_model_cycles="${SAVE_FULL_MODEL_CYCLES-6,11}"
# HEAD_PERMUTE is the shared default for both switches below;
# HEAD_PERMUTE_SELECT / HEAD_PERMUTE_MERGE override them individually.
# A value of "0" later adds the matching --no_head_permute_* flag.
head_permute="${HEAD_PERMUTE:-1}"
head_permute_select="${HEAD_PERMUTE_SELECT:-$head_permute}"
head_permute_merge="${HEAD_PERMUTE_MERGE:-$head_permute}"

# Calibration data: sequence count (--num_samples) and length (--seq_len).
# The length is also reused as --eval_seq_len.
calib_sequences="${CALIB_SEQUENCES:-128}"
seq_len="${SEQ_LEN:-512}"

# Distillation data.
distill_seq_len="${DISTILL_SEQ_LEN:-512}"
target_tokens="${TARGET_TOKENS:-500000}"
distill_batch_size="${DISTILL_BATCH_SIZE:-1}"

# Batch sizes and evaluation sample count.
batch_size="${BATCH_SIZE:-1}"  # forwarded as --batch_size
eval_batch_size="${EVAL_BATCH_SIZE:-1}"
eval_num_samples="${EVAL_NUM_SAMPLES:-200}"

# Training schedule and KL-distillation knobs.
lora_epochs="${LORA_EPOCHS:-0}"
distill_epochs="${DISTILL_EPOCHS:-1.0}"
distill_kl_weight="${DISTILL_KL_WEIGHT:-0.01}"
distill_kl_temp="${DISTILL_KL_TEMP:-4.0}"

# Fisher mode for fuse_layers.py: "param" by default, "tensor" when
# USE_PERTENSOR_FISHER=1.
case "$use_pertensor_fisher" in
  1) fisher_args=(--fisher_mode tensor) ;;
  *) fisher_args=(--fisher_mode param) ;;
esac
#######################################
# Turn a model id or path into a filesystem-friendly slug.
# Maps '/', ':' and '@' to '_', collapses every other run of characters
# outside [alnum _ . -] into a single '_', then trims leading/trailing '_'.
# Arguments: $1 - model id or path (e.g. Qwen/Qwen3-1.7B, /path/to/model)
# Outputs:   the slug on stdout; "model" if nothing usable remains
# NOTE: earlier revisions fed `echo` output (trailing newline -> stray "_")
# through a sed pattern that never matched, producing names such as
# "Qwen_Qwen3-1.7B__progressive_..." and "_path_to_model_...".
#######################################
slugify_model() {
  local slug
  slug="$(printf '%s' "$1" \
    | tr '/:@' '___' \
    | tr -cs '[:alnum:]_.-' '_' \
    | sed -e 's/^_*//' -e 's/_*$//')"
  printf '%s\n' "${slug:-model}"
}

# Directory-name suffix encodes the progressive count and the fisher mode.
output_dir_suffix="progressive_common_${num_progressive}_nopost_only_last"
if [[ "$use_pertensor_fisher" == "1" ]]; then
  output_dir_suffix="${output_dir_suffix}_pertensor"
fi
model_slug="$(slugify_model "$model")"
output_dir_default="$repo_root/results/${model_slug}_${output_dir_suffix}"

# Output directory precedence:
#   1. the next positional argument, unless it looks like an option (-x/--x);
#   2. $OUTDIR;
#   3. results/<model-slug>_<suffix> under the repository root.
# "_$RUN_NAME" is appended afterwards when RUN_NAME is set.
if [[ $# -gt 0 && "$1" != -* ]]; then
  output_dir="$1"
  shift
elif [[ -n "${OUTDIR:-}" ]]; then
  output_dir="$OUTDIR"
else
  output_dir="$output_dir_default"
fi
if [[ -n "${RUN_NAME:-}" ]]; then
  output_dir="${output_dir}_${RUN_NAME}"
fi
# Assemble the argument vector for src/fuse_layers.py. Flags are appended in
# a fixed order; whatever the caller passed after <model> [output_dir] is
# appended last.
python_args=()

# Model and dataset selection.
python_args+=(--model "$model" --dataset "$dataset" --dataset_config "$dataset_config")

# Token budget plus calibration / distillation sequence settings.
python_args+=(
  --target_tokens "$target_tokens"
  --num_samples "$calib_sequences"
  --seq_len "$seq_len"
  --batch_size "$batch_size"
  --distill_batch_size "$distill_batch_size"
  --distill_seq_len "$distill_seq_len"
  --distill_epochs "$distill_epochs"
)

# Evaluation (reuses the calibration sequence length).
python_args+=(
  --eval_batch_size "$eval_batch_size"
  --eval_seq_len "$seq_len"
  --eval_num_samples "$eval_num_samples"
)

# Distillation loss/optimizer constants, eval cadence, LoRA and auto-metric.
python_args+=(
  --distill_kl_weight "$distill_kl_weight"
  --distill_kl_temp "$distill_kl_temp"
  --distill_weight_decay 0.0
  --distill_max_grad_norm 1.0
  --distill_grad_accum_steps 1
  --distill_eval_every 2000
  --lora_eval_every 2000
  --lora_epochs "$lora_epochs"
  --auto_metric dwce
)
# Currently disabled option, kept for reference: --auto_cosine_topk 5

python_args+=("${fisher_args[@]}")
if [[ -n "$save_full_model_cycles" ]]; then
  python_args+=(--save_full_model_cycles "$save_full_model_cycles")
fi

# Distillation method, communication options and run bookkeeping.
python_args+=(
  --distill_method reparam
  --redistrib_teacher_source previous_cycle
  --comm_enabled
  --comm_mu_auto
  --layer auto
  --exclude_pairs 0,1,-1
  --num_progressive "$num_progressive"
  --output_dir "$output_dir"
  --dtype "$dtype"
)

# A "0" switch adds the matching opt-out flag; any other value adds nothing.
if [[ "$head_permute_select" == "0" ]]; then
  python_args+=(--no_head_permute_select)
fi
if [[ "$head_permute_merge" == "0" ]]; then
  python_args+=(--no_head_permute_merge)
fi

# Caller-supplied extras go last.
python_args+=("$@")
# Record run provenance in <output_dir>/run_args.txt: git commit, start time,
# head-permute settings and the exact shell-quoted command line. The file is
# (re)created here; the EXIT-trap handler later appends end time and status.
mkdir -p -- "$output_dir"
run_args_file="$output_dir/run_args.txt"

# `rev-parse --verify -q HEAD` exits non-zero without printing when git is
# missing, when repo_root is not in a repository, or when the repository has
# no commits yet. The previous two-step check let that last case abort the
# whole run under `set -e`; every failure now falls back to "unknown".
git_commit="$(git -C "$repo_root" rev-parse --verify -q HEAD 2>/dev/null)" \
  || git_commit="unknown"

# start_epoch stays global: the EXIT-trap handler uses it for elapsed time.
start_epoch=$(date +%s)
start_time=$(date --iso-8601=seconds)
{
  printf '%s\n' \
    "git_commit=$git_commit" \
    "start_time=$start_time" \
    "HEAD_PERMUTE=$head_permute" \
    "HEAD_PERMUTE_SELECT=$head_permute_select" \
    "HEAD_PERMUTE_MERGE=$head_permute_merge" \
    "command:"
  # %q shell-quotes each word so the line can be pasted back into a shell.
  printf '%q ' python "$repo_root/src/fuse_layers.py" "${python_args[@]}"
  echo
} > "$run_args_file"
#######################################
# EXIT-trap handler: append end time, wall-clock duration and the script's
# exit status to the run-args file.
# Globals:   start_epoch (read), run_args_file (read; appended to)
# Arguments: none
# Notes:     $? must be captured on the very first line, before any other
#            command overwrites it.
#######################################
write_run_summary() {
  local rc=$?
  local now_epoch finished_at duration
  now_epoch=$(date +%s)
  finished_at=$(date --iso-8601=seconds)
  duration=$(( now_epoch - start_epoch ))
  printf '%s\n' \
    "end_time=$finished_at" \
    "elapsed_seconds=$duration" \
    "exit_code=$rc" >> "$run_args_file"
}
# Finalise run_args.txt whenever the shell exits, including a `set -e` abort
# caused by a failing python run.
trap 'write_run_summary' EXIT
# Run in the foreground (not via `exec`) so this shell survives to fire the
# trap; the script's exit status is python's exit status.
fuse_layers_entry="$repo_root/src/fuse_layers.py"
python "$fuse_layers_entry" "${python_args[@]}"