""" AIFinder Training Script Loads data, trains a two-headed GPU classifier, reports metrics, and saves the model. Usage: python3 train.py """ import os import sys import time import joblib import numpy as np import torch import torch.nn as nn from torch.utils.data import TensorDataset, DataLoader from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report from sklearn.preprocessing import LabelEncoder from sklearn.utils.class_weight import compute_class_weight from config import ( MODEL_DIR, TEST_SIZE, RANDOM_STATE, HIDDEN_DIM, EMBED_DIM, DROPOUT, BATCH_SIZE, EPOCHS, LEARNING_RATE, WEIGHT_DECAY, EARLY_STOP_PATIENCE, ) from data_loader import load_all_data from features import FeaturePipeline from model import AIFinderNet def _log(msg, t0=None): """Print a timestamped log message, optionally with elapsed time.""" ts = time.strftime("%H:%M:%S") if t0 is not None: elapsed = time.time() - t0 print(f" [{ts}] {msg} ({elapsed:.1f}s)") else: print(f" [{ts}] {msg}") def main(): t_start = time.time() print("=" * 60) print("AIFinder Training - Provider Classification") print("=" * 60) # ── GPU check ────────────────────────────────────────────── if torch.cuda.is_available(): device = torch.device("cuda") gpu_name = torch.cuda.get_device_name(0) gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3 _log(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)") else: device = torch.device("cpu") _log("No GPU available, using CPU") # ── Load data ────────────────────────────────────────────── _log("Starting data load...") t0 = time.time() texts, providers, models, _is_ai = load_all_data() _log("Data load complete", t0) if len(texts) < 100: print("ERROR: Not enough data loaded. Check dataset access.") sys.exit(1) # ── Encode labels ────────────────────────────────────────── _log("Encoding labels...") t0 = time.time() provider_enc = LabelEncoder() provider_labels = provider_enc.fit_transform(providers) num_providers = len(provider_enc.classes_) _log(f"Labels encoded — {num_providers} providers", t0) # ── Train/test split ─────────────────────────────────────── _log("Splitting train/test...") t0 = time.time() indices = np.arange(len(texts)) train_idx, test_idx = train_test_split( indices, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=provider_labels, ) train_texts = [texts[i] for i in train_idx] test_texts = [texts[i] for i in test_idx] _log(f"Split: {len(train_texts)} train / {len(test_texts)} test", t0) # ── Build features ───────────────────────────────────────── _log("Building feature pipeline (fit on train)...") t0 = time.time() pipeline = FeaturePipeline() X_train = pipeline.fit_transform(train_texts) _log(f"Train features: {X_train.shape}", t0) _log("Transforming test set...") t0 = time.time() X_test = pipeline.transform(test_texts) _log(f"Test features: {X_test.shape}", t0) input_dim = X_train.shape[1] # ── Move to device ───────────────────────────────────────── _log(f"Moving data to {device}...") t0 = time.time() X_train_t = torch.tensor(X_train.toarray(), dtype=torch.float32).to(device) X_test_t = torch.tensor(X_test.toarray(), dtype=torch.float32).to(device) y_prov_train = torch.tensor(provider_labels[train_idx], dtype=torch.long).to(device) y_prov_test = torch.tensor(provider_labels[test_idx], dtype=torch.long).to(device) if device.type == "cuda": mem_used = torch.cuda.memory_allocated() / 1024**3 _log(f"GPU memory used: {mem_used:.2f} GB", t0) else: _log(f"Data on {device}", t0) # ── DataLoaders ──────────────────────────────────────────── 
    batch_size = min(BATCH_SIZE, 512) if device.type == "cpu" else BATCH_SIZE
    train_ds = TensorDataset(X_train_t, y_prov_train)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = TensorDataset(X_test_t, y_prov_test)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # ── Model ──────────────────────────────────────────────────
    _log("Building model...")
    net = AIFinderNet(
        input_dim=input_dim,
        num_providers=num_providers,
        hidden_dim=HIDDEN_DIM,
        embed_dim=EMBED_DIM,
        dropout=DROPOUT,
    ).to(device)
    n_params = sum(p.numel() for p in net.parameters())
    _log(f"Model: {n_params:,} parameters")

    # ── Class-weighted loss ────────────────────────────────────
    prov_weights = compute_class_weight(
        "balanced", classes=np.arange(num_providers), y=provider_labels[train_idx]
    )
    prov_criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(prov_weights, dtype=torch.float32).to(device)
    )

    # ── Optimizer + scheduler ──────────────────────────────────
    optimizer = torch.optim.AdamW(
        net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        epochs=EPOCHS,
        steps_per_epoch=len(train_loader),
    )

    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler() if use_amp else None

    # ── Training loop ──────────────────────────────────────────
    _log(
        f"Training for {EPOCHS} epochs, batch_size={batch_size}, "
        f"early_stop_patience={EARLY_STOP_PATIENCE}..."
    )
    t0 = time.time()
    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    for epoch in range(EPOCHS):
        # ── Train phase ────────────────────────────────────────
        net.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch_X, batch_prov in train_loader:
            optimizer.zero_grad(set_to_none=True)

            if use_amp:
                with torch.amp.autocast(device_type="cuda"):
                    prov_logits = net(batch_X)
                    loss = prov_criterion(prov_logits, batch_prov)
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                optimizer.step()

            scheduler.step()
            epoch_loss += loss.item()
            n_batches += 1

        avg_train_loss = epoch_loss / n_batches

        # ── Validation phase ───────────────────────────────────
        net.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_X, batch_prov in val_loader:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                val_loss += loss.item()
                val_batches += 1
        avg_val_loss = val_loss / val_batches

        # ── Early stopping check ───────────────────────────────
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_state = {k: v.clone() for k, v in net.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        # ── Logging ────────────────────────────────────────────
        if (epoch + 1) % 5 == 0 or epoch == 0:
            lr = scheduler.get_last_lr()[0]
            marker = " *" if patience_counter == 0 else ""
            _log(
                f"Epoch {epoch + 1:>3d}/{EPOCHS} "
                f"train={avg_train_loss:.4f} "
                f"val={avg_val_loss:.4f} "
                f"lr={lr:.2e}{marker}"
            )

        if patience_counter >= EARLY_STOP_PATIENCE:
            _log(
                f"Early stopping at epoch {epoch + 1} "
                f"(best val_loss={best_val_loss:.4f})"
            )
            break

    # Restore best weights
    if best_state is not None:
        net.load_state_dict(best_state)
        _log(f"Restored best weights (val_loss={best_val_loss:.4f})")

    _log("Training complete", t0)

    # ── Evaluate ───────────────────────────────────────────────
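    # Score the (best-weights) model on the full held-out split in a single
    # forward pass and report per-provider precision/recall/F1.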
_log("Evaluating...") net.eval() with torch.no_grad(): prov_logits = net(X_test_t) prov_preds = prov_logits.argmax(dim=1).cpu().numpy() prov_true = y_prov_test.cpu().numpy() print("\n === Provider Classification ===") print( classification_report( prov_true, prov_preds, target_names=provider_enc.classes_, zero_division=0, ) ) # ── Save ─────────────────────────────────────────────────── _log(f"Saving to {MODEL_DIR}/ ...") t0 = time.time() os.makedirs(MODEL_DIR, exist_ok=True) checkpoint = { "input_dim": input_dim, "num_providers": num_providers, "hidden_dim": HIDDEN_DIM, "embed_dim": EMBED_DIM, "dropout": DROPOUT, "state_dict": net.state_dict(), } torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt")) _log(" Saved classifier.pt") joblib.dump(pipeline, os.path.join(MODEL_DIR, "feature_pipeline.joblib")) _log(" Saved feature_pipeline.joblib") joblib.dump(provider_enc, os.path.join(MODEL_DIR, "provider_enc.joblib")) _log(" Saved provider_enc.joblib") _log("All artifacts saved", t0) elapsed = time.time() - t_start if device.type == "cuda": mem_peak = torch.cuda.max_memory_allocated() / 1024**3 print(f"\n{'=' * 60}") print(f"Training complete in {elapsed:.1f}s (peak GPU mem: {mem_peak:.2f} GB)") print(f"{'=' * 60}") else: print(f"\n{'=' * 60}") print(f"Training complete in {elapsed:.1f}s") print(f"{'=' * 60}") if __name__ == "__main__": main()