# Hugging Face Spaces page residue (Space status badge: "Sleeping") — not part of the program.
# Third-party deps: torch/torchvision (model + data), gradio (UI), matplotlib (plots).
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
import gradio as gr
import matplotlib.pyplot as plt
# -----------------------------
# Custom nn.Module (simple + fast)
# -----------------------------
class MnistMLP(nn.Module):
    """Two-layer MLP classifier for 28x28 MNIST images.

    Architecture: flatten -> Linear(784, hidden) -> ReLU -> Dropout -> Linear(hidden, 10).
    The forward pass returns raw logits (no softmax), as expected by
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self, hidden=256, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden, 10)

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, 10)."""
        flat = x.view(x.size(0), -1)          # (B, 1, 28, 28) -> (B, 784)
        activated = F.relu(self.fc1(flat))
        activated = self.drop(activated)
        return self.fc2(activated)            # raw logits
# -----------------------------
# Helpers
# -----------------------------
def get_device():
    """Return the CUDA device when available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def make_plot(train_loss, val_loss, train_acc, val_acc):
    """Plot train/val loss per epoch and return the matplotlib figure.

    NOTE(review): ``train_acc``/``val_acc`` are accepted for signature
    compatibility with the caller but are not plotted here; accuracy
    curves are produced by ``make_acc_plot``.
    """
    fig = plt.figure()
    xs = np.arange(1, len(train_loss) + 1)
    for series, name in ((train_loss, "train loss"), (val_loss, "val loss")):
        plt.plot(xs, series, label=name)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("Loss curves")
    plt.legend()
    plt.tight_layout()
    return fig
def make_acc_plot(train_acc, val_acc):
    """Plot train/val accuracy per epoch and return the matplotlib figure."""
    fig = plt.figure()
    xs = np.arange(1, len(train_acc) + 1)
    for series, name in ((train_acc, "train acc"), (val_acc, "val acc")):
        plt.plot(xs, series, label=name)
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.title("Accuracy curves")
    plt.legend()
    plt.tight_layout()
    return fig
def evaluate(model, loader, device):
    """Evaluate ``model`` over ``loader`` without gradient tracking.

    Returns:
        tuple: ``(avg_loss, accuracy)`` — the mean of per-batch
        cross-entropy losses and the fraction of correctly classified
        samples. Both fall back to safe values on an empty loader.
    """
    criterion = nn.CrossEntropyLoss()
    model.eval()
    loss_total = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_total += criterion(logits, targets).item()
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.numel()
    # max(1, ...) guards the divisions when the loader/dataset is empty.
    return loss_total / max(1, len(loader)), n_correct / max(1, n_seen)
# -----------------------------
# Train function (ONLY runs on button click)
# -----------------------------
def train_mnist(epochs, lr, batch_size, hidden, dropout, train_subset, progress=gr.Progress()):
    """Train an MnistMLP on (a subset of) MNIST and return UI artifacts.

    Args:
        epochs, lr, batch_size, hidden, dropout: hyperparameters from the UI
            controls; coerced with int()/float() below since gradio widgets
            may deliver them as floats.
        train_subset: number of MNIST training examples to use (speed knob);
            10% of the resulting set is split off for validation.
        progress: gradio progress tracker. NOTE: ``gr.Progress()`` as a
            default argument is the gradio-documented idiom for injecting
            a tracker, not an accidental mutable default.

    Returns:
        tuple: (summary_markdown, loss_figure, acc_figure, csv_log_string).
    """
    device = get_device()
    # Dataset is created here (not at import), so the app loads instantly.
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # canonical MNIST mean/std
    ])
    progress(0, desc="Downloading/loading MNIST (first run can take a bit)…")
    full_train = datasets.MNIST(root="data", train=True, download=True, transform=tfm)
    # Optional subset for speed
    n = int(train_subset)
    if n < len(full_train):
        full_train = Subset(full_train, range(n))
    # Split train/val (90/10); val_size is clamped to at least one sample.
    val_size = max(1, int(0.1 * len(full_train)))
    train_size = len(full_train) - val_size
    train_ds, val_ds = random_split(full_train, [train_size, val_size])
    train_loader = DataLoader(train_ds, batch_size=int(batch_size), shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=int(batch_size), shuffle=False, num_workers=0)
    model = MnistMLP(hidden=int(hidden), dropout=float(dropout)).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(lr))
    crit = nn.CrossEntropyLoss()
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    start = time.time()
    for ep in range(1, int(epochs) + 1):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        # Report fractional progress before each epoch starts.
        progress((ep - 1) / max(1, int(epochs)), desc=f"Training epoch {ep}/{int(epochs)}…")
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = crit(logits, y)
            loss.backward()
            opt.step()
            running_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.numel()
        # Per-epoch averages; max(1, ...) guards against empty loaders.
        train_loss = running_loss / max(1, len(train_loader))
        train_acc = correct / max(1, total)
        val_loss, val_acc = evaluate(model, val_loader, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
    elapsed = time.time() - start
    progress(1, desc="Done")
    # Build the "hyperparameters" summary shown as markdown in the UI.
    summary = (
        f"**Device:** `{device}`\n\n"
        f"**Hyperparameters**\n"
        f"- epochs: `{int(epochs)}`\n"
        f"- lr: `{float(lr)}`\n"
        f"- batch_size: `{int(batch_size)}`\n"
        f"- hidden: `{int(hidden)}`\n"
        f"- dropout: `{float(dropout)}`\n"
        f"- train_subset: `{int(train_subset)}` (10% used for val)\n\n"
        f"**Final metrics**\n"
        f"- train loss: `{train_losses[-1]:.4f}` | train acc: `{train_accs[-1]:.4f}`\n"
        f"- val loss: `{val_losses[-1]:.4f}` | val acc: `{val_accs[-1]:.4f}`\n\n"
        f"**Time:** `{elapsed:.1f}s`"
    )
    loss_fig = make_plot(train_losses, val_losses, train_accs, val_accs)
    acc_fig = make_acc_plot(train_accs, val_accs)
    # Return figs + a small table-like text log
    log_lines = ["epoch,train_loss,val_loss,train_acc,val_acc"]
    for i in range(len(train_losses)):
        log_lines.append(
            f"{i+1},{train_losses[i]:.4f},{val_losses[i]:.4f},{train_accs[i]:.4f},{val_accs[i]:.4f}"
        )
    log_csv = "\n".join(log_lines)
    return summary, loss_fig, acc_fig, log_csv
# -----------------------------
# Gradio UI (simple)
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# MNIST Trainer (PyTorch `nn.Module`) — Loss Curves + Hyperparameters")
    gr.Markdown("This app only downloads/trains when you click **Train**, so it won’t hang on load.")
    with gr.Row():
        # Left column: hyperparameter controls.
        with gr.Column():
            epochs = gr.Slider(1, 10, value=3, step=1, label="epochs")
            lr = gr.Number(value=1e-3, label="learning rate", precision=6)
            batch = gr.Slider(32, 256, value=128, step=32, label="batch_size")
            hidden = gr.Slider(64, 512, value=256, step=64, label="hidden units")
            dropout = gr.Slider(0.0, 0.6, value=0.2, step=0.05, label="dropout")
            train_subset = gr.Slider(1000, 60000, value=12000, step=1000, label="train_subset (speed control)")
            train_btn = gr.Button("Train")
        # Right column: training outputs.
        with gr.Column():
            summary = gr.Markdown()
            loss_plot = gr.Plot(label="Loss curves")
            acc_plot = gr.Plot(label="Accuracy curves")
            log_csv = gr.Textbox(label="Epoch log (CSV)", lines=10)
    # Wire the button: `inputs` map positionally onto train_mnist's parameters,
    # `outputs` onto its 4-tuple return value.
    train_btn.click(
        fn=train_mnist,
        inputs=[epochs, lr, batch, hidden, dropout, train_subset],
        outputs=[summary, loss_plot, acc_plot, log_csv],
    )
# Queue requests so long-running training jobs are handled sequentially.
demo.queue()
if __name__ == "__main__":
    demo.launch()