rice_model2 / model.py
rawhaturrafin's picture
Upload 2 files
7203481 verified
raw
history blame contribute delete
841 Bytes
import base64
import io

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
# Load the serialized model once at import time so every request reuses it.
# NOTE(review): torch.load unpickles arbitrary Python objects; only point this
# at a checkpoint that comes from a trusted source.
# map_location="cpu" restores the weights onto the CPU, the same device as the
# tensors built in the request handler, even if the file was saved from a GPU.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects a fully pickled module; if this file holds
# a whole model, weights_only=False may be required -- confirm against the
# deployed torch version.
model = torch.load("rice-recognizer-vgg16-v1.pkl", map_location="cpu")
# Put the module in evaluation mode before serving predictions.
model.eval()

# Application object that the route decorator further down registers on.
app = FastAPI()
class ImageData(BaseModel):
    """Request payload accepted by the ``/predict/`` route."""

    # Image file contents as a base64 string; the handler decodes it with
    # base64.b64decode before opening it as an image.
    image: str
# Target (height, width) every image is resized to before inference.
_INPUT_SIZE = (224, 224)

# Per-channel mean / std handed to Normalize.
# NOTE(review): these match the ImageNet statistics listed in the torchvision
# pretrained-model docs -- confirm the model was trained with the same values.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to each request image, in order:
# resize -> convert the PIL image to a tensor -> normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
@app.post("/predict/")
async def predict(data: ImageData):
    """Classify a base64-encoded image and return the predicted class index.

    Args:
        data: Request body whose ``image`` field holds the base64-encoded
            contents of an image file.

    Returns:
        A dict ``{"prediction": <int>}`` holding the index of the largest
        model output along dimension 1.

    Raises:
        HTTPException: Status 400 when ``image`` is not decodable base64 or
            the decoded bytes cannot be opened as an image.
    """
    try:
        raw_bytes = base64.b64decode(data.image)
        # Convert to RGB so the 3-value mean/std in ``transform`` always
        # lines up with the channel count, whatever the source image mode.
        img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError and
        # truncated-file errors are OSErrors. Either one means the client
        # sent a bad payload, so answer 400 instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="image must be a base64-encoded image file",
        ) from exc

    # unsqueeze(0) adds a leading dimension of size 1 (a batch of one image).
    img_tensor = transform(img).unsqueeze(0)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(img_tensor)

    # torch.max over dim 1 returns (values, indices); only the index is used.
    _, predicted = torch.max(outputs, 1)
    return {"prediction": predicted.item()}