"""FastAPI service that estimates meal nutrition from a base64-encoded image.

Depending on ``MODEL_TASK`` and which checkpoint files are present, a request
is answered by a nutrition-regression model, a locally loaded classifier, or
a ``transformers`` image-classification pipeline combined with the static
``NUTRITION`` table.
"""

import base64
import hmac
import io
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path

import torch
from fastapi import FastAPI, Header, HTTPException
from huggingface_hub import snapshot_download
from PIL import Image
from pydantic import BaseModel, Field
from transformers import pipeline

from calority_nutrition_model import load_nutrition_checkpoint, predict_nutrients
from calority_scratch_model import image_to_tensor, load_checkpoint

# Runtime configuration, read once at import time from the environment.
MODEL_ID = os.getenv("MODEL_ID", "nateraw/food")
MODEL_DIR = os.getenv("MODEL_DIR", "")
HF_MODEL_REPO_ID = os.getenv("HF_MODEL_REPO_ID", "")
MODEL_TASK = os.getenv("MODEL_TASK", "classification")
MODEL_API_KEY = os.getenv("MODEL_API_KEY", "")

app = FastAPI(title="Calority Meal Model", version="0.1.0")


class AnalyzeMealRequest(BaseModel):
    """Request body accepted by ``POST /analyze-meal``."""

    imageBase64: str = Field(min_length=1)
    mimeType: str = "image/jpeg"
    portionContext: str = ""


@dataclass(frozen=True)
class NutritionProfile:
    """Default serving size plus per-100 g calorie and macro values for a label."""

    serving_g: int
    calories_100g: int
    protein_100g: float
    carbs_100g: float
    fat_100g: float


# Keyed by the normalized (lowercase, space-separated) classifier label.
NUTRITION = {
    "apple pie": NutritionProfile(140, 237, 1.9, 34.0, 11.0),
    "baby back ribs": NutritionProfile(220, 290, 20.0, 6.0, 21.0),
    "baklava": NutritionProfile(80, 428, 6.0, 54.0, 21.0),
    "beef carpaccio": NutritionProfile(120, 160, 22.0, 1.0, 7.0),
    "beef tartare": NutritionProfile(150, 190, 20.0, 2.0, 12.0),
    "beet salad": NutritionProfile(180, 90, 3.0, 12.0, 4.0),
    "bibimbap": NutritionProfile(450, 145, 6.0, 20.0, 4.0),
    "bread pudding": NutritionProfile(160, 220, 5.0, 32.0, 8.0),
    "breakfast burrito": NutritionProfile(280, 210, 10.0, 23.0, 9.0),
    "bruschetta": NutritionProfile(120, 190, 6.0, 25.0, 7.0),
    "caesar salad": NutritionProfile(220, 170, 8.0, 8.0, 12.0),
    "cannoli": NutritionProfile(90, 310, 7.0, 33.0, 16.0),
    "caprese salad": NutritionProfile(180, 170, 9.0, 5.0, 13.0),
    "carrot cake": NutritionProfile(120, 415, 4.0, 50.0, 22.0),
    "cheesecake": NutritionProfile(125, 321, 6.0, 26.0, 22.0),
    "chicken curry": NutritionProfile(300, 165, 13.0, 7.0, 9.0),
    "chicken quesadilla": NutritionProfile(250, 260, 14.0, 22.0, 13.0),
    "chicken wings": NutritionProfile(180, 290, 24.0, 1.0, 20.0),
    "chocolate cake": NutritionProfile(120, 371, 5.0, 53.0, 16.0),
    "club sandwich": NutritionProfile(260, 240, 13.0, 22.0, 12.0),
    "cup cakes": NutritionProfile(80, 305, 4.0, 47.0, 12.0),
    "donuts": NutritionProfile(80, 452, 5.0, 51.0, 25.0),
    "dumplings": NutritionProfile(220, 190, 9.0, 26.0, 6.0),
    "edamame": NutritionProfile(160, 121, 11.0, 9.0, 5.0),
    "falafel": NutritionProfile(180, 333, 13.0, 32.0, 18.0),
    "filet mignon": NutritionProfile(180, 250, 26.0, 0.0, 16.0),
    "fish and chips": NutritionProfile(350, 230, 11.0, 24.0, 10.0),
    "french fries": NutritionProfile(150, 312, 3.4, 41.0, 15.0),
    "fried rice": NutritionProfile(300, 165, 5.0, 25.0, 5.0),
    "greek salad": NutritionProfile(220, 110, 4.0, 7.0, 8.0),
    "grilled cheese sandwich": NutritionProfile(180, 350, 12.0, 28.0, 21.0),
    "hamburger": NutritionProfile(250, 295, 17.0, 24.0, 14.0),
    "hot dog": NutritionProfile(150, 290, 11.0, 24.0, 17.0),
    "hummus": NutritionProfile(120, 166, 8.0, 14.0, 10.0),
    "lasagna": NutritionProfile(320, 170, 10.0, 16.0, 8.0),
    "macaroni and cheese": NutritionProfile(250, 164, 7.0, 20.0, 6.0),
    "omelette": NutritionProfile(180, 154, 11.0, 1.0, 12.0),
    "pancakes": NutritionProfile(220, 227, 6.0, 28.0, 10.0),
    "pizza": NutritionProfile(250, 266, 11.0, 33.0, 10.0),
    "ramen": NutritionProfile(500, 90, 4.0, 12.0, 3.0),
    "samosa": NutritionProfile(150, 260, 6.0, 30.0, 13.0),
    "sashimi": NutritionProfile(160, 130, 22.0, 0.0, 4.0),
    "spaghetti bolognese": NutritionProfile(350, 150, 8.0, 20.0, 5.0),
    "steak": NutritionProfile(220, 250, 26.0, 0.0, 15.0),
    "sushi": NutritionProfile(220, 145, 7.0, 24.0, 2.0),
    "tacos": NutritionProfile(220, 210, 10.0, 21.0, 10.0),
    "waffles": NutritionProfile(180, 291, 8.0, 33.0, 14.0),
}

# Used whenever a predicted label has no entry in NUTRITION.
DEFAULT_PROFILE = NutritionProfile(250, 180, 8.0, 20.0, 6.0)

# Matches a 2-4 digit quantity followed by a gram unit, e.g. "250 g" or "300grams".
_GRAMS_PATTERN = re.compile(r"(\d{2,4})\s*(g|gram|grams)\b")


@lru_cache(maxsize=1)
def classifier():
    """Return the cached ``transformers`` image-classification pipeline for MODEL_ID."""
    return pipeline("image-classification", model=MODEL_ID)


@lru_cache(maxsize=1)
def resolved_model_dir() -> str:
    """Return the local checkpoint directory, or ``""`` when none is configured.

    ``MODEL_DIR`` takes precedence; otherwise ``HF_MODEL_REPO_ID`` is fetched
    with ``snapshot_download`` and the resulting path is returned.
    """
    if MODEL_DIR:
        return MODEL_DIR
    if HF_MODEL_REPO_ID:
        return snapshot_download(repo_id=HF_MODEL_REPO_ID)
    return ""


@lru_cache(maxsize=1)
def scratch_classifier():
    """Load the local classification checkpoint, if one is available.

    Returns:
        ``(model, labels, device)``, or ``None`` when no model directory is
        configured, ``MODEL_TASK`` is not ``"classification"``, or
        ``model.pt`` is missing from the directory.
    """
    model_dir = resolved_model_dir()
    if not model_dir or MODEL_TASK != "classification":
        return None

    model_path = Path(model_dir)
    if not (model_path / "model.pt").exists():
        return None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, labels = load_checkpoint(model_path, device=device)
    return model, labels, device


@lru_cache(maxsize=1)
def nutrition_regressor():
    """Load the nutrition-regression checkpoint, if one is available.

    Returns:
        ``(model, target_mean, target_std, device)``, or ``None`` when no
        model directory is configured, ``MODEL_TASK`` is not
        ``"nutrition-regression"``, or either ``model.pt`` or
        ``target_stats.json`` is missing from the directory.
    """
    model_dir = resolved_model_dir()
    if not model_dir or MODEL_TASK != "nutrition-regression":
        return None

    model_path = Path(model_dir)
    if not (model_path / "model.pt").exists() or not (model_path / "target_stats.json").exists():
        return None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, target_mean, target_std = load_nutrition_checkpoint(model_path, device=device)
    return model, target_mean, target_std, device


def classify_image(image: Image.Image) -> list[dict]:
    """Return up to three ``{"label", "score"}`` predictions for ``image``.

    Uses the local checkpoint when ``scratch_classifier()`` provides one and
    falls back to the ``transformers`` pipeline (``top_k=3``) otherwise.
    """
    scratch = scratch_classifier()
    if scratch is None:
        return classifier()(image, top_k=3)

    model, labels, device = scratch
    tensor = image_to_tensor(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = torch.softmax(model(tensor), dim=1)[0]
    # Never request more entries than there are labels.
    top_scores, top_indices = torch.topk(probabilities, k=min(3, len(labels)))
    return [
        {"label": labels[index.item()], "score": score.item()}
        for score, index in zip(top_scores, top_indices)
    ]


def _rounded_non_negative(value):
    """Round ``value`` and clamp it at zero; a mass or energy cannot be negative."""
    return max(0, round(value))


def analyze_nutrients(image: Image.Image, portion_context: str) -> dict | None:
    """Estimate whole-plate nutrition with the regression model.

    Args:
        image: Decoded meal photo.
        portion_context: Free-text user note; appended to the confidence note
            when non-empty.

    Returns:
        The response payload for ``/analyze-meal``, or ``None`` when no
        regression model is available so the caller can fall back to
        classification.
    """
    regressor = nutrition_regressor()
    if regressor is None:
        return None

    model, target_mean, target_std, device = regressor
    nutrients = predict_nutrients(model, image, target_mean, target_std, device)

    # NOTE(review): assumes the regression output can dip below zero for some
    # images; clamping keeps the reported values physically meaningful.
    calories = _rounded_non_negative(nutrients["total_calories"])
    mass = _rounded_non_negative(nutrients["total_mass"])
    fat = _rounded_non_negative(nutrients["total_fat"])
    carbs = _rounded_non_negative(nutrients["total_carb"])
    protein = _rounded_non_negative(nutrients["total_protein"])

    # 4 kcal per gram of protein and carbs, 9 kcal per gram of fat.
    macro_calories = (protein * 4) + (carbs * 4) + (fat * 9)
    macro_gap = abs(macro_calories - calories)
    confidence = "medium" if calories > 0 else "low"
    confidence_note = (
        f"Estimated from image using Calority nutrition regression. Macro calories differ by {macro_gap} kcal."
    )
    if portion_context:
        confidence_note = f"{confidence_note} User context: {portion_context}."

    return {
        "name": "Food Plate",
        "calories": calories,
        "protein": protein,
        "carbs": carbs,
        "fat": fat,
        "ingredients": [
            f"Estimated total mass {mass}g",
            f"Protein {protein}g - {protein * 4} kcal",
            f"Carbs {carbs}g - {carbs * 4} kcal",
            f"Fat {fat}g - {fat * 9} kcal",
        ],
        "confidence": confidence,
        "confidenceNote": confidence_note,
        "nutritionDetails": {
            "totalMass": mass,
            "calories": calories,
            "protein": protein,
            "carbs": carbs,
            "fat": fat,
            "macroCalories": macro_calories,
        },
    }


def require_auth(authorization: str | None) -> None:
    """Reject the request unless it carries the configured bearer token.

    Authentication is disabled when ``MODEL_API_KEY`` is empty.

    Raises:
        HTTPException: 401 when the ``Authorization`` header is missing or
            does not equal ``Bearer <MODEL_API_KEY>``.
    """
    if not MODEL_API_KEY:
        return

    expected = f"Bearer {MODEL_API_KEY}".encode("utf-8")
    supplied = (authorization or "").encode("utf-8")
    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Bytes are compared because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid model service token")


def decode_image(image_base64: str) -> Image.Image:
    """Decode a base64 string into an RGB ``PIL.Image``.

    Raises:
        HTTPException: 400 when the payload cannot be base64-decoded or
            opened as an image.
    """
    try:
        raw = base64.b64decode(image_base64)
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as exc:
        # Request boundary: any decoding failure is reported as a client error.
        raise HTTPException(status_code=400, detail="Invalid imageBase64") from exc


def normalize_label(label: str) -> str:
    """Lowercase ``label`` and turn underscores and hyphens into spaces."""
    return label.lower().replace("_", " ").replace("-", " ").strip()


def grams_from_context(portion_context: str, fallback: int) -> int:
    """Extract a portion size in grams from free text.

    Returns the first 2-4 digit gram quantity found in ``portion_context``,
    clamped to the range 30-1200, or ``fallback`` when none is found.
    """
    match = _GRAMS_PATTERN.search(portion_context.lower())
    if match:
        return max(30, min(1200, int(match.group(1))))
    return fallback


def nutrition_for(label: str, grams: int) -> dict:
    """Scale the per-100 g profile for ``label`` to a portion of ``grams``.

    Labels missing from ``NUTRITION`` use ``DEFAULT_PROFILE``. The result
    holds rounded calories and macros plus a human-readable ingredient line.
    """
    profile = NUTRITION.get(label, DEFAULT_PROFILE)
    factor = grams / 100
    calories = round(profile.calories_100g * factor)
    protein = round(profile.protein_100g * factor)
    carbs = round(profile.carbs_100g * factor)
    fat = round(profile.fat_100g * factor)
    return {
        "calories": calories,
        "protein": protein,
        "carbs": carbs,
        "fat": fat,
        "ingredient": f"{label.title()} estimated {grams}g - {calories} kcal",
    }


def confidence_from(score: float) -> tuple[str, str]:
    """Map a classifier score to a ``(confidence level, note)`` pair."""
    if score >= 0.75:
        return "high", ""
    if score >= 0.45:
        return "medium", "The food is visible, but the model is not fully certain."
    return "low", "The model could not confidently identify the meal."


@app.get("/health")
def health() -> dict:
    """Report service status and which model backend is active.

    Calls the cached loaders, so the first request may load (and download)
    the configured model.
    """
    # Same precedence as resolved_model_dir(): MODEL_DIR wins over the Hub
    # repo, so the source reported here is the one actually loaded.
    local_source = MODEL_DIR or HF_MODEL_REPO_ID
    if nutrition_regressor():
        model_source = f"nutrition-regression:{local_source}"
    elif scratch_classifier():
        model_source = f"classification:{local_source}"
    else:
        model_source = f"pipeline:{MODEL_ID}"
    return {"status": "ok", "model": model_source}


@app.post("/analyze-meal")
def analyze_meal(payload: AnalyzeMealRequest, authorization: str | None = Header(default=None)) -> dict:
    """Analyze a meal photo and return estimated calories and macros.

    The nutrition-regression result is returned when that model is
    available. Otherwise the top classifier label is looked up in
    ``NUTRITION`` and scaled to the portion size taken from
    ``portionContext`` (or the profile's default serving).

    Raises:
        HTTPException: 401 for a bad token, 400 for an undecodable image,
            500 when the classifier yields no predictions.
    """
    require_auth(authorization)
    image = decode_image(payload.imageBase64)

    nutrient_result = analyze_nutrients(image, payload.portionContext)
    if nutrient_result:
        return nutrient_result

    predictions = classify_image(image)
    if not predictions:
        # Fail with an explicit message instead of an IndexError below.
        raise HTTPException(status_code=500, detail="Model returned no predictions")

    best = predictions[0]
    label = normalize_label(best["label"])
    score = float(best["score"])
    profile = NUTRITION.get(label, DEFAULT_PROFILE)
    grams = grams_from_context(payload.portionContext, profile.serving_g)
    macros = nutrition_for(label, grams)
    confidence, confidence_note = confidence_from(score)

    alternatives = [
        f"{normalize_label(item['label']).title()} ({round(float(item['score']) * 100)}%)"
        for item in predictions[1:]
    ]
    # Runner-up labels are only worth showing when the top guess is uncertain.
    if alternatives and confidence != "high":
        confidence_note = f"{confidence_note} Alternatives: {', '.join(alternatives)}".strip()

    return {
        "name": label.title(),
        "calories": macros["calories"],
        "protein": macros["protein"],
        "carbs": macros["carbs"],
        "fat": macros["fat"],
        "ingredients": [macros["ingredient"]],
        "confidence": confidence,
        "confidenceNote": confidence_note,
    }