# Deploy Calority model API
# Author: okd06 - commit cecd1f0 (verified)
import base64
import hmac
import io
import math
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path

import torch
from fastapi import FastAPI, Header, HTTPException
from huggingface_hub import snapshot_download
from PIL import Image
from pydantic import BaseModel, Field
from transformers import pipeline

from calority_nutrition_model import load_nutrition_checkpoint, predict_nutrients
from calority_scratch_model import image_to_tensor, load_checkpoint
# Service configuration, read once from the environment at import time.
# Hugging Face model id for the fallback image-classification pipeline.
MODEL_ID = os.environ.get("MODEL_ID", "nateraw/food")
# Local directory with custom weights; takes precedence over HF_MODEL_REPO_ID.
MODEL_DIR = os.environ.get("MODEL_DIR", "")
# Hugging Face repo to snapshot custom weights from when MODEL_DIR is unset.
HF_MODEL_REPO_ID = os.environ.get("HF_MODEL_REPO_ID", "")
# Which custom model to load: "classification" or "nutrition-regression".
MODEL_TASK = os.environ.get("MODEL_TASK", "classification")
# Shared bearer token; an empty value disables authentication.
MODEL_API_KEY = os.environ.get("MODEL_API_KEY", "")
# FastAPI application; the /health and /analyze-meal routes below are registered on it.
app = FastAPI(title="Calority Meal Model", version="0.1.0")
class AnalyzeMealRequest(BaseModel):
    """JSON request body for POST /analyze-meal."""

    # Base64-encoded image bytes; must be non-empty.
    imageBase64: str = Field(min_length=1)
    # Declared MIME type of the image.
    # NOTE(review): not read anywhere in this file. PIL detects the format from
    # the decoded bytes, so this is informational only.
    mimeType: str = "image/jpeg"
    # Optional free-text portion hint (e.g. "about 300g"). The classification
    # path parses grams from it; the regression path only echoes it in the note.
    portionContext: str = ""
@dataclass(frozen=True)
class NutritionProfile:
    """Immutable nutrition reference for one dish: default serving plus per-100 g values."""

    # Default serving size in grams, used when the request gives no portion weight.
    serving_g: int
    # Energy in kcal per 100 g.
    calories_100g: int
    # Macronutrients in grams per 100 g.
    protein_100g: float
    carbs_100g: float
    fat_100g: float
# Per-dish nutrition lookup keyed by normalized classifier label (lowercase,
# "_"/"-" replaced by spaces, as produced by normalize_label).
# Positional fields follow NutritionProfile:
#   (serving_g, calories_100g, protein_100g, carbs_100g, fat_100g)
# NOTE(review): the keys appear to cover a subset of the default classifier's
# label set; confirm against the model's id2label. The source of the values is
# not recorded here.
NUTRITION = {
    "apple pie": NutritionProfile(140, 237, 1.9, 34.0, 11.0),
    "baby back ribs": NutritionProfile(220, 290, 20.0, 6.0, 21.0),
    "baklava": NutritionProfile(80, 428, 6.0, 54.0, 21.0),
    "beef carpaccio": NutritionProfile(120, 160, 22.0, 1.0, 7.0),
    "beef tartare": NutritionProfile(150, 190, 20.0, 2.0, 12.0),
    "beet salad": NutritionProfile(180, 90, 3.0, 12.0, 4.0),
    "bibimbap": NutritionProfile(450, 145, 6.0, 20.0, 4.0),
    "bread pudding": NutritionProfile(160, 220, 5.0, 32.0, 8.0),
    "breakfast burrito": NutritionProfile(280, 210, 10.0, 23.0, 9.0),
    "bruschetta": NutritionProfile(120, 190, 6.0, 25.0, 7.0),
    "caesar salad": NutritionProfile(220, 170, 8.0, 8.0, 12.0),
    "cannoli": NutritionProfile(90, 310, 7.0, 33.0, 16.0),
    "caprese salad": NutritionProfile(180, 170, 9.0, 5.0, 13.0),
    "carrot cake": NutritionProfile(120, 415, 4.0, 50.0, 22.0),
    "cheesecake": NutritionProfile(125, 321, 6.0, 26.0, 22.0),
    "chicken curry": NutritionProfile(300, 165, 13.0, 7.0, 9.0),
    "chicken quesadilla": NutritionProfile(250, 260, 14.0, 22.0, 13.0),
    "chicken wings": NutritionProfile(180, 290, 24.0, 1.0, 20.0),
    "chocolate cake": NutritionProfile(120, 371, 5.0, 53.0, 16.0),
    "club sandwich": NutritionProfile(260, 240, 13.0, 22.0, 12.0),
    "cup cakes": NutritionProfile(80, 305, 4.0, 47.0, 12.0),
    "donuts": NutritionProfile(80, 452, 5.0, 51.0, 25.0),
    "dumplings": NutritionProfile(220, 190, 9.0, 26.0, 6.0),
    "edamame": NutritionProfile(160, 121, 11.0, 9.0, 5.0),
    "falafel": NutritionProfile(180, 333, 13.0, 32.0, 18.0),
    "filet mignon": NutritionProfile(180, 250, 26.0, 0.0, 16.0),
    "fish and chips": NutritionProfile(350, 230, 11.0, 24.0, 10.0),
    "french fries": NutritionProfile(150, 312, 3.4, 41.0, 15.0),
    "fried rice": NutritionProfile(300, 165, 5.0, 25.0, 5.0),
    "greek salad": NutritionProfile(220, 110, 4.0, 7.0, 8.0),
    "grilled cheese sandwich": NutritionProfile(180, 350, 12.0, 28.0, 21.0),
    "hamburger": NutritionProfile(250, 295, 17.0, 24.0, 14.0),
    "hot dog": NutritionProfile(150, 290, 11.0, 24.0, 17.0),
    "hummus": NutritionProfile(120, 166, 8.0, 14.0, 10.0),
    "lasagna": NutritionProfile(320, 170, 10.0, 16.0, 8.0),
    "macaroni and cheese": NutritionProfile(250, 164, 7.0, 20.0, 6.0),
    "omelette": NutritionProfile(180, 154, 11.0, 1.0, 12.0),
    "pancakes": NutritionProfile(220, 227, 6.0, 28.0, 10.0),
    "pizza": NutritionProfile(250, 266, 11.0, 33.0, 10.0),
    "ramen": NutritionProfile(500, 90, 4.0, 12.0, 3.0),
    "samosa": NutritionProfile(150, 260, 6.0, 30.0, 13.0),
    "sashimi": NutritionProfile(160, 130, 22.0, 0.0, 4.0),
    "spaghetti bolognese": NutritionProfile(350, 150, 8.0, 20.0, 5.0),
    "steak": NutritionProfile(220, 250, 26.0, 0.0, 15.0),
    "sushi": NutritionProfile(220, 145, 7.0, 24.0, 2.0),
    "tacos": NutritionProfile(220, 210, 10.0, 21.0, 10.0),
    "waffles": NutritionProfile(180, 291, 8.0, 33.0, 14.0),
}
# Generic profile used for any label not present in NUTRITION.
DEFAULT_PROFILE = NutritionProfile(250, 180, 8.0, 20.0, 6.0)
@lru_cache(maxsize=1)
def classifier():
    """Return the shared Hugging Face image-classification pipeline for MODEL_ID.

    The pipeline is built on the first call; later calls reuse the cached instance.
    """
    task = "image-classification"
    return pipeline(task, model=MODEL_ID)
@lru_cache(maxsize=1)
def resolved_model_dir() -> str:
    """Return the local directory holding custom model weights, or "" if none.

    An explicit MODEL_DIR wins. Otherwise, when HF_MODEL_REPO_ID is set, the repo
    is fetched with ``snapshot_download`` and its local path is returned.
    The result is cached, so the download happens at most once per process.
    """
    if not MODEL_DIR and not HF_MODEL_REPO_ID:
        return ""
    # `or` short-circuits, so the download only runs when MODEL_DIR is empty.
    return MODEL_DIR or snapshot_download(repo_id=HF_MODEL_REPO_ID)
@lru_cache(maxsize=1)
def scratch_classifier():
    """Load the custom ("scratch") classification checkpoint once, if configured.

    Returns:
        ``(model, labels, device)`` when MODEL_TASK is "classification" and the
        resolved model directory contains ``model.pt``; otherwise ``None``.
        The outcome, including ``None``, is cached.
    """
    directory = resolved_model_dir()
    is_enabled = bool(directory) and MODEL_TASK == "classification"
    if not is_enabled:
        return None

    checkpoint_dir = Path(directory)
    weights_file = checkpoint_dir / "model.pt"
    if not weights_file.exists():
        return None

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network, class_names = load_checkpoint(checkpoint_dir, device=target)
    return network, class_names, target
@lru_cache(maxsize=1)
def nutrition_regressor():
    """Load the nutrition-regression checkpoint once, if configured.

    Returns:
        ``(model, target_mean, target_std, device)`` when MODEL_TASK is
        "nutrition-regression" and the resolved directory contains both
        ``model.pt`` and ``target_stats.json``; otherwise ``None`` (cached).
    """
    directory = resolved_model_dir()
    if not directory or MODEL_TASK != "nutrition-regression":
        return None

    checkpoint_dir = Path(directory)
    required_files = ("model.pt", "target_stats.json")
    # all() stops at the first missing file, matching the original short-circuit.
    if not all((checkpoint_dir / name).exists() for name in required_files):
        return None

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network, mean, std = load_nutrition_checkpoint(checkpoint_dir, device=target)
    return network, mean, std, target
def classify_image(image: Image.Image) -> list[dict]:
    """Return up to three ``{"label", "score"}`` predictions for ``image``.

    Uses the custom scratch classifier when one is loaded; otherwise it defers
    to the Hugging Face pipeline with ``top_k=3``.
    """
    scratch = scratch_classifier()
    if scratch is not None:
        model, labels, device = scratch
        batch = image_to_tensor(image).unsqueeze(0).to(device)
        with torch.no_grad():
            # Softmax over the class dimension, then take the single image's row.
            class_probs = torch.softmax(model(batch), dim=1)[0]
        best_scores, best_indices = torch.topk(class_probs, k=min(3, len(labels)))
        return [
            {"label": labels[idx], "score": prob}
            for prob, idx in zip(best_scores.tolist(), best_indices.tolist())
        ]
    return classifier()(image, top_k=3)
def _non_negative_round(value: float) -> int:
    """Round a model output to the nearest int, mapping negatives and NaN/inf to 0."""
    number = float(value)
    if not math.isfinite(number):
        return 0
    return max(0, round(number))


def analyze_nutrients(image: Image.Image, portion_context: str) -> dict | None:
    """Estimate a meal's nutrients directly with the nutrition-regression model.

    Args:
        image: Decoded RGB meal photo.
        portion_context: Optional user portion hint. It is not fed to the model;
            it is only echoed in ``confidenceNote``.

    Returns:
        The analysis response dict, or ``None`` when no nutrition-regression
        model is configured (the caller then falls back to classification).
    """
    regressor = nutrition_regressor()
    if regressor is None:
        return None
    model, target_mean, target_std, device = regressor
    nutrients = predict_nutrients(model, image, target_mean, target_std, device)

    # Regression outputs are unbounded after de-normalization. Clamp them to
    # non-negative ints so clients never see negative calories or grams, and so
    # a NaN/inf output does not make round() raise and turn into a 500.
    calories = _non_negative_round(nutrients["total_calories"])
    mass = _non_negative_round(nutrients["total_mass"])
    fat = _non_negative_round(nutrients["total_fat"])
    carbs = _non_negative_round(nutrients["total_carb"])
    protein = _non_negative_round(nutrients["total_protein"])

    # 4/4/9 kcal per gram of protein/carbs/fat. The gap to the model's own
    # calorie estimate is reported as a self-consistency hint.
    protein_kcal = protein * 4
    carbs_kcal = carbs * 4
    fat_kcal = fat * 9
    macro_calories = protein_kcal + carbs_kcal + fat_kcal
    macro_gap = abs(macro_calories - calories)

    confidence = "medium" if calories > 0 else "low"
    confidence_note = (
        f"Estimated from image using Calority nutrition regression. Macro calories differ by {macro_gap} kcal."
    )
    # Skip whitespace-only hints and avoid a doubled period at the end.
    context = portion_context.strip().rstrip(".")
    if context:
        confidence_note = f"{confidence_note} User context: {context}."

    return {
        "name": "Food Plate",
        "calories": calories,
        "protein": protein,
        "carbs": carbs,
        "fat": fat,
        "ingredients": [
            f"Estimated total mass {mass}g",
            f"Protein {protein}g - {protein_kcal} kcal",
            f"Carbs {carbs}g - {carbs_kcal} kcal",
            f"Fat {fat}g - {fat_kcal} kcal",
        ],
        "confidence": confidence,
        "confidenceNote": confidence_note,
        "nutritionDetails": {
            "totalMass": mass,
            "calories": calories,
            "protein": protein,
            "carbs": carbs,
            "fat": fat,
            "macroCalories": macro_calories,
        },
    }
def require_auth(authorization: str | None) -> None:
    """Reject the request unless it carries the configured bearer token.

    Authentication is disabled when MODEL_API_KEY is empty. Otherwise the
    ``Authorization`` header must equal ``"Bearer <MODEL_API_KEY>"`` exactly.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    if not MODEL_API_KEY:
        return
    expected = f"Bearer {MODEL_API_KEY}"
    supplied = authorization or ""
    # Compare in constant time so response latency does not reveal how many
    # leading characters of a guessed token were correct. Compare bytes,
    # because compare_digest raises TypeError for non-ASCII str arguments.
    if not hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid model service token")
def decode_image(image_base64: str) -> Image.Image:
    """Decode a base64 image payload into an RGB PIL image.

    Accepts raw base64 or a ``data:<mime>;base64,<payload>`` data URL, which
    browser clients commonly send. Raw base64 input decodes exactly as before.

    Raises:
        HTTPException: 400 when the payload is not base64 or not an image PIL
            can open.
    """
    payload = image_base64.strip()
    if payload.startswith("data:"):
        # Only the part after the first comma is base64. Without this step the
        # non-validating b64decode skips ':' ';' ',' and decodes the header
        # letters as garbage, so the request fails.
        _, _, payload = payload.partition(",")
    try:
        raw = base64.b64decode(payload)
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as exc:
        # Request boundary: any decode or image-parse failure is a client
        # error, not a server fault.
        raise HTTPException(status_code=400, detail="Invalid imageBase64") from exc
def normalize_label(label: str) -> str:
    """Canonicalize a classifier label for lookup in NUTRITION.

    Lower-cases the label, turns underscores and hyphens into spaces, and trims
    surrounding whitespace, e.g. ``"Apple_Pie"`` -> ``"apple pie"``.
    """
    separators_to_spaces = str.maketrans("_-", "  ")
    lowered = label.lower()
    return lowered.translate(separators_to_spaces).strip()
# Portion weight such as "250g", "300 grams", "1,000 g" or "99.6 g".
# The two lookbehinds stop a match from starting inside a longer number. The
# old pattern read "1,000 g" as 0 g (clamped to 30) and "12345g" as 2345 g.
_GRAMS_PATTERN = re.compile(
    r"(?<!\d)(?<!\d[.,])"                                # not the tail of a number
    r"((?:\d{1,3}(?:,\d{3})+|\d{2,4})(?:\.\d+)?)"        # 1,000-style or 2-4 digits
    r"\s*(?:g|gram|grams)\b"
)


def grams_from_context(portion_context: str, fallback: int) -> int:
    """Extract a portion weight in grams from free text, clamped to [30, 1200].

    Only the first gram amount is used. Single-digit amounts (e.g. "5g") are
    ignored, as before. Decimal fractions are rounded and thousands separators
    are understood.

    Args:
        portion_context: User-supplied text, matched case-insensitively.
        fallback: Value returned when no gram amount is found.

    Returns:
        The parsed weight clamped to 30..1200 g, or ``fallback`` unchanged.
    """
    match = _GRAMS_PATTERN.search(portion_context.lower())
    if match is None:
        return fallback
    grams = round(float(match.group(1).replace(",", "")))
    return max(30, min(1200, grams))
def nutrition_for(label: str, grams: int) -> dict:
    """Scale a dish's per-100 g profile to ``grams`` and round every value.

    Labels missing from NUTRITION use DEFAULT_PROFILE. The result holds
    calories (kcal), protein/carbs/fat (g), and a readable ``ingredient`` line.
    """
    profile = NUTRITION.get(label, DEFAULT_PROFILE)
    portions_of_100g = grams / 100
    per_100g = (
        profile.calories_100g,
        profile.protein_100g,
        profile.carbs_100g,
        profile.fat_100g,
    )
    calories, protein, carbs, fat = [round(amount * portions_of_100g) for amount in per_100g]
    return {
        "calories": calories,
        "protein": protein,
        "carbs": carbs,
        "fat": fat,
        "ingredient": f"{label.title()} estimated {grams}g - {calories} kcal",
    }
def confidence_from(score: float) -> tuple[str, str]:
    """Map a classifier probability to a ``(level, note)`` pair.

    Scores >= 0.75 are "high" (empty note), >= 0.45 "medium", anything else
    (including NaN) "low".
    """
    bands = (
        (0.75, "high", ""),
        (0.45, "medium", "The food is visible, but the model is not fully certain."),
    )
    for threshold, level, note in bands:
        if score >= threshold:
            return level, note
    return "low", "The model could not confidently identify the meal."
@app.get("/health")
def health() -> dict:
    """Liveness probe that also reports which model backend is active.

    Backends are checked in the order the analysis path prefers them:
    nutrition regression, then the scratch classifier, then the HF pipeline.
    The first probe may load (or download) model weights.
    """
    configured_source = HF_MODEL_REPO_ID or MODEL_DIR
    if nutrition_regressor():
        backend = f"nutrition-regression:{configured_source}"
    elif scratch_classifier():
        backend = f"classification:{configured_source}"
    else:
        backend = f"pipeline:{MODEL_ID}"
    return {"status": "ok", "model": backend}
@app.post("/analyze-meal")
def analyze_meal(payload: AnalyzeMealRequest, authorization: str | None = Header(default=None)) -> dict:
    """Analyze a meal photo and return an estimated nutrition breakdown.

    Authenticates the caller and decodes the image. It then prefers the
    nutrition-regression model. If that model is not configured, it classifies
    the dish and scales the matching per-100 g profile. The scale comes from the
    grams parsed from ``portionContext``, or else the dish's default serving.

    Raises:
        HTTPException: 401 for a bad token, 400 for an undecodable image.
    """
    require_auth(authorization)
    image = decode_image(payload.imageBase64)

    regression = analyze_nutrients(image, payload.portionContext)
    if regression:
        return regression

    predictions = classify_image(image)
    top = predictions[0]
    label = normalize_label(top["label"])
    score = float(top["score"])

    default_serving = NUTRITION.get(label, DEFAULT_PROFILE).serving_g
    grams = grams_from_context(payload.portionContext, default_serving)
    macros = nutrition_for(label, grams)
    confidence, confidence_note = confidence_from(score)

    # Runner-up labels such as "Pizza (12%)", appended only when not "high".
    alternatives = []
    for candidate in predictions[1:]:
        display_name = normalize_label(candidate["label"]).title()
        percent = round(float(candidate["score"]) * 100)
        alternatives.append(f"{display_name} ({percent}%)")
    if confidence != "high" and alternatives:
        confidence_note = f"{confidence_note} Alternatives: {', '.join(alternatives)}".strip()

    result = {"name": label.title()}
    for nutrient in ("calories", "protein", "carbs", "fat"):
        result[nutrient] = macros[nutrient]
    result["ingredients"] = [macros["ingredient"]]
    result["confidence"] = confidence
    result["confidenceNote"] = confidence_note
    return result