dds3579 commited on
Commit
6b6e702
·
verified ·
1 Parent(s): afd54c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision import transforms
6
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
7
+ import io
8
+
9
+ # FastAPI app
10
+ app = FastAPI()
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_credentials=True,
15
+ allow_methods=["*"],
16
+ allow_headers=["*"],
17
+ )
18
+
19
+ # Load model + processor
20
+ model_name = "dwililiya/food101-model-classification"
21
+ extractor = AutoFeatureExtractor.from_pretrained(model_name)
22
+ model = AutoModelForImageClassification.from_pretrained(model_name)
23
+
24
+ # Device check (RTX 4050 will be used if running locally)
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ model.to(device)
27
+
28
+ # Nepali food calorie ranges (demo mapping)
29
+ calorie_map = {
30
+ "dal": "150-200 kcal per bowl",
31
+ "bhat": "300-400 kcal per plate",
32
+ "momo": "300-500 kcal (10 pcs)",
33
+ "sel roti": "150-250 kcal each",
34
+ "default": "N/A"
35
+ }
36
+
37
+ @app.post("/predict")
38
+ async def predict(file: UploadFile = File(...)):
39
+ try:
40
+ # Load image
41
+ image = Image.open(io.BytesIO(await file.read())).convert("RGB")
42
+ inputs = extractor(images=image, return_tensors="pt").to(device)
43
+
44
+ # Predict
45
+ with torch.no_grad():
46
+ outputs = model(**inputs)
47
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
48
+ pred_id = probs.argmax(-1).item()
49
+ confidence = probs[0][pred_id].item()
50
+ label = model.config.id2label[pred_id].lower()
51
+
52
+ # Map to Nepali calorie range (fallback default)
53
+ calories = calorie_map.get(label, calorie_map["default"])
54
+
55
+ return {
56
+ "food": label,
57
+ "calories": calories,
58
+ "confidence": round(confidence, 3)
59
+ }
60
+ except Exception as e:
61
+ return {"error": str(e)}