# Scrape residue from the hosting page ("Spaces: Running") kept as a comment:
# Spaces: Running
# Running
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.decomposition import TruncatedSVD
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

from src import config
from src.hybrid_model import SimpleCNN
from src.utils import load_data
def set_seed(seed=42):
    """Seed all RNGs in use (random, NumPy, torch CPU/CUDA) for reproducibility.

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # BUG FIX: determinism also requires disabling cuDNN autotuning; with
    # benchmark=True cuDNN may pick different (non-deterministic) kernels
    # from run to run even when deterministic=True is set.
    torch.backends.cudnn.benchmark = False
def train_svd(X_flat, n_components=20):
    """Fit a mean-centred TruncatedSVD (PCA-style) on flattened samples.

    Args:
        X_flat: float tensor of shape (N, D) — flattened training samples.
        n_components: number of SVD components to keep.

    Returns:
        The fitted TruncatedSVD. The training mean is stashed on the
        estimator as `_train_mean` so inference code can centre new data
        identically.

    Side effect: pickles the fitted model to ``config.SVD_MODEL_PATH``.
    """
    print(f"Training SVD (k={n_components})...")
    X_np = X_flat.numpy()
    # Centre the data so TruncatedSVD behaves like PCA.
    mean = X_np.mean(axis=0)
    svd = TruncatedSVD(n_components=n_components, random_state=42).fit(X_np - mean)
    # NOTE(review): ad-hoc private attribute — downstream loaders must
    # subtract this mean before calling transform().
    svd._train_mean = mean
    # Robustness fix: make sure the output directory exists before writing.
    out_dir = os.path.dirname(config.SVD_MODEL_PATH)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(config.SVD_MODEL_PATH, "wb") as f:
        pickle.dump(svd, f)
    return svd
def train_cnn(X_flat, y, batch_size=64, epochs=5):
    """Train SimpleCNN on flattened 28x28 images with an 80/20 stratified split.

    Args:
        X_flat: float tensor of shape (N, 784) — assumes 28x28 images; TODO confirm.
        y: integer class-label tensor of shape (N,).
        batch_size: mini-batch size for both loaders.
        epochs: number of training epochs.

    Returns:
        (model, history): the model restored to its best-validation-accuracy
        weights (moved to CPU before saving) and a dict of per-epoch
        'train_loss'/'val_loss'/'train_acc'/'val_acc' lists.

    Side effects: saves the best state dict to ``config.CNN_MODEL_PATH`` and
    the history dict alongside it (suffix ``_history.pkl``).
    """
    X_train, X_val, y_train, y_val = train_test_split(
        X_flat.numpy(), y.numpy(), test_size=0.2, random_state=42, stratify=y.numpy()
    )

    def to_loader(X, y, shuffle=True):
        # Reshape flat vectors back to (N, 1, 28, 28) image batches.
        ds = TensorDataset(
            torch.tensor(X).view(-1, 1, 28, 28),
            torch.tensor(y, dtype=torch.long),
        )
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

    train_loader = to_loader(X_train, y_train)
    val_loader = to_loader(X_val, y_val, False)

    # Resolve the device once instead of querying model parameters per batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleCNN().to(device)
    opt = optim.Adam(model.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_acc, best_state = 0, None
    for epoch in range(epochs):
        model.train()
        t_loss, t_corr = 0.0, 0
        for x, labels in train_loader:
            x, labels = x.to(device), labels.to(device)
            opt.zero_grad()
            out = model(x)
            loss = crit(out, labels)
            loss.backward()
            opt.step()
            t_loss += loss.item()
            t_corr += (out.argmax(1) == labels).sum().item()

        model.eval()
        v_loss, v_corr = 0.0, 0
        with torch.no_grad():
            for x, labels in val_loader:
                x, labels = x.to(device), labels.to(device)
                out = model(x)
                v_loss += crit(out, labels).item()
                v_corr += (out.argmax(1) == labels).sum().item()

        # BUG FIX: the loss totals were accumulated but never recorded, so
        # the 'train_loss'/'val_loss' history lists stayed empty on disk.
        history['train_loss'].append(t_loss / len(train_loader))
        history['val_loss'].append(v_loss / len(val_loader))
        history['train_acc'].append(100 * t_corr / len(X_train))
        history['val_acc'].append(100 * v_corr / len(X_val))
        print(f"Epoch {epoch+1}: Train Acc {history['train_acc'][-1]:.2f}%, Val Acc {history['val_acc'][-1]:.2f}%")

        if history['val_acc'][-1] > best_acc:
            best_acc = history['val_acc'][-1]
            # BUG FIX: state_dict().copy() is a *shallow* dict copy — the
            # tensors it references keep training, so the saved "best" model
            # was really the final one. Clone every tensor to CPU instead.
            best_state = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }

    # Guard against epochs == 0, where no best state was ever captured.
    if best_state is not None:
        model.load_state_dict(best_state)
    torch.save(model.cpu().state_dict(), config.CNN_MODEL_PATH)
    with open(config.CNN_MODEL_PATH.replace('.pth', '_history.pkl'), 'wb') as f:
        pickle.dump(history, f)
    return model, history
| if __name__ == "__main__": | |
| set_seed() | |
| X, y = load_data() | |
| train_svd(X.view(-1, 784)) | |
| train_cnn(X.view(-1, 784), y) | |