arcspan / scripts /run_train_v9.sh
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/bin/bash
# =============================================================================
# Round 9: 5-class R9 dataset training
#
# R9 data:
# - R8 strict/deleaked train
# - CyberNER_harmonized deleaked + OPF span-format normalized
# - DNRTI deleaked
# - Prefix-80 deduplicated
#
# Main change vs R8:
# - O-downsample lowered from 0.7 to 0.3. R8 improved recall with 0.7, but
# overpredicted Indicator; R9 has better Org/System coverage and should not
# need such an aggressive O-token loss mask.
#
# Defaults target RTX PRO 6000 96GB. Override BATCH_SIZE/GRAD_ACCUM_STEPS for
# smaller GPUs, e.g. BATCH_SIZE=1 GRAD_ACCUM_STEPS=8 on a 32GB RTX 5090.
# =============================================================================
set -euo pipefail
# Run from the repository root. ALKYLINE_DIR overrides its location; the
# default is the original ~/alkyline.
cd "${ALKYLINE_DIR:-$HOME/alkyline}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
# Every setting below may be overridden from the environment.
DATA="${DATA:-data/processed}"
LABELS="${LABELS:-data/label_spaces}"
RESULTS="${RESULTS:-results}"
# Prefer the project virtualenv interpreter; fall back to python3 on PATH.
PYTHON="${PYTHON:-$PWD/.venv/bin/python}"
if [[ ! -x "$PYTHON" ]]; then
  # Checked explicitly: a bare failing `PYTHON="$(command -v python3)"` makes
  # `set -e` abort the script without any explanation.
  if ! fallback_python="$(command -v python3)"; then
    echo "ERROR: '$PYTHON' is not executable and python3 is not on PATH" >&2
    exit 1
  fi
  PYTHON="$fallback_python"
fi
# Command prefix for the opf CLI, kept as an array so the interpreter path
# survives spaces.
OPF=("$PYTHON" -m opf)
PATIENCE="${PATIENCE:-3}"
BATCH_SIZE="${BATCH_SIZE:-4}"
GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-2}"
EPOCHS="${EPOCHS:-15}"
LEARNING_RATE="${LEARNING_RATE:-5e-5}"
O_DOWNSAMPLE="${O_DOWNSAMPLE:-0.3}"
DEVICE="${DEVICE:-cuda}"
TRAIN_LOG="${TRAIN_LOG:-train_r9.log}"
CKPT_DIR="${CKPT_DIR:-checkpoints/r9_5class}"
mkdir -p -- "$RESULTS"
printf '%s\n' \
  "================================================================" \
  " R9 PRE-FLIGHT" \
  "================================================================"
# Rebuild the R9 dataset, then run the readiness audit on it; each step's
# combined stdout/stderr is mirrored into its own log file.
{ "$PYTHON" scripts/build_r9_dataset.py; } 2>&1 | tee build_r9.log
{
  "$PYTHON" scripts/audit_r9_readiness.py \
    --json-out "$RESULTS/r9_readiness_audit.json" \
    --md-out "$RESULTS/r9_readiness_audit.md"
} 2>&1 | tee audit_r9_readiness.log
echo "================================================================"
echo " R9: strict merged 5-class training"
echo "================================================================"
echo "Start time: $(date)"
echo "Train data: $DATA/r9_5class_train.jsonl"
echo "Validation data: $DATA/r9_5class_valid.jsonl"
echo "Checkpoint dir: $CKPT_DIR"
echo "Batch/accum: $BATCH_SIZE / $GRAD_ACCUM_STEPS"
echo "Epochs/patience: $EPOCHS / $PATIENCE"
echo "LR: $LEARNING_RATE"
echo "O-downsample: $O_DOWNSAMPLE"
echo "Device: $DEVICE"
echo "================================================================"
# Training runs in the background so the early-stop monitor can watch
# $TRAIN_LOG concurrently.
#
# The trainer is deliberately NOT launched as `opf train ... | tee LOG &`.
# For a background pipeline, bash sets $! to the PID of the LAST stage (tee).
# The monitor would then be handed tee's PID, and `wait` would return as soon
# as tee exited even if the trainer were still running. Instead, the trainer
# writes the log itself (so $! is the trainer's PID), and `tail` mirrors the
# log to the console.
: >"$TRAIN_LOG" # create/truncate now so tail and the monitor can open it immediately
"${OPF[@]}" train "$DATA/r9_5class_train.jsonl" \
  --validation-dataset "$DATA/r9_5class_valid.jsonl" \
  --label-space-json "$LABELS/cyner_5class.json" \
  --output-dir "$CKPT_DIR" \
  --overwrite-output \
  --epochs "$EPOCHS" --batch-size "$BATCH_SIZE" --grad-accum-steps "$GRAD_ACCUM_STEPS" \
  --learning-rate "$LEARNING_RATE" \
  --warmup-fraction 0.1 --lr-schedule cosine \
  --loss-fn focal --focal-gamma 2.0 \
  --llrd-factor 0.9 \
  --o-downsample "$O_DOWNSAMPLE" \
  --device "$DEVICE" >>"$TRAIN_LOG" 2>&1 &
TRAIN_PID=$!
# GNU coreutils tail: --pid makes it exit on its own once the trainer is gone.
tail -n +1 -f --pid="$TRAIN_PID" "$TRAIN_LOG" &
TAIL_PID=$!
# NOTE(review): the monitor gets TRAIN_PID as its third argument, presumably
# so it can stop training on early stop -- confirm in scripts/early_stop_monitor.sh.
bash scripts/early_stop_monitor.sh "$TRAIN_LOG" "$PATIENCE" "$TRAIN_PID" &
MONITOR_PID=$!
# A non-zero trainer status does not abort the script (an early-stop kill
# produces one); it is reported instead of being silently discarded.
TRAIN_STATUS=0
wait "$TRAIN_PID" || TRAIN_STATUS=$?
kill "$MONITOR_PID" 2>/dev/null || true
wait "$TAIL_PID" 2>/dev/null || true
if ((TRAIN_STATUS != 0)); then
  echo "Training exited with status $TRAIN_STATUS (non-zero is expected if the early-stop monitor terminated it)" >&2
fi
echo "Training finished: $(date)"
#######################################
# Print the epoch with the lowest val_loss found in a training log.
# Considers lines starting with "epoch". Fields are split on runs of
# space, ':', '/' and '='. The epoch is the 2nd field (its leading digits
# when present), and the loss is the numeric prefix of the field after
# "val_loss". Non-numeric losses (e.g. "nan") are ignored. On ties, the
# earliest logged epoch wins.
# Arguments: $1 - path to the training log
# Outputs:   the best epoch on stdout, or nothing if none could be parsed
# Returns:   0 (also when the log is missing or unreadable)
#######################################
best_epoch_from_log() {
  local log=$1
  [[ -r "$log" ]] || return 0
  # A single awk pass instead of grep|awk|sort|head: a log with no epoch lines
  # no longer makes grep exit 1 (which under pipefail + set -e killed the
  # script at the assignment), and no lexical `sort -n` quirks (exponents,
  # nan) are involved.
  awk -F'[ :/=]+' '
    /^epoch/ {
      for (i = 1; i < NF; i++) {
        if ($i != "val_loss") continue
        if (!match($(i + 1), /^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?/)) break
        v = substr($(i + 1), RSTART, RLENGTH) + 0
        ep = $2
        if (match(ep, /^[0-9]+/)) ep = substr(ep, RSTART, RLENGTH)
        if (!found || v < best_v) { found = 1; best_v = v; best_ep = ep }
        break
      }
    }
    END { if (found) print best_ep }
  ' "$log"
}

BEST_EPOCH=$(best_epoch_from_log "$TRAIN_LOG")
echo "Best epoch by val_loss: $BEST_EPOCH"
# Default to the final model in CKPT_DIR. If the final model.safetensors is
# missing (run stopped early), or a saved directory exists for the best epoch,
# use CKPT_DIR/epoch_<BEST_EPOCH> instead.
CKPT="$CKPT_DIR"
if [[ ! -f "$CKPT/model.safetensors" && -n "$BEST_EPOCH" ]]; then
  echo "Training killed early - using epoch $BEST_EPOCH checkpoint"
  CKPT="${CKPT_DIR}/epoch_${BEST_EPOCH}"
elif [[ -n "$BEST_EPOCH" && -d "${CKPT_DIR}/epoch_${BEST_EPOCH}" ]]; then
  echo "Using best-epoch checkpoint (epoch $BEST_EPOCH) over final"
  CKPT="${CKPT_DIR}/epoch_${BEST_EPOCH}"
fi
# Best-effort selection: warn, but let the evaluation steps report the failure.
if [[ ! -e "$CKPT" ]]; then
  echo "WARNING: selected checkpoint '$CKPT' does not exist" >&2
elif [[ "$CKPT" == "$CKPT_DIR" && ! -f "$CKPT/model.safetensors" ]]; then
  echo "WARNING: no final model.safetensors and no parseable val_loss epoch in '$TRAIN_LOG'" >&2
fi
echo "Selected checkpoint: $CKPT"
# Flags shared by every `opf eval` run below. They are kept in an array so a
# checkpoint path or device string containing spaces or glob characters is
# passed through as one argument. The previous unquoted string was
# word-split and glob-expanded.
EVAL_FLAGS=(--checkpoint "$CKPT" --decode-mode viterbi --per-class --label-counts --device "$DEVICE")
echo ""
echo "================================================================"
echo " EVALUATION PHASE"
echo "================================================================"
echo ""
echo "===== Eval R9: Enriched 5-class test ====="
"${OPF[@]}" eval "$DATA/enriched_5class_test.jsonl" \
  "${EVAL_FLAGS[@]}" 2>&1 | tee eval_r9_enriched.log
echo ""
echo "===== Eval R9: CyNER test ====="
"${OPF[@]}" eval "$DATA/cyner_test.jsonl" \
  "${EVAL_FLAGS[@]}" 2>&1 | tee eval_r9_cyner.log
echo ""
echo "===== Eval R9: SecureBERT2 5-class test ====="
"${OPF[@]}" eval "$DATA/securebert2_5class_test.jsonl" \
  "${EVAL_FLAGS[@]}" 2>&1 | tee eval_r9_sb2.log
echo ""
echo "===== Eval R9: APTNER 5-class independent test ====="
"${OPF[@]}" eval "$DATA/aptner_5class_test_clean.jsonl" \
  "${EVAL_FLAGS[@]}" 2>&1 | tee eval_r9_aptner.log
printf '\n%s\n' "===== Viterbi Grid Search (R9 validation) ====="
# Grid-search Viterbi decoding on the R9 validation split for the selected
# checkpoint. The script's result is written to $RESULTS/viterbi_r9_best.json
# (presumably the best setting found -- see scripts/viterbi_grid_search.py).
{
  "$PYTHON" scripts/viterbi_grid_search.py \
    --checkpoint "$CKPT" \
    --val-data "$DATA/r9_5class_valid.jsonl" \
    --output "$RESULTS/viterbi_r9_best.json" \
    --device "$DEVICE"
} 2>&1 | tee viterbi_r9.log
# Exact-match evaluation (Viterbi decoding) on CyNER and APTNER, in that order.
# Each row: <banner label>|<test file under $DATA>|<json under $RESULTS>|<log file>
# Runs in a subshell so the loop variables do not leak into the script; the
# subshell inherits `set -e`/pipefail, so a failing run still stops everything.
(
  for spec in \
    "CyNER test|cyner_test.jsonl|r9_cyner_exact_match.json|exact_r9_cyner.log" \
    "APTNER independent test|aptner_5class_test_clean.jsonl|r9_aptner_exact_match.json|exact_r9_aptner.log"; do
    IFS='|' read -r label test_file json_out log_file <<<"$spec"
    echo ""
    echo "===== Exact-Match Eval R9: $label ====="
    "$PYTHON" scripts/eval_exact_match.py \
      --checkpoint "$CKPT" \
      --test-data "$DATA/$test_file" \
      --device "$DEVICE" \
      --decode-mode viterbi \
      --json-out "$RESULTS/$json_out" \
      2>&1 | tee "$log_file"
  done
)
printf '%s\n' \
  "" \
  "================================================================" \
  " R9 COMPLETE - $(date)" \
  "================================================================" \
  "Checkpoint: $CKPT" \
  "Logs: $TRAIN_LOG, eval_r9_{enriched,cyner,sb2,aptner}.log, exact_r9_{cyner,aptner}.log, viterbi_r9.log"
# Quick look at the headline metric lines (micro/macro, plus lines starting
# with a space) for two of the eval logs. A missing log or no matching lines
# is not an error.
(
  for pair in "Enriched|eval_r9_enriched.log" "APTNER independent|eval_r9_aptner.log"; do
    printf '\n--- %s ---\n' "${pair%%|*}"
    grep -E '(micro|macro|^ )' "${pair#*|}" 2>/dev/null | head -n 20 || true
  done
)
printf '%s\n' "================================================================"