"""Training entry point: fits a mean-centered TruncatedSVD baseline and a
SimpleCNN classifier on flattened 28x28 images, checkpointing both to the
paths declared in ``src.config``."""
import copy
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.decomposition import TruncatedSVD
from sklearn.model_selection import train_test_split

from src.hybrid_model import SimpleCNN
from src.utils import load_data
from src import config


def set_seed(seed=42):
    """Seed every RNG (python, numpy, torch CPU+CUDA) for reproducibility.

    Also forces deterministic cuDNN kernels, which may cost some speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def train_svd(X_flat, n_components=20):
    """Fit a mean-centered TruncatedSVD on flattened images and pickle it.

    Args:
        X_flat: torch tensor of shape (N, 784) — flattened images.
        n_components: number of singular components to keep.

    Returns:
        The fitted ``TruncatedSVD`` estimator (also written to
        ``config.SVD_MODEL_PATH``).
    """
    print(f"Training SVD (k={n_components})...")
    X_np = X_flat.numpy()
    mean = X_np.mean(axis=0)
    # TruncatedSVD does not center the data itself, so center manually and
    # stash the training mean on the estimator for use at inference time.
    svd = TruncatedSVD(n_components=n_components, random_state=42).fit(X_np - mean)
    svd._train_mean = mean
    with open(config.SVD_MODEL_PATH, "wb") as f:
        pickle.dump(svd, f)
    return svd


def train_cnn(X_flat, y, batch_size=64, epochs=5):
    """Train SimpleCNN with a stratified 80/20 split, keeping the best epoch.

    Args:
        X_flat: torch tensor of shape (N, 784) — flattened images.
        y: torch tensor of integer class labels, shape (N,).
        batch_size: minibatch size for both loaders.
        epochs: number of training epochs.

    Returns:
        (model, history) — the model restored to its best-validation-accuracy
        weights (moved to CPU), and a dict of per-epoch train/val loss and
        accuracy curves. The model state dict is saved to
        ``config.CNN_MODEL_PATH`` and the history pickled alongside it.
    """
    X_train, X_val, y_train, y_val = train_test_split(
        X_flat.numpy(), y.numpy(),
        test_size=0.2, random_state=42, stratify=y.numpy())

    def to_loader(X, labels, shuffle=True):
        # Reshape flat vectors back to (N, 1, 28, 28) image batches.
        ds = TensorDataset(torch.tensor(X).view(-1, 1, 28, 28),
                           torch.tensor(labels, dtype=torch.long))
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

    train_loader = to_loader(X_train, y_train)
    val_loader = to_loader(X_val, y_val, shuffle=False)

    # Hoist device selection: the original recomputed
    # next(model.parameters()).device on every batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleCNN().to(device)
    opt = optim.Adam(model.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    # Start below any reachable accuracy so the first epoch always
    # checkpoints; with the original best_acc=0 and a strict '>', a run whose
    # val accuracy never exceeded 0 left best_state=None and crashed.
    best_acc, best_state = -1.0, None

    for epoch in range(epochs):
        model.train()
        t_loss, t_corr = 0.0, 0
        for x, labels in train_loader:
            x, labels = x.to(device), labels.to(device)
            opt.zero_grad()
            out = model(x)
            loss = crit(out, labels)
            loss.backward()
            opt.step()
            t_loss += loss.item()
            t_corr += (out.argmax(1) == labels).sum().item()

        model.eval()
        v_loss, v_corr = 0.0, 0
        with torch.no_grad():
            for x, labels in val_loader:
                x, labels = x.to(device), labels.to(device)
                out = model(x)
                v_loss += crit(out, labels).item()
                v_corr += (out.argmax(1) == labels).sum().item()

        # FIX: the original accumulated t_loss/v_loss but never recorded
        # them, so the pickled history had empty loss curves.
        history['train_loss'].append(t_loss / len(train_loader))
        history['val_loss'].append(v_loss / len(val_loader))
        history['train_acc'].append(100 * t_corr / len(X_train))
        history['val_acc'].append(100 * v_corr / len(X_val))
        print(f"Epoch {epoch+1}: Train Acc {history['train_acc'][-1]:.2f}%, Val Acc {history['val_acc'][-1]:.2f}%")

        if history['val_acc'][-1] > best_acc:
            best_acc = history['val_acc'][-1]
            # FIX: state_dict().copy() is a shallow copy — its tensors alias
            # the live parameters, so later optimizer steps clobbered the
            # "best" snapshot. Deep-copy so the checkpoint is frozen.
            best_state = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_state)
    torch.save(model.cpu().state_dict(), config.CNN_MODEL_PATH)
    with open(config.CNN_MODEL_PATH.replace('.pth', '_history.pkl'), 'wb') as f:
        pickle.dump(history, f)
    return model, history


if __name__ == "__main__":
    set_seed()
    X, y = load_data()
    train_svd(X.view(-1, 784))
    train_cnn(X.view(-1, 784), y)