|
|
import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
|
|
|
|
|
|
|
app = FastAPI()

# Let browser front-ends on any origin call this API.
# Credentials are disabled on purpose: per the FastAPI/Starlette CORS docs,
# allow_origins / allow_methods / allow_headers must be explicit (not "*")
# when allow_credentials=True, and with a wildcard origin Starlette would
# echo back any caller's Origin for credentialed requests. The /predict
# endpoint in this file reads no cookies or auth headers, so nothing needs
# credentials. If that changes, list explicit origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# Hugging Face checkpoint id served by the /predict endpoint.
model_name = "dwililiya/food101-model-classification"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Preprocessor (PIL image -> tensors) and classifier, loaded once at import
# time and shared by every request; the classifier is moved to `device`.
# NOTE(review): there is no explicit model.eval() here; this relies on
# from_pretrained returning the model in eval mode — confirm for the
# installed transformers version.
extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
|
|
|
|
|
|
|
|
# Human-readable calorie estimates keyed by lower-cased food label.
# The "default" entry is the fallback text when a label has no estimate.
# NOTE(review): keys only match if they equal the model's id2label names
# after lower-casing — confirm the model actually emits these labels.
calorie_map = dict(
    [
        ("dal", "150-200 kcal per bowl"),
        ("bhat", "300-400 kcal per plate"),
        ("momo", "300-500 kcal (10 pcs)"),
        ("sel roti", "150-250 kcal each"),
        ("default", "N/A"),
    ]
)
|
|
|
|
|
# Logger for the prediction endpoint; unexpected failures are recorded here
# with a traceback instead of being echoed back to the client.
_logger = logging.getLogger(__name__)


def _load_rgb_image(image_bytes):
    """Decode *image_bytes* with Pillow and return an RGB copy.

    The source image is opened in a ``with`` block so its handle is released
    once the converted copy exists. Raises ``OSError`` (which includes
    ``PIL.UnidentifiedImageError``) when the bytes are not a readable image.
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        return img.convert("RGB")


def _classify(image):
    """Run the classifier on an RGB PIL image.

    Returns ``(label, confidence)``: the lower-cased top-1 label from
    ``model.config.id2label`` and its softmax probability as a Python float.
    """
    inputs = extractor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    pred_id = probs.argmax(-1).item()
    confidence = probs[0][pred_id].item()
    label = model.config.id2label[pred_id].lower()
    return label, confidence


def _lookup_calories(label):
    """Return the ``calorie_map`` estimate for *label*, else its "default" entry.

    The label is tried verbatim and then with underscores turned into spaces,
    because ``calorie_map`` keys use spaces (e.g. "sel roti").
    """
    for key in (label, label.replace("_", " ")):
        if key in calorie_map:
            return calorie_map[key]
    return calorie_map["default"]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded food image and attach a calorie estimate.

    Returns ``{"food": label, "calories": estimate, "confidence": p}``, with
    *p* rounded to 3 decimals. On failure the body is ``{"error": message}``,
    sent with HTTP 400 for an unreadable image and HTTP 500 for any other
    failure during inference.
    """
    image_bytes = await file.read()

    # Decoding and inference are synchronous CPU/GPU work; running them in the
    # threadpool keeps this coroutine from blocking the event loop.
    # NOTE(review): several requests may now run inference concurrently; add
    # a semaphore if GPU memory becomes a problem.
    try:
        image = await run_in_threadpool(_load_rgb_image, image_bytes)
    except (OSError, Image.DecompressionBombError):
        # Not an image, truncated, or over Pillow's pixel limit: a client error.
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid or unreadable image file"},
        )

    try:
        label, confidence = await run_in_threadpool(_classify, image)
    except Exception:
        # Keep details server-side rather than leaking internals to clients.
        _logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(status_code=500, content={"error": "Prediction failed"})

    return {
        "food": label,
        "calories": _lookup_calories(label),
        "confidence": round(confidence, 3),
    }
|
|
|