"""
Full training entry point.
Run: python scripts/train.py --config configs/training_config.yaml
"""

import click
import yaml
import torch
import os
import gc
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from loguru import logger

try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False

from src.model.base_model import load_model_and_tokenizer
from src.model.style_conditioner import StyleConditioner
from src.training.dataset import WritingCorrectionDataset
from src.training.loss_functions import CombinedCorrectionLoss, CombinedCorrectionLossV2
from src.training.trainer import CorrectionTrainer
from src.training.callbacks import StyleMetricsCallback, EarlyStoppingOnStyleDrift
from src.style.fingerprinter import StyleFingerprinter
from src.evaluation.gleu_scorer import GLEUScorer


# ── Hybrid GPU Management ───────────────────────────────────────────────────
def _setup_device():
    """Detect GPU and configure hybrid VRAM management.

    Returns (device, gpu_info) where gpu_info is a dict with:
      - available: bool
      - name: str
      - vram_total_mb: int
      - vram_free_mb: int
      - compute_cap: tuple
    """
    gpu_info = {"available": False, "name": "CPU", "vram_total_mb": 0,
                "vram_free_mb": 0, "compute_cap": (0, 0)}

    if not torch.cuda.is_available():
        logger.info("No GPU detected β€” training on CPU")
        return "cpu", gpu_info

    gpu_info["available"] = True
    gpu_info["name"] = torch.cuda.get_device_name(0)
    gpu_info["compute_cap"] = torch.cuda.get_device_capability(0)

    # Query actual free VRAM
    vram_total = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024)
    vram_reserved = torch.cuda.memory_reserved(0) // (1024 * 1024)
    vram_allocated = torch.cuda.memory_allocated(0) // (1024 * 1024)
    # Memory held by the caching allocator is unavailable to new allocations,
    # so subtract the larger of reserved/allocated when estimating free VRAM
    vram_free = vram_total - max(vram_reserved, vram_allocated)

    gpu_info["vram_total_mb"] = vram_total
    gpu_info["vram_free_mb"] = vram_free

    logger.info(
        f"GPU: {gpu_info['name']} | "
        f"VRAM: {vram_allocated}MB used / {vram_total}MB total ({vram_free}MB free) | "
        f"Compute: {gpu_info['compute_cap']}"
    )

    # Leave headroom for the system: reserve at most 85% of free VRAM
    # This prevents the desktop/compositor from starving
    usable_vram_mb = int(vram_free * 0.85)
    if usable_vram_mb > 0:
        # Set PyTorch memory limit to avoid hogging all VRAM
        fraction = min(usable_vram_mb / vram_total, 0.90)
        torch.cuda.set_per_process_memory_fraction(fraction, 0)
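        # Worked example (illustrative numbers): with 7,000MB free on an 8,192MB card,
        # usable_vram_mb = int(7000 * 0.85) = 5950 and fraction = min(5950 / 8192, 0.90) ≈ 0.73,
        # so PyTorch is capped at roughly 5,950MB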
        logger.info(
            f"Hybrid GPU mode: capped PyTorch VRAM to {fraction:.0%} "
            f"(~{int(vram_total * fraction)}MB), leaving room for system"
        )

    return "cuda", gpu_info


def _auto_batch_size(model_key: str, device: str, gpu_info: dict,
                     config_batch: int) -> int:
    """Pick optimal batch size based on model size and available resources."""
    if device == "cpu":
        # CPU: T5-Small can handle batch=8 with 32GB RAM, larger models less
        if "small" in model_key:
            return min(config_batch, 8)
        return min(config_batch, 2)

    # GPU: estimate based on free VRAM
    free_mb = gpu_info["vram_free_mb"]

    # Rough VRAM estimates (bf16, seq_len=128); values match the dict below:
    #   T5-Small: ~160MB model + ~60MB/sample
    #   T5-Base:  ~400MB model + ~100MB/sample
    #   T5-Large: ~1000MB model + ~160MB/sample
    model_vram_estimates = {
        "flan-t5-small": {"model_mb": 160, "per_sample_mb": 60},
        "flan-t5-base": {"model_mb": 400, "per_sample_mb": 100},
        "flan-t5-large": {"model_mb": 1000, "per_sample_mb": 160},
        "flan-t5-xl": {"model_mb": 3000, "per_sample_mb": 300},
    }
    est = model_vram_estimates.get(model_key, {"model_mb": 500, "per_sample_mb": 120})

    # Available for batches = free VRAM - model footprint - 300MB safety buffer
    available_for_batches = free_mb - est["model_mb"] - 300
    if available_for_batches <= 0:
        logger.warning("Very tight VRAM β€” using batch_size=1")
        return 1

    max_batch = max(1, available_for_batches // est["per_sample_mb"])
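    # Worked example (illustrative): 8,000MB free with flan-t5-base (400MB model,
    # 100MB/sample) gives (8000 - 400 - 300) // 100 = 73, so the configured batch
    # size is normally the binding limit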
    optimal = min(config_batch, max_batch)

    logger.info(
        f"Auto batch size: {optimal} "
        f"(model ~{est['model_mb']}MB + {optimal}Γ—{est['per_sample_mb']}MB "
        f"= ~{est['model_mb'] + optimal * est['per_sample_mb']}MB / {free_mb}MB free)"
    )
    return max(1, optimal)


@click.command()
@click.option("--config", default="configs/training_config.yaml")
@click.option("--use-v2-loss", is_flag=True, help="Use V2 loss with human pattern term")
def train(config: str, use_v2_loss: bool):
    """Launch the full training pipeline."""
    # Step 1: Load config
    logger.info("Step 1: Loading config...")
    with open(config) as f:
        cfg = yaml.safe_load(f)

    model_cfg = cfg.get("model", {})
    lora_cfg = cfg.get("lora", {})
    data_cfg = cfg.get("data", {})
    train_cfg = cfg.get("training", {})
    loss_cfg = cfg.get("loss", {})
    gen_cfg = cfg.get("generation", {})

    # Step 2: Initialise W&B (optional)
    logger.info("Step 2: Initialising experiment tracking...")
    if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
        wandb.init(
            project="dyslexia-rewriter",
            name=f"train-{model_cfg.get('key', 'flan-t5')}", 
            config=cfg,
        )
    else:
        logger.info("W&B not configured, logging to TensorBoard only")
        os.environ["WANDB_DISABLED"] = "true"

    # Step 3: Detect GPU and configure hybrid VRAM management
    logger.info("Step 3: Setting up device (hybrid GPU mode)...")
    device, gpu_info = _setup_device()

    # Step 4: Load model + tokenizer
    logger.info("Step 4: Loading model and tokenizer...")
    model_key = model_cfg.get("key", "flan-t5-small")
    model, tokenizer, is_seq2seq = load_model_and_tokenizer(
        model_key=model_key,
        quantize=model_cfg.get("quantize", False),
        use_lora=model_cfg.get("use_lora", True),
        lora_config_dict=lora_cfg,
    )

    # Required for PEFT + gradient checkpointing compatibility
    if hasattr(model, 'enable_input_require_grads'):
        model.enable_input_require_grads()

    # ── torch.compile for fused kernels (PyTorch 2.x) ───────────────────────
    if hasattr(torch, "compile") and device == "cuda":
        try:
            # "default" mode: fuses kernels via Triton without CUDA graphs.
            # "reduce-overhead" uses CUDA graphs which break with LoRA/PEFT
            # (tensor outputs get overwritten between graph replays).
            logger.info("Applying torch.compile(mode='default')...")
            model = torch.compile(model, mode="default")
            logger.info("βœ“ torch.compile applied β€” first few steps will be slower (compiling)")
        except Exception as e:
            logger.warning(f"torch.compile failed (non-fatal): {e}")

    # Step 5: Create fingerprinter
    logger.info("Step 5: Creating style fingerprinter...")
    fingerprinter = StyleFingerprinter(
        spacy_model="en_core_web_sm",  # Use small model for training speed
        awl_path="data/awl/coxhead_awl.txt",
    )

    # Step 6: Create datasets
    logger.info("Step 6: Loading datasets...")
    train_dataset = WritingCorrectionDataset(
        data_path=data_cfg.get("train_path", "data/processed/train.jsonl"),
        tokenizer=tokenizer,
        fingerprinter=fingerprinter,
        max_input_length=data_cfg.get("max_input_length", 512),
        max_target_length=data_cfg.get("max_target_length", 512),
        augment_with_synthetic=data_cfg.get("augment_synthetic", True),
        synthetic_ratio=data_cfg.get("synthetic_ratio", 0.3),
    )

    val_dataset = WritingCorrectionDataset(
        data_path=data_cfg.get("val_path", "data/processed/val.jsonl"),
        tokenizer=tokenizer,
        fingerprinter=fingerprinter,
        max_input_length=data_cfg.get("max_input_length", 512),
        max_target_length=data_cfg.get("max_target_length", 512),
        augment_with_synthetic=False,
    )

    logger.info(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}")

    # Free memory after dataset loading
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    # Use a simple CE-only loss for training: the aux models (sentence-transformer,
    # GPT-2, HP classifier) are NOT loaded since they provide no gradient signal
    # (they decode via argmax under no_grad). Skipping them saves ~1GB+ of memory.
    from torch import nn
    class CEOnlyLoss(nn.Module):
        """Cross-entropy only loss β€” the only loss that provides gradient signal."""
        def __init__(self):
            super().__init__()
            self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        def forward(self, logits, labels, **kwargs):
            if logits.dim() == 3:
                ce_logits = logits.view(-1, logits.size(-1))
                ce_labels = labels.view(-1)
            else:
                ce_logits = logits
                ce_labels = labels
            l_ce = self.ce_loss(ce_logits, ce_labels)
            return {"total_loss": l_ce, "ce_loss": l_ce}

    loss_fn = CEOnlyLoss()
    logger.info("Using CE-only loss (aux models skipped to save memory)")
    if use_v2_loss:
        logger.warning(
            "--use-v2-loss was passed, but this entry point currently trains with the "
            "CE-only loss; the V2 combined loss is not instantiated here"
        )

    # Step 8: Create training arguments
    logger.info("Step 8: Creating training arguments...")

    # Auto-detect precision support
    use_bf16 = False
    use_fp16 = False
    if device == "cuda":
        if gpu_info["compute_cap"][0] >= 8:
            use_bf16 = True
            logger.info("Using BF16 (Ampere+ GPU)")
        else:
            use_fp16 = True
            logger.info("Using FP16 (pre-Ampere GPU)")
    elif device == "cpu":
        # Zen 3+ CPUs (Ryzen 5000+) support BF16 in PyTorch 2.x
        try:
            test = torch.tensor([1.0], dtype=torch.bfloat16)
            _ = test + test  # Test BF16 compute works
            use_bf16 = True
            logger.info("Using BF16 on CPU (Zen 3+ detected)")
        except Exception:
            logger.info("BF16 not supported on this CPU, using FP32")

    # Smart batch size based on model + available resources
    config_batch = train_cfg.get("per_device_train_batch_size", 4)
    batch_size = _auto_batch_size(model_key, device, gpu_info, config_batch)
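    # With the default config values (per-device batch 4, accumulation 8) the
    # effective batch is 4 * 8 = 32 examples per optimizer step; if _auto_batch_size
    # shrinks the per-device figure, the effective batch shrinks proportionally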

    # Smart gradient checkpointing:
    # - ENABLE for large models (saves VRAM at cost of compute)
    # - DISABLE for small models (they fit in VRAM, checkpointing is pure overhead)
    # - ALWAYS DISABLE on CPU (plenty of RAM, checkpointing wastes CPU cycles)
    large_models = {"flan-t5-large", "flan-t5-xl", "llama-3.1-8b"}
    use_grad_ckpt = model_key in large_models and device == "cuda"
    if use_grad_ckpt:
        logger.info("Gradient checkpointing: ON (large model, saving VRAM)")
    else:
        logger.info(f"Gradient checkpointing: OFF ({'small model fits in VRAM' if device == 'cuda' else 'CPU has plenty of RAM'})")

    # Dataloader workers: Python 3.14 changed the default multiprocessing start method
    # on Linux to "forkserver", which hits "too many fds" with num_workers > 0.
    # Use 0 (main-process loading); the dataset is pre-tokenized, so the overhead is minimal.
    num_workers = train_cfg.get("dataloader_num_workers", 0)

    # Filter report_to to only available tools
    report_to = []
    if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
        report_to.append("wandb")
    report_to.append("tensorboard")

    training_args = TrainingArguments(
        output_dir=train_cfg.get("output_dir", "checkpoints/"),
        num_train_epochs=train_cfg.get("num_train_epochs", 5),
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=train_cfg.get("per_device_eval_batch_size", 8) if device == "cuda" else 2,
        gradient_accumulation_steps=train_cfg.get("gradient_accumulation_steps", 8),
        learning_rate=train_cfg.get("learning_rate", 3e-4),
        lr_scheduler_type=train_cfg.get("lr_scheduler_type", "cosine"),
        warmup_ratio=train_cfg.get("warmup_ratio", 0.05),
        weight_decay=train_cfg.get("weight_decay", 0.01),
        fp16=use_fp16,
        bf16=use_bf16,
        eval_strategy=train_cfg.get("evaluation_strategy", "steps"),
        eval_steps=train_cfg.get("eval_steps", 100),
        save_strategy=train_cfg.get("save_strategy", "steps"),
        save_steps=train_cfg.get("save_steps", 100),
        save_total_limit=train_cfg.get("save_total_limit", 3),
        load_best_model_at_end=False,  # Handled manually below (PEFT adapters break Trainer's loader)
        metric_for_best_model=train_cfg.get("metric_for_best_model", "eval_loss"),
        greater_is_better=train_cfg.get("greater_is_better", False),
        logging_dir=train_cfg.get("logging_dir", "logs/"),
        logging_steps=train_cfg.get("logging_steps", 25),
        report_to=report_to,
        dataloader_num_workers=num_workers,
        seed=train_cfg.get("seed", 42),
        remove_unused_columns=False,  # We have custom columns (style_vector, etc.)
        gradient_checkpointing=use_grad_ckpt,
    )

    # Step 9: Create trainer
    logger.info("Step 9: Creating trainer...")
    trainer = CorrectionTrainer(
        loss_fn=loss_fn,
        fingerprinter=fingerprinter,
        tokenizer=tokenizer,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        callbacks=[
            StyleMetricsCallback(),
            EarlyStoppingOnStyleDrift(min_style_similarity=0.75),
        ],
    )

    # Step 10: Train
    logger.info("Step 10: Starting training...")
    logger.info(
        f"Config summary: model={model_key} | batch={batch_size} | "
        f"accum={training_args.gradient_accumulation_steps} | "
        f"effective_batch={batch_size * training_args.gradient_accumulation_steps} | "
        f"epochs={training_args.num_train_epochs} | "
        f"precision={'bf16' if use_bf16 else 'fp16' if use_fp16 else 'fp32'} | "
        f"grad_ckpt={use_grad_ckpt} | device={device}"
    )
    trainer.train()

    # Step 11: Save best model (manual PEFT-aware loading)
    logger.info("Step 11: Saving best model...")
    output_dir = train_cfg.get("output_dir", "checkpoints/")
    save_path = os.path.join(output_dir, "best_model")

    # Find the best checkpoint recorded in the checkpoints' trainer_state.json files
    import glob
    import json as json_mod

    best_ckpt = None
    checkpoints = [
        p for p in glob.glob(os.path.join(output_dir, "checkpoint-*"))
        if p.rsplit("-", 1)[-1].isdigit()
    ]
    # Sort numerically by step so e.g. checkpoint-1000 comes after checkpoint-900
    checkpoints.sort(key=lambda p: int(p.rsplit("-", 1)[-1]))
    for ckpt_dir in checkpoints:
        ts = os.path.join(ckpt_dir, "trainer_state.json")
        if os.path.exists(ts):
            with open(ts) as f:
                state = json_mod.load(f)
            best_path = state.get("best_model_checkpoint")
            if best_path:
                best_ckpt = best_path

    if best_ckpt and os.path.isdir(best_ckpt):
        logger.info(f"Loading best checkpoint from {best_ckpt}")
        # Reload the best adapter weights
        best_adapter = os.path.join(best_ckpt, "adapter_model.safetensors")
        if os.path.exists(best_adapter):
            model.load_adapter(best_ckpt, adapter_name="default")
            logger.info(f"Loaded best adapter from {best_ckpt}")
        else:
            logger.warning(f"No adapter found at {best_ckpt}, saving current model")
    else:
        logger.info("No best checkpoint found, saving final model state")

    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)
    logger.info(f"Model saved to {save_path}")

    if HAS_WANDB and wandb.run is not None:
        wandb.finish()

    logger.info("βœ“ Training complete!")


if __name__ == "__main__":
    train()