import time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Subset, random_split from torchvision import datasets, transforms import gradio as gr import matplotlib.pyplot as plt # ----------------------------- # Custom model.nn (simple + fast) # ----------------------------- class MnistMLP(nn.Module): def __init__(self, hidden=256, dropout=0.2): super().__init__() self.fc1 = nn.Linear(28 * 28, hidden) self.drop = nn.Dropout(dropout) self.fc2 = nn.Linear(hidden, 10) def forward(self, x): x = x.view(x.size(0), -1) # flatten x = F.relu(self.fc1(x)) x = self.drop(x) return self.fc2(x) # logits # ----------------------------- # Helpers # ----------------------------- def get_device(): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def make_plot(train_loss, val_loss, train_acc, val_acc): # One figure, two lines for loss, two lines for acc fig = plt.figure() epochs = np.arange(1, len(train_loss) + 1) plt.plot(epochs, train_loss, label="train loss") plt.plot(epochs, val_loss, label="val loss") plt.xlabel("epoch") plt.ylabel("loss") plt.title("Loss curves") plt.legend() plt.tight_layout() return fig def make_acc_plot(train_acc, val_acc): fig = plt.figure() epochs = np.arange(1, len(train_acc) + 1) plt.plot(epochs, train_acc, label="train acc") plt.plot(epochs, val_acc, label="val acc") plt.xlabel("epoch") plt.ylabel("accuracy") plt.title("Accuracy curves") plt.legend() plt.tight_layout() return fig def evaluate(model, loader, device): model.eval() correct = 0 total = 0 loss_sum = 0.0 crit = nn.CrossEntropyLoss() with torch.no_grad(): for x, y in loader: x, y = x.to(device), y.to(device) logits = model(x) loss = crit(logits, y) loss_sum += loss.item() preds = logits.argmax(dim=1) correct += (preds == y).sum().item() total += y.numel() avg_loss = loss_sum / max(1, len(loader)) acc = correct / max(1, total) return avg_loss, acc # ----------------------------- # Train function (ONLY runs on button click) # ----------------------------- def train_mnist(epochs, lr, batch_size, hidden, dropout, train_subset, progress=gr.Progress()): device = get_device() # Dataset is created here (not at import), so the app loads instantly. tfm = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) progress(0, desc="Downloading/loading MNIST (first run can take a bit)…") full_train = datasets.MNIST(root="data", train=True, download=True, transform=tfm) # Optional subset for speed n = int(train_subset) if n < len(full_train): full_train = Subset(full_train, range(n)) # Split train/val val_size = max(1, int(0.1 * len(full_train))) train_size = len(full_train) - val_size train_ds, val_ds = random_split(full_train, [train_size, val_size]) train_loader = DataLoader(train_ds, batch_size=int(batch_size), shuffle=True, num_workers=0) val_loader = DataLoader(val_ds, batch_size=int(batch_size), shuffle=False, num_workers=0) model = MnistMLP(hidden=int(hidden), dropout=float(dropout)).to(device) opt = torch.optim.Adam(model.parameters(), lr=float(lr)) crit = nn.CrossEntropyLoss() train_losses, val_losses = [], [] train_accs, val_accs = [], [] start = time.time() for ep in range(1, int(epochs) + 1): model.train() running_loss = 0.0 correct = 0 total = 0 progress((ep - 1) / max(1, int(epochs)), desc=f"Training epoch {ep}/{int(epochs)}…") for x, y in train_loader: x, y = x.to(device), y.to(device) opt.zero_grad() logits = model(x) loss = crit(logits, y) loss.backward() opt.step() running_loss += loss.item() preds = logits.argmax(dim=1) correct += (preds == y).sum().item() total += y.numel() train_loss = running_loss / max(1, len(train_loader)) train_acc = correct / max(1, total) val_loss, val_acc = evaluate(model, val_loader, device) train_losses.append(train_loss) val_losses.append(val_loss) train_accs.append(train_acc) val_accs.append(val_acc) elapsed = time.time() - start progress(1, desc="Done") # Build “hyperparameters” summary summary = ( f"**Device:** `{device}`\n\n" f"**Hyperparameters**\n" f"- epochs: `{int(epochs)}`\n" f"- lr: `{float(lr)}`\n" f"- batch_size: `{int(batch_size)}`\n" f"- hidden: `{int(hidden)}`\n" f"- dropout: `{float(dropout)}`\n" f"- train_subset: `{int(train_subset)}` (10% used for val)\n\n" f"**Final metrics**\n" f"- train loss: `{train_losses[-1]:.4f}` | train acc: `{train_accs[-1]:.4f}`\n" f"- val loss: `{val_losses[-1]:.4f}` | val acc: `{val_accs[-1]:.4f}`\n\n" f"**Time:** `{elapsed:.1f}s`" ) loss_fig = make_plot(train_losses, val_losses, train_accs, val_accs) acc_fig = make_acc_plot(train_accs, val_accs) # Return figs + a small table-like text log log_lines = ["epoch,train_loss,val_loss,train_acc,val_acc"] for i in range(len(train_losses)): log_lines.append( f"{i+1},{train_losses[i]:.4f},{val_losses[i]:.4f},{train_accs[i]:.4f},{val_accs[i]:.4f}" ) log_csv = "\n".join(log_lines) return summary, loss_fig, acc_fig, log_csv # ----------------------------- # Gradio UI (simple) # ----------------------------- with gr.Blocks() as demo: gr.Markdown("# MNIST Trainer (PyTorch `nn.Module`) — Loss Curves + Hyperparameters") gr.Markdown("This app only downloads/trains when you click **Train**, so it won’t hang on load.") with gr.Row(): with gr.Column(): epochs = gr.Slider(1, 10, value=3, step=1, label="epochs") lr = gr.Number(value=1e-3, label="learning rate", precision=6) batch = gr.Slider(32, 256, value=128, step=32, label="batch_size") hidden = gr.Slider(64, 512, value=256, step=64, label="hidden units") dropout = gr.Slider(0.0, 0.6, value=0.2, step=0.05, label="dropout") train_subset = gr.Slider(1000, 60000, value=12000, step=1000, label="train_subset (speed control)") train_btn = gr.Button("Train") with gr.Column(): summary = gr.Markdown() loss_plot = gr.Plot(label="Loss curves") acc_plot = gr.Plot(label="Accuracy curves") log_csv = gr.Textbox(label="Epoch log (CSV)", lines=10) train_btn.click( fn=train_mnist, inputs=[epochs, lr, batch, hidden, dropout, train_subset], outputs=[summary, loss_plot, acc_plot, log_csv], ) demo.queue() if __name__ == "__main__": demo.launch()