caloraify / app.py
Unnatrathi's picture
Update app.py
aac8c26 verified
# =============================================================================
# CaLoRAify β€” HuggingFace Space: FastAPI Inference Server
# Deploy this file to a HuggingFace Space (SDK: Docker or Gradio with custom app)
#
# Space hardware: CPU Basic (free) works for demo; T4 Small for faster response
# =============================================================================
#
# HOW TO DEPLOY:
# 1. Create a new Space at huggingface.co/new-space
# 2. Choose SDK: "Docker"
# 3. Upload this file as app.py + requirements.txt + Dockerfile
# 4. Set Space secrets:
# HF_MODEL_ID = "HuggingFaceTB/SmolVLM2-500M-Instruct"
# LORA_REPO_ID = "your_username/caloraify-lora-adapter" (your uploaded adapter)
# API_KEY = "any_secret_string_you_choose"
# 5. Your endpoint will be: https://your_username-caloraify.hf.space/analyze
#
# HOW TO UPLOAD YOUR LORA ADAPTER:
# Run this in Colab after training:
# from huggingface_hub import HfApi
# api = HfApi()
# api.create_repo("caloraify-lora-adapter", repo_type="model", private=False)
# api.upload_folder(
# folder_path="./caloraify-smolvlm2-lora",
# repo_id="your_username/caloraify-lora-adapter",
# repo_type="model"
# )
import os
import io
import json
import base64
import logging
import time
from contextlib import asynccontextmanager
from typing import Optional
import torch
from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from PIL import Image
import re
import requests as req_lib
# ── Lazy imports for heavy ML deps ──────────────────────────────────────────
# Imported inside lifespan so the Space starts quickly and fails clearly
_processor = None
_model = None
logger = logging.getLogger("caloraify")
logging.basicConfig(level=logging.INFO)
# ── Startup / shutdown ───────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
global _processor, _model
from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
from transformers.models.smolvlm.configuration_smolvlm import SmolVLMConfig
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
from bitsandbytes.nn import Linear4bit
from peft import PeftModel
MODEL_ID = os.environ.get("HF_MODEL_ID", "HuggingFaceTB/SmolVLM2-500M-Instruct")
LORA_REPO = os.environ.get("LORA_REPO_ID", "")
# Registry patch (same fix as training notebook)
AutoModelForVision2Seq.register(SmolVLMConfig, SmolVLMForConditionalGeneration, exist_ok=True)
logger.info("Loading processor …")
_processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
# Use 4-bit if CUDA available, else fp32 on CPU
if torch.cuda.is_available():
bnb = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
dtype = torch.bfloat16
logger.info("Loading model in 4-bit NF4 on GPU …")
else:
bnb = None
dtype = torch.float32
logger.info("No GPU β€” loading model in fp32 on CPU (slower) …")
base = AutoModelForVision2Seq.from_pretrained(
MODEL_ID,
quantization_config=bnb,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else "cpu",
trust_remote_code=True,
_fast_init=False,
)
# Cast vision encoder to bfloat16 (same fix as training notebook)
if torch.cuda.is_available():
for module in base.modules():
if isinstance(module, Linear4bit):
continue
for p in module.parameters(recurse=False):
if p.dtype == torch.float32:
p.data = p.data.to(torch.bfloat16)
# Load LoRA adapter if provided
if LORA_REPO:
logger.info(f"Loading LoRA adapter from {LORA_REPO} …")
_model = PeftModel.from_pretrained(base, LORA_REPO)
else:
logger.warning("No LORA_REPO_ID set β€” running base model without fine-tuning")
_model = base
_model.eval()
_model.config.use_cache = True
logger.info("Model ready βœ…")
yield # server runs here
logger.info("Shutting down …")
# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(
title="CaLoRAify Inference API",
description="Food photo β†’ ingredient list + calorie JSON",
version="1.0.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST", "GET"],
allow_headers=["*"],
)
# ── Auth dependency ───────────────────────────────────────────────────────────
def verify_api_key(x_api_key: Optional[str] = Header(default=None)):
expected = os.environ.get("API_KEY", "")
if expected and x_api_key != expected:
raise HTTPException(status_code=401, detail="Invalid API key")
# ── Request / response models ─────────────────────────────────────────────────
class AnalyzeRequest(BaseModel):
image_b64: str # base64-encoded JPEG/PNG
max_new_tokens: int = 300
class NutritionResponse(BaseModel):
ingredients: str
portion_notes: str
calories: Optional[float] = None
protein_g: Optional[float] = None
carbs_g: Optional[float] = None
fat_g: Optional[float] = None
fibre_g: Optional[float] = None
raw_text: str # full model output for debugging
latency_ms: int
# ── Inference helper ──────────────────────────────────────────────────────────
def _decode_image(b64: str) -> Image.Image:
try:
data = base64.b64decode(b64)
img = Image.open(io.BytesIO(data)).convert("RGB")
img = img.resize((384, 384))
return img
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image: {e}")
def _run_inference(image: Image.Image, max_new_tokens: int) -> str:
conversation = [
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": (
"What food are in this image? "
"Reply in exactly this format:\n"
"Ingredients detected: [food name and ingredients separated by commas]\n"
"Example: Ingredients detected: banana, rice, chicken, broccoli"
),
},
],
}
]
prompt = _processor.apply_chat_template(
conversation, tokenize=False, add_generation_prompt=True
)
inputs = _processor(
images=[[image]],
text=[prompt],
return_tensors="pt",
truncation=False,
)
device = next(_model.parameters()).device
if "pixel_values" in inputs and inputs["pixel_values"] is not None:
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.inference_mode():
out_ids = _model.generate(
**inputs,
max_new_tokens=150,
do_sample=False,
temperature=1.0,
repetition_penalty=1.3,
no_repeat_ngram_size=3,
)
new_tokens = out_ids[:, inputs["input_ids"].shape[-1]:]
return _processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
def _get_nutrition_from_api(ingredients_text: str) -> dict:
NUTRITION_TABLE = {
"banana": {"calories": 89, "protein_g": 1.1, "carbs_g": 23.0, "fat_g": 0.3, "fibre_g": 2.6},
"apple": {"calories": 72, "protein_g": 0.4, "carbs_g": 19.0, "fat_g": 0.2, "fibre_g": 2.4},
"orange": {"calories": 62, "protein_g": 1.2, "carbs_g": 15.4, "fat_g": 0.2, "fibre_g": 3.1},
"mango": {"calories": 99, "protein_g": 1.4, "carbs_g": 25.0, "fat_g": 0.6, "fibre_g": 2.6},
"grape": {"calories": 69, "protein_g": 0.7, "carbs_g": 18.1, "fat_g": 0.2, "fibre_g": 0.9},
"strawberry": {"calories": 32, "protein_g": 0.7, "carbs_g": 7.7, "fat_g": 0.3, "fibre_g": 2.0},
"watermelon": {"calories": 30, "protein_g": 0.6, "carbs_g": 7.6, "fat_g": 0.2, "fibre_g": 0.4},
"pineapple": {"calories": 50, "protein_g": 0.5, "carbs_g": 13.0, "fat_g": 0.1, "fibre_g": 1.4},
"coconut": {"calories": 354, "protein_g": 3.3, "carbs_g": 15.0, "fat_g": 33.0,"fibre_g": 9.0},
"lettuce": {"calories": 15, "protein_g": 1.4, "carbs_g": 2.9, "fat_g": 0.2, "fibre_g": 1.3},
"salad": {"calories": 20, "protein_g": 1.8, "carbs_g": 3.6, "fat_g": 0.3, "fibre_g": 2.0},
"carrot": {"calories": 41, "protein_g": 0.9, "carbs_g": 10.0, "fat_g": 0.2, "fibre_g": 2.8},
"onion": {"calories": 40, "protein_g": 1.1, "carbs_g": 9.3, "fat_g": 0.1, "fibre_g": 1.7},
"rice": {"calories": 206, "protein_g": 4.3, "carbs_g": 45.0, "fat_g": 0.4, "fibre_g": 0.6},
"chicken": {"calories": 239, "protein_g": 27.0,"carbs_g": 0.0, "fat_g": 14.0,"fibre_g": 0.0},
"egg": {"calories": 155, "protein_g": 13.0,"carbs_g": 1.1, "fat_g": 11.0,"fibre_g": 0.0},
"bread": {"calories": 265, "protein_g": 9.0, "carbs_g": 49.0, "fat_g": 3.2, "fibre_g": 2.7},
"milk": {"calories": 61, "protein_g": 3.2, "carbs_g": 4.8, "fat_g": 3.3, "fibre_g": 0.0},
"cheese": {"calories": 402, "protein_g": 25.0,"carbs_g": 1.3, "fat_g": 33.0,"fibre_g": 0.0},
"pizza": {"calories": 266, "protein_g": 11.0,"carbs_g": 33.0, "fat_g": 10.0,"fibre_g": 2.3},
"burger": {"calories": 295, "protein_g": 17.0,"carbs_g": 24.0, "fat_g": 14.0,"fibre_g": 1.3},
"pasta": {"calories": 220, "protein_g": 8.1, "carbs_g": 43.0, "fat_g": 1.3, "fibre_g": 2.5},
"fish": {"calories": 136, "protein_g": 20.0,"carbs_g": 0.0, "fat_g": 6.0, "fibre_g": 0.0},
"potato": {"calories": 77, "protein_g": 2.0, "carbs_g": 17.0, "fat_g": 0.1, "fibre_g": 2.2},
"broccoli": {"calories": 34, "protein_g": 2.8, "carbs_g": 6.6, "fat_g": 0.4, "fibre_g": 2.6},
"tomato": {"calories": 18, "protein_g": 0.9, "carbs_g": 3.9, "fat_g": 0.2, "fibre_g": 1.2},
"sandwich": {"calories": 250, "protein_g": 12.0,"carbs_g": 33.0, "fat_g": 7.0, "fibre_g": 2.5},
"soup": {"calories": 71, "protein_g": 3.8, "carbs_g": 8.0, "fat_g": 2.0, "fibre_g": 1.5},
"chocolate": {"calories": 546, "protein_g": 5.0, "carbs_g": 60.0, "fat_g": 31.0,"fibre_g": 7.0},
"cake": {"calories": 347, "protein_g": 5.0, "carbs_g": 55.0, "fat_g": 12.0,"fibre_g": 1.0},
"dal": {"calories": 116, "protein_g": 9.0, "carbs_g": 20.0, "fat_g": 0.4, "fibre_g": 8.0},
"roti": {"calories": 297, "protein_g": 9.9, "carbs_g": 61.0, "fat_g": 1.7, "fibre_g": 1.9},
"biryani": {"calories": 200, "protein_g": 8.0, "carbs_g": 30.0, "fat_g": 6.0, "fibre_g": 1.5},
"paneer": {"calories": 265, "protein_g": 18.0,"carbs_g": 3.4, "fat_g": 20.0,"fibre_g": 0.0},
"idli": {"calories": 58, "protein_g": 2.0, "carbs_g": 12.0, "fat_g": 0.4, "fibre_g": 0.5},
"dosa": {"calories": 168, "protein_g": 3.7, "carbs_g": 30.0, "fat_g": 3.7, "fibre_g": 1.5},
"samosa": {"calories": 262, "protein_g": 3.5, "carbs_g": 28.0, "fat_g": 15.0,"fibre_g": 2.0},
"noodle": {"calories": 138, "protein_g": 4.5, "carbs_g": 25.0, "fat_g": 2.0, "fibre_g": 1.8},
"omelette": {"calories": 154, "protein_g": 11.0,"carbs_g": 0.4, "fat_g": 12.0,"fibre_g": 0.0},
"yogurt": {"calories": 59, "protein_g": 10.0,"carbs_g": 3.6, "fat_g": 0.4, "fibre_g": 0.0},
"coffee": {"calories": 2, "protein_g": 0.3, "carbs_g": 0.0, "fat_g": 0.0, "fibre_g": 0.0},
}
text_lower = ingredients_text.lower()
# Count how many times each food appears β€” pick the one mentioned FIRST
# (first mention = most prominent ingredient in the model's output)
matched = {}
for food in NUTRITION_TABLE:
idx = text_lower.find(food)
if idx != -1:
matched[food] = idx # store position of first mention
if matched:
# Use the food mentioned earliest in the text (most prominent)
primary_food = min(matched, key=matched.get)
# If multiple foods found, sum up unique ones for a combined estimate
unique_foods = list(set(matched.keys()))[:5] # max 5 ingredients
if len(unique_foods) > 1:
total = {"calories": 0, "protein_g": 0, "carbs_g": 0, "fat_g": 0, "fibre_g": 0}
for food in unique_foods:
for key in total:
total[key] += NUTRITION_TABLE[food][key]
# Average across ingredients (rough estimate per serving)
count = len(unique_foods)
logger.info(f"Combined nutrition for: {unique_foods}")
return {k: round(v / count, 1) for k, v in total.items()}
else:
logger.info(f"Single food nutrition for: {primary_food}")
return NUTRITION_TABLE[primary_food]
# Try Open Food Facts as last resort
try:
import re as _re
words = _re.findall(r'\b[a-zA-Z]{4,}\b', ingredients_text)
query = words[0] if words else "food"
r = req_lib.get(
"https://world.openfoodfacts.org/cgi/search.pl",
params={"search_terms": query, "search_simple": 1,
"action": "process", "json": 1, "page_size": 3,
"fields": "nutriments"},
timeout=8,
)
for product in r.json().get("products", []):
n = product.get("nutriments", {})
cal = float(n.get("energy-kcal_100g") or 0)
if cal > 0:
return {
"calories": round(cal, 1),
"protein_g": round(float(n.get("proteins_100g", 0) or 0), 1),
"carbs_g": round(float(n.get("carbohydrates_100g", 0) or 0), 1),
"fat_g": round(float(n.get("fat_100g", 0) or 0), 1),
"fibre_g": round(float(n.get("fiber_100g", 0) or 0), 1),
}
except Exception as e:
logger.warning(f"OpenFoodFacts failed: {e}")
return {}
def _parse_response(raw: str) -> dict:
import re
result = {
"ingredients": "",
"portion_notes": "Portion estimated from image.",
"raw_text": raw,
"calories": None,
"protein_g": None,
"carbs_g": None,
"fat_g": None,
"fibre_g": None,
}
# ── Strategy 1: structured CaLoRAify format ───────────────────────────
if "Ingredients detected:" in raw:
ing_start = raw.index("Ingredients detected:") + len("Ingredients detected:")
# End at next section or newline
for end_marker in ["Portion Analysis:", "JSON Summary:", "\n\n"]:
if end_marker in raw[ing_start:]:
ing_end = raw.index(end_marker, ing_start)
break
else:
ing_end = min(ing_start + 200, len(raw))
result["ingredients"] = raw[ing_start:ing_end].strip().rstrip(".")
if "Portion Analysis:" in raw:
pa_start = raw.index("Portion Analysis:") + len("Portion Analysis:")
pa_end = raw.find("JSON Summary:", pa_start)
pa_end = pa_end if pa_end != -1 else pa_start + 150
result["portion_notes"] = raw[pa_start:pa_end].strip()
if "JSON Summary:" in raw:
json_start = raw.index("JSON Summary:") + len("JSON Summary:")
brace_start = raw.find("{", json_start)
brace_end = raw.rfind("}") + 1
if brace_start != -1 and brace_end > brace_start:
try:
nutrition = json.loads(raw[brace_start:brace_end])
result.update({
"calories": nutrition.get("calories_kcal") or nutrition.get("calories"),
"protein_g": nutrition.get("protein_g"),
"carbs_g": nutrition.get("carbs_g") or nutrition.get("carbohydrate_g"),
"fat_g": nutrition.get("fat_g"),
"fibre_g": nutrition.get("fibre_g"),
})
except json.JSONDecodeError:
pass
# ── Strategy 2: scan raw text for known food words directly ───────────
# This works regardless of what format the model uses
if not result["ingredients"] or result["calories"] is None:
FOOD_WORDS = [
"banana", "apple", "orange", "mango", "grape", "strawberry",
"watermelon", "pineapple", "coconut", "lettuce", "salad",
"carrot", "onion", "rice", "chicken", "egg", "bread", "milk",
"cheese", "pizza", "burger", "pasta", "fish", "potato",
"broccoli", "tomato", "sandwich", "soup", "chocolate", "cake",
"dal", "roti", "biryani", "paneer", "idli", "dosa", "samosa",
"noodle", "omelette", "yogurt", "coffee", "curry", "taco",
"sushi", "steak", "bacon", "butter", "cream", "icecream",
]
text_lower = raw.lower()
found_foods = []
for food in FOOD_WORDS:
if food in text_lower:
found_foods.append(food)
if found_foods:
# Use found foods as ingredients if not already set
if not result["ingredients"]:
result["ingredients"] = ", ".join(found_foods)
# Get nutrition from found foods
if result["calories"] is None:
nutrition = _get_nutrition_from_api(", ".join(found_foods))
if nutrition:
result.update(nutrition)
# ── Strategy 3: fallback β€” use entire raw text for API lookup ─────────
if not result["ingredients"]:
result["ingredients"] = raw.strip()[:150]
if result["calories"] is None and result["ingredients"]:
nutrition = _get_nutrition_from_api(result["ingredients"])
if nutrition:
result.update(nutrition)
return result
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
return {
"status": "ok",
"model_loaded": _model is not None,
"cuda": torch.cuda.is_available(),
}
@app.post("/analyze", response_model=NutritionResponse)
def analyze(req: AnalyzeRequest, _=Depends(verify_api_key)):
if _model is None:
raise HTTPException(status_code=503, detail="Model not loaded yet")
image = _decode_image(req.image_b64)
t0 = time.monotonic()
raw = _run_inference(image, req.max_new_tokens)
ms = int((time.monotonic() - t0) * 1000)
parsed = _parse_response(raw)
return NutritionResponse(**parsed, latency_ms=ms)
@app.post("/debug")
def debug(req: AnalyzeRequest, _=Depends(verify_api_key)):
image = _decode_image(req.image_b64)
raw = _run_inference(image, req.max_new_tokens)
return {"raw_text": raw}