"""Gradio demo: classify hand-written digits with a tiny CNN trained on MNIST.

On import the module either loads weights from ``mnist_cnn.pt`` or, if that
file is missing, downloads MNIST, trains the model for a few epochs and saves
the weights. It then builds a Gradio interface; running the file as a script
launches it.
"""

import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Device (CPU for compatibility with Hugging Face Spaces)
device = torch.device("cpu")

# Side length (pixels) that both training and uploaded images are resized to.
# Keep this <= 8: the three stride-2 convolutions below must reduce the map to
# 1x1 so that Flatten yields exactly 10 logits. Conv/BatchNorm parameter shapes
# do not depend on it, so existing checkpoints stay loadable.
IMG_SIZE = 6

# Transform for training and uploaded images: resize, then convert to a float
# tensor with values in [0, 1].
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])


def conv(ic, oc):
    """Return a 3x3, stride-2 convolution followed by batch normalization.

    Args:
        ic: number of input channels.
        oc: number of output channels.
    """
    ks = 3
    return nn.Sequential(
        # padding=ks//2 pads symmetrically, so stride 2 roughly halves H and W.
        nn.Conv2d(ic, oc, stride=2, kernel_size=ks, padding=ks // 2),
        nn.BatchNorm2d(oc),
    )


class SimpleCNN(nn.Module):
    """Three stride-2 conv blocks; the last emits 10 channels used as logits."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            conv(1, 8),
            nn.Dropout2d(0.25),
            nn.ReLU(),
            conv(8, 16),
            nn.Dropout2d(0.25),
            nn.ReLU(),
            conv(16, 10),
            # With IMG_SIZE <= 8 the spatial map is 1x1 here, giving [N, 10].
            nn.Flatten(),
        )

    def forward(self, x):
        return self.model(x)


def train_model():
    """Download MNIST, train a fresh SimpleCNN for 3 epochs and return it."""
    train_dataset = datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    batch_size = 36
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = SimpleCNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.005)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(3):  # Keep it light for HF Spaces
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return model


# Load or train model
model_path = "mnist_cnn.pt"
if os.path.exists(model_path):
    model = SimpleCNN().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    model.load_state_dict(
        torch.load(model_path, map_location=device, weights_only=True)
    )
else:
    model = train_model()
    torch.save(model.state_dict(), model_path)
# Inference only from here on: disable dropout, use BatchNorm running stats.
model.eval()


def predict(img):
    """Classify a single digit image.

    Args:
        img: a PIL image or a NumPy array, as delivered by ``gr.Image``;
            ``None`` when nothing was provided.

    Returns:
        ``"Predicted digit: <n>"`` on success, or ``"Invalid image"`` when the
        input is missing or of an unsupported type.
    """
    # gr.Image hands over NumPy arrays unless type="pil" is set; accept both so
    # the function works however the component is configured.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    if not isinstance(img, Image.Image):
        return "Invalid image"
    img = img.convert("L")

    # NOTE(review): MNIST digits are light strokes on a dark background; a
    # dark-on-light upload is passed through unchanged and may be misread.
    x = transform(img).unsqueeze(0).to(device)  # Shape: [1, 1, IMG_SIZE, IMG_SIZE]
    model.eval()
    with torch.no_grad():
        output = model(x)
    pred = torch.argmax(output, dim=1).item()
    return f"Predicted digit: {pred}"


# Gradio Interface
demo = gr.Interface(
    fn=predict,
    # type="pil" makes Gradio pass a PIL image; the default ("numpy") would
    # otherwise reach predict as an ndarray.
    inputs=gr.Image(type="pil"),
    outputs="text",
    title=f"MNIST Digit Classifier ({IMG_SIZE}x{IMG_SIZE} CNN)",
    description=(
        "Upload or draw a digit to classify it using a lightweight CNN "
        f"trained on MNIST resized to {IMG_SIZE}×{IMG_SIZE}."
    ),
)

if __name__ == "__main__":
    demo.launch()