| | """ |
| | AIFinder Training Script |
| | Loads data, trains a two-headed GPU classifier, reports metrics, and saves the model. |
| | |
| | Usage: python3 train.py |
| | """ |
| |
|
import os
import sys
import time

import joblib
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight

from config import (
    MODEL_DIR,
    TEST_SIZE,
    RANDOM_STATE,
    HIDDEN_DIM,
    EMBED_DIM,
    DROPOUT,
    BATCH_SIZE,
    EPOCHS,
    LEARNING_RATE,
    WEIGHT_DECAY,
    EARLY_STOP_PATIENCE,
)
from data_loader import load_all_data
from features import FeaturePipeline
from model import AIFinderNet


def _log(msg, t0=None):
    """Print a timestamped log message, optionally with elapsed time."""
    ts = time.strftime("%H:%M:%S")
    if t0 is not None:
        elapsed = time.time() - t0
        print(f" [{ts}] {msg} ({elapsed:.1f}s)")
    else:
        print(f" [{ts}] {msg}")


def main():
    t_start = time.time()

    print("=" * 60)
    print("AIFinder Training - Provider Classification")
    print("=" * 60)
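
    # Prefer the GPU when one is available; everything below also runs on CPU.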
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        _log(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
    else:
        device = torch.device("cpu")
        _log("No GPU available, using CPU")
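
    # load_all_data returns parallel lists; the per-model and is_ai labels are
    # unused in this provider-only run, hence the underscore names.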
    _log("Starting data load...")
    t0 = time.time()
    texts, providers, _models, _is_ai = load_all_data()
    _log("Data load complete", t0)

    # Bail out early on a suspiciously small corpus.
    if len(texts) < 100:
        print("ERROR: Not enough data loaded. Check dataset access.")
        sys.exit(1)
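
    # LabelEncoder maps provider names to integer class ids; the fitted
    # encoder is saved below so predictions can be decoded at inference time.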
    _log("Encoding labels...")
    t0 = time.time()
    provider_enc = LabelEncoder()
    provider_labels = provider_enc.fit_transform(providers)
    num_providers = len(provider_enc.classes_)
    _log(f"Labels encoded -> {num_providers} providers", t0)
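
    # Split on indices so the same index sets can slice both the text list and
    # the label array; stratifying keeps provider proportions stable per split.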
    _log("Splitting train/test...")
    t0 = time.time()
    indices = np.arange(len(texts))
    train_idx, test_idx = train_test_split(
        indices,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=provider_labels,
    )
    train_texts = [texts[i] for i in train_idx]
    test_texts = [texts[i] for i in test_idx]
    _log(f"Split: {len(train_texts)} train / {len(test_texts)} test", t0)
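
    # Fit the feature pipeline on training text only, then apply the frozen
    # transform to the test split, so no test statistics leak into the features.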
    _log("Building feature pipeline (fit on train)...")
    t0 = time.time()
    pipeline = FeaturePipeline()
    X_train = pipeline.fit_transform(train_texts)
    _log(f"Train features: {X_train.shape}", t0)

    _log("Transforming test set...")
    t0 = time.time()
    X_test = pipeline.transform(test_texts)
    _log(f"Test features: {X_test.shape}", t0)

    input_dim = X_train.shape[1]
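
    # The feature matrices are sparse (hence .toarray()); densifying the whole
    # matrix up front trades memory for fast batched indexing on the device.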
    _log(f"Moving data to {device}...")
    t0 = time.time()
    X_train_t = torch.tensor(X_train.toarray(), dtype=torch.float32).to(device)
    X_test_t = torch.tensor(X_test.toarray(), dtype=torch.float32).to(device)
    y_prov_train = torch.tensor(provider_labels[train_idx], dtype=torch.long).to(device)
    y_prov_test = torch.tensor(provider_labels[test_idx], dtype=torch.long).to(device)
    if device.type == "cuda":
        mem_used = torch.cuda.memory_allocated() / 1024**3
        _log(f"GPU memory used: {mem_used:.2f} GB", t0)
    else:
        _log(f"Data on {device}", t0)
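
    # The tensors already live on the target device, so the DataLoaders only
    # slice into them; the batch size is capped on CPU to keep batches cheap.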
    batch_size = min(BATCH_SIZE, 512) if device.type == "cpu" else BATCH_SIZE
    train_ds = TensorDataset(X_train_t, y_prov_train)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = TensorDataset(X_test_t, y_prov_test)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
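
    # AIFinderNet maps a dense feature vector to provider logits; the
    # architecture hyperparameters all come from config.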
    _log("Building model...")
    net = AIFinderNet(
        input_dim=input_dim,
        num_providers=num_providers,
        hidden_dim=HIDDEN_DIM,
        embed_dim=EMBED_DIM,
        dropout=DROPOUT,
    ).to(device)
    n_params = sum(p.numel() for p in net.parameters())
    _log(f"Model: {n_params:,} parameters")
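
    # Provider classes are rarely balanced, so weight the loss with sklearn's
    # "balanced" scheme: each class weighted inversely to its train frequency.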
    prov_weights = compute_class_weight(
        "balanced", classes=np.arange(num_providers), y=provider_labels[train_idx]
    )
    prov_criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(prov_weights, dtype=torch.float32).to(device)
    )
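
    # AdamW decouples weight decay from the gradient update. OneCycleLR
    # anneals the LR once per optimizer step, which is why scheduler.step()
    # sits inside the batch loop below. Mixed precision is CUDA-only.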
    optimizer = torch.optim.AdamW(
        net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        epochs=EPOCHS,
        steps_per_epoch=len(train_loader),
    )
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler() if use_amp else None
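
    # Train/validate loop with early stopping on validation loss; the best
    # state_dict is snapshotted in memory and restored after the loop.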
    _log(
        f"Training for {EPOCHS} epochs, batch_size={batch_size}, "
        f"early_stop_patience={EARLY_STOP_PATIENCE}..."
    )
    t0 = time.time()

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    for epoch in range(EPOCHS):
        # --- Train ---
        net.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch_X, batch_prov in train_loader:
            optimizer.zero_grad(set_to_none=True)
            if use_amp:
                with torch.amp.autocast(device_type="cuda"):
                    prov_logits = net(batch_X)
                    loss = prov_criterion(prov_logits, batch_prov)
                scaler.scale(loss).backward()
                # Unscale so gradient clipping sees true-scale gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                optimizer.step()
            scheduler.step()  # OneCycleLR steps per batch, not per epoch
            epoch_loss += loss.item()
            n_batches += 1

        avg_train_loss = epoch_loss / n_batches

        # --- Validate ---
        net.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_X, batch_prov in val_loader:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                val_loss += loss.item()
                val_batches += 1
        avg_val_loss = val_loss / val_batches

        # Snapshot weights whenever validation improves; otherwise count
        # toward the early-stop patience.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_state = {k: v.clone() for k, v in net.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        # Log the first epoch and every fifth one; '*' marks a new best.
        if (epoch + 1) % 5 == 0 or epoch == 0:
            lr = scheduler.get_last_lr()[0]
            marker = " *" if patience_counter == 0 else ""
            _log(
                f"Epoch {epoch + 1:>3d}/{EPOCHS} "
                f"train={avg_train_loss:.4f} "
                f"val={avg_val_loss:.4f} "
                f"lr={lr:.2e}{marker}"
            )

        if patience_counter >= EARLY_STOP_PATIENCE:
            _log(
                f"Early stopping at epoch {epoch + 1} "
                f"(best val_loss={best_val_loss:.4f})"
            )
            break

    # The final epoch is not necessarily the best one; restore the snapshot.
    if best_state is not None:
        net.load_state_dict(best_state)
        _log(f"Restored best weights (val_loss={best_val_loss:.4f})")

    _log("Training complete", t0)
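
    # Evaluate with one full-batch forward pass over the test tensor; fine at
    # this scale, though a much larger test split would need batching.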
    _log("Evaluating...")
    net.eval()
    with torch.no_grad():
        prov_logits = net(X_test_t)

    prov_preds = prov_logits.argmax(dim=1).cpu().numpy()
    prov_true = y_prov_test.cpu().numpy()

    print("\n === Provider Classification ===")
    print(
        classification_report(
            prov_true,
            prov_preds,
            target_names=provider_enc.classes_,
            zero_division=0,
        )
    )
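
    # Persist everything inference needs: the weights plus the hyperparameters
    # required to rebuild the network, the fitted feature pipeline, and the
    # provider label encoder.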
    _log(f"Saving to {MODEL_DIR}/ ...")
    t0 = time.time()
    os.makedirs(MODEL_DIR, exist_ok=True)

    checkpoint = {
        "input_dim": input_dim,
        "num_providers": num_providers,
        "hidden_dim": HIDDEN_DIM,
        "embed_dim": EMBED_DIM,
        "dropout": DROPOUT,
        "state_dict": net.state_dict(),
    }
    torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
    _log(" Saved classifier.pt")

    joblib.dump(pipeline, os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
    _log(" Saved feature_pipeline.joblib")
    joblib.dump(provider_enc, os.path.join(MODEL_DIR, "provider_enc.joblib"))
    _log(" Saved provider_enc.joblib")

    _log("All artifacts saved", t0)

    elapsed = time.time() - t_start
    print(f"\n{'=' * 60}")
    if device.type == "cuda":
        mem_peak = torch.cuda.max_memory_allocated() / 1024**3
        print(f"Training complete in {elapsed:.1f}s (peak GPU mem: {mem_peak:.2f} GB)")
    else:
        print(f"Training complete in {elapsed:.1f}s")
    print(f"{'=' * 60}")


if __name__ == "__main__":
    main()