"""
Pre-trains the HumanPatternClassifier on both Kaggle datasets.
Run this BEFORE the main training loop.
The saved classifier weights are then loaded frozen during main training.
Run: python scripts/pretrain_human_pattern_classifier.py
Output: checkpoints/human_pattern_classifier.pt
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score, roc_auc_score
import numpy as np
from loguru import logger
import os
import yaml
try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False

from src.training.human_pattern_extractor import (
    HumanPatternFeatureExtractor,
    KaggleHumanPatternDataset,
    HumanPatternClassifier,
)


def train_classifier(config_path: str = "configs/training_config.yaml"):
"""Pre-train the human pattern classifier on Kaggle datasets."""
# Load config
with open(config_path) as f:
config = yaml.safe_load(f)
hp_cfg = config.get("human_pattern", {})
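    # Expected `human_pattern` block in configs/training_config.yaml (sketch;
    # keys are the ones read in this function, values are this script's defaults):
    #
    #   human_pattern:
    #     shanegerami_path: data/raw/shanegerami/AI_Human.csv
    #     starblasters_path: data/raw/starblasters8/data.parquet
    #     max_samples_per_source: 50000
    #     pretrain_batch_size: 512
    #     pretrain_epochs: 20
    #     pretrain_lr: 0.001
    #     target_auc: 0.88
    #     classifier_path: checkpoints/human_pattern_classifier.pt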
    # Init W&B (optional)
    if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
        wandb.init(project="dyslexia-rewriter", name="human-pattern-pretrain", tags=["pretrain"])
    else:
        logger.info("W&B not configured, logging to console only")

    # Create extractor
    logger.info("Creating feature extractor...")
    extractor = HumanPatternFeatureExtractor(spacy_model="en_core_web_sm")
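    # Note: the classifier below is built with input_dim=17, so the extractor
    # is assumed to emit 17-dimensional feature vectors per text sample.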
    # Load datasets
    shanegerami_path = hp_cfg.get("shanegerami_path", "data/raw/shanegerami/AI_Human.csv")
    starblasters_path = hp_cfg.get("starblasters_path", "data/raw/starblasters8/data.parquet")
    max_samples = hp_cfg.get("max_samples_per_source", 50000)

    logger.info("Loading datasets...")
    dataset = KaggleHumanPatternDataset(
        shanegerami_path=shanegerami_path,
        starblasters_path=starblasters_path,
        extractor=extractor,
        max_samples_per_source=max_samples,
    )
    if len(dataset) == 0:
        logger.error("No data loaded! Check dataset paths.")
        return

    # Pre-compute features
    dataset.precompute_features()
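    # Assumption: precompute_features() runs the (spaCy-based) extractor once
    # over all samples and caches the vectors, so the epochs below only index
    # cached tensors instead of re-extracting features every pass.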
    # Train/val split (80/20)
    val_size = int(len(dataset) * 0.2)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Create dataloaders
    batch_size = hp_cfg.get("pretrain_batch_size", 512)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    logger.info(f"Train: {train_size} | Val: {val_size} | Batch size: {batch_size}")

    # Create model
    classifier = HumanPatternClassifier(input_dim=17, hidden_dim=128)
    device = "cpu"  # the MLP is tiny (17 -> 128), so CPU is sufficient here
    classifier = classifier.to(device)

    # Training setup
    epochs = hp_cfg.get("pretrain_epochs", 20)
    lr = hp_cfg.get("pretrain_lr", 1e-3)
    target_auc = hp_cfg.get("target_auc", 0.88)
    optimizer = torch.optim.AdamW(classifier.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.BCELoss()
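    # nn.BCELoss expects probabilities in [0, 1], so HumanPatternClassifier
    # presumably ends in a sigmoid; this is consistent with the 0.5
    # thresholding and the raw-output AUC computation in the loop below.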
    best_auc = 0.0
    save_path = hp_cfg.get("classifier_path", "checkpoints/human_pattern_classifier.pt")
    # Create the output directory for the configured path, not just "checkpoints"
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    # Training loop
    for epoch in range(1, epochs + 1):
        classifier.train()
        train_loss = 0.0
        train_preds = []
        train_labels = []
        for features, labels in train_loader:
            features = features.to(device)
            labels = labels.float().to(device)
            optimizer.zero_grad()
            outputs = classifier(features)
            loss = criterion(outputs, labels)
            loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item() * features.size(0)
            train_preds.extend(outputs.detach().cpu().numpy())
            train_labels.extend(labels.cpu().numpy())
        scheduler.step()
        train_loss /= train_size

        # Validation
        classifier.eval()
        val_preds = []
        val_labels = []
        val_loss = 0.0
        with torch.no_grad():
            for features, labels in val_loader:
                features = features.to(device)
                labels = labels.float().to(device)
                outputs = classifier(features)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * features.size(0)
                val_preds.extend(outputs.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())
        val_loss /= val_size

        # Metrics
        train_preds_binary = [1 if p > 0.5 else 0 for p in train_preds]
        val_preds_binary = [1 if p > 0.5 else 0 for p in val_preds]
        train_acc = accuracy_score(train_labels, train_preds_binary)
        val_acc = accuracy_score(val_labels, val_preds_binary)
        try:
            train_auc = roc_auc_score(train_labels, train_preds)
            val_auc = roc_auc_score(val_labels, val_preds)
        except ValueError:
            # roc_auc_score raises ValueError if a split contains only one class
            train_auc = 0.0
            val_auc = 0.0

        logger.info(
            f"Epoch {epoch}/{epochs} | "
            f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} AUC: {train_auc:.4f} | "
            f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} AUC: {val_auc:.4f}"
        )

        # Log to W&B
        if HAS_WANDB and wandb.run is not None:
            wandb.log({
                "epoch": epoch,
                "train/loss": train_loss,
                "train/accuracy": train_acc,
                "train/auc": train_auc,
                "val/loss": val_loss,
                "val/accuracy": val_acc,
                "val/auc": val_auc,
                "lr": scheduler.get_last_lr()[0],
            })
        # Save best model by AUC
        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(classifier.state_dict(), save_path)
            logger.info(f" ✓ New best AUC: {val_auc:.4f} — saved to {save_path}")

        # Early stopping if target AUC reached
        if val_auc >= target_auc:
            logger.info(f"Target AUC {target_auc} reached at epoch {epoch}! Stopping.")
            break
    logger.info(f"\nPre-training complete. Best AUC: {best_auc:.4f}")
    if HAS_WANDB and wandb.run is not None:
        wandb.finish()


if __name__ == "__main__":
    train_classifier()