# rewrite / train.sh
# morpheuslord's picture
# Add files using upload-large-folder tool
# 3df5819 verified
#!/usr/bin/env bash
# ═══════════════════════════════════════════════════════════════════════════
# train.sh β€” Multi-stage training orchestrator with checkpoint system
# ═══════════════════════════════════════════════════════════════════════════
#
# Usage: bash train.sh [--config CONFIG] [--auto]
#
# Each stage prompts: [S]kip, [R]edo, [C]ontinue
# Use --auto to skip all prompts and auto-detect what needs running
#
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
export PYTHONPATH="${SCRIPT_DIR}:${PYTHONPATH:-}"
# Default config; overridden by --config=FILE, --config FILE, or a bare path.
# NOTE: previously CONFIG was seeded from $1 unconditionally, so
# `bash train.sh --auto` made CONFIG the literal string "--auto".
CONFIG="configs/training_config.yaml"
AUTO_MODE=false
# Parse args (supports --auto, --config=FILE, --config FILE, bare FILE)
expect_config=false
for arg in "$@"; do
  if [ "$expect_config" = true ]; then
    CONFIG="$arg"
    expect_config=false
    continue
  fi
  case $arg in
    --auto) AUTO_MODE=true ;;
    --config=*) CONFIG="${arg#*=}" ;;
    --config) expect_config=true ;;   # value arrives as the next argument
    -*) ;;                            # unknown flag: ignored (old behaviour)
    *) CONFIG="$arg" ;;               # legacy positional config path
  esac
done
# ── Colors ──────────────────────────────────────────────────────────────────
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'   # reset attributes
# Logging helpers. info/ok report progress on stdout; warn/err are
# diagnostics and therefore go to stderr so piped/captured stdout stays clean.
info() { echo -e "${CYAN}[INFO]${NC} $1"; }
ok() { echo -e "${GREEN}[ OK]${NC} $1"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $1" >&2; }
err() { echo -e "${RED}[FAIL]${NC} $1" >&2; }
# ── Stage prompt function ──────────────────────────────────────────────────
# Asks user to [S]kip, [R]edo, or [C]ontinue for each stage.
#   $1 - human-readable stage name
#   $2 - marker/output file whose existence means the stage already ran
# Returns 0 when the stage should run, 1 when it should be skipped.
prompt_stage() {
  local stage_name="$1"
  local check_file="$2"
  echo ""
  echo -e "${BOLD}═══ Stage: ${stage_name} ═══${NC}"
  if [ "$AUTO_MODE" = true ]; then
    if [ -n "$check_file" ] && [ -e "$check_file" ]; then
      info "Auto-mode: $check_file exists, skipping"
      return 1   # Skip
    fi
    return 0     # Continue
  fi
  if [ -n "$check_file" ] && [ -e "$check_file" ]; then
    warn "Previous output found: $check_file"
    echo -e "  ${YELLOW}[S]${NC}kip | ${CYAN}[R]${NC}edo | ${GREEN}[C]${NC}ontinue"
    # If stdin is closed (non-interactive run without --auto) read fails with
    # EOF; default to skip instead of aborting the whole script via `set -e`.
    read -rp "  Choice [S/R/C]: " choice || choice=""
    case "${choice,,}" in
      r|redo) info "Redoing ${stage_name}..."; return 0 ;;
      c|continue) info "Continuing ${stage_name}..."; return 0 ;;
      *) info "Skipping ${stage_name}"; return 1 ;;
    esac
  else
    info "No previous output found. Running ${stage_name}..."
    return 0
  fi
}
# ── Detect environment ─────────────────────────────────────────────────────
# Probes the runtime and sets globals consumed by the stage functions:
#   PYTHON, GPU_AVAILABLE, GPU_NAME, COMPUTE_CAP, PRECISION.
# Also exports WANDB_DISABLED when no W&B API key is configured.
detect_env() {
  echo -e "${BOLD}═══ Environment Detection ═══${NC}"

  # Pick the first available Python interpreter, preferring python3.
  PYTHON=""
  local candidate
  for candidate in python3 python; do
    if command -v "$candidate" &>/dev/null; then
      PYTHON="$candidate"
      break
    fi
  done
  if [ -z "$PYTHON" ]; then
    err "Python not found!"
    exit 1
  fi
  ok "Python: $($PYTHON --version 2>&1)"

  # GPU probe via torch; choose mixed-precision mode from compute capability.
  if $PYTHON -c "import torch; print(torch.cuda.is_available())" 2>/dev/null | grep -q "True"; then
    GPU_AVAILABLE=true
    GPU_NAME=$($PYTHON -c "import torch; print(torch.cuda.get_device_name(0))" 2>/dev/null || echo "Unknown")
    ok "GPU: $GPU_NAME"
    # bf16 requires compute capability >= 8; otherwise fall back to fp16.
    COMPUTE_CAP=$($PYTHON -c "import torch; print(torch.cuda.get_device_capability()[0])" 2>/dev/null || echo "0")
    if [ "$COMPUTE_CAP" -ge 8 ]; then
      PRECISION="bf16"
    else
      PRECISION="fp16"
    fi
    ok "Precision: $PRECISION"
  else
    GPU_AVAILABLE=false
    PRECISION="fp32"
    warn "No GPU detected β€” training will use CPU (optimised settings)"
  fi

  # Weights & Biases: silently fall back to TensorBoard when no key is set.
  if [ -n "${WANDB_API_KEY:-}" ]; then
    ok "W&B: API key found"
  else
    warn "W&B: No API key (WANDB_API_KEY). Logging to TensorBoard only."
    export WANDB_DISABLED=true
  fi
}
# ═══════════════════════════════════════════════════════════════════════════
# STAGE 1: Install dependencies & download models
# ═══════════════════════════════════════════════════════════════════════════
# Installs Python requirements plus spaCy/NLTK assets, then drops a marker
# file so subsequent runs can skip this stage.
stage_1_setup() {
  # Guard clause: bail out early when the user (or auto-mode) skips the stage.
  prompt_stage "Setup & Dependencies" ".train_stage1_done" || return 0
  info "Installing Python dependencies..."
  $PYTHON -m pip install -r requirements.txt --quiet 2>&1 | tail -5
  info "Downloading spaCy models..."
  $PYTHON -m spacy download en_core_web_sm --quiet 2>/dev/null || true
  info "Downloading NLTK data..."
  $PYTHON -c "import nltk; nltk.download('punkt', quiet=True); nltk.download('punkt_tab', quiet=True)" 2>/dev/null || true
  touch .train_stage1_done   # marker consulted by prompt_stage on the next run
  ok "Setup complete"
}
# ═══════════════════════════════════════════════════════════════════════════
# STAGE 2: Data preprocessing
# ═══════════════════════════════════════════════════════════════════════════
# Converts the raw datasets into the unified JSONL the trainer consumes;
# completion is detected via data/processed/train.jsonl.
stage_2_preprocess() {
  # Guard clause: return immediately when the stage is skipped.
  prompt_stage "Data Preprocessing" "data/processed/train.jsonl" || return 0
  info "Preprocessing datasets into unified JSONL..."
  $PYTHON scripts/preprocess_data.py
  ok "Data preprocessing complete"
}
# ═══════════════════════════════════════════════════════════════════════════
# STAGE 3: Pre-train human pattern classifier
# ═══════════════════════════════════════════════════════════════════════════
# Produces checkpoints/human_pattern_classifier.pt, which stage 4 later uses
# to decide whether the V2 loss can be enabled.
stage_3_pretrain_classifier() {
  # Guard clause: return immediately when the stage is skipped.
  prompt_stage "Pre-train Human Pattern Classifier" "checkpoints/human_pattern_classifier.pt" || return 0
  info "Pre-training human pattern classifier on Kaggle datasets..."
  info "This may take a while on CPU (extracting features for ~100k texts)..."
  $PYTHON scripts/pretrain_human_pattern_classifier.py
  ok "Human pattern classifier pre-trained"
}
# ═══════════════════════════════════════════════════════════════════════════
# STAGE 4: Main model training
# ═══════════════════════════════════════════════════════════════════════════
# Runs scripts/train.py with $CONFIG; opts into the V2 loss when the stage-3
# classifier checkpoint exists. Completion marker: checkpoints/best_model/config.json.
stage_4_train() {
  # Guard clause: return immediately when the stage is skipped.
  prompt_stage "Main Model Training" "checkpoints/best_model/config.json" || return 0
  info "Starting main model training..."
  info "Config: $CONFIG"
  local v2_flag=""
  if [ -f "checkpoints/human_pattern_classifier.pt" ]; then
    info "Human pattern classifier found β€” using V2 loss (with anti-AI term)"
    v2_flag="--use-v2-loss"
  fi
  # $v2_flag is deliberately unquoted: when empty it must expand to nothing.
  $PYTHON scripts/train.py --config "$CONFIG" $v2_flag
  ok "Main training complete"
}
# ═══════════════════════════════════════════════════════════════════════════
# STAGE 5: Evaluation
# ═══════════════════════════════════════════════════════════════════════════
# Evaluates the trained model on the test split, writing results under logs/.
stage_5_evaluate() {
  # Guard clause: return immediately when the stage is skipped.
  prompt_stage "Evaluation" "logs/eval_results_test.json" || return 0
  info "Running evaluation on test set..."
  mkdir -p logs
  $PYTHON scripts/evaluate.py --config "$CONFIG" --split test
  ok "Evaluation complete"
}
# ═══════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════
# Entry point: prints the banner, probes the environment (detect_env sets
# PYTHON/PRECISION/etc.), then runs the five stages in order. Each stage
# prompts — or auto-skips under --auto — via prompt_stage.
main() {
echo ""
echo -e "${BOLD}╔══════════════════════════════════════════════════════════╗${NC}"
echo -e "${BOLD}β•‘ Dyslexia Academic Writing Corrector β€” Training Suite β•‘${NC}"
echo -e "${BOLD}β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•${NC}"
echo ""
detect_env
# Stages run sequentially; any stage failure aborts via `set -e`.
stage_1_setup
stage_2_preprocess
stage_3_pretrain_classifier
stage_4_train
stage_5_evaluate
echo ""
echo -e "${GREEN}${BOLD}═══ All stages complete! ═══${NC}"
echo -e " Model saved to: ${CYAN}checkpoints/best_model/${NC}"
echo -e " Eval results: ${CYAN}logs/eval_results_test.json${NC}"
echo -e " Start inference: ${CYAN}bash start.sh${NC}"
echo ""
}
main