Spaces:
Sleeping
Sleeping
import base64
import io

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image, ImageOps
# ----------------------------------------
# 1. Create the FastAPI app FIRST
# ----------------------------------------
from fastapi.middleware.cors import CORSMiddleware

# The route decorators further down reference ``app``, so it must exist first.
app = FastAPI()

# Accept requests from any origin. Credentials stay disabled: a wildcard
# origin together with credentials is not allowed by the CORS rules (Starlette
# would instead echo back the caller's Origin, letting any site send
# credentialed requests), and none of the routes here read cookies or auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ----------------------------------------
# 2. Serve index.html (no static folder needed)
# ----------------------------------------
@app.get("/")
def serve_home():
    """Return the front-end page ``index.html`` for the site root.

    The path is relative, so the file is looked up in the process's current
    working directory.
    """
    return FileResponse("index.html")
# ----------------------------------------
# 3. Your model definition
# ----------------------------------------
class LeNet5(nn.Module):
    """LeNet-5 style CNN with batch normalization for single-channel images.

    The fully connected head expects 16 feature maps of 5x5 after the two
    conv/pool stages, which is what a ``(N, 1, 32, 32)`` batch produces:
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.

    Attribute names and the layer order inside ``layer1``/``layer2`` determine
    the ``state_dict`` keys; changing them breaks loading saved weights.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # 5x5 conv -> batch norm -> ReLU -> 2x2 max-pool
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 5),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        self.layer1 = conv_stage(1, 6)
        self.layer2 = conv_stage(6, 16)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(16 * 5 * 5, 120)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a ``(N, 1, 32, 32)`` batch to ``(N, num_classes)`` logits."""
        features = self.flatten(self.layer2(self.layer1(x)))
        hidden = self.relu(self.fc1(self.relu(self.fc(features))))
        return self.fc2(hidden)
# Build the network once at import time and load the trained weights onto the
# CPU. torch.load unpickles the file, so only ship a trusted checkpoint.
model = LeNet5()
model.load_state_dict(
    torch.load("lenet5.pth", map_location=torch.device("cpu"))
)
# Inference mode: BatchNorm layers use their stored running statistics.
model.eval()
# ----------------------------------------
# 4. Predict route (base64 image)
# ----------------------------------------
from pydantic import BaseModel


class CanvasImage(BaseModel):
    """JSON body of the predict route.

    ``image`` holds base64-encoded image bytes, either bare or as a data URL
    such as ``"data:image/png;base64,<payload>"``; when a comma is present only
    the part after the first comma is decoded.
    """

    image: str


# Normalization constants: the commonly used MNIST mean/std, presumably the
# ones the weights were trained with.
_MNIST_MEAN = (0.1307,)
_MNIST_STD = (0.3081,)

# Built once at import time rather than on every request.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(_MNIST_MEAN, _MNIST_STD)


def _decode_image(encoded):
    """Decode a base64 string or data URL into a PIL image.

    Raises:
        ValueError: the payload is not valid base64 or not a readable image.
    """
    # "data:<mime>;base64,<payload>" -> "<payload>"; bare base64 is used whole.
    _, comma, tail = encoded.partition(",")
    payload = tail if comma else encoded
    try:
        raw = base64.b64decode(payload)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise ValueError("image is not valid base64") from err
    try:
        image = Image.open(io.BytesIO(raw))
        image.load()  # decode now so truncated data fails here, not later
    except (OSError, Image.DecompressionBombError) as err:
        # OSError also covers PIL.UnidentifiedImageError (not an image at all).
        raise ValueError("image could not be decoded") from err
    return image


def _preprocess(image):
    """Convert a decoded drawing into a ``(1, 1, 32, 32)`` model input."""
    # convert("L") ignores alpha, so transparent pixels would come out black;
    # composite onto white first. Fully opaque images are unaffected.
    if image.mode in ("RGBA", "LA") or "transparency" in image.info:
        rgba = image.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, rgba)
    image = image.convert("L").resize((28, 28))
    # Invert so dark strokes on a light background become light-on-dark like
    # MNIST (assumes the page draws dark on light -- confirm in index.html).
    image = ImageOps.invert(image)
    # Pad 28x28 -> 32x32 (the size LeNet5's 16*5*5 head requires) BEFORE
    # normalizing, so the border equals the normalized black background
    # instead of 0.0, which corresponds to a gray pixel value of 0.1307.
    tensor = torch.nn.functional.pad(_to_tensor(image), (2, 2, 2, 2))
    return _normalize(tensor).unsqueeze(0)


# NOTE(review): "/predict" is assumed; confirm it matches the URL index.html posts to.
@app.post("/predict")
async def predict_digit(data: CanvasImage):
    """Classify a hand-drawn digit sent as a base64 image.

    Returns ``{"prediction": <int>}``, the index of the largest logit.
    Responds with HTTP 400 when the image cannot be decoded.
    """
    try:
        image = _decode_image(data.image)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err

    # NOTE(review): inference is CPU-bound and, since this handler is
    # ``async def``, runs on the event loop; a plain ``def`` would let FastAPI
    # run it in its worker thread pool instead.
    with torch.no_grad():
        logits = model(_preprocess(image))
    return {"prediction": logits.argmax(1).item()}