#!/usr/bin/env bash
# ═══════════════════════════════════════════════════════════════════════════
# train.sh — Multi-stage training orchestrator with checkpoint system
# ═══════════════════════════════════════════════════════════════════════════
#
# Usage: bash train.sh [--config CONFIG] [--auto]
#
# Each stage prompts: [S]kip, [R]edo, [C]ontinue
# Use --auto to skip all prompts and auto-detect what needs running
#
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
export PYTHONPATH="${SCRIPT_DIR}:${PYTHONPATH:-}"

CONFIG="configs/training_config.yaml"
AUTO_MODE=false

# ── Argument parsing ───────────────────────────────────────────────────────
# BUGFIX: the previous version ran CONFIG="${1:-default}" before scanning
# flags, so `bash train.sh --auto` set CONFIG to the literal string "--auto".
# This parser also accepts the space-separated `--config CONFIG` form that
# the usage text advertises (previously only `--config=CONFIG` worked), and
# still honors a bare positional config path for backward compatibility.
while [ $# -gt 0 ]; do
  case "$1" in
    --auto)     AUTO_MODE=true ;;
    --config=*) CONFIG="${1#*=}" ;;
    --config)   shift
                CONFIG="${1:?--config requires a value}" ;;
    --*)        printf 'Unknown option: %s\n' "$1" >&2
                exit 2 ;;
    *)          CONFIG="$1" ;;  # legacy: first positional arg is the config
  esac
  shift
done

# ── Colors ──────────────────────────────────────────────────────────────────
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'

info() { echo -e "${CYAN}[INFO]${NC} $1"; }
ok()   { echo -e "${GREEN}[ OK]${NC} $1"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
err()  { echo -e "${RED}[FAIL]${NC} $1"; }

# ── Stage prompt function ──────────────────────────────────────────────────
# Asks user to [S]kip, [R]edo, or [C]ontinue for each stage.
# Arguments: $1 - stage name (display only)
#            $2 - marker file; if it exists the stage is considered done
# Returns:   0 = run the stage, 1 = skip it
prompt_stage() {
  local stage_name="$1"
  local check_file="$2"

  echo ""
  echo -e "${BOLD}═══ Stage: ${stage_name} ═══${NC}"

  # In --auto mode never prompt: skip iff the marker file already exists.
  if [ "$AUTO_MODE" = true ]; then
    if [ -n "$check_file" ] && [ -e "$check_file" ]; then
      info "Auto-mode: $check_file exists, skipping"
      return 1  # Skip
    fi
    return 0  # Continue
  fi

  if [ -n "$check_file" ] && [ -e "$check_file" ]; then
    warn "Previous output found: $check_file"
    echo -e " ${YELLOW}[S]${NC}kip | ${CYAN}[R]${NC}edo | ${GREEN}[C]${NC}ontinue"
    read -rp " Choice [S/R/C]: " choice
    # ${choice,,} lowercases the reply (bash 4+); anything unrecognized skips.
    case "${choice,,}" in
      r|redo)     info "Redoing ${stage_name}...";    return 0 ;;
      c|continue) info "Continuing ${stage_name}..."; return 0 ;;
      *)          info "Skipping ${stage_name}";      return 1 ;;
    esac
  else
    info "No previous output found. Running ${stage_name}..."
    return 0
  fi
}

# ── Detect environment ─────────────────────────────────────────────────────
# Probes the runtime and sets globals: PYTHON, GPU_AVAILABLE, PRECISION.
# Exits if no Python interpreter is found; disables W&B if no API key is set.
detect_env() {
  echo -e "${BOLD}═══ Environment Detection ═══${NC}"

  # Python
  if command -v python3 &>/dev/null; then
    PYTHON=python3
  elif command -v python &>/dev/null; then
    PYTHON=python
  else
    err "Python not found!"
    exit 1
  fi
  ok "Python: $($PYTHON --version 2>&1)"

  # GPU
  if $PYTHON -c "import torch; print(torch.cuda.is_available())" 2>/dev/null | grep -q "True"; then
    GPU_AVAILABLE=true
    GPU_NAME=$($PYTHON -c "import torch; print(torch.cuda.get_device_name(0))" 2>/dev/null || echo "Unknown")
    ok "GPU: $GPU_NAME"
    # Check compute capability for bf16 (major version >= 8 → bf16)
    COMPUTE_CAP=$($PYTHON -c "import torch; print(torch.cuda.get_device_capability()[0])" 2>/dev/null || echo "0")
    if [ "$COMPUTE_CAP" -ge 8 ]; then
      PRECISION="bf16"
    else
      PRECISION="fp16"
    fi
    ok "Precision: $PRECISION"
  else
    GPU_AVAILABLE=false
    PRECISION="fp32"
    warn "No GPU detected — training will use CPU (optimised settings)"
  fi

  # W&B
  if [ -n "${WANDB_API_KEY:-}" ]; then
    ok "W&B: API key found"
  else
    warn "W&B: No API key (WANDB_API_KEY). Logging to TensorBoard only."
    export WANDB_DISABLED=true
  fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 1: Install dependencies & download models
# ═══════════════════════════════════════════════════════════════════════════
# Marker file: .train_stage1_done (touched on success).
stage_1_setup() {
  if prompt_stage "Setup & Dependencies" ".train_stage1_done"; then
    info "Installing Python dependencies..."
    $PYTHON -m pip install -r requirements.txt --quiet 2>&1 | tail -5
    info "Downloading spaCy models..."
    # Best-effort: model downloads are optional, so failures are ignored.
    $PYTHON -m spacy download en_core_web_sm --quiet 2>/dev/null || true
    info "Downloading NLTK data..."
    $PYTHON -c "import nltk; nltk.download('punkt', quiet=True); nltk.download('punkt_tab', quiet=True)" 2>/dev/null || true
    touch .train_stage1_done
    ok "Setup complete"
  fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 2: Data preprocessing
# ═══════════════════════════════════════════════════════════════════════════
# Marker file: data/processed/train.jsonl (written by the preprocess script).
stage_2_preprocess() {
  if prompt_stage "Data Preprocessing" "data/processed/train.jsonl"; then
    info "Preprocessing datasets into unified JSONL..."
    $PYTHON scripts/preprocess_data.py
    ok "Data preprocessing complete"
  fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 3: Pre-train human pattern classifier
# ═══════════════════════════════════════════════════════════════════════════
# Marker file: checkpoints/human_pattern_classifier.pt (also enables the V2
# loss in stage 4).
stage_3_pretrain_classifier() {
  if prompt_stage "Pre-train Human Pattern Classifier" "checkpoints/human_pattern_classifier.pt"; then
    info "Pre-training human pattern classifier on Kaggle datasets..."
    info "This may take a while on CPU (extracting features for ~100k texts)..."
    $PYTHON scripts/pretrain_human_pattern_classifier.py
    ok "Human pattern classifier pre-trained"
  fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 4: Main model training
# ═══════════════════════════════════════════════════════════════════════════
# Marker file: checkpoints/best_model/config.json.
stage_4_train() {
  if prompt_stage "Main Model Training" "checkpoints/best_model/config.json"; then
    info "Starting main model training..."
    info "Config: $CONFIG"
    # Add V2 loss flag if the stage-3 classifier checkpoint exists
    V2_FLAG=""
    if [ -f "checkpoints/human_pattern_classifier.pt" ]; then
      info "Human pattern classifier found — using V2 loss (with anti-AI term)"
      V2_FLAG="--use-v2-loss"
    fi
    # shellcheck disable=SC2086 — V2_FLAG is intentionally unquoted: when
    # empty it must expand to zero arguments, not an empty-string argument.
    $PYTHON scripts/train.py --config "$CONFIG" $V2_FLAG
    ok "Main training complete"
  fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 5: Evaluation
# ═══════════════════════════════════════════════════════════════════════════
# Marker file: logs/eval_results_test.json.
stage_5_evaluate() {
  if prompt_stage "Evaluation" "logs/eval_results_test.json"; then
    info "Running evaluation on test set..."
    mkdir -p logs
    $PYTHON scripts/evaluate.py --config "$CONFIG" --split test
    ok "Evaluation complete"
  fi
}

# ═══════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════
main() {
  echo ""
  echo -e "${BOLD}╔══════════════════════════════════════════════════════════╗${NC}"
  echo -e "${BOLD}║ Dyslexia Academic Writing Corrector — Training Suite ║${NC}"
  echo -e "${BOLD}╚══════════════════════════════════════════════════════════╝${NC}"
  echo ""

  detect_env
  stage_1_setup
  stage_2_preprocess
  stage_3_pretrain_classifier
  stage_4_train
  stage_5_evaluate

  echo ""
  echo -e "${GREEN}${BOLD}═══ All stages complete! ═══${NC}"
  echo -e " Model saved to: ${CYAN}checkpoints/best_model/${NC}"
  echo -e " Eval results: ${CYAN}logs/eval_results_test.json${NC}"
  echo -e " Start inference: ${CYAN}bash start.sh${NC}"
  echo ""
}

main "$@"