"""
MamaGuard — Training Loop
Trains the Mamba3 model with class-weighted loss and LR scheduling.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
import os
from src.model import MamaGuardMamba3
from src.data_pipeline import get_dataloaders
# Training artifacts (fitted scaler and model weights) are written under the
# "models" directory, relative to the current working directory.
SCALER_SAVE_PATH = "models/scaler.pkl"
MODEL_SAVE_PATH = "models/mamaguard_mamba3.pt"

# Create the output directory at import time; a no-op when it already exists.
os.makedirs(name="models", exist_ok=True)
def compute_class_weights(train_loader, high_risk_boost: float = 3.0):
    """Compute inverse-frequency class weights with a boost for high-risk.

    The loader is iterated once and the labels (second item of every yielded
    3-tuple) are counted over three classes. Each weight is
    ``total / (3 * count)``, after which the weight of class index 2 (printed
    as HIGH) is multiplied by ``high_risk_boost``.

    Args:
        train_loader: Iterable yielding ``(features, labels, extra)`` tuples
            where ``labels`` is a torch tensor of integer class ids.
        high_risk_boost: Multiplier applied to the class-2 weight. The default
            of 3.0 reproduces the original fixed 3x boost.

    Returns:
        A ``torch.float32`` tensor of per-class weights (length 3 when all
        labels are in ``{0, 1, 2}``).

    Raises:
        ValueError: If the loader yields no labels at all; without any sample
            every weight would silently come out as zero.
    """
    # Gather whole batches as arrays instead of extending a Python list one
    # scalar at a time; detach/cpu keeps this working for non-CPU label tensors.
    label_batches = [
        y.detach().cpu().numpy().ravel() for _, y, _ in train_loader
    ]
    if not label_batches or sum(batch.size for batch in label_batches) == 0:
        raise ValueError(
            "train_loader yielded no labels; cannot compute class weights"
        )
    all_labels = np.concatenate(label_batches)

    counts = np.bincount(all_labels, minlength=3)
    total = counts.sum()
    # The small epsilon avoids division by zero for a class with no samples.
    weights = total / (3 * counts + 1e-6)
    weights[2] *= high_risk_boost  # high-risk class boost
    print(f"Class weights: LOW={weights[0]:.2f}, MID={weights[1]:.2f}, HIGH={weights[2]:.2f}")
    return torch.tensor(weights, dtype=torch.float32)
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_batch_loss, accuracy)``.

    With an ``optimizer`` the model is switched to training mode and updated
    after every batch (gradient norm clipped at 1.0). Without one the model is
    switched to eval mode and the pass runs with gradient tracking disabled.
    The loss is the mean of the per-batch losses; accuracy is 0.0 when the
    loader produced no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(is_training):
        for X_batch, y_batch, _ in loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            if is_training:
                optimizer.zero_grad()
            logits = model(X_batch)
            loss = criterion(logits, y_batch)
            if is_training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            loss_sum += loss.item()
            n_correct += (logits.argmax(dim=-1) == y_batch).sum().item()
            n_seen += len(y_batch)
    accuracy = n_correct / n_seen if n_seen else 0.0
    return loss_sum / len(loader), accuracy


def train(
    csv_path: str = "data/maternal_health.csv",
    epochs: int = 50,
    batch_size: int = 32,
    lr: float = 1e-3,
    device: str = None
):
    """Full training loop with validation and best-model checkpointing.

    Loads the data, pickles the fitted scaler to ``SCALER_SAVE_PATH``, trains
    a ``MamaGuardMamba3`` classifier with class-weighted cross-entropy, AdamW
    and ``ReduceLROnPlateau`` (stepped on validation loss), and writes the
    weights with the best validation accuracy to ``MODEL_SAVE_PATH``.

    Args:
        csv_path: Path of the CSV passed to ``get_dataloaders``.
        epochs: Number of training epochs.
        batch_size: Batch size passed to ``get_dataloaders``.
        lr: Initial learning rate for AdamW.
        device: Torch device string; when ``None``, "cuda" is used if
            available, otherwise "cpu".

    Returns:
        The model, carrying the weights of the best checkpoint (the epoch
        with the highest validation accuracy) whenever at least one epoch ran.

    Raises:
        ValueError: If the training or validation loader has no batches.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Training on: {device}")

    # Load data
    train_loader, val_loader, scaler = get_dataloaders(csv_path, batch_size)
    # Fail early with a clear message instead of a ZeroDivisionError later on.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            f"No batches to train/validate on (train={len(train_loader)}, "
            f"val={len(val_loader)}); check {csv_path!r}"
        )
    with open(SCALER_SAVE_PATH, "wb") as f:
        pickle.dump(scaler, f)
    print(f"Scaler saved to {SCALER_SAVE_PATH}")

    # Build model
    model = MamaGuardMamba3(
        input_dim=6, d_model=64, n_layers=4, n_classes=3, d_state=32
    ).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Loss function with class weights
    class_weights = compute_class_weights(train_loader).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Optimizer + scheduler
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5,
    )

    # Training loop
    best_val_acc = 0.0
    best_state = None  # in-memory copy of the best weights seen so far
    for epoch in range(1, epochs + 1):
        avg_train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )
        avg_val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        print(
            f"Epoch {epoch:3d}/{epochs} | "
            f"Train Loss: {avg_train_loss:.4f} Acc: {train_acc:.3f} | "
            f"Val Loss: {avg_val_loss:.4f} Acc: {val_acc:.3f}"
        )
        scheduler.step(avg_val_loss)

        # Always checkpoint the first epoch so a model file exists even when
        # validation accuracy never rises above 0.0.
        if best_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            # Clone: state_dict() tensors alias the live parameters and would
            # otherwise keep changing as training continues.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f" * Best model saved (val_acc={val_acc:.3f})")

    print(f"\nTraining complete. Best val accuracy: {best_val_acc:.3f}")
    if best_state is not None:
        # Return the checkpointed weights rather than the last-epoch ones.
        model.load_state_dict(best_state)
        print(f"Model saved to: {MODEL_SAVE_PATH}")
    return model
if __name__ == "__main__":
    # Run training with the default hyperparameters when executed as a script.
    train()