import base64
import binascii
import io

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel


# Serialized model file. It holds a whole pickled nn.Module rather than a
# state_dict, since .eval() is called directly on what torch.load returns.
MODEL_PATH = "rice-recognizer-vgg16-v1.pkl"

# weights_only=False: unpickling a full nn.Module requires it on PyTorch >= 2.6,
# where torch.load defaults to weights_only=True (the keyword itself needs
# PyTorch >= 1.13). This executes arbitrary pickle code, so only ever point
# MODEL_PATH at a trusted file.
# map_location="cpu": predict() builds its input tensors on the CPU, so the
# model must live there too, even if it was saved from a GPU.
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval()

app = FastAPI()


class ImageData(BaseModel):
    """JSON request body accepted by the ``/predict/`` endpoint."""

    # Base64-encoded image bytes; the /predict/ handler decodes them.
    image: str


# Per-channel (R, G, B) normalization statistics. These are the widely used
# ImageNet mean/std values; presumably they match the model's training
# preprocessing -- confirm against the training pipeline.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every request image before inference:
#   1. resize to exactly 224x224 (the aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor scaled to [0, 1],
#   3. normalize each channel with the statistics above.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)


def _decode_image(encoded: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image.

    Raises:
        HTTPException: 400 if the payload is not valid base64, or if Pillow
            cannot identify or fully decode the bytes as an image.
    """
    try:
        # Non-strict decoding (the b64decode default) is kept so payloads with
        # embedded whitespace/newlines are still accepted.
        raw = base64.b64decode(encoded)
    except binascii.Error as err:  # e.g. incorrect padding
        raise HTTPException(status_code=400, detail="image is not valid base64") from err
    try:
        with Image.open(io.BytesIO(raw)) as img:
            # convert() forces a full decode, so truncated/corrupt data fails
            # here; it returns a new image that outlives the closed original.
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-image errors are OSError subclasses.
        raise HTTPException(status_code=400, detail="image could not be decoded") from err


def _predict_index(encoded: str) -> int:
    """Decode, preprocess and classify one image; return the top class index."""
    img = _decode_image(encoded)
    batch = transform(img).unsqueeze(0)  # add a leading batch dimension of 1
    # Grad mode is thread-local, so no_grad must be entered in the same thread
    # that runs the forward pass (this function runs in a worker thread).
    with torch.no_grad():
        outputs = model(batch)
    # Assumes outputs are (batch, num_classes) scores, as the reduction over
    # dim 1 implies; argmax picks the highest-scoring class.
    return outputs.argmax(dim=1).item()


@app.post("/predict/")
async def predict(data: ImageData):
    """Classify a base64-encoded image and return the predicted class index.

    Args:
        data: Request body whose ``image`` field holds base64-encoded image
            bytes.

    Returns:
        ``{"prediction": <int>}`` -- the index of the highest-scoring model
        output. Mapping the index to a label is not done here.

    Raises:
        HTTPException: 400 when ``image`` is not valid base64 or cannot be
            decoded as an image.
    """
    # Decoding, preprocessing and the forward pass are synchronous CPU work;
    # running them directly in this coroutine would block the event loop for
    # every other request, so hand them to FastAPI's worker thread pool.
    prediction = await run_in_threadpool(_predict_index, data.image)
    return {"prediction": prediction}