import base64
import io

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image, ImageOps

# ----------------------------------------
# 1. Create the FastAPI app FIRST
# ----------------------------------------
app = FastAPI()

# Allow the canvas front-end to call /predict from any origin.
# NOTE(review): browsers reject "*" origins combined with
# allow_credentials=True for credentialed requests — tighten the origin
# list before deploying anywhere non-local.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ----------------------------------------
# 2. Serve index.html (no static folder needed)
# ----------------------------------------
@app.get("/")
def serve_home():
    """Serve the drawing-canvas front-end from the working directory."""
    return FileResponse("index.html")


# ----------------------------------------
# 3. Model definition
# ----------------------------------------
class LeNet5(nn.Module):
    """LeNet-5 variant (BatchNorm added) for 1x32x32 grayscale digit images.

    A 32x32 input passes through two conv5 + maxpool2 stages
    (32 -> 28 -> 14 -> 10 -> 5), leaving a 16x5x5 feature map — hence the
    16 * 5 * 5 input size of the first fully-connected layer.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, 5), nn.BatchNorm2d(6), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, 5), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(16 * 5 * 5, 120)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw class logits for a (N, 1, 32, 32) batch."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.flatten(x)
        x = self.relu(self.fc(x))
        x = self.relu(self.fc1(x))
        return self.fc2(x)


# Load trained weights once at startup. weights_only=True restricts the
# pickle-based checkpoint loader to tensors/containers so a tampered
# lenet5.pth cannot execute arbitrary code (available since torch 1.13).
model = LeNet5()
model.load_state_dict(
    torch.load("lenet5.pth", map_location="cpu", weights_only=True)
)
model.eval()
# ----------------------------------------
# 4. Predict route (base64 image)
# ----------------------------------------
from pydantic import BaseModel


class CanvasImage(BaseModel):
    # Data-URL ("data:image/png;base64,...") or bare base64 PNG exported
    # from the drawing canvas.
    image: str


# Built once at import time instead of on every request.
_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean / std
])


@app.post("/predict")
async def predict_digit(data: CanvasImage):
    """Decode the posted canvas image and return the predicted digit.

    Returns:
        {"prediction": <int>} — argmax over the model's 10 class logits.
    """
    # Accept both "data:image/png;base64,<payload>" and a bare payload;
    # split(",")[1] would raise IndexError when the data-URL header is
    # absent, [-1] handles both forms.
    img_data = data.image.split(",")[-1]
    img_bytes = base64.b64decode(img_data)

    image = Image.open(io.BytesIO(img_bytes)).convert("L")
    image = image.resize((28, 28))
    # Canvas drawings are dark-on-light; MNIST digits are light-on-dark.
    image = ImageOps.invert(image)

    img = _TRANSFORM(image).unsqueeze(0)
    # Pad 28x28 -> 32x32 as the LeNet-5 architecture expects.
    # NOTE(review): padding happens *after* Normalize, so the border is
    # filled with 0.0 rather than the normalized black value
    # ((0 - 0.1307) / 0.3081 ≈ -0.42). Kept as-is to match whatever
    # preprocessing the checkpoint was trained with — confirm against the
    # training pipeline before "fixing".
    img = torch.nn.functional.pad(img, (2, 2, 2, 2))

    with torch.no_grad():
        output = model(img)
        pred = output.argmax(1).item()

    return {"prediction": pred}