# AIFinder / train.py
"""
AIFinder Training Script
Loads data, trains a two-headed GPU classifier, reports metrics, and saves the model.
Usage: python3 train.py
"""
import os
import sys
import time
import joblib
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from config import (
MODEL_DIR,
TEST_SIZE,
RANDOM_STATE,
HIDDEN_DIM,
EMBED_DIM,
DROPOUT,
BATCH_SIZE,
EPOCHS,
LEARNING_RATE,
WEIGHT_DECAY,
EARLY_STOP_PATIENCE,
)
from data_loader import load_all_data
from features import FeaturePipeline
from model import AIFinderNet
def _log(msg, t0=None):
"""Print a timestamped log message, optionally with elapsed time."""
ts = time.strftime("%H:%M:%S")
if t0 is not None:
elapsed = time.time() - t0
print(f" [{ts}] {msg} ({elapsed:.1f}s)")
else:
print(f" [{ts}] {msg}")
def main():
    """Train and persist the AIFinder provider classifier.

    End-to-end pipeline: select device -> load data -> encode provider
    labels -> stratified train/test split -> fit the feature pipeline on
    the train split only -> mini-batch training (AMP on CUDA, OneCycleLR,
    gradient clipping, early stopping on validation loss) -> print a
    classification report -> save the checkpoint, feature pipeline, and
    label encoder under MODEL_DIR.

    Exits with status 1 when fewer than 100 samples are loaded.
    """
    t_start = time.time()
    print("=" * 60)
    print("AIFinder Training - Provider Classification")
    print("=" * 60)
    # ── GPU check ──────────────────────────────────────────────
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        _log(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
    else:
        device = torch.device("cpu")
        _log("No GPU available, using CPU")
    # ── Load data ──────────────────────────────────────────────
    _log("Starting data load...")
    t0 = time.time()
    # NOTE(review): `models` and `_is_ai` are returned by load_all_data()
    # but unused here — this script trains only on provider labels.
    texts, providers, models, _is_ai = load_all_data()
    _log("Data load complete", t0)
    if len(texts) < 100:
        print("ERROR: Not enough data loaded. Check dataset access.")
        sys.exit(1)
    # ── Encode labels ──────────────────────────────────────────
    _log("Encoding labels...")
    t0 = time.time()
    provider_enc = LabelEncoder()
    provider_labels = provider_enc.fit_transform(providers)
    num_providers = len(provider_enc.classes_)
    _log(f"Labels encoded β€” {num_providers} providers", t0)
    # ── Train/test split ───────────────────────────────────────
    _log("Splitting train/test...")
    t0 = time.time()
    # Split on indices (not the texts themselves) so the same index sets
    # can slice texts, labels, and any other parallel arrays consistently.
    indices = np.arange(len(texts))
    train_idx, test_idx = train_test_split(
        indices,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=provider_labels,  # keep per-provider class balance in both splits
    )
    train_texts = [texts[i] for i in train_idx]
    test_texts = [texts[i] for i in test_idx]
    _log(f"Split: {len(train_texts)} train / {len(test_texts)} test", t0)
    # ── Build features ─────────────────────────────────────────
    # Fit the pipeline on the train split only; the test split is merely
    # transformed, avoiding feature-level leakage.
    _log("Building feature pipeline (fit on train)...")
    t0 = time.time()
    pipeline = FeaturePipeline()
    X_train = pipeline.fit_transform(train_texts)
    _log(f"Train features: {X_train.shape}", t0)
    _log("Transforming test set...")
    t0 = time.time()
    X_test = pipeline.transform(test_texts)
    _log(f"Test features: {X_test.shape}", t0)
    input_dim = X_train.shape[1]
    # ── Move to device ─────────────────────────────────────────
    _log(f"Moving data to {device}...")
    t0 = time.time()
    # The .toarray() calls imply the pipeline returns a sparse matrix;
    # densifying the whole set up front can be memory-heavy for large
    # feature dimensions — TODO confirm this fits device memory.
    X_train_t = torch.tensor(X_train.toarray(), dtype=torch.float32).to(device)
    X_test_t = torch.tensor(X_test.toarray(), dtype=torch.float32).to(device)
    y_prov_train = torch.tensor(provider_labels[train_idx], dtype=torch.long).to(device)
    y_prov_test = torch.tensor(provider_labels[test_idx], dtype=torch.long).to(device)
    if device.type == "cuda":
        mem_used = torch.cuda.memory_allocated() / 1024**3
        _log(f"GPU memory used: {mem_used:.2f} GB", t0)
    else:
        _log(f"Data on {device}", t0)
    # ── DataLoaders ────────────────────────────────────────────
    # Cap the batch size on CPU to keep per-batch latency reasonable.
    batch_size = min(BATCH_SIZE, 512) if device.type == "cpu" else BATCH_SIZE
    train_ds = TensorDataset(X_train_t, y_prov_train)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # NOTE(review): the held-out test set doubles as the validation set
    # for early stopping below, so the final report is not fully unseen.
    val_ds = TensorDataset(X_test_t, y_prov_test)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    # ── Model ──────────────────────────────────────────────────
    _log("Building model...")
    net = AIFinderNet(
        input_dim=input_dim,
        num_providers=num_providers,
        hidden_dim=HIDDEN_DIM,
        embed_dim=EMBED_DIM,
        dropout=DROPOUT,
    ).to(device)
    n_params = sum(p.numel() for p in net.parameters())
    _log(f"Model: {n_params:,} parameters")
    # ── Class-weighted loss ────────────────────────────────────
    # "balanced" weights counteract provider class imbalance; computed
    # from the train-split labels only.
    prov_weights = compute_class_weight(
        "balanced", classes=np.arange(num_providers), y=provider_labels[train_idx]
    )
    prov_criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(prov_weights, dtype=torch.float32).to(device)
    )
    # ── Optimizer + scheduler ──────────────────────────────────
    optimizer = torch.optim.AdamW(
        net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    # OneCycleLR is stepped per batch, hence steps_per_epoch below.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        epochs=EPOCHS,
        steps_per_epoch=len(train_loader),
    )
    # Mixed-precision training only on CUDA; scaler is None on CPU.
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler() if use_amp else None
    # ── Training loop ──────────────────────────────────────────
    _log(
        f"Training for {EPOCHS} epochs, batch_size={batch_size}, "
        f"early_stop_patience={EARLY_STOP_PATIENCE}..."
    )
    t0 = time.time()
    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0
    for epoch in range(EPOCHS):
        # ── Train phase ───────────────────────────────────────
        net.train()
        epoch_loss = 0.0
        n_batches = 0
        for batch_X, batch_prov in train_loader:
            optimizer.zero_grad(set_to_none=True)
            if use_amp:
                with torch.amp.autocast(device_type="cuda"):
                    prov_logits = net(batch_X)
                    loss = prov_criterion(prov_logits, batch_prov)
                # Unscale before clipping so the norm is computed on
                # true (unscaled) gradients.
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
            n_batches += 1
        avg_train_loss = epoch_loss / n_batches
        # ── Validation phase ──────────────────────────────────
        net.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_X, batch_prov in val_loader:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                val_loss += loss.item()
                val_batches += 1
        avg_val_loss = val_loss / val_batches
        # ── Early stopping check ──────────────────────────────
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Clone tensors so the snapshot is not mutated by later steps.
            best_state = {k: v.clone() for k, v in net.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
        # ── Logging ───────────────────────────────────────────
        # Log every 5 epochs (and the first); "*" marks a new best epoch.
        if (epoch + 1) % 5 == 0 or epoch == 0:
            lr = scheduler.get_last_lr()[0]
            marker = " *" if patience_counter == 0 else ""
            _log(
                f"Epoch {epoch + 1:>3d}/{EPOCHS} "
                f"train={avg_train_loss:.4f} "
                f"val={avg_val_loss:.4f} "
                f"lr={lr:.2e}{marker}"
            )
        if patience_counter >= EARLY_STOP_PATIENCE:
            _log(
                f"Early stopping at epoch {epoch + 1} "
                f"(best val_loss={best_val_loss:.4f})"
            )
            break
    # Restore best weights
    if best_state is not None:
        net.load_state_dict(best_state)
        _log(f"Restored best weights (val_loss={best_val_loss:.4f})")
    _log("Training complete", t0)
    # ── Evaluate ───────────────────────────────────────────────
    # Single full-batch forward pass over the test tensor.
    _log("Evaluating...")
    net.eval()
    with torch.no_grad():
        prov_logits = net(X_test_t)
        prov_preds = prov_logits.argmax(dim=1).cpu().numpy()
    prov_true = y_prov_test.cpu().numpy()
    print("\n === Provider Classification ===")
    print(
        classification_report(
            prov_true,
            prov_preds,
            target_names=provider_enc.classes_,
            zero_division=0,
        )
    )
    # ── Save ───────────────────────────────────────────────────
    # Checkpoint stores the architecture hyperparameters alongside the
    # weights so inference code can rebuild the net without config.py.
    _log(f"Saving to {MODEL_DIR}/ ...")
    t0 = time.time()
    os.makedirs(MODEL_DIR, exist_ok=True)
    checkpoint = {
        "input_dim": input_dim,
        "num_providers": num_providers,
        "hidden_dim": HIDDEN_DIM,
        "embed_dim": EMBED_DIM,
        "dropout": DROPOUT,
        "state_dict": net.state_dict(),
    }
    torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
    _log(" Saved classifier.pt")
    joblib.dump(pipeline, os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
    _log(" Saved feature_pipeline.joblib")
    joblib.dump(provider_enc, os.path.join(MODEL_DIR, "provider_enc.joblib"))
    _log(" Saved provider_enc.joblib")
    _log("All artifacts saved", t0)
    elapsed = time.time() - t_start
    if device.type == "cuda":
        mem_peak = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n{'=' * 60}")
        print(f"Training complete in {elapsed:.1f}s (peak GPU mem: {mem_peak:.2f} GB)")
        print(f"{'=' * 60}")
    else:
        print(f"\n{'=' * 60}")
        print(f"Training complete in {elapsed:.1f}s")
        print(f"{'=' * 60}")
if __name__ == "__main__":
main()