temp_ss / script /run_abprune_inst.sh
LJYAI's picture
upload script
3738140 verified
#!/usr/bin/env bash
# Launcher for src_inst/fuse_layers.py.
#
# Positional arguments: <model> [output_dir] [extra fuse_layers args...]
# CUDA_VISIBLE_DEVICES is honoured when set; otherwise GPU 0 is used.
set -euo pipefail

# Default to device 0 when the variable is unset or empty.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"

# This script lives one directory below the repository root.
repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"

#######################################
# Print the command-line synopsis.
# Outputs: usage text on stdout (callers redirect to stderr for errors)
#######################################
usage() {
  cat <<'USAGE'
Usage:
script/run_abprune_inst.sh <model> [output_dir] [extra fuse_layers args...]
Examples:
script/run_abprune_inst.sh Qwen/Qwen3-1.7B
script/run_abprune_inst.sh /path/to/model /path/to/output --num_progressive 8
USAGE
}

if [[ $# -lt 1 ]]; then
  # A missing <model> is an error: report on stderr so stdout stays clean for
  # anything capturing this script's output. Exit status is unchanged (1).
  usage >&2
  exit 1
fi

# An explicit help request is not an error: print to stdout and succeed,
# instead of passing "-h"/"--help" on as the model name.
case "$1" in
  -h | --help)
    usage
    exit 0
    ;;
esac
# Original note: "all meta-llama/Llama-2-7b-hf, meta-llama/Llama-3.1-8B,
# facebook/opt-6.7b" -- presumably the models this launcher is run with; confirm.

# The first positional argument is the model; it is consumed here so that the
# remaining arguments can be interpreted further down.
model="$1"
shift

# Each setting below is read from the upper-case environment variable named in
# its expansion and falls back to the default shown when that variable is
# unset or empty.

# --- Dataset, sequence lengths and batch sizes ------------------------------
dataset="${DATASET:-slimpajama}"
dataset_config="${DATASET_CONFIG:-none}"
target_tokens="${TARGET_TOKENS:-500000}"
calib_sequences="${CALIB_SEQUENCES:-128}"
seq_len="${SEQ_LEN:-1024}"
batch_size="${BATCH_SIZE:-2}"
dtype="${DTYPE:-bfloat16}"
num_progressive="${NUM_PROGRESSIVE:-16}"

# --- Evaluation --------------------------------------------------------------
eval_batch_size="${EVAL_BATCH_SIZE:-1}"
eval_num_samples="${EVAL_NUM_SAMPLES:-200}"

# --- Calibration / distillation data source ----------------------------------
calibration_source="${CALIBRATION_SOURCE:-lm}"
# Distillation uses the calibration source unless overridden explicitly.
distillation_source="${DISTILLATION_SOURCE:-$calibration_source}"

# --- Instruction dataset (only forwarded when INSTRUCTION_DATASET is set) ----
instruction_dataset="${INSTRUCTION_DATASET:-}"
instruction_config="${INSTRUCTION_CONFIG:-none}"
instruction_split="${INSTRUCTION_SPLIT:-train}"
instruction_format="${INSTRUCTION_FORMAT:-auto}"
instruction_field_instruction="${INSTRUCTION_FIELD_INSTRUCTION:-instruction}"
instruction_field_input="${INSTRUCTION_FIELD_INPUT:-input}"
instruction_field_output="${INSTRUCTION_FIELD_OUTPUT:-output}"

# --- Distillation ------------------------------------------------------------
distill_batch_size="${DISTILL_BATCH_SIZE:-1}"
distill_seq_len="${DISTILL_SEQ_LEN:-1024}"
distill_inst_samples="${DISTILL_INST_SAMPLES:-500}"
distill_epochs="${DISTILL_EPOCHS:-1.0}"
distill_kl_weight="${DISTILL_KL_WEIGHT:-0.02}"
distill_kl_temp="${DISTILL_KL_TEMP:-4.0}"
distill_hidden_mse_weight="${DISTILL_HIDDEN_MSE_WEIGHT:-1.0}"
distill_attn_mse_weight="${DISTILL_ATTN_MSE_WEIGHT:-0.25}"
distill_mlp_mse_weight="${DISTILL_MLP_MSE_WEIGHT:-1.0}"
lora_epochs="${LORA_EPOCHS:-0}"

# --- Reparameterisation ------------------------------------------------------
reparam_eta="${REPARAM_ETA:-0}"
reparam_gamma="${REPARAM_GAMMA:-0}"
reparam_attn_reg_scale="${REPARAM_ATTN_REG_SCALE:-1.0}"
reparam_mlp_reg_scale="${REPARAM_MLP_REG_SCALE:-1.0}"
reparam_param_subset="${REPARAM_PARAM_SUBSET:-mlp}"

# --- Switches ("1" enables, see where each one is consumed) ------------------
use_pertensor_fisher="${USE_PERTENSOR_FISHER:-0}"
save_full_model_cycles="${SAVE_FULL_MODEL_CYCLES:-6,11}"
comm_skip_post_reselect="${COMM_SKIP_POST_RESELECT:-1}"
# HEAD_PERMUTE supplies the default for both finer-grained switches below.
head_permute="${HEAD_PERMUTE:-1}"
head_permute_select="${HEAD_PERMUTE_SELECT:-$head_permute}"
head_permute_merge="${HEAD_PERMUTE_MERGE:-$head_permute}"
# --fisher_mode is "param" unless USE_PERTENSOR_FISHER=1 selects "tensor"; the
# same switch adds a "_pertensor" tag to the default output directory name.
fisher_args=(--fisher_mode param)
output_dir_suffix="progressive_common_${num_progressive}_nopost_only_last"
if [[ "$use_pertensor_fisher" == "1" ]]; then
  fisher_args=(--fisher_mode tensor)
  output_dir_suffix="${output_dir_suffix}_pertensor"
fi

#######################################
# Turn a model id or path into a string usable as one directory-name component.
# '/', ':' and '@' become '_'; any other character outside [:alnum:], '_', '.'
# and '-' becomes '_'; runs of '_' collapse to one; leading and trailing '_'
# are removed.
# Arguments: $1 - model id or filesystem path
# Outputs:   the slug on stdout
#######################################
model_slug_for() {
  # printf rather than echo: a trailing newline would be mapped to '_' by the
  # second tr. sed -E so that '+' is a quantifier; the previous expression
  # ('\\+' inside single quotes) matched a literal backslash and never stripped
  # anything, leaving e.g. "Qwen_Qwen3-1.7B_" and "_path_to_model_".
  printf '%s' "$1" \
    | tr '/:@' '___' \
    | tr -cs '[:alnum:]_.-' '_' \
    | sed -E 's/^_+//; s/_+$//'
}

model_slug="$(model_slug_for "$model")"
output_dir_default="$repo_root/results/${model_slug}_${output_dir_suffix}"

# Output directory precedence:
#   1. next positional argument, unless it starts with "--" (then it is an
#      extra fuse_layers option and is left in "$@");
#   2. the OUTDIR environment variable;
#   3. the default under results/ computed above.
if [[ $# -gt 0 && "${1:0:2}" != "--" ]]; then
  output_dir="$1"
  shift
elif [[ -n "${OUTDIR:-}" ]]; then
  output_dir="$OUTDIR"
else
  output_dir="$output_dir_default"
fi

# A non-empty RUN_NAME is appended as a suffix to whichever directory was chosen.
if [[ -n "${RUN_NAME:-}" ]]; then
  output_dir="${output_dir}_${RUN_NAME}"
fi
# Argument vector for fuse_layers.py. Newlines separate words inside an array
# literal, so no line continuations are needed. Element order is significant
# only in that caller-supplied extras are appended last.
python_args=(
  --model "$model"
  --dataset "$dataset"
  --dataset_config "$dataset_config"
  --target_tokens "$target_tokens"
  --num_samples "$calib_sequences"
  --seq_len "$seq_len"
  --batch_size "$batch_size"
  --calibration_source "$calibration_source"
  --distillation_source "$distillation_source"
  --distill_batch_size "$distill_batch_size"
  --distill_inst_samples "$distill_inst_samples"
  --distill_seq_len "$distill_seq_len"
  --distill_epochs "$distill_epochs"
  --eval_batch_size "$eval_batch_size"
  # Evaluation is given the same sequence length as --seq_len.
  --eval_seq_len "$seq_len"
  --eval_num_samples "$eval_num_samples"
  --distill_kl_weight "$distill_kl_weight"
  --distill_kl_temp "$distill_kl_temp"
  --distill_hidden_mse_weight "$distill_hidden_mse_weight"
  --distill_attn_mse_weight "$distill_attn_mse_weight"
  --distill_mlp_mse_weight "$distill_mlp_mse_weight"
  --reparam_eta "$reparam_eta"
  --reparam_gamma "$reparam_gamma"
  --reparam_attn_reg_scale "$reparam_attn_reg_scale"
  --reparam_mlp_reg_scale "$reparam_mlp_reg_scale"
  --reparam_param_subset "$reparam_param_subset"
  # The next five values are fixed here and have no environment override.
  --distill_weight_decay 0.0
  --distill_max_grad_norm 1.0
  --distill_grad_accum_steps 1
  --distill_eval_every 2000
  --lora_eval_every 2000
  --lora_epochs "$lora_epochs"
)

# Instruction-dataset options are forwarded only when a dataset was named.
if [[ "$instruction_dataset" ]]; then
  python_args+=(
    --instruction_dataset "$instruction_dataset"
    --instruction_config "$instruction_config"
    --instruction_split "$instruction_split"
    --instruction_format "$instruction_format"
    --instruction_field_instruction "$instruction_field_instruction"
    --instruction_field_input "$instruction_field_input"
    --instruction_field_output "$instruction_field_output"
  )
fi

python_args+=("${fisher_args[@]}")

# An empty SAVE_FULL_MODEL_CYCLES value omits the flag entirely.
if [[ "$save_full_model_cycles" ]]; then
  python_args+=(--save_full_model_cycles "$save_full_model_cycles")
fi

python_args+=(
  --distill_method reparam
  --redistrib_teacher_source previous_cycle
  --comm_enabled
  --comm_mu_auto
  --layer auto
  --exclude_pairs 0,1,-1
  --num_progressive "$num_progressive"
  --output_dir "$output_dir"
  --dtype "$dtype"
)

# Boolean switches: only the exact strings "1" / "0" add a flag.
case "$comm_skip_post_reselect" in
  1) python_args+=(--comm_skip_post_reselect) ;;
esac
case "$head_permute_select" in
  0) python_args+=(--no_head_permute_select) ;;
esac
case "$head_permute_merge" in
  0) python_args+=(--no_head_permute_merge) ;;
esac

# Remaining command-line words are passed through verbatim, last.
python_args+=("$@")
mkdir -p "$output_dir"
run_args_file="$output_dir/run_args.txt"

#######################################
# Report the HEAD commit of the git work tree containing a directory.
# Arguments: $1 - directory to inspect
# Outputs:   the commit hash on stdout, or "unknown" when the directory is not
#            inside a work tree, HEAD cannot be resolved (e.g. a repository
#            with no commits yet), or git is unavailable
# Returns:   always 0, so a missing commit never aborts the run under set -e
#######################################
resolve_git_commit() {
  local root=$1
  local commit
  # stderr is silenced on purpose: failure here is an expected, handled case.
  if git -C "$root" rev-parse --is-inside-work-tree >/dev/null 2>&1 \
    && commit=$(git -C "$root" rev-parse HEAD 2>/dev/null); then
    printf '%s\n' "$commit"
  else
    printf '%s\n' "unknown"
  fi
}

git_commit="$(resolve_git_commit "$repo_root")"

# start_epoch is kept in seconds so the elapsed time can be computed at exit.
start_epoch=$(date +%s)
# NOTE: --iso-8601 is a GNU date option.
start_time=$(date --iso-8601=seconds)

# Start run_args.txt afresh with the provenance of this run and the exact
# command, shell-quoted so it can be pasted back into a terminal.
{
  printf '%s\n' "git_commit=$git_commit"
  printf '%s\n' "start_time=$start_time"
  printf '%s\n' "HEAD_PERMUTE=$head_permute"
  printf '%s\n' "HEAD_PERMUTE_SELECT=$head_permute_select"
  printf '%s\n' "HEAD_PERMUTE_MERGE=$head_permute_merge"
  printf '%s\n' "command:"
  printf '%q ' python "$repo_root/src_inst/fuse_layers.py" "${python_args[@]}"
  printf '\n'
} > "$run_args_file"
#######################################
# Append the end-of-run record to the run arguments file.
# Globals:   start_epoch (read), run_args_file (appended to)
# Arguments: none; the status of the command that ran immediately before this
#            function is recorded as exit_code
# Outputs:   three lines (end_time, elapsed_seconds, exit_code) added to
#            $run_args_file
#######################################
write_run_summary() {
  # Must be the first statement: any earlier command would overwrite $?.
  local status=$?
  local finished_epoch finished_iso
  finished_epoch=$(date +%s)
  finished_iso=$(date --iso-8601=seconds)
  printf '%s\n' \
    "end_time=$finished_iso" \
    "elapsed_seconds=$((finished_epoch - start_epoch))" \
    "exit_code=$status" >> "$run_args_file"
}
# Record end time and exit status on every exit path of this shell.
trap write_run_summary EXIT

# src_inst and the repository root go first on the module search path; any
# PYTHONPATH the caller already had is kept after them. Python runs as a child
# process (not via exec) so the EXIT trap above still fires afterwards, and it
# stays the last command so its status becomes the script's exit status.
python_path="$repo_root/src_inst:$repo_root${PYTHONPATH:+:$PYTHONPATH}"
PYTHONPATH="$python_path" python "$repo_root/src_inst/fuse_layers.py" "${python_args[@]}"