#!/usr/bin/env bash
#
# Zero-shot evaluation wrapper around the `lm_eval` CLI.
# Run without arguments to print usage.
set -o errexit -o nounset -o pipefail

# Expose GPU 2 unless the caller has already chosen devices (an empty value
# counts as "not chosen").
: "${CUDA_VISIBLE_DEVICES:=2}"
export CUDA_VISIBLE_DEVICES

# Repository root = parent of the directory that contains this script.
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
|
# Both <model_path> and <output_dir> are required; otherwise print usage and
# fail. Usage goes to stderr so it never mixes with anything a caller captures
# from stdout.
if [[ $# -lt 2 ]]; then
  cat >&2 <<'USAGE'
Usage:
  script/run_eval_zeroshot.sh <model_path> <output_dir> [--mmlu] [extra lm_eval args...]

Options:
  --mmlu              also run the mmlu task (appended to TASKS)
  Any other arguments are passed to lm_eval unchanged, after the defaults.

Environment (all optional):
  TASKS               comma-separated lm_eval tasks
                      (default: arc_easy,arc_challenge,hellaswag,piqa,winogrande,openbookqa,boolq)
  DEVICE              value for lm_eval --device (default: cuda)
  NUM_FEWSHOT         value for lm_eval --num_fewshot (default: 0)
  OUTPUT_FILE         results file name inside <output_dir> (default: zeroshot_results.json)
  CUDA_VISIBLE_DEVICES  GPUs to expose (default: 2)

Examples:
  script/run_eval_zeroshot.sh /path/to/model /path/to/output
  script/run_eval_zeroshot.sh /path/to/model /path/to/output --mmlu
  script/run_eval_zeroshot.sh /path/to/model /path/to/output --tasks arc_easy,arc_challenge,hellaswag
USAGE
  exit 1
fi
|
|
# Required positional arguments; everything after them is scanned below for
# --mmlu and otherwise forwarded to lm_eval.
MODEL_PATH="$1"
OUTPUT_DIR="$2"
shift 2

# Tunables, each overridable from the environment (empty counts as unset).
TASKS="${TASKS:-arc_easy,arc_challenge,hellaswag,piqa,winogrande,openbookqa,boolq}"
DEVICE="${DEVICE:-cuda}"
# Batch size for lm_eval: a number, or `auto` to let lm_eval choose.
# Defaults to 32, the value the lm_eval invocation has always used; the
# previous declared default ("auto") never actually took effect.
BATCH_SIZE="${BATCH_SIZE:-32}"
NUM_FEWSHOT="${NUM_FEWSHOT:-0}"
OUTPUT_FILE="${OUTPUT_FILE:-zeroshot_results.json}"
|
|
# Strip our own --mmlu flag out of the remaining arguments; everything else
# is kept, in order, for lm_eval.
INCLUDE_MMLU=0
PASSTHROUGH_ARGS=()
for arg in "$@"; do
  case "$arg" in
    --mmlu) INCLUDE_MMLU=1 ;;
    *) PASSTHROUGH_ARGS+=("$arg") ;;
  esac
done

# Append mmlu once: the comma-padded comparison avoids adding a duplicate
# when it is already one of the listed tasks.
if (( INCLUDE_MMLU == 1 )) && [[ ",$TASKS," != *",mmlu,"* ]]; then
  TASKS+=",mmlu"
fi
|
|
# Prepare the output directory and collect provenance for the run-args file.
mkdir -p -- "$OUTPUT_DIR"
RUN_ARGS_FILE="$OUTPUT_DIR/run_zeroshot_args.txt"

# When MODEL_PATH is a local directory, record (and pass on) its absolute
# path so the run-args file stays meaningful regardless of the caller's cwd.
# Anything else (e.g. a hub model id) is used unchanged. CDPATH is cleared so
# `cd` cannot jump elsewhere or print to stdout.
RESOLVED_MODEL_PATH="$MODEL_PATH"
if [[ -d "$MODEL_PATH" ]]; then
  RESOLVED_MODEL_PATH="$(CDPATH='' cd -- "$MODEL_PATH" && pwd)"
fi

# Commit of the repository containing this script, or "unknown" when git is
# missing, ROOT_DIR is not a repository, or HEAD has no commit yet. The
# fallback matters under `set -e`: a bare failing `rev-parse HEAD` would
# abort the run before evaluation starts.
git_commit=$(git -C "$ROOT_DIR" rev-parse HEAD 2>/dev/null) || git_commit="unknown"
start_epoch=$(date +%s)
start_time=$(date --iso-8601=seconds)
|
|
# Full lm_eval invocation, kept as an array so every element survives
# quoting intact. Extra CLI arguments are appended after the defaults, so a
# repeated flag (e.g. --tasks) presumably overrides the default here — this
# relies on lm_eval's parser keeping the last value; confirm if it matters.
# NOTE(review): lm_eval appears to split --model_args on commas, so a model
# path containing ',' would likely be mis-parsed.
LM_EVAL_CMD=(
  lm_eval
  --model hf
  --model_args "pretrained=$RESOLVED_MODEL_PATH"
  --tasks "$TASKS"
  --num_fewshot "$NUM_FEWSHOT"
  --device "$DEVICE"
  --batch_size "$BATCH_SIZE"   # honour BATCH_SIZE instead of a hard-coded 32
  --output_path "$OUTPUT_DIR/$OUTPUT_FILE"
)
# Guarded append: on bash < 4.4, expanding an empty array under `set -u`
# aborts with "unbound variable".
if (( ${#PASSTHROUGH_ARGS[@]} > 0 )); then
  LM_EVAL_CMD+=("${PASSTHROUGH_ARGS[@]}")
fi
|
|
# Start (truncate) the run-args file: provenance first, then the exact
# command as a shell-quoted, copy-pasteable line.
{
  printf '%s\n' "git_commit=$git_commit"
  printf '%s\n' "start_time=$start_time"
  printf '%s\n' "resolved_model_path=$RESOLVED_MODEL_PATH"
  printf '%s\n' "command:"
  printf '%q ' "${LM_EVAL_CMD[@]}"
  printf '\n'
} >"$RUN_ARGS_FILE"
|
|
# EXIT-trap handler: append end time, wall-clock duration in seconds (from
# start_epoch) and the shell's exit status to RUN_ARGS_FILE.
# The status must be captured first, before any other command resets $?.
write_run_summary() {
  local status=$?
  local finished_epoch finished_at
  finished_epoch=$(date +%s)
  finished_at=$(date --iso-8601=seconds)
  printf '%s\n' \
    "end_time=$finished_at" \
    "elapsed_seconds=$(( finished_epoch - start_epoch ))" \
    "exit_code=$status" >>"$RUN_ARGS_FILE"
}
trap write_run_summary EXIT
|
|
# Log the command shell-quoted so arguments with spaces stay unambiguous.
printf 'Running:'
printf ' %q' "${LM_EVAL_CMD[@]}"
printf '\n'

# Run lm_eval as a child rather than via `exec`: exec would replace this
# shell, so the EXIT trap would never fire and end_time / elapsed_seconds /
# exit_code would never reach the run-args file. The `||` keeps `set -e`
# from exiting before the status is captured; the script then exits with
# lm_eval's status, which the EXIT trap records.
# NOTE(review): signals sent only to this script's PID are no longer
# delivered directly to lm_eval.
lm_eval_status=0
"${LM_EVAL_CMD[@]}" || lm_eval_status=$?
exit "$lm_eval_status"
|
|