#!/usr/bin/env bash
# Recovered from a web file view (revision fd1afc8, 4,957 bytes); the viewer's line-number gutter was removed.
set -o nounset -o pipefail
########################################
# Configuration (this is the only section you should need to edit)
########################################
SCRIPT_PATH="qwen3_plain_ar.py"
DATASET_PATH="muse_mucodec_chord.ds"

# Training-run directory holding the checkpoints and the final export.
RUN_DIR="/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar"

# Tokenizer (must be one that ships a chat_template).
TOKENIZER_PATH="${RUN_DIR}/final"

# Step suffixes of the checkpoints to evaluate; each one becomes
# "${RUN_DIR}/checkpoint-<step>" in CHECKPOINTS, in this order.
CKPT_STEPS=(
  907 1814 2721 3628 4535
  5442 6349 7256 8163 9070
  9977 10884 11791 12698 13605
  14512 15419 16326 17233 18140
)
CHECKPOINTS=()
for _ckpt_step in "${CKPT_STEPS[@]}"; do
  CHECKPOINTS+=("${RUN_DIR}/checkpoint-${_ckpt_step}")
done
unset _ckpt_step

# Root directory for all outputs.
OUTPUT_ROOT="/root/batch_preditions_ablation"
# Number of validation samples (indices 0..NUM_SAMPLES-1) run per checkpoint.
NUM_SAMPLES=20
########################################
# Inference parameters (safe to tune)
########################################
# Execution settings, forwarded as --device / --dtype / --attn_implementation.
DEVICE="cuda:0"
DTYPE="bfloat16"
ATTN_IMPLEMENTATION="sdpa"

# Sampling settings, forwarded as --temperature / --top_k / --top_p.
TEMPERATURE="1.0"
TOP_K="50"
TOP_P="0.9"

# Forwarded as --max_new_tokens_per_section.
MAX_NEW_TOKENS="4096"

# "true" adds --skip_decode (skip audio decoding); recommended while debugging.
SKIP_DECODE="false"
########################################
# Log files
########################################
# Append-only logs shared by all checkpoints: FAILED_LOG collects one line per
# failed sample or missing checkpoint; SUCCESS_LOG collects one line per
# successful sample plus the run banners.
FAILED_LOG="${OUTPUT_ROOT}/failed_cases.log"
SUCCESS_LOG="${OUTPUT_ROOT}/success_cases.log"
########################################
# Start
########################################
# Fail fast on setup problems. The script deliberately runs without `set -e`
# (per-sample failures are logged and skipped), so without these checks a
# missing interpreter, entry script or unwritable OUTPUT_ROOT would only
# surface as one identical failure per sample per checkpoint.
if ! command -v python >/dev/null 2>&1; then
  printf '[FATAL] python not found on PATH\n' >&2
  exit 1
fi
if [[ ! -f "${SCRIPT_PATH}" ]]; then
  # SCRIPT_PATH is handed to `python` as-is, so a relative path is resolved
  # against the current working directory.
  printf '[FATAL] inference script not found: %s (cwd: %s)\n' \
    "${SCRIPT_PATH}" "${PWD}" >&2
  exit 1
fi
if ! mkdir -p "${OUTPUT_ROOT}" || ! touch "${FAILED_LOG}" "${SUCCESS_LOG}"; then
  printf '[FATAL] cannot create output root or log files under: %s\n' \
    "${OUTPUT_ROOT}" >&2
  exit 1
fi
echo "======================================" | tee -a "${SUCCESS_LOG}"
echo "Batch inference started at $(date)" | tee -a "${SUCCESS_LOG}"
echo "Output root: ${OUTPUT_ROOT}" | tee -a "${SUCCESS_LOG}"
echo "======================================" | tee -a "${SUCCESS_LOG}"
# For every checkpoint, run validation samples 0..NUM_SAMPLES-1, each as a
# separate `python SCRIPT_PATH infer` process. Each process's stdout/stderr,
# preceded by the exact command line, is appended to OUT_DIR/run.log.
# - A failing sample is recorded in FAILED_LOG and the loop moves on.
# - A missing checkpoint directory (or an output dir that cannot be created)
#   skips that checkpoint entirely.
for CKPT in "${CHECKPOINTS[@]}"; do
  CKPT_NAME=$(basename "${CKPT}")
  OUT_DIR="${OUTPUT_ROOT}/${CKPT_NAME}"
  CKPT_LOG="${OUT_DIR}/run.log"
  echo "======================================"
  echo "Running checkpoint: ${CKPT_NAME}"
  echo "Output dir: ${OUT_DIR}"
  echo "======================================"
  if [[ ! -d "${CKPT}" ]]; then
    echo "[ERROR] checkpoint directory not found: ${CKPT}" | tee -a "${FAILED_LOG}"
    continue
  fi
  # Without a writable output dir every sample of this checkpoint would fail
  # the same way; report it once and move on.
  if ! mkdir -p "${OUT_DIR}" || ! touch "${CKPT_LOG}"; then
    echo "[ERROR] cannot create output dir or log: ${OUT_DIR}" | tee -a "${FAILED_LOG}"
    continue
  fi
  CKPT_OK=0
  CKPT_FAILED=0
  for ((i = 0; i < NUM_SAMPLES; i++)); do
    echo "[INFO] checkpoint=${CKPT_NAME} sample_idx=${i}" | tee -a "${CKPT_LOG}"
    CMD=(
      python "${SCRIPT_PATH}" infer
      --model_path "${CKPT}"
      --tokenizer_path "${TOKENIZER_PATH}"
      --dataset_path "${DATASET_PATH}"
      --split validation
      --sample_idx "${i}"
      --device "${DEVICE}"
      --dtype "${DTYPE}"
      --attn_implementation "${ATTN_IMPLEMENTATION}"
      --temperature "${TEMPERATURE}"
      --top_k "${TOP_K}"
      --top_p "${TOP_P}"
      --max_new_tokens_per_section "${MAX_NEW_TOKENS}"
      --output_dir "${OUT_DIR}"
      --output_prefix "sample_${i}"
    )
    if [[ "${SKIP_DECODE}" == true ]]; then
      CMD+=(--skip_decode)
    fi
    # Log the command shell-quoted (%q) so a single sample can be reproduced
    # by copy-paste even when an argument contains spaces or metacharacters.
    {
      printf '[CMD]'
      printf ' %q' "${CMD[@]}"
      printf '\n'
      "${CMD[@]}"
    } >> "${CKPT_LOG}" 2>&1
    # Status of the group is the status of its last command, i.e. the python run.
    EXIT_CODE=$?
    if (( EXIT_CODE != 0 )); then
      echo "[ERROR] checkpoint=${CKPT_NAME} sample_idx=${i} exit_code=${EXIT_CODE}" | tee -a "${FAILED_LOG}"
      CKPT_FAILED=$((CKPT_FAILED + 1))
    else
      echo "[OK] checkpoint=${CKPT_NAME} sample_idx=${i}" | tee -a "${SUCCESS_LOG}"
      CKPT_OK=$((CKPT_OK + 1))
    fi
  done
  # Include per-checkpoint counts so a [DONE] line in the success log cannot
  # be mistaken for "all samples succeeded".
  echo "[DONE] checkpoint=${CKPT_NAME} ok=${CKPT_OK} failed=${CKPT_FAILED}" | tee -a "${SUCCESS_LOG}"
done
# Closing banner: printed to stdout and appended to SUCCESS_LOG, including
# where both log files live.
printf '%s\n' \
  "======================================" \
  "Batch inference finished at $(date)" \
  "Success log: ${SUCCESS_LOG}" \
  "Failed log: ${FAILED_LOG}" \
  | tee -a "${SUCCESS_LOG}"
echo "All done." |