# Coconut-MNIST — src/train_models.py
# feat: complete Hybrid SVD-CNN system with interactive app
# (author: ymlin105, commit b25b9cb)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.decomposition import TruncatedSVD
from sklearn.model_selection import train_test_split
import pickle
import os
import numpy as np
import random
from src.hybrid_model import SimpleCNN
from src.utils import load_data
from src import config
def set_seed(seed=42):
    """Seed every RNG (Python, NumPy, PyTorch CPU/CUDA) for reproducible runs.

    Args:
        seed: Base seed applied to all random number generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable
    # Force deterministic cuDNN kernels. Disabling `benchmark` is required as
    # well: otherwise cuDNN auto-selects algorithms at runtime, which can vary
    # between runs and break determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def train_svd(X_flat, n_components=20):
    """Fit a TruncatedSVD on mean-centered data and persist it to disk.

    TruncatedSVD does not center its input, so the feature-wise mean is
    subtracted here and stashed on the fitted object (``_train_mean``) so
    inference code can apply the same centering before projecting.

    Args:
        X_flat: Tensor of shape (n_samples, n_features) of flattened images.
        n_components: Number of SVD components to keep.

    Returns:
        The fitted TruncatedSVD instance (with ``_train_mean`` attached).
    """
    print(f"Training SVD (k={n_components})...")
    X_np = X_flat.numpy()
    mean = X_np.mean(axis=0)
    svd = TruncatedSVD(n_components=n_components, random_state=42).fit(X_np - mean)
    svd._train_mean = mean  # consumed at inference time to re-center inputs
    # Robustness: make sure the destination directory exists before writing,
    # so a fresh checkout doesn't crash on the first run.
    out_dir = os.path.dirname(config.SVD_MODEL_PATH)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(config.SVD_MODEL_PATH, "wb") as f:
        pickle.dump(svd, f)
    return svd
def train_cnn(X_flat, y, batch_size=64, epochs=5):
    """Train SimpleCNN on flattened 28x28 images with an 80/20 stratified split.

    Tracks per-epoch loss/accuracy, restores the weights of the best
    validation epoch, saves the CPU state dict to ``config.CNN_MODEL_PATH``
    and pickles the training history alongside it.

    Args:
        X_flat: Tensor of shape (n_samples, 784) of flattened images.
        y: Tensor of integer class labels aligned with ``X_flat``.
        batch_size: Mini-batch size for both train and validation loaders.
        epochs: Number of training epochs.

    Returns:
        (model, history): the trained model (moved to CPU) and a dict with
        'train_loss', 'val_loss', 'train_acc', 'val_acc' lists — one entry
        per epoch.
    """
    X_train, X_val, y_train, y_val = train_test_split(
        X_flat.numpy(), y.numpy(), test_size=0.2, random_state=42, stratify=y.numpy()
    )

    def to_loader(X, y, shuffle=True):
        # Reshape flat 784-vectors back to (N, 1, 28, 28) images for the CNN.
        ds = TensorDataset(torch.tensor(X).view(-1, 1, 28, 28),
                           torch.tensor(y, dtype=torch.long))
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

    train_loader = to_loader(X_train, y_train)
    val_loader = to_loader(X_val, y_val, False)

    # Hoist the device lookup: the original re-derived it via
    # next(model.parameters()).device on every batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleCNN().to(device)
    opt = optim.Adam(model.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_acc, best_state = 0, None
    for epoch in range(epochs):
        model.train()
        t_loss, t_corr = 0.0, 0
        for x, labels in train_loader:
            x, labels = x.to(device), labels.to(device)
            opt.zero_grad()
            out = model(x)
            loss = crit(out, labels)
            loss.backward()
            opt.step()
            t_loss += loss.item()
            t_corr += (out.argmax(1) == labels).sum().item()

        model.eval()
        v_loss, v_corr = 0.0, 0
        with torch.no_grad():
            for x, labels in val_loader:
                x, labels = x.to(device), labels.to(device)
                out = model(x)
                v_loss += crit(out, labels).item()
                v_corr += (out.argmax(1) == labels).sum().item()

        # BUG FIX: losses were accumulated but never recorded, leaving the
        # 'train_loss'/'val_loss' history lists permanently empty.
        history['train_loss'].append(t_loss / max(len(train_loader), 1))
        history['val_loss'].append(v_loss / max(len(val_loader), 1))
        history['train_acc'].append(100 * t_corr / len(X_train))
        history['val_acc'].append(100 * v_corr / len(X_val))
        print(f"Epoch {epoch+1}: Train Acc {history['train_acc'][-1]:.2f}%, Val Acc {history['val_acc'][-1]:.2f}%")

        if history['val_acc'][-1] > best_acc:
            best_acc = history['val_acc'][-1]
            # BUG FIX: state_dict().copy() is a shallow dict copy — the tensor
            # values are still the live parameters, which the optimizer keeps
            # mutating in place, so the "best" snapshot silently drifted to
            # the final-epoch weights. Clone each tensor instead.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    if best_state is not None:  # guard: epochs == 0 leaves no snapshot
        model.load_state_dict(best_state)
    torch.save(model.cpu().state_dict(), config.CNN_MODEL_PATH)
    with open(config.CNN_MODEL_PATH.replace('.pth', '_history.pkl'), 'wb') as f:
        pickle.dump(history, f)
    return model, history
if __name__ == "__main__":
    # Reproducible end-to-end training: seed all RNGs, load the dataset once,
    # then fit both models on the same flattened 784-dimensional view.
    set_seed()
    images, labels = load_data()
    flat_images = images.view(-1, 784)
    train_svd(flat_images)
    train_cnn(flat_images, labels)