#!/usr/bin/env bash
# ═══════════════════════════════════════════════════════════════════════════
# train.sh - Multi-stage training orchestrator with checkpoint system
# ═══════════════════════════════════════════════════════════════════════════
#
# Usage: bash train.sh [--config=CONFIG] [--auto]
#
# Each stage prompts: [S]kip, [R]edo, [C]ontinue
# Use --auto to skip all prompts and auto-detect what needs running
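#
# Examples:
#   bash train.sh                                  # interactive run with the default config
#   bash train.sh --auto                           # non-interactive: skip stages whose outputs exist
#   bash train.sh --config=configs/my_config.yaml  # example path: override the default config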
#
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
export PYTHONPATH="${SCRIPT_DIR}:${PYTHONPATH:-}"
CONFIG="${1:-configs/training_config.yaml}"
AUTO_MODE=false

# Parse args
for arg in "$@"; do
    case $arg in
        --auto) AUTO_MODE=true ;;
        --config=*) CONFIG="${arg#*=}" ;;
    esac
done

# ── Colors ──────────────────────────────────────────────────────────────────
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'

info()  { echo -e "${CYAN}[INFO]${NC} $1"; }
ok()    { echo -e "${GREEN}[  OK]${NC} $1"; }
warn()  { echo -e "${YELLOW}[WARN]${NC} $1"; }
err()   { echo -e "${RED}[FAIL]${NC} $1"; }

# ── Stage prompt function ──────────────────────────────────────────────────
# Asks user to [S]kip, [R]edo, or [C]ontinue for each stage
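# Returns 0 when the stage should run and 1 when it should be skipped, so each
# stage body can be wrapped in: if prompt_stage "Name" "check_file"; then ... fi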
prompt_stage() {
    local stage_name="$1"
    local check_file="$2"  # File to check if stage already completed

    echo ""
    echo -e "${BOLD}═══ Stage: ${stage_name} ═══${NC}"

    if [ "$AUTO_MODE" = true ]; then
        if [ -n "$check_file" ] && [ -e "$check_file" ]; then
            info "Auto-mode: $check_file exists, skipping"
            return 1  # Skip
        fi
        return 0  # Continue
    fi

    if [ -n "$check_file" ] && [ -e "$check_file" ]; then
        warn "Previous output found: $check_file"
        echo -e "  ${YELLOW}[S]${NC}kip  |  ${CYAN}[R]${NC}edo  |  ${GREEN}[C]${NC}ontinue"
        read -rp "  Choice [S/R/C]: " choice
        case "${choice,,}" in
            r|redo)   info "Redoing ${stage_name}..."; return 0 ;;
            c|continue) info "Continuing ${stage_name}..."; return 0 ;;
            *)         info "Skipping ${stage_name}"; return 1 ;;
        esac
    else
        info "No previous output found. Running ${stage_name}..."
        return 0
    fi
}

# ── Detect environment ─────────────────────────────────────────────────────
detect_env() {
    echo -e "${BOLD}═══ Environment Detection ═══${NC}"

    # Python
    if command -v python3 &>/dev/null; then
        PYTHON=python3
    elif command -v python &>/dev/null; then
        PYTHON=python
    else
        err "Python not found!"
        exit 1
    fi
    ok "Python: $($PYTHON --version 2>&1)"

    # GPU
    if $PYTHON -c "import torch; print(torch.cuda.is_available())" 2>/dev/null | grep -q "True"; then
        GPU_AVAILABLE=true
        GPU_NAME=$($PYTHON -c "import torch; print(torch.cuda.get_device_name(0))" 2>/dev/null || echo "Unknown")
        ok "GPU: $GPU_NAME"

        # Check compute capability for bf16
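        # (bfloat16 requires compute capability >= 8.0, i.e. Ampere or newer; older GPUs fall back to fp16)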
        COMPUTE_CAP=$($PYTHON -c "import torch; print(torch.cuda.get_device_capability()[0])" 2>/dev/null || echo "0")
        if [ "$COMPUTE_CAP" -ge 8 ]; then
            PRECISION="bf16"
        else
            PRECISION="fp16"
        fi
        ok "Precision: $PRECISION"
    else
        GPU_AVAILABLE=false
        PRECISION="fp32"
        warn "No GPU detected β€” training will use CPU (optimised settings)"
    fi

    # W&B
    if [ -n "${WANDB_API_KEY:-}" ]; then
        ok "W&B: API key found"
    else
        warn "W&B: No API key (WANDB_API_KEY). Logging to TensorBoard only."
        export WANDB_DISABLED=true
    fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 1: Install dependencies & download models
# ═══════════════════════════════════════════════════════════════════════════
stage_1_setup() {
    if prompt_stage "Setup & Dependencies" ".train_stage1_done"; then
        info "Installing Python dependencies..."
        $PYTHON -m pip install -r requirements.txt --quiet 2>&1 | tail -5

        info "Downloading spaCy models..."
        $PYTHON -m spacy download en_core_web_sm --quiet 2>/dev/null || true

        info "Downloading NLTK data..."
        $PYTHON -c "import nltk; nltk.download('punkt', quiet=True); nltk.download('punkt_tab', quiet=True)" 2>/dev/null || true

        touch .train_stage1_done
        ok "Setup complete"
    fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 2: Data preprocessing
# ═══════════════════════════════════════════════════════════════════════════
stage_2_preprocess() {
    if prompt_stage "Data Preprocessing" "data/processed/train.jsonl"; then
        info "Preprocessing datasets into unified JSONL..."
        $PYTHON scripts/preprocess_data.py
        ok "Data preprocessing complete"
    fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 3: Pre-train human pattern classifier
# ═══════════════════════════════════════════════════════════════════════════
stage_3_pretrain_classifier() {
    if prompt_stage "Pre-train Human Pattern Classifier" "checkpoints/human_pattern_classifier.pt"; then
        info "Pre-training human pattern classifier on Kaggle datasets..."
        info "This may take a while on CPU (extracting features for ~100k texts)..."
        $PYTHON scripts/pretrain_human_pattern_classifier.py
        ok "Human pattern classifier pre-trained"
    fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 4: Main model training
# ═══════════════════════════════════════════════════════════════════════════
stage_4_train() {
    if prompt_stage "Main Model Training" "checkpoints/best_model/config.json"; then
        info "Starting main model training..."
        info "Config: $CONFIG"

        # Add V2 loss flag if classifier exists
        V2_FLAG=""
        if [ -f "checkpoints/human_pattern_classifier.pt" ]; then
            info "Human pattern classifier found β€” using V2 loss (with anti-AI term)"
            V2_FLAG="--use-v2-loss"
        fi

        $PYTHON scripts/train.py --config "$CONFIG" $V2_FLAG
        ok "Main training complete"
    fi
}

# ═══════════════════════════════════════════════════════════════════════════
# STAGE 5: Evaluation
# ═══════════════════════════════════════════════════════════════════════════
stage_5_evaluate() {
    if prompt_stage "Evaluation" "logs/eval_results_test.json"; then
        info "Running evaluation on test set..."
        mkdir -p logs
        $PYTHON scripts/evaluate.py --config "$CONFIG" --split test
        ok "Evaluation complete"
    fi
}
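
# Note: in --auto mode a stage is skipped whenever its check file already exists.
# To force a single stage to re-run on the next invocation, remove its check file
# first, e.g.:
#   rm -f .train_stage1_done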

# ═══════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════
main() {
    echo ""
    echo -e "${BOLD}╔══════════════════════════════════════════════════════════╗${NC}"
    echo -e "${BOLD}β•‘  Dyslexia Academic Writing Corrector β€” Training Suite   β•‘${NC}"
    echo -e "${BOLD}β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•${NC}"
    echo ""

    detect_env

    stage_1_setup
    stage_2_preprocess
    stage_3_pretrain_classifier
    stage_4_train
    stage_5_evaluate

    echo ""
    echo -e "${GREEN}${BOLD}═══ All stages complete! ═══${NC}"
    echo -e "  Model saved to: ${CYAN}checkpoints/best_model/${NC}"
    echo -e "  Eval results:   ${CYAN}logs/eval_results_test.json${NC}"
    echo -e "  Start inference: ${CYAN}bash start.sh${NC}"
    echo ""
}

main