""" Pre-trains the HumanPatternClassifier on both Kaggle datasets. Run this BEFORE the main training loop. The saved classifier weights are then loaded frozen during main training. Run: python scripts/pretrain_human_pattern_classifier.py Output: checkpoints/human_pattern_classifier.pt """ import torch import torch.nn as nn from torch.utils.data import DataLoader, random_split from sklearn.metrics import accuracy_score, roc_auc_score import numpy as np from loguru import logger import os import yaml try: import wandb HAS_WANDB = True except ImportError: HAS_WANDB = False from src.training.human_pattern_extractor import ( HumanPatternFeatureExtractor, KaggleHumanPatternDataset, HumanPatternClassifier, ) def train_classifier(config_path: str = "configs/training_config.yaml"): """Pre-train the human pattern classifier on Kaggle datasets.""" # Load config with open(config_path) as f: config = yaml.safe_load(f) hp_cfg = config.get("human_pattern", {}) # Init W&B (optional) if HAS_WANDB and os.environ.get("WANDB_API_KEY"): wandb.init(project="dyslexia-rewriter", name="human-pattern-pretrain", tags=["pretrain"]) else: logger.info("W&B not configured, logging to console only") # Create extractor logger.info("Creating feature extractor...") extractor = HumanPatternFeatureExtractor(spacy_model="en_core_web_sm") # Load datasets shanegerami_path = hp_cfg.get("shanegerami_path", "data/raw/shanegerami/AI_Human.csv") starblasters_path = hp_cfg.get("starblasters_path", "data/raw/starblasters8/data.parquet") max_samples = hp_cfg.get("max_samples_per_source", 50000) logger.info("Loading datasets...") dataset = KaggleHumanPatternDataset( shanegerami_path=shanegerami_path, starblasters_path=starblasters_path, extractor=extractor, max_samples_per_source=max_samples, ) if len(dataset) == 0: logger.error("No data loaded! Check dataset paths.") return # Pre-compute features dataset.precompute_features() # Train/val split (80/20) val_size = int(len(dataset) * 0.2) train_size = len(dataset) - val_size train_dataset, val_dataset = random_split( dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42), ) # Create dataloaders batch_size = hp_cfg.get("pretrain_batch_size", 512) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0) logger.info(f"Train: {train_size} | Val: {val_size} | Batch size: {batch_size}") # Create model classifier = HumanPatternClassifier(input_dim=17, hidden_dim=128) device = "cpu" classifier = classifier.to(device) # Training setup epochs = hp_cfg.get("pretrain_epochs", 20) lr = hp_cfg.get("pretrain_lr", 1e-3) target_auc = hp_cfg.get("target_auc", 0.88) optimizer = torch.optim.AdamW(classifier.parameters(), lr=lr, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) criterion = nn.BCELoss() best_auc = 0.0 os.makedirs("checkpoints", exist_ok=True) # Training loop for epoch in range(1, epochs + 1): classifier.train() train_loss = 0.0 train_preds = [] train_labels = [] for features, labels in train_loader: features = features.to(device) labels = labels.float().to(device) optimizer.zero_grad() outputs = classifier(features) loss = criterion(outputs, labels) loss.backward() # Gradient clipping for stability torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0) optimizer.step() train_loss += loss.item() * features.size(0) train_preds.extend(outputs.detach().cpu().numpy()) train_labels.extend(labels.cpu().numpy()) scheduler.step() train_loss /= train_size # Validation classifier.eval() val_preds = [] val_labels = [] val_loss = 0.0 with torch.no_grad(): for features, labels in val_loader: features = features.to(device) labels = labels.float().to(device) outputs = classifier(features) loss = criterion(outputs, labels) val_loss += loss.item() * features.size(0) val_preds.extend(outputs.cpu().numpy()) val_labels.extend(labels.cpu().numpy()) val_loss /= val_size # Metrics train_preds_binary = [1 if p > 0.5 else 0 for p in train_preds] val_preds_binary = [1 if p > 0.5 else 0 for p in val_preds] train_acc = accuracy_score(train_labels, train_preds_binary) val_acc = accuracy_score(val_labels, val_preds_binary) try: train_auc = roc_auc_score(train_labels, train_preds) val_auc = roc_auc_score(val_labels, val_preds) except ValueError: train_auc = 0.0 val_auc = 0.0 logger.info( f"Epoch {epoch}/{epochs} | " f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} AUC: {train_auc:.4f} | " f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} AUC: {val_auc:.4f}" ) # Log to W&B if HAS_WANDB and wandb.run is not None: wandb.log({ "epoch": epoch, "train/loss": train_loss, "train/accuracy": train_acc, "train/auc": train_auc, "val/loss": val_loss, "val/accuracy": val_acc, "val/auc": val_auc, "lr": scheduler.get_last_lr()[0], }) # Save best model by AUC if val_auc > best_auc: best_auc = val_auc save_path = hp_cfg.get("classifier_path", "checkpoints/human_pattern_classifier.pt") torch.save(classifier.state_dict(), save_path) logger.info(f" ✓ New best AUC: {val_auc:.4f} — saved to {save_path}") # Early stopping if target AUC reached if val_auc >= target_auc: logger.info(f"Target AUC {target_auc} reached at epoch {epoch}! Stopping.") break logger.info(f"\nPre-training complete. Best AUC: {best_auc:.4f}") if HAS_WANDB and wandb.run is not None: wandb.finish() if __name__ == "__main__": train_classifier()