File size: 1,900 Bytes
6b6e702 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from torchvision import transforms
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# ASGI application object; route handlers are registered on it below.
app = FastAPI()

# Cross-origin policy: any origin, method and header is accepted, and
# credentialed requests are permitted.
# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# site make credentialed cross-origin calls. Nothing in this app reads cookies
# or auth headers, so an explicit origin list (or allow_credentials=False) is
# probably safer — confirm what the frontend sends before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Hugging Face checkpoint; the preprocessor and the classifier are both
# loaded from this same repository id.
model_name = "dwililiya/food101-model-classification"

# Use the GPU when torch can see CUDA (e.g. a local RTX 4050), else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# nn.Module.to() moves parameters in place, so `model` itself now lives on
# `device`; request tensors must be moved to the same device before inference.
model.to(device)
# Demo calorie ranges for a few Nepali dishes, keyed by lower-case food name.
# The "default" entry ("N/A") is the value to report for any other label.
# NOTE(review): the checkpoint name suggests Food-101 classes; if so, none of
# these dish names may ever be predicted — verify against model.config.id2label.
calorie_map = {
    "dal": "150-200 kcal per bowl",
    "bhat": "300-400 kcal per plate",
    "momo": "300-500 kcal (10 pcs)",
    "sel roti": "150-250 kcal each",
    "default": "N/A",
}
logger = logging.getLogger(__name__)


def _classify_image_bytes(data):
    """Decode raw image bytes and return the top-1 (label, confidence).

    This is blocking work (PIL decode + model forward pass), so it must not be
    called directly on the event loop; ``predict`` runs it in a worker thread.

    Args:
        data: The uploaded file's raw bytes.

    Returns:
        A ``(label, confidence)`` tuple: the lower-cased class name from
        ``model.config.id2label`` and its softmax probability as a float.

    Raises:
        Whatever PIL, the extractor or the model raise for undecodable or
        unsupported input (e.g. ``PIL.UnidentifiedImageError``).
    """
    image = Image.open(io.BytesIO(data)).convert("RGB")
    inputs = extractor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    pred_id = probs.argmax(-1).item()
    # A single image was submitted, so row 0 holds its class probabilities.
    confidence = probs[0][pred_id].item()
    label = model.config.id2label[pred_id].lower()
    return label, confidence


def _lookup_calories(label):
    """Return the calorie_map entry for *label*, or the "default" entry."""
    # calorie_map keys use spaces ("sel roti"), while classifier labels may use
    # underscores; try the exact label first, then a space-normalized form.
    # NOTE(review): confirm the label format of this checkpoint.
    for key in (label, label.replace("_", " ")):
        if key in calorie_map:
            return calorie_map[key]
    return calorie_map["default"]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded food image and attach a calorie estimate.

    Returns:
        ``{"food": label, "calories": str, "confidence": float}`` with the
        confidence rounded to 3 places, or ``{"error": message}`` if anything
        fails. NOTE(review): errors are still sent with HTTP 200 to keep the
        existing response contract; consider a 4xx/5xx status instead.
    """
    try:
        data = await file.read()
        # The decode and forward pass are synchronous and CPU/GPU-bound;
        # running them inline in this coroutine would block the event loop and
        # stall every other request, so offload them to the threadpool.
        label, confidence = await run_in_threadpool(_classify_image_bytes, data)
        return {
            "food": label,
            "calories": _lookup_calories(label),
            "confidence": round(confidence, 3),
        }
    except Exception as e:
        # Top-level request boundary: record the traceback instead of
        # swallowing it, then report the failure in the existing format.
        logger.exception("Prediction failed for upload %r", file.filename)
        return {"error": str(e)}
|