Spaces:
Sleeping
Sleeping
| import base64 | |
| import io | |
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| import torch | |
| from fastapi import FastAPI, Header, HTTPException | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| from pydantic import BaseModel, Field | |
| from transformers import pipeline | |
| from calority_nutrition_model import load_nutrition_checkpoint, predict_nutrients | |
| from calority_scratch_model import image_to_tensor, load_checkpoint | |
# --- Service configuration, read once from the environment at import time ---
# Hugging Face model id used by the fallback image-classification pipeline.
MODEL_ID = os.environ.get("MODEL_ID", "nateraw/food")
# Local directory with a custom checkpoint; checked before HF_MODEL_REPO_ID.
MODEL_DIR = os.environ.get("MODEL_DIR", "")
# Hub repo to snapshot-download when MODEL_DIR is not set.
HF_MODEL_REPO_ID = os.environ.get("HF_MODEL_REPO_ID", "")
# Which custom checkpoint loader may activate: "classification" or "nutrition-regression".
MODEL_TASK = os.environ.get("MODEL_TASK", "classification")
# Shared secret for the Bearer check in require_auth; empty disables the check.
MODEL_API_KEY = os.environ.get("MODEL_API_KEY", "")

app = FastAPI(version="0.1.0", title="Calority Meal Model")
class AnalyzeMealRequest(BaseModel):
    """JSON body accepted by the meal-analysis handler."""

    # Base64-encoded image bytes; required and must be non-empty.
    imageBase64: str = Field(..., min_length=1)
    # Declared MIME type of the image. NOTE(review): no visible code reads this field.
    mimeType: str = "image/jpeg"
    # Free-text portion hint (e.g. "about 200 g"). It is parsed for grams on the
    # classification path and echoed into the note on the regression path.
    portionContext: str = ""
@dataclass(frozen=True)
class NutritionProfile:
    """Reference nutrition data for one food label.

    Values are densities per 100 g, plus a typical serving weight. The serving
    weight is used when the user's portion context contains no gram amount.
    The ``@dataclass`` decorator is required: without it the class has no
    positional ``__init__``, and ``NutritionProfile(140, 237, ...)`` raises
    TypeError.
    """

    serving_g: int  # typical serving weight in grams
    calories_100g: int  # kcal per 100 g
    protein_100g: float  # grams of protein per 100 g
    carbs_100g: float  # grams of carbohydrate per 100 g
    fat_100g: float  # grams of fat per 100 g
# Reference nutrition table keyed by normalized classifier label (see normalize_label).
# Each row: (serving_g, calories_100g, protein_100g, carbs_100g, fat_100g).
_NUTRITION_ROWS: dict[str, tuple[int, int, float, float, float]] = {
    "apple pie": (140, 237, 1.9, 34.0, 11.0),
    "baby back ribs": (220, 290, 20.0, 6.0, 21.0),
    "baklava": (80, 428, 6.0, 54.0, 21.0),
    "beef carpaccio": (120, 160, 22.0, 1.0, 7.0),
    "beef tartare": (150, 190, 20.0, 2.0, 12.0),
    "beet salad": (180, 90, 3.0, 12.0, 4.0),
    "bibimbap": (450, 145, 6.0, 20.0, 4.0),
    "bread pudding": (160, 220, 5.0, 32.0, 8.0),
    "breakfast burrito": (280, 210, 10.0, 23.0, 9.0),
    "bruschetta": (120, 190, 6.0, 25.0, 7.0),
    "caesar salad": (220, 170, 8.0, 8.0, 12.0),
    "cannoli": (90, 310, 7.0, 33.0, 16.0),
    "caprese salad": (180, 170, 9.0, 5.0, 13.0),
    "carrot cake": (120, 415, 4.0, 50.0, 22.0),
    "cheesecake": (125, 321, 6.0, 26.0, 22.0),
    "chicken curry": (300, 165, 13.0, 7.0, 9.0),
    "chicken quesadilla": (250, 260, 14.0, 22.0, 13.0),
    "chicken wings": (180, 290, 24.0, 1.0, 20.0),
    "chocolate cake": (120, 371, 5.0, 53.0, 16.0),
    "club sandwich": (260, 240, 13.0, 22.0, 12.0),
    "cup cakes": (80, 305, 4.0, 47.0, 12.0),
    "donuts": (80, 452, 5.0, 51.0, 25.0),
    "dumplings": (220, 190, 9.0, 26.0, 6.0),
    "edamame": (160, 121, 11.0, 9.0, 5.0),
    "falafel": (180, 333, 13.0, 32.0, 18.0),
    "filet mignon": (180, 250, 26.0, 0.0, 16.0),
    "fish and chips": (350, 230, 11.0, 24.0, 10.0),
    "french fries": (150, 312, 3.4, 41.0, 15.0),
    "fried rice": (300, 165, 5.0, 25.0, 5.0),
    "greek salad": (220, 110, 4.0, 7.0, 8.0),
    "grilled cheese sandwich": (180, 350, 12.0, 28.0, 21.0),
    "hamburger": (250, 295, 17.0, 24.0, 14.0),
    "hot dog": (150, 290, 11.0, 24.0, 17.0),
    "hummus": (120, 166, 8.0, 14.0, 10.0),
    "lasagna": (320, 170, 10.0, 16.0, 8.0),
    "macaroni and cheese": (250, 164, 7.0, 20.0, 6.0),
    "omelette": (180, 154, 11.0, 1.0, 12.0),
    "pancakes": (220, 227, 6.0, 28.0, 10.0),
    "pizza": (250, 266, 11.0, 33.0, 10.0),
    "ramen": (500, 90, 4.0, 12.0, 3.0),
    "samosa": (150, 260, 6.0, 30.0, 13.0),
    "sashimi": (160, 130, 22.0, 0.0, 4.0),
    "spaghetti bolognese": (350, 150, 8.0, 20.0, 5.0),
    "steak": (220, 250, 26.0, 0.0, 15.0),
    "sushi": (220, 145, 7.0, 24.0, 2.0),
    "tacos": (220, 210, 10.0, 21.0, 10.0),
    "waffles": (180, 291, 8.0, 33.0, 14.0),
}
NUTRITION: dict[str, NutritionProfile] = {
    label: NutritionProfile(*row) for label, row in _NUTRITION_ROWS.items()
}
# Used for any label not present in NUTRITION.
DEFAULT_PROFILE = NutritionProfile(250, 180, 8.0, 20.0, 6.0)
@lru_cache(maxsize=1)
def classifier():
    """Return the shared Hugging Face image-classification pipeline for MODEL_ID.

    ``classify_image`` falls back to this pipeline when no custom checkpoint is
    configured. Without caching, every request rebuilt the pipeline and reloaded
    the model weights. The cache holds exactly one pipeline for the life of the
    process.
    """
    return pipeline("image-classification", model=MODEL_ID)
@lru_cache(maxsize=1)
def resolved_model_dir() -> str:
    """Return the local directory holding the custom checkpoint, or "" if none.

    MODEL_DIR wins when set. Otherwise HF_MODEL_REPO_ID is fetched with
    ``snapshot_download``. The result is cached so the Hub is contacted at most
    once per process, rather than on every request and health check.
    ``lru_cache`` does not cache exceptions, so a failed download is retried on
    the next call.
    """
    if MODEL_DIR:
        return MODEL_DIR
    if HF_MODEL_REPO_ID:
        return snapshot_download(repo_id=HF_MODEL_REPO_ID)
    return ""
@lru_cache(maxsize=1)
def scratch_classifier():
    """Load the custom classification checkpoint, or return None if not configured.

    Returns:
        ``(model, labels, device)`` when MODEL_TASK is "classification" and the
        resolved model directory contains ``model.pt``. Otherwise returns None,
        and ``classify_image`` falls back to the pipeline.

    The result, including None, is cached so the checkpoint is read from disk
    once per process. Configuration comes from env vars read at import. If
    ``model.pt`` appears after the first call, a restart is needed to pick it up.
    """
    # Check the task first: it is free, and it avoids resolving (possibly
    # downloading) a model directory this loader would not use anyway.
    if MODEL_TASK != "classification":
        return None
    model_dir = resolved_model_dir()
    if not model_dir:
        return None
    model_path = Path(model_dir)
    if not (model_path / "model.pt").exists():
        return None
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): assumes load_checkpoint puts the model in eval mode -- confirm.
    model, labels = load_checkpoint(model_path, device=device)
    return model, labels, device
@lru_cache(maxsize=1)
def nutrition_regressor():
    """Load the nutrition-regression checkpoint, or return None if not configured.

    Returns:
        ``(model, target_mean, target_std, device)`` when MODEL_TASK is
        "nutrition-regression" and the resolved model directory contains both
        ``model.pt`` and ``target_stats.json``. Otherwise returns None.

    The result, including None, is cached so the checkpoint is loaded once per
    process instead of on every request and health check.
    """
    # Cheap configuration check before resolving (possibly downloading) the model dir.
    if MODEL_TASK != "nutrition-regression":
        return None
    model_dir = resolved_model_dir()
    if not model_dir:
        return None
    model_path = Path(model_dir)
    required = (model_path / "model.pt", model_path / "target_stats.json")
    if not all(path.exists() for path in required):
        return None
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, target_mean, target_std = load_nutrition_checkpoint(model_path, device=device)
    return model, target_mean, target_std, device
def classify_image(image: Image.Image) -> list[dict]:
    """Return up to three ``{"label": str, "score": float}`` predictions, best first.

    Uses the custom checkpoint from ``scratch_classifier`` when one is
    available. Otherwise delegates to the Hugging Face pipeline with
    ``top_k=3``.
    """
    scratch = scratch_classifier()
    if scratch is not None:
        model, labels, device = scratch
        batch = image_to_tensor(image).unsqueeze(0).to(device)  # batch of one
        with torch.no_grad():
            logits = model(batch)
            probs = torch.softmax(logits, dim=1)[0]
        k = min(3, len(labels))  # never ask topk for more classes than exist
        scores, indices = torch.topk(probs, k=k)
        return [
            {"label": labels[idx], "score": prob}
            for prob, idx in zip(scores.tolist(), indices.tolist())
        ]
    return classifier()(image, top_k=3)
def _non_negative_int(value: float) -> int:
    """Round *value* to the nearest int and clamp it at 0."""
    return max(0, round(value))


def analyze_nutrients(image: Image.Image, portion_context: str) -> dict | None:
    """Estimate whole-plate nutrition with the regression checkpoint, if configured.

    Args:
        image: PIL image of the meal (analyze_meal passes the RGB image from
            decode_image).
        portion_context: free-text user hint. It is only echoed into
            ``confidenceNote``; it does not rescale the estimate.

    Returns:
        None when ``nutrition_regressor()`` returns None, so the caller falls
        back to classification. Otherwise returns the API response dict:
        integer kcal/grams, an ingredient-style macro breakdown, and
        ``nutritionDetails``.
    """
    regressor = nutrition_regressor()
    if regressor is None:
        return None
    model, target_mean, target_std, device = regressor
    nutrients = predict_nutrients(model, image, target_mean, target_std, device)

    # Nothing visible constrains the regression outputs to be >= 0. A negative
    # kcal or gram value is meaningless to clients, so clamp each one at 0.
    calories = _non_negative_int(nutrients["total_calories"])
    mass = _non_negative_int(nutrients["total_mass"])
    fat = _non_negative_int(nutrients["total_fat"])
    carbs = _non_negative_int(nutrients["total_carb"])
    protein = _non_negative_int(nutrients["total_protein"])

    # Energy implied by the macros (4 kcal/g for protein and carbs, 9 kcal/g for
    # fat). Its gap to the direct calorie prediction is reported as a sanity signal.
    protein_kcal = protein * 4
    carbs_kcal = carbs * 4
    fat_kcal = fat * 9
    macro_calories = protein_kcal + carbs_kcal + fat_kcal
    macro_gap = abs(macro_calories - calories)

    confidence = "medium" if calories > 0 else "low"
    confidence_note = (
        "Estimated from image using Calority nutrition regression. "
        f"Macro calories differ by {macro_gap} kcal."
    )
    if portion_context:
        confidence_note = f"{confidence_note} User context: {portion_context}."

    return {
        "name": "Food Plate",
        "calories": calories,
        "protein": protein,
        "carbs": carbs,
        "fat": fat,
        "ingredients": [
            f"Estimated total mass {mass}g",
            f"Protein {protein}g - {protein_kcal} kcal",
            f"Carbs {carbs}g - {carbs_kcal} kcal",
            f"Fat {fat}g - {fat_kcal} kcal",
        ],
        "confidence": confidence,
        "confidenceNote": confidence_note,
        "nutritionDetails": {
            "totalMass": mass,
            "calories": calories,
            "protein": protein,
            "carbs": carbs,
            "fat": fat,
            "macroCalories": macro_calories,
        },
    }
def require_auth(authorization: str | None) -> None:
    """Reject the request unless it carries ``Bearer <MODEL_API_KEY>``.

    The check is skipped entirely when MODEL_API_KEY is empty. The comparison
    is constant-time, so response timing does not reveal how many leading
    characters of a guessed token were correct.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    import hmac  # local import: only this function needs it

    if not MODEL_API_KEY:
        return
    expected = f"Bearer {MODEL_API_KEY}".encode("utf-8")
    # Encode to bytes: compare_digest rejects str arguments containing non-ASCII.
    supplied = (authorization or "").encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid model service token")
def decode_image(image_base64: str) -> Image.Image:
    """Decode a base64 payload into an RGB PIL image.

    Accepts bare base64 or a ``data:<mime>;base64,<payload>`` URL. The header
    is stripped first; otherwise its characters would be decoded into the
    image bytes and corrupt them.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or is not a
            decodable image.
    """
    payload = image_base64.strip()
    if payload.startswith("data:"):
        _, _, payload = payload.partition(",")
    try:
        raw = base64.b64decode(payload)
        # convert() returns a fully loaded copy, so the source handle can be closed.
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("RGB")
    except Exception as exc:  # request boundary: every decode failure is a client error
        raise HTTPException(status_code=400, detail="Invalid imageBase64") from exc
def normalize_label(label: str) -> str:
    """Canonicalise a model label for NUTRITION lookup.

    Lowercases the label, turns '_' and '-' into spaces, and trims outer
    whitespace. For example, "Apple_Pie" becomes "apple pie".
    """
    separators_to_space = str.maketrans("_-", "  ")
    return label.lower().translate(separators_to_space).strip()
# A 2-4 digit amount followed by a gram unit. The lookbehind stops the pattern
# from matching the tail of a longer number: "10000g" used to be read as
# "0000g" and clamped to 30 g.
_GRAMS_PATTERN = re.compile(r"(?<!\d)(\d{2,4})\s*(?:g|gram|grams)\b")


def grams_from_context(portion_context: str, fallback: int) -> int:
    """Extract a portion size in grams from free-text user context.

    Matching is case-insensitive and uses the first amount like "200g",
    "150 gram" or "300 grams". The result is clamped to 30..1200 g.

    Args:
        portion_context: user-supplied text, possibly empty.
        fallback: grams to use when no amount is found.

    Returns:
        The clamped gram amount, or *fallback*.
    """
    match = _GRAMS_PATTERN.search(portion_context.lower())
    if match is None:
        return fallback
    return max(30, min(1200, int(match.group(1))))
def nutrition_for(label: str, grams: int) -> dict:
    """Scale the per-100 g profile for *label* to *grams*.

    Labels missing from NUTRITION use DEFAULT_PROFILE. Returns integer
    (rounded) calories, protein, carbs and fat, plus a one-line "ingredient"
    summary string.
    """
    profile = NUTRITION.get(label, DEFAULT_PROFILE)
    scale = grams / 100

    def scaled(per_100g: float) -> int:
        return round(per_100g * scale)

    kcal = scaled(profile.calories_100g)
    return {
        "calories": kcal,
        "protein": scaled(profile.protein_100g),
        "carbs": scaled(profile.carbs_100g),
        "fat": scaled(profile.fat_100g),
        "ingredient": f"{label.title()} estimated {grams}g - {kcal} kcal",
    }
def confidence_from(score: float) -> tuple[str, str]:
    """Map a top-1 classifier probability to ``(level, note)``.

    score >= 0.75 gives "high" with an empty note, and score >= 0.45 gives
    "medium". Anything else, including NaN, gives "low".
    """
    tiers = (
        (0.75, "high", ""),
        (0.45, "medium", "The food is visible, but the model is not fully certain."),
    )
    for threshold, level, note in tiers:
        if score >= threshold:
            return level, note
    return "low", "The model could not confidently identify the meal."
def health() -> dict:
    """Report liveness and which model backend would serve requests.

    The regression checkpoint is checked first, then the custom classifier.
    If neither is configured, the pipeline for MODEL_ID is reported.

    NOTE(review): no route decorator is visible on this function; presumably it
    is registered on ``app`` as a health endpoint -- confirm.
    """
    source = HF_MODEL_REPO_ID or MODEL_DIR
    if nutrition_regressor():
        return {"status": "ok", "model": f"nutrition-regression:{source}"}
    if scratch_classifier():
        return {"status": "ok", "model": f"classification:{source}"}
    return {"status": "ok", "model": f"pipeline:{MODEL_ID}"}
def analyze_meal(payload: AnalyzeMealRequest, authorization: str | None = Header(default=None)) -> dict:
    """Authenticate the caller, decode the image, and return a nutrition estimate.

    Prefers the nutrition-regression result from ``analyze_nutrients``.
    Otherwise it classifies the image, looks the top label up in NUTRITION, and
    scales it to the grams found in ``portionContext``. Without a gram amount,
    the label's default serving is used. Runner-up predictions are appended to
    the note unless confidence is "high".

    NOTE(review): no route decorator is visible on this function; presumably it
    is registered on ``app`` as a POST endpoint -- confirm.
    """
    require_auth(authorization)
    image = decode_image(payload.imageBase64)

    regression_result = analyze_nutrients(image, payload.portionContext)
    if regression_result:
        return regression_result

    predictions = classify_image(image)
    top = predictions[0]
    runners_up = predictions[1:]

    label = normalize_label(top["label"])
    score = float(top["score"])
    default_serving = NUTRITION.get(label, DEFAULT_PROFILE).serving_g
    grams = grams_from_context(payload.portionContext, default_serving)
    macros = nutrition_for(label, grams)
    confidence, confidence_note = confidence_from(score)

    alternative_labels = [
        f"{normalize_label(item['label']).title()} ({round(float(item['score']) * 100)}%)"
        for item in runners_up
    ]
    if confidence != "high" and alternative_labels:
        joined = ", ".join(alternative_labels)
        confidence_note = f"{confidence_note} Alternatives: {joined}".strip()

    macro_fields = {key: macros[key] for key in ("calories", "protein", "carbs", "fat")}
    return {
        "name": label.title(),
        **macro_fields,
        "ingredients": [macros["ingredient"]],
        "confidence": confidence,
        "confidenceNote": confidence_note,
    }