"""FastAPI service that classifies a base64-encoded image with a pickled torch model."""

import base64
import io

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel

MODEL_PATH = "rice-recognizer-vgg16-v1.pkl"

# Load the model once at import time.
# NOTE(security): the checkpoint is a full pickled nn.Module, so it must be
# unpickled with weights_only=False (torch >= 2.6 defaults to True and would
# refuse it). Unpickling can execute arbitrary code -- only load a trusted file.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host
# and keeps the model on the same device as the CPU input tensors built below.
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
model.eval()

app = FastAPI()


class ImageData(BaseModel):
    """Request body for ``POST /predict/``."""

    # Base64-encoded image bytes; an optional "data:<mime>;base64," prefix is accepted.
    image: str


# Preprocessing: resize to 224x224, convert to a (3, 224, 224) float tensor,
# then normalize with the standard ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _decode_image(payload: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image; raise HTTP 400 on bad input."""
    # Browser FileReader.readAsDataURL yields "data:image/png;base64,<data>";
    # keep only the part after the comma.
    if payload.startswith("data:"):
        _, _, payload = payload.partition(",")
    try:
        raw = base64.b64decode(payload)
    except ValueError as err:  # binascii.Error (bad padding) and non-ASCII str are ValueErrors
        raise HTTPException(status_code=400, detail="Invalid base64 image data") from err
    try:
        with Image.open(io.BytesIO(raw)) as img:
            # convert() forces a full decode, so truncated files fail here, inside the try.
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL.UnidentifiedImageError and truncated-image errors.
        raise HTTPException(status_code=400, detail="Could not decode image") from err


def _predict_sync(payload: str) -> int:
    """Decode *payload*, run the model, and return the highest-scoring class index."""
    img = _decode_image(payload)
    batch = transform(img).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(batch)
    # Index of the max logit along the class dimension (first index on ties).
    return outputs.argmax(dim=1).item()


@app.post("/predict/")
async def predict(data: ImageData):
    """Classify the posted image and return ``{"prediction": <class index>}``.

    Decoding and inference are CPU-bound, so they run in the thread pool to
    avoid blocking the event loop for other requests. Malformed base64 or
    undecodable image data yields HTTP 400.
    """
    prediction = await run_in_threadpool(_predict_sync, data.image)
    return {"prediction": prediction}