#!/usr/bin/env bash
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# train.sh β Multi-stage training orchestrator with checkpoint system
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
#
# Usage: bash train.sh [--config CONFIG | --config=CONFIG] [--auto]
#        bash train.sh CONFIG        # bare positional config path also works
#
# Each stage prompts: [S]kip, [R]edo, [C]ontinue
# Use --auto to skip all prompts and auto-detect what needs running
#
set -euo pipefail

# Make repo-local modules importable no matter where the script is run from.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
export PYTHONPATH="${SCRIPT_DIR}:${PYTHONPATH:-}"

# Defaults, overridable via CLI below.
CONFIG="configs/training_config.yaml"
AUTO_MODE=false

# Parse args.
# Fix: the old `CONFIG="${1:-default}"` blindly took the first argument, so
# `bash train.sh --auto` set CONFIG to the literal string "--auto". Flags are
# now recognised first; only a non-flag argument is treated as a config path.
# Both `--config=FILE` and `--config FILE` (as advertised in Usage) work.
expect_config_value=false
for arg in "$@"; do
  if [ "$expect_config_value" = true ]; then
    CONFIG="$arg"
    expect_config_value=false
    continue
  fi
  case $arg in
    --auto) AUTO_MODE=true ;;
    --config=*) CONFIG="${arg#*=}" ;;
    --config) expect_config_value=true ;;
    -*) echo "Unknown option (ignored): $arg" >&2 ;;  # keep old lenient behavior
    *) CONFIG="$arg" ;;  # positional config path (backward compatible)
  esac
done
unset expect_config_value
# ββ Colors ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'

# Uniform logger: _log COLOR TAG MESSAGE β prints colored "[TAG] MESSAGE".
_log() { echo -e "${1}[${2}]${NC} ${3}"; }
info() { _log "${CYAN}" "INFO" "$1"; }
ok()   { _log "${GREEN}" " OK" "$1"; }
warn() { _log "${YELLOW}" "WARN" "$1"; }
err()  { _log "${RED}" "FAIL" "$1"; }
# ββ Stage prompt function ββββββββββββββββββββββββββββββββββββββββββββββββββ
# Asks user to [S]kip, [R]edo, or [C]ontinue for each stage.
# $1 = stage name (display only); $2 = sentinel/output file whose existence
# means the stage already completed. Returns 0 = run the stage, 1 = skip it.
prompt_stage() {
  local name="$1"
  local marker="$2"
  echo ""
  echo -e "${BOLD}βββ Stage: ${name} βββ${NC}"

  # Did a previous run leave its output behind?
  local have_output=false
  if [ -n "$marker" ] && [ -e "$marker" ]; then
    have_output=true
  fi

  # Unattended mode: skip completed stages, run everything else.
  if [ "$AUTO_MODE" = true ]; then
    if [ "$have_output" = true ]; then
      info "Auto-mode: $marker exists, skipping"
      return 1 # Skip
    fi
    return 0 # Continue
  fi

  # Interactive mode: nothing to resume β just run it.
  if [ "$have_output" = false ]; then
    info "No previous output found. Running ${name}..."
    return 0
  fi

  warn "Previous output found: $marker"
  echo -e " ${YELLOW}[S]${NC}kip | ${CYAN}[R]${NC}edo | ${GREEN}[C]${NC}ontinue"
  local choice
  read -rp " Choice [S/R/C]: " choice
  case "${choice,,}" in
    r|redo) info "Redoing ${name}..."; return 0 ;;
    c|continue) info "Continuing ${name}..."; return 0 ;;
    *) info "Skipping ${name}"; return 1 ;;
  esac
}
# ββ Detect environment βββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Probes the runtime and sets globals: PYTHON (interpreter command),
# GPU_AVAILABLE (true/false), PRECISION (bf16/fp16/fp32), plus GPU_NAME and
# COMPUTE_CAP on the GPU path. Exits 1 when no Python interpreter is found.
detect_env() {
  echo -e "${BOLD}βββ Environment Detection βββ${NC}"

  # Interpreter: prefer python3, fall back to python.
  local cand
  PYTHON=""
  for cand in python3 python; do
    if command -v "$cand" &>/dev/null; then
      PYTHON="$cand"
      break
    fi
  done
  if [ -z "$PYTHON" ]; then
    err "Python not found!"
    exit 1
  fi
  ok "Python: $($PYTHON --version 2>&1)"

  # GPU probe via torch; any failure (torch missing, CUDA absent) β CPU path.
  local cuda_probe
  cuda_probe=$($PYTHON -c "import torch; print(torch.cuda.is_available())" 2>/dev/null || true)
  if [[ "$cuda_probe" == *"True"* ]]; then
    GPU_AVAILABLE=true
    GPU_NAME=$($PYTHON -c "import torch; print(torch.cuda.get_device_name(0))" 2>/dev/null || echo "Unknown")
    ok "GPU: $GPU_NAME"
    # bf16 requires compute capability >= 8 (Ampere or newer); else fp16.
    COMPUTE_CAP=$($PYTHON -c "import torch; print(torch.cuda.get_device_capability()[0])" 2>/dev/null || echo "0")
    if [ "$COMPUTE_CAP" -ge 8 ]; then
      PRECISION="bf16"
    else
      PRECISION="fp16"
    fi
    ok "Precision: $PRECISION"
  else
    GPU_AVAILABLE=false
    PRECISION="fp32"
    warn "No GPU detected β training will use CPU (optimised settings)"
  fi

  # W&B logging only when an API key is present; otherwise disable it.
  if [ -z "${WANDB_API_KEY:-}" ]; then
    warn "W&B: No API key (WANDB_API_KEY). Logging to TensorBoard only."
    export WANDB_DISABLED=true
  else
    ok "W&B: API key found"
  fi
}
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# STAGE 1: Install dependencies & download models
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
stage_1_setup() {
  # Skip entirely when the user (or auto-mode) declines to run this stage.
  prompt_stage "Setup & Dependencies" ".train_stage1_done" || return 0
  info "Installing Python dependencies..."
  $PYTHON -m pip install -r requirements.txt --quiet 2>&1 | tail -5
  info "Downloading spaCy models..."
  $PYTHON -m spacy download en_core_web_sm --quiet 2>/dev/null || true
  info "Downloading NLTK data..."
  $PYTHON -c "import nltk; nltk.download('punkt', quiet=True); nltk.download('punkt_tab', quiet=True)" 2>/dev/null || true
  # Sentinel file checked by prompt_stage on the next run.
  touch .train_stage1_done
  ok "Setup complete"
}
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# STAGE 2: Data preprocessing
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
stage_2_preprocess() {
  # Guard clause: bail out cleanly when the stage is skipped.
  prompt_stage "Data Preprocessing" "data/processed/train.jsonl" || return 0
  info "Preprocessing datasets into unified JSONL..."
  $PYTHON scripts/preprocess_data.py
  ok "Data preprocessing complete"
}
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# STAGE 3: Pre-train human pattern classifier
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
stage_3_pretrain_classifier() {
  # Guard clause: bail out cleanly when the stage is skipped.
  prompt_stage "Pre-train Human Pattern Classifier" "checkpoints/human_pattern_classifier.pt" || return 0
  info "Pre-training human pattern classifier on Kaggle datasets..."
  info "This may take a while on CPU (extracting features for ~100k texts)..."
  $PYTHON scripts/pretrain_human_pattern_classifier.py
  ok "Human pattern classifier pre-trained"
}
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# STAGE 4: Main model training
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Runs scripts/train.py with $CONFIG; adds --use-v2-loss when the stage-3
# classifier checkpoint exists.
stage_4_train() {
  if prompt_stage "Main Model Training" "checkpoints/best_model/config.json"; then
    info "Starting main model training..."
    info "Config: $CONFIG"
    # Fix: the optional flag was previously held in a non-local V2_FLAG
    # variable (leaked into the global namespace) and expanded unquoted,
    # relying on word-splitting. A local array is the robust idiom.
    local -a extra_args=()
    if [ -f "checkpoints/human_pattern_classifier.pt" ]; then
      info "Human pattern classifier found β using V2 loss (with anti-AI term)"
      extra_args+=("--use-v2-loss")
    fi
    # ${arr[@]+...} avoids "unbound variable" on an empty array under
    # `set -u` with bash < 4.4.
    $PYTHON scripts/train.py --config "$CONFIG" ${extra_args[@]+"${extra_args[@]}"}
    ok "Main training complete"
  fi
}
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# STAGE 5: Evaluation
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
stage_5_evaluate() {
  # Guard clause: bail out cleanly when the stage is skipped.
  prompt_stage "Evaluation" "logs/eval_results_test.json" || return 0
  info "Running evaluation on test set..."
  mkdir -p logs
  $PYTHON scripts/evaluate.py --config "$CONFIG" --split test
  ok "Evaluation complete"
}
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Main
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Banner, environment detection, then the five pipeline stages in order.
main() {
  echo ""
  echo -e "${BOLD}ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ${NC}"
  echo -e "${BOLD}β Dyslexia Academic Writing Corrector β Training Suite β${NC}"
  echo -e "${BOLD}ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ${NC}"
  echo ""
  detect_env
  # Each stage decides for itself whether to run, via prompt_stage.
  local stage
  for stage in \
      stage_1_setup \
      stage_2_preprocess \
      stage_3_pretrain_classifier \
      stage_4_train \
      stage_5_evaluate; do
    "$stage"
  done
  echo ""
  echo -e "${GREEN}${BOLD}βββ All stages complete! βββ${NC}"
  echo -e " Model saved to: ${CYAN}checkpoints/best_model/${NC}"
  echo -e " Eval results: ${CYAN}logs/eval_results_test.json${NC}"
  echo -e " Start inference: ${CYAN}bash start.sh${NC}"
  echo ""
}
main "$@"