# SheGuard / src/train.py
# Deploy SheGuard - Maternal Risk Assessment with Mamba3 SSM (commit 9686dbe)
"""
MamaGuard — Training Loop
Trains the Mamba3 model with class-weighted loss and LR scheduling.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
import os
from src.model import MamaGuardMamba3
from src.data_pipeline import get_dataloaders
MODEL_SAVE_PATH = "models/mamaguard_mamba3.pt"
SCALER_SAVE_PATH = "models/scaler.pkl"
os.makedirs("models", exist_ok=True)
def compute_class_weights(train_loader):
"""Compute inverse-frequency class weights with 3× boost for high-risk."""
all_labels = []
for _, y, _ in train_loader:
all_labels.extend(y.numpy())
counts = np.bincount(all_labels, minlength=3)
total = counts.sum()
weights = total / (3 * counts + 1e-6)
weights[2] *= 3.0 # high-risk class boost
print(f"Class weights: LOW={weights[0]:.2f}, MID={weights[1]:.2f}, HIGH={weights[2]:.2f}")
return torch.tensor(weights, dtype=torch.float32)
def train(
csv_path: str = "data/maternal_health.csv",
epochs: int = 50,
batch_size: int = 32,
lr: float = 1e-3,
device: str = None
):
"""Full training loop with validation and best-model checkpointing."""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on: {device}")
# Load data
train_loader, val_loader, scaler = get_dataloaders(csv_path, batch_size)
with open(SCALER_SAVE_PATH, "wb") as f:
pickle.dump(scaler, f)
print(f"Scaler saved to {SCALER_SAVE_PATH}")
# Build model
model = MamaGuardMamba3(
input_dim=6, d_model=64, n_layers=4, n_classes=3, d_state=32
).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")
# Loss function with class weights
class_weights = compute_class_weights(train_loader).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
# Optimizer + scheduler
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5,
)
# Training loop
best_val_loss = float("inf")
best_val_acc = 0.0
for epoch in range(1, epochs + 1):
# Training phase
model.train()
train_loss, train_correct, train_total = 0.0, 0, 0
for X_batch, y_batch, _ in train_loader:
X_batch = X_batch.to(device)
y_batch = y_batch.to(device)
optimizer.zero_grad()
logits = model(X_batch)
loss = criterion(logits, y_batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
train_loss += loss.item()
preds = logits.argmax(dim=-1)
train_correct += (preds == y_batch).sum().item()
train_total += len(y_batch)
# Validation phase
model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad():
for X_batch, y_batch, _ in val_loader:
X_batch = X_batch.to(device)
y_batch = y_batch.to(device)
logits = model(X_batch)
loss = criterion(logits, y_batch)
val_loss += loss.item()
preds = logits.argmax(dim=-1)
val_correct += (preds == y_batch).sum().item()
val_total += len(y_batch)
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
train_acc = train_correct / train_total
val_acc = val_correct / val_total
print(
f"Epoch {epoch:3d}/{epochs} | "
f"Train Loss: {avg_train_loss:.4f} Acc: {train_acc:.3f} | "
f"Val Loss: {avg_val_loss:.4f} Acc: {val_acc:.3f}"
)
scheduler.step(avg_val_loss)
if val_acc > best_val_acc:
best_val_acc = val_acc
best_val_loss = avg_val_loss
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f" * Best model saved (val_acc={val_acc:.3f})")
print(f"\nTraining complete. Best val accuracy: {best_val_acc:.3f}")
print(f"Model saved to: {MODEL_SAVE_PATH}")
return model
if __name__ == "__main__":
train()