# food_api / app.py
# Author: dds3579
# Commit message: Create app.py
# Commit: 6b6e702 (verified)
import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from torchvision import transforms
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# FastAPI application instance; route handlers in this module register on it.
app = FastAPI()

# CORS policy: any origin, method and header may call this API.
# Credentials are disabled on purpose: a wildcard origin must not be combined
# with credentialed requests (cookies / Authorization headers), and the
# endpoint in this module does not read cookies or auth headers.
# NOTE(review): if credentialed cross-origin calls are ever required, replace
# the "*" entries with explicit origins/methods/headers before re-enabling.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Checkpoint identifier passed to both `from_pretrained` calls below.
model_name = "dwililiya/food101-model-classification"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pre-processing pipeline and classification model, loaded once at import
# time so that request handlers can reuse them.
extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Move the model weights onto the selected device.
model.to(device)
# Demo lookup table: lower-case food name -> human-readable calorie range for
# a handful of Nepali dishes. The "default" entry is the fallback text used
# when a food name has no entry of its own.
# NOTE(review): the keys are Nepali dish names while the checkpoint name
# suggests Food-101 classes — confirm that any predicted label can actually
# match one of these keys.
calorie_map = {
    "dal": "150-200 kcal per bowl",
    "bhat": "300-400 kcal per plate",
    "momo": "300-500 kcal (10 pcs)",
    "sel roti": "150-250 kcal each",
    "default": "N/A",
}
logger = logging.getLogger(__name__)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded food image and look up its calorie range.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        On success, a dict with:
          * ``"food"``: the predicted label, lower-cased;
          * ``"calories"``: the matching ``calorie_map`` entry, or the
            ``"default"`` entry when the label has no mapping;
          * ``"confidence"``: softmax probability of the predicted class,
            rounded to three decimals.
        On failure, a dict with a single ``"error"`` key holding a message.
        Failures are reported in the response body (not raised), which keeps
        the response shape existing clients receive.
    """
    try:
        payload = await file.read()
        if not payload:
            # An empty body would otherwise surface as an opaque PIL error
            # that includes the repr of the in-memory buffer.
            return {"error": "Uploaded file is empty"}

        # Decode inside a context manager so the source image is closed as
        # soon as the RGB copy has been made.
        with Image.open(io.BytesIO(payload)) as raw_image:
            image = raw_image.convert("RGB")
        inputs = extractor(images=image, return_tensors="pt").to(device)

        # Inference only: no gradients are needed.
        # NOTE: this runs synchronously inside the coroutine.
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            pred_id = probs.argmax(-1).item()
            confidence = probs[0][pred_id].item()
            label = model.config.id2label[pred_id].lower()

        # Exact match first; otherwise retry with underscores turned into
        # spaces so multi-word keys such as "sel roti" can match a label
        # spelled "sel_roti" (assumes labels may use underscores as word
        # separators — TODO confirm against the model's id2label).
        calories = calorie_map.get(label)
        if calories is None:
            calories = calorie_map.get(
                label.replace("_", " "), calorie_map["default"]
            )

        return {
            "food": label,
            "calories": calories,
            "confidence": round(confidence, 3),
        }
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # (str(e) alone loses it) and report the message to the client.
        logger.exception("Prediction failed")
        return {"error": str(e)}