"""FastAPI service that classifies an uploaded food photo and returns a calorie estimate.

The image classifier and its preprocessor are loaded once at import time;
``POST /predict`` runs the classifier on one uploaded image per request.
"""

import io
import logging

from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
from torchvision import transforms  # NOTE(review): currently unused in this file
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI()
# NOTE(review): any-origin CORS combined with allow_credentials=True is very
# permissive; restrict allow_origins to the known frontend before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model + processor once at startup (from_pretrained may fetch the
# weights from the Hugging Face Hub if they are not cached locally).
model_name = "dwililiya/food101-model-classification"
extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # make inference mode explicit (disables dropout etc.)

# Nepali food calorie ranges (demo mapping).
# NOTE(review): keys must match the model's id2label names after lowercasing
# (or after "_" -> " " normalization). If the checkpoint uses the standard
# Food-101 classes, these dishes may never be predicted and every response
# will fall back to "default" -- confirm against model.config.id2label.
calorie_map = {
    "dal": "150-200 kcal per bowl",
    "bhat": "300-400 kcal per plate",
    "momo": "300-500 kcal (10 pcs)",
    "sel roti": "150-250 kcal each",
    "default": "N/A",
}


def _normalize_label(label: str) -> str:
    """Return *label* lowercased with underscores replaced by spaces (lookup key form)."""
    return label.lower().replace("_", " ").strip()


def _lookup_calories(label: str) -> str:
    """Return the calorie text for *label*, trying it verbatim, then normalized, then the default."""
    if label in calorie_map:
        return calorie_map[label]
    return calorie_map.get(_normalize_label(label), calorie_map["default"])


def _classify(image_bytes: bytes):
    """Decode *image_bytes* and return ``(label, confidence)`` for the top-1 class.

    ``label`` is the model's id2label name lowercased; ``confidence`` is its
    softmax probability as a float. This is blocking CPU/GPU work, so it is
    meant to run in a worker thread rather than on the event loop.

    Raises whatever PIL / the model raise (e.g. for bytes that are not an image).
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        image = img.convert("RGB")
    inputs = extractor(images=image, return_tensors="pt").to(device)

    # no_grad is thread-local, so it must be entered in the thread doing inference.
    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    pred_id = probs.argmax(-1).item()
    # Index 0: a single image is sent per request, so the batch has one row.
    confidence = probs[0][pred_id].item()
    label = model.config.id2label[pred_id].lower()
    return label, confidence


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted food and calorie range.

    Args:
        file: The uploaded image (any format PIL can open).

    Returns:
        On success ``{"food": label, "calories": text, "confidence": float}``
        with confidence rounded to 3 decimals. On any failure
        ``{"error": message}`` (still sent with HTTP 200, as before); the
        traceback is logged server-side.
    """
    try:
        image_bytes = await file.read()
        # Decoding + inference block; run them in the thread pool so the event
        # loop keeps serving other requests while the model is busy.
        label, confidence = await run_in_threadpool(_classify, image_bytes)
        return {
            "food": label,
            "calories": _lookup_calories(label),
            "confidence": round(confidence, 3),
        }
    except Exception as e:
        logger.exception("Prediction failed for upload %r", file.filename)
        return {"error": str(e)}