# NMIST_PyTorch / app.py
# Uploaded by eaglelandsonce; latest commit dfdc53b (verified): "Update app.py"
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
import gradio as gr
import matplotlib.pyplot as plt
# -----------------------------
# Custom nn.Module model (simple + fast)
# -----------------------------
class MnistMLP(nn.Module):
    """Single-hidden-layer perceptron that maps image batches to class logits.

    The network is ``Linear -> ReLU -> Dropout -> Linear``. Inputs are
    flattened per sample, so any shape ``(batch, ...)`` whose trailing
    dimensions multiply to ``in_features`` is accepted.

    Args:
        hidden: Width of the hidden layer.
        dropout: Dropout probability applied after the hidden activation.
        in_features: Flattened size of one input sample (28 * 28 by default).
        num_classes: Number of output logits (10 by default).
    """

    def __init__(self, hidden=256, dropout=0.2, *, in_features=28 * 28, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return unnormalised class scores (logits), one row per sample."""
        # flatten() falls back to a copy when the input is not contiguous,
        # whereas view() raises for such tensors (e.g. a transposed batch).
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.drop(x)
        return self.fc2(x)
# -----------------------------
# Helpers
# -----------------------------
def get_device():
    """Pick the compute device: CUDA when PyTorch reports it available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def make_plot(train_loss, val_loss, train_acc, val_acc):
    """Build the loss-curve figure for one training run.

    Only the two loss series are drawn; accuracy is not plotted here.

    Args:
        train_loss: Per-epoch training loss values.
        val_loss: Per-epoch validation loss values, plotted against the
            epoch axis derived from ``len(train_loss)``.
        train_acc: Unused; accepted so existing callers keep working.
        val_acc: Unused; accepted so existing callers keep working.

    Returns:
        A matplotlib ``Figure`` with one axes holding both loss curves,
        epochs numbered from 1 on the x axis.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps every
    # figure it creates registered until it is explicitly closed and draws on
    # a process-wide "current figure", neither of which suits a handler that
    # is invoked repeatedly by a long-running app.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    epochs = np.arange(1, len(train_loss) + 1)
    ax.plot(epochs, train_loss, label="train loss")
    ax.plot(epochs, val_loss, label="val loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title("Loss curves")
    ax.legend()
    fig.tight_layout()
    return fig
def make_acc_plot(train_acc, val_acc):
    """Build the accuracy-curve figure for one training run.

    Args:
        train_acc: Per-epoch training accuracy values.
        val_acc: Per-epoch validation accuracy values, plotted against the
            epoch axis derived from ``len(train_acc)``.

    Returns:
        A matplotlib ``Figure`` with one axes holding both accuracy curves,
        epochs numbered from 1 on the x axis.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps every
    # figure it creates registered until it is explicitly closed and draws on
    # a process-wide "current figure", neither of which suits a handler that
    # is invoked repeatedly by a long-running app.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    epochs = np.arange(1, len(train_acc) + 1)
    ax.plot(epochs, train_acc, label="train acc")
    ax.plot(epochs, val_acc, label="val acc")
    ax.set_xlabel("epoch")
    ax.set_ylabel("accuracy")
    ax.set_title("Accuracy curves")
    ax.legend()
    fig.tight_layout()
    return fig
def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader`` without tracking gradients.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Module that maps an input batch to per-class logits.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, acc)``: the per-sample mean cross-entropy loss and the
        fraction of correct argmax predictions. Both are ``0.0`` when the
        loader yields no samples.
    """
    model.eval()
    crit = nn.CrossEntropyLoss()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            batch_n = y.numel()
            # crit returns the batch mean; scale it back to a batch sum so a
            # short final batch is not over-weighted in the overall average.
            loss_sum += crit(logits, y).item() * batch_n
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += batch_n
    denom = max(1, total)  # avoids division by zero for an empty loader
    return loss_sum / denom, correct / denom
# -----------------------------
# Train function (ONLY runs on button click)
# -----------------------------
def _train_one_epoch(model, loader, opt, crit, device):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``."""
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        logits = model(x)
        loss = crit(logits, y)
        loss.backward()
        opt.step()
        batch_n = y.numel()
        # crit yields a batch mean; weight it by the batch size so the epoch
        # figure is a true per-sample mean even when the last batch is short.
        loss_sum += loss.item() * batch_n
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch_n
    denom = max(1, total)  # avoids division by zero for an empty loader
    return loss_sum / denom, correct / denom


def train_mnist(epochs, lr, batch_size, hidden, dropout, train_subset, progress=gr.Progress()):
    """Train a fresh ``MnistMLP`` on MNIST and report metrics and curves.

    The MNIST training split is loaded from (or downloaded to) ``data``,
    optionally truncated to its first ``train_subset`` samples, and randomly
    split 90/10 into training and validation sets. The model is optimised
    with Adam on cross-entropy loss and validated after every epoch.

    Args:
        epochs: Number of passes over the training set (cast to ``int``).
        lr: Adam learning rate (cast to ``float``).
        batch_size: Batch size for both loaders (cast to ``int``).
        hidden: Hidden-layer width of the model (cast to ``int``).
        dropout: Dropout probability of the model (cast to ``float``).
        train_subset: Number of leading samples to keep (cast to ``int``);
            the full set is used when this is not smaller than the dataset.
        progress: Callable used to report progress as
            ``progress(fraction, desc=...)``; defaults to ``gr.Progress()``.

    Returns:
        ``(summary, loss_fig, acc_fig, log_csv)``: a Markdown summary, the
        loss figure, the accuracy figure, and a CSV string with one row of
        metrics per epoch.

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
    """
    # Coerce every hyperparameter exactly once instead of at each use.
    n_epochs = int(epochs)
    learning_rate = float(lr)
    batch = int(batch_size)
    n_hidden = int(hidden)
    p_drop = float(dropout)
    n_samples = int(train_subset)
    if n_epochs < 1:
        # Fail before any download: with no epochs the metric lists stay
        # empty and the summary lookup further down would raise IndexError.
        raise ValueError(f"epochs must be at least 1, got {n_epochs}")

    device = get_device()

    # Dataset is created here (not at import), so the app loads instantly.
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    progress(0, desc="Downloading/loading MNIST (first run can take a bit)…")
    full_train = datasets.MNIST(root="data", train=True, download=True, transform=tfm)

    # Optional subset for speed
    if n_samples < len(full_train):
        full_train = Subset(full_train, range(n_samples))

    # Split train/val: 10% (at least one sample) is held out for validation.
    val_size = max(1, int(0.1 * len(full_train)))
    train_size = len(full_train) - val_size
    train_ds, val_ds = random_split(full_train, [train_size, val_size])
    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=0)

    model = MnistMLP(hidden=n_hidden, dropout=p_drop).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    crit = nn.CrossEntropyLoss()

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    start = time.time()
    for ep in range(1, n_epochs + 1):
        progress((ep - 1) / n_epochs, desc=f"Training epoch {ep}/{n_epochs}…")
        train_loss, train_acc = _train_one_epoch(model, train_loader, opt, crit, device)
        val_loss, val_acc = evaluate(model, val_loader, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
    elapsed = time.time() - start
    progress(1, desc="Done")

    # Build “hyperparameters” summary
    summary = (
        f"**Device:** `{device}`\n\n"
        f"**Hyperparameters**\n"
        f"- epochs: `{n_epochs}`\n"
        f"- lr: `{learning_rate}`\n"
        f"- batch_size: `{batch}`\n"
        f"- hidden: `{n_hidden}`\n"
        f"- dropout: `{p_drop}`\n"
        f"- train_subset: `{n_samples}` (10% used for val)\n\n"
        f"**Final metrics**\n"
        f"- train loss: `{train_losses[-1]:.4f}` | train acc: `{train_accs[-1]:.4f}`\n"
        f"- val loss: `{val_losses[-1]:.4f}` | val acc: `{val_accs[-1]:.4f}`\n\n"
        f"**Time:** `{elapsed:.1f}s`"
    )
    loss_fig = make_plot(train_losses, val_losses, train_accs, val_accs)
    acc_fig = make_acc_plot(train_accs, val_accs)

    # Return figs + a small table-like text log (header row, then one row per epoch)
    per_epoch = zip(train_losses, val_losses, train_accs, val_accs)
    log_lines = ["epoch,train_loss,val_loss,train_acc,val_acc"]
    log_lines.extend(
        f"{ep},{tl:.4f},{vl:.4f},{ta:.4f},{va:.4f}"
        for ep, (tl, vl, ta, va) in enumerate(per_epoch, start=1)
    )
    log_csv = "\n".join(log_lines)
    return summary, loss_fig, acc_fig, log_csv
# -----------------------------
# Gradio UI (simple)
# -----------------------------
# Page layout: a heading, then one row with the hyperparameter controls on the
# left and the training outputs on the right.
with gr.Blocks() as demo:
    gr.Markdown(value="# MNIST Trainer (PyTorch `nn.Module`) — Loss Curves + Hyperparameters")
    gr.Markdown(value="This app only downloads/trains when you click **Train**, so it won’t hang on load.")

    with gr.Row():
        # Left column: inputs, created in the order they are passed to train_mnist.
        with gr.Column():
            epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="epochs")
            lr = gr.Number(value=1e-3, label="learning rate", precision=6)
            batch = gr.Slider(minimum=32, maximum=256, value=128, step=32, label="batch_size")
            hidden = gr.Slider(minimum=64, maximum=512, value=256, step=64, label="hidden units")
            dropout = gr.Slider(minimum=0.0, maximum=0.6, value=0.2, step=0.05, label="dropout")
            train_subset = gr.Slider(
                minimum=1000,
                maximum=60000,
                value=12000,
                step=1000,
                label="train_subset (speed control)",
            )
            train_btn = gr.Button(value="Train")

        # Right column: outputs, created in the order train_mnist returns them.
        with gr.Column():
            summary = gr.Markdown()
            loss_plot = gr.Plot(label="Loss curves")
            acc_plot = gr.Plot(label="Accuracy curves")
            log_csv = gr.Textbox(label="Epoch log (CSV)", lines=10)

    # Training is wired to the button only; nothing runs when the page loads.
    train_btn.click(
        fn=train_mnist,
        inputs=[epochs, lr, batch, hidden, dropout, train_subset],
        outputs=[summary, loss_plot, acc_plot, log_csv],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()