arpit-gour02's picture
Update app.py
de6b24b verified
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import io, base64
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageOps
# ----------------------------------------
# 1. Create the FastAPI app FIRST
# ----------------------------------------
app = FastAPI()

from fastapi.middleware.cors import CORSMiddleware

# Open CORS policy: any origin may call the API with any method and header.
# Credentials stay disabled because a wildcard origin must not be combined
# with allow_credentials=True (see the FastAPI CORS documentation); the
# routes in this file do not read cookies or Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ----------------------------------------
# 2. Serve index.html (no static folder needed)
# ----------------------------------------
@app.get("/")
def serve_home():
    """Serve ``index.html`` as the response for the root URL.

    The path is relative, so the file is looked up from the directory the
    server process was started in.
    """
    home_page = "index.html"
    return FileResponse(home_page)
# ----------------------------------------
# 3. Your model definition
# ----------------------------------------
class LeNet5(nn.Module):
    """LeNet-5 style CNN: two convolutional stages, then three linear layers.

    Each stage is ``Conv2d(kernel 5) -> BatchNorm2d -> ReLU -> MaxPool2d(2)``.
    The first linear layer expects ``16 * 5 * 5`` features, which is what the
    two stages produce for a single-channel 32x32 input.

    The attribute names (``layer1``, ``layer2``, ``fc``, ``fc1``, ``fc2``)
    determine the ``state_dict`` keys, so they must stay as they are for
    saved weights to load.

    Args:
        num_classes: Size of the output layer (number of logits per sample).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = self._conv_stage(1, 6)
        self.layer2 = self._conv_stage(6, 16)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc1 = nn.Linear(in_features=120, out_features=84)
        self.fc2 = nn.Linear(in_features=84, out_features=num_classes)
        self.relu = nn.ReLU()

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv -> batch-norm -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        """Return the raw class scores (no softmax applied) for batch ``x``."""
        features = self.layer2(self.layer1(x))
        hidden = self.flatten(features)
        # Both hidden linear layers are followed by the same ReLU.
        for dense in (self.fc, self.fc1):
            hidden = self.relu(dense(hidden))
        return self.fc2(hidden)
# Build the network, restore the saved weights onto the CPU, and switch to
# evaluation mode so BatchNorm uses its stored running statistics.
model = LeNet5()
_state_dict = torch.load("lenet5.pth", map_location="cpu")
model.load_state_dict(_state_dict)
model.eval()
# ----------------------------------------
# 4. Predict route (base64 image)
# ----------------------------------------
from pydantic import BaseModel
class CanvasImage(BaseModel):
    """Request body accepted by the ``/predict`` route."""

    # Text-encoded image sent by the client; the /predict handler
    # base64-decodes it before opening it as an image.
    image: str
@app.post("/predict")
async def predict_digit(data: CanvasImage):
    """Classify the digit contained in a base64-encoded image.

    Args:
        data: Request body whose ``image`` field holds base64 image data,
            optionally preceded by a data-URL header ending in a comma
            (e.g. ``data:image/png;base64,....``).

    Returns:
        A dict ``{"prediction": <int>}`` holding the index of the
        highest-scoring class.

    Raises:
        HTTPException: With status 400 when the payload is not valid base64
            or does not decode to an image that Pillow can open.
    """
    # Imported locally so this handler does not depend on a file-level import.
    from fastapi import HTTPException

    # Keep what follows the first comma (the data-URL header separator);
    # a string without a comma is treated as the bare base64 payload.
    encoded = data.image.split(",", 1)[-1]
    try:
        img_bytes = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(img_bytes)).convert("L")
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; Pillow's unreadable-image errors
        # are OSErrors. Both mean the client sent bad data, not a server bug.
        raise HTTPException(
            status_code=400,
            detail="image must be a base64-encoded image or data URL",
        ) from exc

    image = image.resize((28, 28))
    # Swap light and dark so a dark stroke on a light canvas becomes a light
    # stroke on a dark background.
    image = ImageOps.invert(image)

    # NOTE(review): mean/std look like the usual MNIST statistics - confirm
    # they match what the model was trained with.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    img = transform(image).unsqueeze(0)  # add the batch dimension
    # Zero-pad 2 pixels per side: 28x28 -> 32x32, the size the network's
    # first linear layer (16 * 5 * 5 inputs) implies.
    # NOTE(review): the padding is applied after normalisation, so the border
    # is 0 in normalised space rather than the background value - confirm
    # this matches the training pipeline.
    img = torch.nn.functional.pad(img, (2, 2, 2, 2))

    with torch.no_grad():
        output = model(img)
    pred = output.argmax(1).item()
    return {"prediction": pred}