| """ |
| Pre-trains the HumanPatternClassifier on both Kaggle datasets. |
| Run this BEFORE the main training loop. |
| The saved classifier weights are then loaded frozen during main training. |
| |
| Run: python scripts/pretrain_human_pattern_classifier.py |
| Output: checkpoints/human_pattern_classifier.pt |
| """ |

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score, roc_auc_score
from loguru import logger
import os
import yaml
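
# wandb is optional; fall back to console-only logging when it is not installed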
try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False

from src.training.human_pattern_extractor import (
    HumanPatternFeatureExtractor,
    KaggleHumanPatternDataset,
    HumanPatternClassifier,
)


def train_classifier(config_path: str = "configs/training_config.yaml"):
    """Pre-train the human pattern classifier on Kaggle datasets."""

    with open(config_path) as f:
        config = yaml.safe_load(f)

    hp_cfg = config.get("human_pattern", {})
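
    # Keys read from the "human_pattern" config section, with their defaults
    # (an illustrative YAML sketch, not the committed config file):
    #
    #     human_pattern:
    #       shanegerami_path: data/raw/shanegerami/AI_Human.csv
    #       starblasters_path: data/raw/starblasters8/data.parquet
    #       max_samples_per_source: 50000
    #       pretrain_batch_size: 512
    #       pretrain_epochs: 20
    #       pretrain_lr: 0.001
    #       target_auc: 0.88
    #       classifier_path: checkpoints/human_pattern_classifier.pt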

    if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
        wandb.init(project="dyslexia-rewriter", name="human-pattern-pretrain", tags=["pretrain"])
    else:
        logger.info("W&B not configured, logging to console only")

    logger.info("Creating feature extractor...")
    extractor = HumanPatternFeatureExtractor(spacy_model="en_core_web_sm")

    shanegerami_path = hp_cfg.get("shanegerami_path", "data/raw/shanegerami/AI_Human.csv")
    starblasters_path = hp_cfg.get("starblasters_path", "data/raw/starblasters8/data.parquet")
    max_samples = hp_cfg.get("max_samples_per_source", 50000)

    logger.info("Loading datasets...")
    dataset = KaggleHumanPatternDataset(
        shanegerami_path=shanegerami_path,
        starblasters_path=starblasters_path,
        extractor=extractor,
        max_samples_per_source=max_samples,
    )

    if len(dataset) == 0:
        logger.error("No data loaded! Check dataset paths.")
        return
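
    # Compute features for every example once, up front, so each training
    # epoch only reads cached tensors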
    dataset.precompute_features()
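
    # 80/20 train/validation split with a fixed seed for reproducibility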
    val_size = int(len(dataset) * 0.2)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )

    batch_size = hp_cfg.get("pretrain_batch_size", 512)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    logger.info(f"Train: {train_size} | Val: {val_size} | Batch size: {batch_size}")
    classifier = HumanPatternClassifier(input_dim=17, hidden_dim=128)
    device = "cpu"
    classifier = classifier.to(device)

    epochs = hp_cfg.get("pretrain_epochs", 20)
    lr = hp_cfg.get("pretrain_lr", 1e-3)
    target_auc = hp_cfg.get("target_auc", 0.88)
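
    # AdamW with cosine LR decay; BCELoss expects probabilities in [0, 1],
    # so the classifier head is assumed to end in a sigmoid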
    optimizer = torch.optim.AdamW(classifier.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.BCELoss()

    best_auc = 0.0
    # Resolve the save path once and create its directory, so saving works
    # even when the config points somewhere other than checkpoints/
    save_path = hp_cfg.get("classifier_path", "checkpoints/human_pattern_classifier.pt")
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    for epoch in range(1, epochs + 1):
        classifier.train()
        train_loss = 0.0
        train_preds = []
        train_labels = []

        for features, labels in train_loader:
            features = features.to(device)
            labels = labels.float().to(device)

            optimizer.zero_grad()
            outputs = classifier(features)
            loss = criterion(outputs, labels)
            loss.backward()
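
            # Clip gradients before the optimizer step to stabilize training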
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item() * features.size(0)
            train_preds.extend(outputs.detach().cpu().numpy())
            train_labels.extend(labels.cpu().numpy())

        scheduler.step()
        train_loss /= train_size
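
        # Validation pass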
        classifier.eval()
        val_preds = []
        val_labels = []
        val_loss = 0.0

        with torch.no_grad():
            for features, labels in val_loader:
                features = features.to(device)
                labels = labels.float().to(device)
                outputs = classifier(features)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * features.size(0)
                val_preds.extend(outputs.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        val_loss /= val_size
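
        # Accuracy uses a 0.5 decision threshold; AUC uses the raw scores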
        train_preds_binary = [1 if p > 0.5 else 0 for p in train_preds]
        val_preds_binary = [1 if p > 0.5 else 0 for p in val_preds]

        train_acc = accuracy_score(train_labels, train_preds_binary)
        val_acc = accuracy_score(val_labels, val_preds_binary)
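
        # roc_auc_score raises ValueError when a split contains only one class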
        try:
            train_auc = roc_auc_score(train_labels, train_preds)
            val_auc = roc_auc_score(val_labels, val_preds)
        except ValueError:
            train_auc = 0.0
            val_auc = 0.0

        logger.info(
            f"Epoch {epoch}/{epochs} | "
            f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} AUC: {train_auc:.4f} | "
            f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} AUC: {val_auc:.4f}"
        )

        if HAS_WANDB and wandb.run is not None:
            wandb.log({
                "epoch": epoch,
                "train/loss": train_loss,
                "train/accuracy": train_acc,
                "train/auc": train_auc,
                "val/loss": val_loss,
                "val/accuracy": val_acc,
                "val/auc": val_auc,
                "lr": scheduler.get_last_lr()[0],
            })
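
        # Keep only the checkpoint with the best validation AUC so far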
        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(classifier.state_dict(), save_path)
            logger.info(f" ✓ New best AUC: {val_auc:.4f} — saved to {save_path}")

        if val_auc >= target_auc:
            logger.info(f"Target AUC {target_auc} reached at epoch {epoch}! Stopping.")
            break

    logger.info(f"\nPre-training complete. Best AUC: {best_auc:.4f}")

    if HAS_WANDB and wandb.run is not None:
        wandb.finish()


if __name__ == "__main__":
    train_classifier()