Unnatrathi commited on
Commit
c43c869
Β·
verified Β·
1 Parent(s): bbf568b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -65
app.py CHANGED
@@ -178,7 +178,15 @@ def _run_inference(image: Image.Image, max_new_tokens: int) -> str:
178
  "role": "user",
179
  "content": [
180
  {"type": "image"},
181
- {"type": "text", "text": "Analyze this food image for ingredients and calories."},
 
 
 
 
 
 
 
 
182
  ],
183
  }
184
  ]
@@ -193,9 +201,7 @@ def _run_inference(image: Image.Image, max_new_tokens: int) -> str:
193
  )
194
  device = next(_model.parameters()).device
195
  if "pixel_values" in inputs and inputs["pixel_values"] is not None:
196
- inputs["pixel_values"] = inputs["pixel_values"].to(
197
- torch.bfloat16 if torch.cuda.is_available() else torch.float32
198
- )
199
  inputs = {k: v.to(device) for k, v in inputs.items()}
200
 
201
  with torch.inference_mode():
@@ -210,39 +216,65 @@ def _run_inference(image: Image.Image, max_new_tokens: int) -> str:
210
  return _processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
211
 
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
 
214
 
215
- def _get_nutrition_from_api(ingredients_text: str) -> dict:
216
  try:
217
- # Extract the most important food word from the description
218
- import re
219
- # Common food words to search for
220
- food_keywords = [
221
- "banana", "apple", "orange", "mango", "rice", "chicken", "egg",
222
- "bread", "milk", "yogurt", "cheese", "pizza", "burger", "salad",
223
- "pasta", "fish", "beef", "pork", "potato", "tomato", "onion",
224
- "carrot", "broccoli", "spinach", "lemon", "grape", "strawberry",
225
- "chocolate", "cake", "cookie", "coffee", "tea", "juice", "soup",
226
- "sandwich", "taco", "noodle", "tofu", "paneer", "dal", "roti",
227
- "biryani", "curry", "samosa", "idli", "dosa"
228
- ]
229
-
230
- text_lower = ingredients_text.lower()
231
  food_query = None
232
-
233
- # Find first matching food keyword
234
- for keyword in food_keywords:
235
  if keyword in text_lower:
236
  food_query = keyword
237
  break
238
-
239
- # If no keyword matched, use first 2 words of description
240
  if not food_query:
241
- words = re.findall(r'\b[a-zA-Z]{4,}\b', ingredients_text)
242
  food_query = words[0] if words else ingredients_text[:20]
243
 
244
- logger.info(f"Searching nutrition for: {food_query}")
245
-
246
  response = req_lib.get(
247
  "https://world.openfoodfacts.org/cgi/search.pl",
248
  params={
@@ -250,62 +282,66 @@ def _get_nutrition_from_api(ingredients_text: str) -> dict:
250
  "search_simple": 1,
251
  "action": "process",
252
  "json": 1,
253
- "page_size": 3,
254
  "fields": "product_name,nutriments",
255
  },
256
- timeout=10,
257
  )
258
  response.raise_for_status()
259
- data = response.json()
260
- products = data.get("products", [])
261
-
262
- if not products:
263
- logger.warning(f"No products found for: {food_query}")
264
- return {}
265
 
266
- # Find first product with complete nutrition data
267
  for product in products:
268
- nutriments = product.get("nutriments", {})
269
- calories = nutriments.get("energy-kcal_100g", 0)
270
- if calories and calories > 0:
 
 
 
 
 
271
  return {
272
- "calories": round(float(calories), 1),
273
- "protein_g": round(float(nutriments.get("proteins_100g", 0) or 0), 1),
274
- "carbs_g": round(float(nutriments.get("carbohydrates_100g", 0) or 0), 1),
275
- "fat_g": round(float(nutriments.get("fat_100g", 0) or 0), 1),
276
- "fibre_g": round(float(nutriments.get("fiber_100g", 0) or 0), 1),
277
  }
278
-
279
- return {}
280
-
281
  except Exception as e:
282
- logger.warning(f"Nutrition API failed: {e}")
283
- return {}
284
 
 
 
 
 
 
 
 
 
285
 
286
 
287
  def _parse_response(raw: str) -> dict:
288
- # FIX: pre-initialize all nutrition keys to None so KeyError never happens
289
  result = {
290
- "ingredients": "",
291
  "portion_notes": "",
292
- "raw_text": raw,
293
- "calories": None, # ← these were missing
294
- "protein_g": None,
295
- "carbs_g": None,
296
- "fat_g": None,
297
- "fibre_g": None,
298
  }
299
 
300
- # Try structured CaLoRAify format first
301
  if "Ingredients detected:" in raw:
302
  ing_start = raw.index("Ingredients detected:") + len("Ingredients detected:")
303
- ing_end = raw.index(".", ing_start) if "." in raw[ing_start:] else len(raw)
 
304
  result["ingredients"] = raw[ing_start:ing_end].strip()
305
 
306
  if "Portion Analysis:" in raw:
307
  pa_start = raw.index("Portion Analysis:") + len("Portion Analysis:")
308
- pa_end = raw.index(".", pa_start) if "." in raw[pa_start:] else len(raw)
 
309
  result["portion_notes"] = raw[pa_start:pa_end].strip()
310
 
311
  if "JSON Summary:" in raw:
@@ -326,21 +362,20 @@ def _parse_response(raw: str) -> dict:
326
  except json.JSONDecodeError:
327
  pass
328
 
329
- # Fallback: model gave plain description
330
  if not result["ingredients"]:
331
- result["ingredients"] = raw.strip()[:200]
332
  result["portion_notes"] = "Portion estimated from image."
333
 
334
- # If no calories yet, call Nutritionix/OpenFoodFacts
335
  if result["calories"] is None and result["ingredients"]:
336
- logger.info(f"Calling nutrition API for: {result['ingredients'][:80]}")
337
  nutrition = _get_nutrition_from_api(result["ingredients"])
338
  if nutrition:
339
  result.update(nutrition)
340
 
341
  return result
342
 
343
-
344
  # ── Endpoints ─────────────────────────────────────────────────────────────────
345
  @app.get("/health")
346
  def health():
 
178
  "role": "user",
179
  "content": [
180
  {"type": "image"},
181
+ {
182
+ "type": "text",
183
+ "text": (
184
+ "Look at this food photo and tell me: "
185
+ "1) What food or dish do you see? "
186
+ "2) List all visible ingredients. "
187
+ "Start your answer with: 'Ingredients detected:'"
188
+ ),
189
+ },
190
  ],
191
  }
192
  ]
 
201
  )
202
  device = next(_model.parameters()).device
203
  if "pixel_values" in inputs and inputs["pixel_values"] is not None:
204
+ inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
 
 
205
  inputs = {k: v.to(device) for k, v in inputs.items()}
206
 
207
  with torch.inference_mode():
 
216
  return _processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
217
 
218
 
219
+ def _get_nutrition_from_api(ingredients_text: str) -> dict:
220
+ """Nutrition lookup β€” Open Food Facts + hardcoded fallback table."""
221
+
222
+ # ── Hardcoded fallback for 30 common foods ─────────────────────────────
223
+ NUTRITION_TABLE = {
224
+ "banana": {"calories": 89, "protein_g": 1.1, "carbs_g": 23.0, "fat_g": 0.3, "fibre_g": 2.6},
225
+ "apple": {"calories": 72, "protein_g": 0.4, "carbs_g": 19.0, "fat_g": 0.2, "fibre_g": 2.4},
226
+ "orange": {"calories": 62, "protein_g": 1.2, "carbs_g": 15.4, "fat_g": 0.2, "fibre_g": 3.1},
227
+ "mango": {"calories": 99, "protein_g": 1.4, "carbs_g": 25.0, "fat_g": 0.6, "fibre_g": 2.6},
228
+ "grape": {"calories": 69, "protein_g": 0.7, "carbs_g": 18.1, "fat_g": 0.2, "fibre_g": 0.9},
229
+ "strawberry": {"calories": 32, "protein_g": 0.7, "carbs_g": 7.7, "fat_g": 0.3, "fibre_g": 2.0},
230
+ "watermelon": {"calories": 30, "protein_g": 0.6, "carbs_g": 7.6, "fat_g": 0.2, "fibre_g": 0.4},
231
+ "rice": {"calories": 206, "protein_g": 4.3, "carbs_g": 45.0, "fat_g": 0.4, "fibre_g": 0.6},
232
+ "chicken": {"calories": 239, "protein_g": 27.0,"carbs_g": 0.0, "fat_g": 14.0,"fibre_g": 0.0},
233
+ "egg": {"calories": 155, "protein_g": 13.0,"carbs_g": 1.1, "fat_g": 11.0,"fibre_g": 0.0},
234
+ "bread": {"calories": 265, "protein_g": 9.0, "carbs_g": 49.0, "fat_g": 3.2, "fibre_g": 2.7},
235
+ "milk": {"calories": 61, "protein_g": 3.2, "carbs_g": 4.8, "fat_g": 3.3, "fibre_g": 0.0},
236
+ "cheese": {"calories": 402, "protein_g": 25.0,"carbs_g": 1.3, "fat_g": 33.0,"fibre_g": 0.0},
237
+ "pizza": {"calories": 266, "protein_g": 11.0,"carbs_g": 33.0, "fat_g": 10.0,"fibre_g": 2.3},
238
+ "burger": {"calories": 295, "protein_g": 17.0,"carbs_g": 24.0, "fat_g": 14.0,"fibre_g": 1.3},
239
+ "pasta": {"calories": 220, "protein_g": 8.1, "carbs_g": 43.0, "fat_g": 1.3, "fibre_g": 2.5},
240
+ "fish": {"calories": 136, "protein_g": 20.0,"carbs_g": 0.0, "fat_g": 6.0, "fibre_g": 0.0},
241
+ "potato": {"calories": 77, "protein_g": 2.0, "carbs_g": 17.0, "fat_g": 0.1, "fibre_g": 2.2},
242
+ "broccoli": {"calories": 34, "protein_g": 2.8, "carbs_g": 6.6, "fat_g": 0.4, "fibre_g": 2.6},
243
+ "carrot": {"calories": 41, "protein_g": 0.9, "carbs_g": 10.0, "fat_g": 0.2, "fibre_g": 2.8},
244
+ "tomato": {"calories": 18, "protein_g": 0.9, "carbs_g": 3.9, "fat_g": 0.2, "fibre_g": 1.2},
245
+ "salad": {"calories": 20, "protein_g": 1.8, "carbs_g": 3.6, "fat_g": 0.3, "fibre_g": 2.0},
246
+ "sandwich": {"calories": 250, "protein_g": 12.0,"carbs_g": 33.0, "fat_g": 7.0, "fibre_g": 2.5},
247
+ "soup": {"calories": 71, "protein_g": 3.8, "carbs_g": 8.0, "fat_g": 2.0, "fibre_g": 1.5},
248
+ "chocolate": {"calories": 546, "protein_g": 5.0, "carbs_g": 60.0, "fat_g": 31.0,"fibre_g": 7.0},
249
+ "cake": {"calories": 347, "protein_g": 5.0, "carbs_g": 55.0, "fat_g": 12.0,"fibre_g": 1.0},
250
+ "dal": {"calories": 116, "protein_g": 9.0, "carbs_g": 20.0, "fat_g": 0.4, "fibre_g": 8.0},
251
+ "roti": {"calories": 297, "protein_g": 9.9, "carbs_g": 61.0, "fat_g": 1.7, "fibre_g": 1.9},
252
+ "biryani": {"calories": 200, "protein_g": 8.0, "carbs_g": 30.0, "fat_g": 6.0, "fibre_g": 1.5},
253
+ "paneer": {"calories": 265, "protein_g": 18.0,"carbs_g": 3.4, "fat_g": 20.0,"fibre_g": 0.0},
254
+ "idli": {"calories": 58, "protein_g": 2.0, "carbs_g": 12.0, "fat_g": 0.4, "fibre_g": 0.5},
255
+ "dosa": {"calories": 168, "protein_g": 3.7, "carbs_g": 30.0, "fat_g": 3.7, "fibre_g": 1.5},
256
+ "samosa": {"calories": 262, "protein_g": 3.5, "carbs_g": 28.0, "fat_g": 15.0,"fibre_g": 2.0},
257
+ "noodle": {"calories": 138, "protein_g": 4.5, "carbs_g": 25.0, "fat_g": 2.0, "fibre_g": 1.8},
258
+ "coffee": {"calories": 2, "protein_g": 0.3, "carbs_g": 0.0, "fat_g": 0.0, "fibre_g": 0.0},
259
+ "omelette": {"calories": 154, "protein_g": 11.0,"carbs_g": 0.4, "fat_g": 12.0,"fibre_g": 0.0},
260
+ "yogurt": {"calories": 59, "protein_g": 10.0,"carbs_g": 3.6, "fat_g": 0.4, "fibre_g": 0.0},
261
+ }
262
 
263
+ text_lower = ingredients_text.lower()
264
 
265
+ # ── Step 1: Try Open Food Facts API ───────────────────────────────────
266
  try:
267
+ import re as _re
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  food_query = None
269
+ for keyword in NUTRITION_TABLE.keys():
 
 
270
  if keyword in text_lower:
271
  food_query = keyword
272
  break
 
 
273
  if not food_query:
274
+ words = _re.findall(r'\b[a-zA-Z]{4,}\b', ingredients_text)
275
  food_query = words[0] if words else ingredients_text[:20]
276
 
277
+ logger.info(f"OpenFoodFacts query: {food_query}")
 
278
  response = req_lib.get(
279
  "https://world.openfoodfacts.org/cgi/search.pl",
280
  params={
 
282
  "search_simple": 1,
283
  "action": "process",
284
  "json": 1,
285
+ "page_size": 5,
286
  "fields": "product_name,nutriments",
287
  },
288
+ timeout=8,
289
  )
290
  response.raise_for_status()
291
+ products = response.json().get("products", [])
 
 
 
 
 
292
 
 
293
  for product in products:
294
+ n = product.get("nutriments", {})
295
+ cal = n.get("energy-kcal_100g") or n.get("energy-kcal") or 0
296
+ try:
297
+ cal = float(cal)
298
+ except (TypeError, ValueError):
299
+ cal = 0
300
+ if cal > 0:
301
+ logger.info(f"OpenFoodFacts found: {product.get('product_name')} = {cal} kcal")
302
  return {
303
+ "calories": round(cal, 1),
304
+ "protein_g": round(float(n.get("proteins_100g", 0) or 0), 1),
305
+ "carbs_g": round(float(n.get("carbohydrates_100g", 0) or 0), 1),
306
+ "fat_g": round(float(n.get("fat_100g", 0) or 0), 1),
307
+ "fibre_g": round(float(n.get("fiber_100g", 0) or 0), 1),
308
  }
 
 
 
309
  except Exception as e:
310
+ logger.warning(f"OpenFoodFacts failed: {e}")
 
311
 
312
+ # ── Step 2: Hardcoded table fallback ──────────────────────────────────
313
+ for food, values in NUTRITION_TABLE.items():
314
+ if food in text_lower:
315
+ logger.info(f"Using hardcoded nutrition for: {food}")
316
+ return values
317
+
318
+ logger.warning("No nutrition data found from any source")
319
+ return {}
320
 
321
 
322
  def _parse_response(raw: str) -> dict:
 
323
  result = {
324
+ "ingredients": "",
325
  "portion_notes": "",
326
+ "raw_text": raw,
327
+ "calories": None,
328
+ "protein_g": None,
329
+ "carbs_g": None,
330
+ "fat_g": None,
331
+ "fibre_g": None,
332
  }
333
 
334
+ # Try structured format
335
  if "Ingredients detected:" in raw:
336
  ing_start = raw.index("Ingredients detected:") + len("Ingredients detected:")
337
+ ing_end = raw.find(".", ing_start)
338
+ ing_end = ing_end if ing_end != -1 else len(raw)
339
  result["ingredients"] = raw[ing_start:ing_end].strip()
340
 
341
  if "Portion Analysis:" in raw:
342
  pa_start = raw.index("Portion Analysis:") + len("Portion Analysis:")
343
+ pa_end = raw.find(".", pa_start)
344
+ pa_end = pa_end if pa_end != -1 else len(raw)
345
  result["portion_notes"] = raw[pa_start:pa_end].strip()
346
 
347
  if "JSON Summary:" in raw:
 
362
  except json.JSONDecodeError:
363
  pass
364
 
365
+ # Fallback: use entire raw text as ingredient description
366
  if not result["ingredients"]:
367
+ result["ingredients"] = raw.strip()[:300]
368
  result["portion_notes"] = "Portion estimated from image."
369
 
370
+ # Call nutrition API if no calories yet
371
  if result["calories"] is None and result["ingredients"]:
372
+ logger.info(f"Looking up nutrition for: {result['ingredients'][:80]}")
373
  nutrition = _get_nutrition_from_api(result["ingredients"])
374
  if nutrition:
375
  result.update(nutrition)
376
 
377
  return result
378
 
 
379
  # ── Endpoints ─────────────────────────────────────────────────────────────────
380
  @app.get("/health")
381
  def health():