File size: 841 Bytes
import base64
import io

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel
# Path of the serialized model file (resolved relative to the working directory).
MODEL_PATH = "rice-recognizer-vgg16-v1.pkl"

# Load the full pickled model once at import time and switch it to inference
# mode. map_location="cpu" keeps all weights on the CPU: the request tensors fed
# to the model are built on the CPU, so a checkpoint saved from a GPU would
# otherwise fail to load on a CPU-only host or hit a device mismatch at inference.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. On torch >= 2.6 the default weights_only=True rejects full nn.Module
# pickles; pass weights_only=False there, or switch to saving a state_dict.
model = torch.load(MODEL_PATH, map_location="cpu")
model.eval()

app = FastAPI()
class ImageData(BaseModel):
    """JSON request body for the prediction endpoint.

    The client sends one field carrying the image file as base64 text.
    """

    # Base64-encoded image bytes; decoded server-side before inference.
    image: str
# Preprocessing pipeline applied to every request image: resize to a fixed
# 224x224 input, convert the PIL image to a tensor, then normalize each RGB
# channel with the per-channel mean/std below.
# NOTE(review): these look like the common ImageNet statistics; presumably they
# match the model's training preprocessing, so confirm before changing either.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def _decode_image(encoded: str) -> Image.Image:
    """Decode a base64-encoded image string into an RGB PIL image.

    A browser-style data-URL prefix (``data:image/png;base64,``) is accepted and
    stripped before decoding.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or cannot be
            read as an image.
    """
    if encoded.startswith("data:"):
        # Keep only the payload after the first comma of the data URL.
        encoded = encoded.partition(",")[2]
    try:
        raw = base64.b64decode(encoded)
        # Image.open only reads the header; convert() forces the full decode,
        # so both stay inside the try to catch truncated/corrupt images too.
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except (ValueError, OSError, Image.DecompressionBombError) as err:
        # Bad base64 raises binascii.Error (a ValueError); unreadable image data
        # raises UnidentifiedImageError (an OSError). Both are client errors.
        raise HTTPException(status_code=400, detail="Invalid image data") from err


def _predict_class(encoded: str) -> int:
    """Decode one image and return the index of the model's highest-scoring class.

    Blocking (image decoding plus a forward pass), so it is meant to run in a
    worker thread rather than on the event loop.
    """
    img = _decode_image(encoded)
    batch = transform(img).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(batch)
    # Index of the maximum score along dim 1 for the single batch row.
    return outputs.argmax(dim=1).item()


@app.post("/predict/")
async def predict(data: ImageData):
    """Classify a base64-encoded image posted as ``{"image": "<base64>"}``.

    Returns:
        ``{"prediction": <class index>}``.

    Raises:
        HTTPException: 400 for payloads that cannot be decoded as an image.
    """
    # Decoding and inference are blocking; running them inline in this async
    # handler would stall the event loop for every concurrent request.
    prediction = await run_in_threadpool(_predict_class, data.image)
    return {"prediction": prediction}
|