File size: 4,410 Bytes
9686dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
MamaGuard — Training Loop
Trains the Mamba3 model with class-weighted loss and LR scheduling.
"""

import os
import pickle
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from src.data_pipeline import get_dataloaders
from src.model import MamaGuardMamba3

# Output artifacts written by train().
MODEL_SAVE_PATH = "models/mamaguard_mamba3.pt"  # best-validation state_dict
SCALER_SAVE_PATH = "models/scaler.pkl"          # pickled scaler from get_dataloaders
# Create the shared output directory at import time so both paths are writable.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)


def compute_class_weights(train_loader, high_risk_boost: float = 3.0):
    """Compute inverse-frequency class weights with an extra boost for high-risk.

    Class ``c`` gets ``total / (3 * count_c)``, so a perfectly balanced label
    distribution yields weights of 1.0. The HIGH class (index 2) is then
    multiplied by ``high_risk_boost``.

    Args:
        train_loader: Iterable of ``(X, y, _)`` batches. Only ``y`` is read;
            it must hold integer class labels in ``{0, 1, 2}``.
        high_risk_boost: Extra multiplier applied to the HIGH class weight.

    Returns:
        A float32 tensor of shape ``(3,)`` for
        ``nn.CrossEntropyLoss(weight=...)``.

    Raises:
        ValueError: If the loader yields no labels, or a label is outside
            ``{0, 1, 2}``.
    """
    n_classes = 3
    # .cpu() makes this work even if a loader yields device tensors.
    chunks = [y.detach().cpu().reshape(-1).numpy() for _, y, _ in train_loader]
    labels = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int64)
    if labels.size == 0:
        raise ValueError("train_loader yielded no labels; cannot compute class weights")
    if labels.min() < 0 or labels.max() >= n_classes:
        raise ValueError(
            f"class labels must be in [0, {n_classes - 1}], "
            f"got range [{labels.min()}, {labels.max()}]"
        )

    counts = np.bincount(labels, minlength=n_classes)
    total  = counts.sum()

    # Clamp empty classes to a count of 1. A class absent from training then
    # gets a large but finite weight, instead of total / 3e-6, which would
    # dominate the weighted-mean loss of any batch containing that class.
    weights = total / (n_classes * np.maximum(counts, 1))
    weights[2] *= high_risk_boost   # high-risk class boost
    print(f"Class weights: LOW={weights[0]:.2f}, MID={weights[1]:.2f}, HIGH={weights[2]:.2f}")
    return torch.tensor(weights, dtype=torch.float32)


def _run_epoch(model, loader, criterion, device, optimizer=None, max_grad_norm=1.0):
    """Make one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Args:
        model: Module to run; switched to train or eval mode accordingly.
        loader: Iterable of ``(X, y, _)`` batches.
        criterion: Loss callable ``criterion(logits, y)``.
        device: Device the batches are moved to.
        optimizer: When provided, gradients are computed, clipped to
            ``max_grad_norm`` and applied. When ``None``, no graph is built.
        max_grad_norm: Total gradient-norm limit used during training.

    Returns:
        ``(mean_batch_loss, accuracy)``. The loss is the mean of the per-batch
        values returned by ``criterion``; accuracy is correct / total samples.

    Raises:
        ValueError: If ``loader`` yields no samples (averages would be 0/0).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum, n_batches, n_correct, n_seen = 0.0, 0, 0, 0

    # Build the autograd graph only for the training pass.
    with torch.set_grad_enabled(training):
        for X_batch, y_batch, _ in loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            if training:
                optimizer.zero_grad()
            logits = model(X_batch)
            loss = criterion(logits, y_batch)
            if training:
                loss.backward()
                # Clip the total gradient norm before applying the update.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()

            loss_sum  += loss.item()
            n_batches += 1
            n_correct += (logits.argmax(dim=-1) == y_batch).sum().item()
            n_seen    += len(y_batch)

    if n_seen == 0:
        raise ValueError("data loader yielded no samples")
    return loss_sum / n_batches, n_correct / n_seen


def train(
    csv_path:   str           = "data/maternal_health.csv",
    epochs:     int           = 50,
    batch_size: int           = 32,
    lr:         float         = 1e-3,
    device:     Optional[str] = None
):
    """Train MamaGuardMamba3 with validation and best-model checkpointing.

    Before training, the scaler returned by ``get_dataloaders`` is pickled to
    ``SCALER_SAVE_PATH``. Each epoch then does four things:
    - trains on the training split;
    - evaluates on the validation split;
    - steps a ReduceLROnPlateau scheduler on the mean validation loss;
    - writes ``model.state_dict()`` to ``MODEL_SAVE_PATH`` whenever validation
      accuracy improves (ties are broken by lower validation loss).

    Args:
        csv_path: CSV file passed to ``get_dataloaders``.
        epochs: Number of training epochs.
        batch_size: Batch size passed to ``get_dataloaders``.
        lr: Initial AdamW learning rate.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        The trained model. It holds the weights from the FINAL epoch, which may
        differ from the best checkpoint written to ``MODEL_SAVE_PATH``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Training on: {device}")

    # Load data and persist the scaler returned alongside the loaders.
    train_loader, val_loader, scaler = get_dataloaders(csv_path, batch_size)
    with open(SCALER_SAVE_PATH, "wb") as f:
        pickle.dump(scaler, f)
    print(f"Scaler saved to {SCALER_SAVE_PATH}")

    # Build model
    model = MamaGuardMamba3(
        input_dim=6, d_model=64, n_layers=4, n_classes=3, d_state=32
    ).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Loss function with class weights
    class_weights = compute_class_weights(train_loader).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Optimizer + scheduler: halve the LR once validation loss has failed to
    # improve for more than 5 consecutive epochs.
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5,
    )

    best_val_acc  = 0.0
    best_val_loss = float("inf")
    saved_any     = False

    for epoch in range(1, epochs + 1):
        avg_train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )
        avg_val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        print(
            f"Epoch {epoch:3d}/{epochs} | "
            f"Train Loss: {avg_train_loss:.4f} Acc: {train_acc:.3f} | "
            f"Val Loss: {avg_val_loss:.4f} Acc: {val_acc:.3f}"
        )

        scheduler.step(avg_val_loss)

        # Prefer higher accuracy; on an accuracy tie keep the lower-loss model.
        # (The inf initial loss also guarantees a checkpoint on epoch 1.)
        improved = val_acc > best_val_acc or (
            val_acc == best_val_acc and avg_val_loss < best_val_loss
        )
        if improved:
            best_val_acc  = val_acc
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            saved_any = True
            print(f"  * Best model saved (val_acc={val_acc:.3f})")

    print(f"\nTraining complete. Best val accuracy: {best_val_acc:.3f}")
    if saved_any:
        print(f"Model saved to: {MODEL_SAVE_PATH}")
    else:
        print(f"No checkpoint written to: {MODEL_SAVE_PATH}")
    return model


if __name__ == "__main__":
    # Script entry point: run train() with its default arguments.
    train()