# Monserza's picture
# Upload 46 files
# e406c94 verified
# demo.py — Depth Pro + YOLO segmentation + Portion & Nutrition post-processing (tables version)
import sys
import json
import numpy as np
import cv2
import torch
from PIL import Image
import gradio as gr
from ultralytics import YOLO
# -----------------------------------------------------------
# 1. Import depth_pro (adjust path if needed)
# -----------------------------------------------------------
# If depth_pro is in a local folder "ml-depth-pro/src" next to this file:
sys.path.append("ml-depth-pro/src")
import depth_pro # noqa: E402
# -----------------------------------------------------------
# 2. Device selection
# -----------------------------------------------------------
# Prefer CUDA when available; both Depth Pro and YOLO run on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device}")
# -----------------------------------------------------------
# 3. Load Depth Pro model
# -----------------------------------------------------------
# dp_transform preprocesses a PIL image into the tensor dp_model expects.
print("[INFO] Loading Depth Pro model...")
dp_model, dp_transform = depth_pro.create_model_and_transforms()
dp_model = dp_model.to(device)
dp_model.eval()  # inference only; no gradients needed
print("[INFO] Depth Pro ready.")
# -----------------------------------------------------------
# 4. Load YOLO segmentation model
# -----------------------------------------------------------
# TODO: change this to your actual best.pt path
# NOTE(review): hard-coded absolute Windows path — will fail on any other
# machine; consider an env var or relative path.
YOLO_MODEL_PATH = r"C:\Users\monol\Desktop\Senior_demo\ml-depth-pro\model\yolo-seg.pt"
print(f"[INFO] Loading YOLO model from: {YOLO_MODEL_PATH}")
yolo_model = YOLO(YOLO_MODEL_PATH)
print("[INFO] YOLO ready.")
# -----------------------------------------------------------
# 5. Load preset + nutrition metadata
# -----------------------------------------------------------
# presetdata.json: list of reference entries keyed by "class"; code below
# reads "portion" (g), "mask_region" (%), "center_depth" (m) and optional
# "ingredients". Missing file degrades gracefully to empty lookups.
try:
    with open("presetdata.json", "r", encoding="utf-8") as f:
        PRESET_LIST = json.load(f)
    PRESET_BY_CLASS = {item["class"]: item for item in PRESET_LIST}
    print(f"[INFO] Loaded {len(PRESET_LIST)} preset entries.")
except Exception as e:
    print("[WARN] Could not load presetdata.json:", e)
    PRESET_LIST = []
    PRESET_BY_CLASS = {}
# nutrition_data.json: per-class nutrition; code below reads "amount" (g),
# "calories", "protein", "fat", "carbohydrates", "sodium".
try:
    with open("nutrition_data.json", "r", encoding="utf-8") as f:
        NUTRITION_LIST = json.load(f)
    NUTR_BY_CLASS = {item["class"]: item for item in NUTRITION_LIST}
    print(f"[INFO] Loaded {len(NUTRITION_LIST)} nutrition entries.")
except Exception as e:
    print("[WARN] Could not load nutrition_data.json:", e)
    NUTRITION_LIST = []
    NUTR_BY_CLASS = {}
# -----------------------------------------------------------
# 6. Helper: make depth visualization (RGB uint8)
# -----------------------------------------------------------
def make_depth_vis(depth: np.ndarray) -> np.ndarray:
    """
    Render a metric depth map as a colorized RGB image.

    Args:
        depth: HxW float array of depths (meters); may contain NaN/inf.

    Returns:
        HxWx3 uint8 RGB visualization (INFERNO colormap); all-black when
        the map contains no finite values.
    """
    work = depth.copy()
    work[~np.isfinite(work)] = np.nan
    if not np.isfinite(work).any():
        return np.zeros((*depth.shape, 3), dtype=np.uint8)
    # Robust normalization: use the 1st-99th percentile range so a few
    # extreme pixels don't wash out the rest of the map.
    lo = np.nanpercentile(work, 1)
    hi = np.nanpercentile(work, 99)
    if hi <= lo:
        hi = lo + 1e-6
    scaled = np.clip((work - lo) / (hi - lo), 0.0, 1.0)
    gray = (scaled * 255).astype(np.uint8)
    colored_bgr = cv2.applyColorMap(gray, cv2.COLORMAP_INFERNO)
    return cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB)
# -----------------------------------------------------------
# 7. Portion + nutrition helper functions
# Using your equation:
# Mass_in = Mass_ref * (%area_in / %area_ref) * (Z_in / Z_ref)^2
# -----------------------------------------------------------
def estimate_portion_for_class(cls_name, area_in_pct, z_in_m, default_z_in=None,
                               preset_map=None):
    """
    Estimate the portion mass (grams) of one detected class from its mask
    area and depth, relative to a preset reference photo.

    Scaling equation:
        Mass_in = Mass_ref * (%area_in / %area_ref) * (Z_in / Z_ref)^2

    Args:
        cls_name: class name as emitted by the YOLO model.
        area_in_pct: mask area as a percentage of the whole image (0-100).
        z_in_m: median depth for this class in meters (may be None/NaN).
        default_z_in: fallback depth (meters) used when z_in_m is unusable.
        preset_map: optional override of the preset lookup table
            (class -> preset dict). Defaults to the module-level
            PRESET_BY_CLASS loaded from presetdata.json; exposed mainly
            for testing / dependency injection.

    Returns:
        dict with the estimated portion and the inputs used, or None when
        the class has no preset, the preset is malformed, or no usable
        depth is available.
    """
    presets = PRESET_BY_CLASS if preset_map is None else preset_map
    preset = presets.get(cls_name)
    if not preset:
        return None
    try:
        mass_ref = float(preset["portion"])      # grams in reference photo
        area_ref = float(preset["mask_region"])  # % image area in reference
        z_ref = float(preset["center_depth"])    # meters in reference
    except (KeyError, ValueError, TypeError):
        return None
    if area_ref <= 0 or z_ref <= 0:
        return None
    if z_in_m is None:
        z_in_m = default_z_in
    if z_in_m is None or not np.isfinite(z_in_m) or z_in_m <= 0:
        return None
    # Apparent area shrinks with the square of distance, hence the
    # (Z_in / Z_ref)^2 correction on top of the linear area ratio.
    mass_in = mass_ref * (area_in_pct / area_ref) * (z_in_m / z_ref) ** 2
    return {
        "class": cls_name,
        "estimated_portion_g": float(mass_in),
        "area_in_pct": float(area_in_pct),
        "area_ref_pct": float(area_ref),
        "z_in_m": float(z_in_m),
        "z_ref_m": float(z_ref),
        "mass_ref_g": float(mass_ref),
    }
def estimate_nutrition_for_mass(class_name, mass_g, nutrition_map=None):
    """
    Scale a class's reference nutrition facts to an estimated mass.

    Args:
        class_name: class name to look up in the nutrition table.
        mass_g: estimated mass in grams.
        nutrition_map: optional override of the nutrition lookup table
            (class -> nutrition dict). Defaults to the module-level
            NUTR_BY_CLASS loaded from nutrition_data.json; exposed mainly
            for testing / dependency injection.

    Returns:
        dict with mass and linearly-scaled nutrients, or None when the
        class is unknown, the entry is malformed, or the reference
        amount is non-positive. Reference data is typically per 100 g.
    """
    table = NUTR_BY_CLASS if nutrition_map is None else nutrition_map
    nutr = table.get(class_name)
    if not nutr:
        return None
    try:
        ref_mass = float(nutr["amount"])
        calories = float(nutr["calories"])
        protein = float(nutr["protein"])
        fat = float(nutr["fat"])
        carbs = float(nutr["carbohydrates"])
        sodium = float(nutr["sodium"])
    except (KeyError, ValueError, TypeError):
        return None
    if ref_mass <= 0:
        return None
    # All nutrients scale linearly with mass.
    factor = mass_g / ref_mass
    return {
        "class": class_name,
        "mass_g": float(mass_g),
        "calories": calories * factor,
        "protein": protein * factor,
        "fat": fat * factor,
        "carbohydrates": carbs * factor,
        "sodium": sodium * factor,
    }
def breakdown_ingredients(dish_class_name, dish_mass_g):
    """
    Split an estimated dish mass into per-ingredient masses using the
    ingredient ratios in presetdata.json, attaching ingredient-level
    nutrition from nutrition_data.json when available.

    Args:
        dish_class_name: dish class (e.g., "pad kaprao") to look up.
        dish_mass_g: estimated total dish mass in grams.

    Returns:
        (ingredient_masses, ingredient_nutrition): two lists of dicts;
        both empty when the dish has no usable preset or ingredients.
    """
    preset = PRESET_BY_CLASS.get(dish_class_name)
    if not preset or "ingredients" not in preset:
        return [], []
    try:
        portion_ref = float(preset["portion"])
    except (KeyError, ValueError, TypeError):
        return [], []
    if portion_ref <= 0:
        return [], []
    masses = []
    nutrition = []
    for entry in preset["ingredients"]:
        name = entry.get("name")
        try:
            ref_mass = float(entry["amount"])
        except (KeyError, ValueError, TypeError):
            continue  # skip malformed ingredient entries
        # Scale reference ingredient mass by the dish's overall ratio.
        scaled_mass = dish_mass_g * (ref_mass / portion_ref)
        masses.append({
            "dish_class": dish_class_name,
            "ingredient": name,
            "mass_g": float(scaled_mass),
        })
        nutr = estimate_nutrition_for_mass(name, scaled_mass)
        if nutr:
            nutr["dish_class"] = dish_class_name
            nutrition.append(nutr)
    return masses, nutrition
def postprocess_ai_results(rows, center_depth_m):
    """
    Turn per-class segmentation stats into portion and nutrition records.

    Args:
        rows: list of [class_name, area_pct, median_depth_m]; median depth
            may be None when no finite depth samples fell inside the mask.
        center_depth_m: depth at the image center (meters), used as a
            fallback when a class has no usable median depth.

    Returns:
        (portions_json, dish_nutr_json, ingredient_nutr_json):
        - portions_json: per-class portion dicts, e.g.
          {"class": "pad kaprao", "portion": 100, "portion_label": "gram",
           "center_depth": "0.47", "mask_region": "5.07"}
        - dish_nutr_json: dish-level nutrition dicts
        - ingredient_nutr_json: ingredient-level nutrition dicts
          (ingredients without nutrition data are still listed, with
          None for every nutrient)
    """
    portions = []
    dish_nutrition = []
    ingredient_nutrition = []
    for cls_name, area_pct, median_depth in rows:
        if area_pct is None:
            continue
        # Prefer the class's own median depth; fall back to center depth.
        if median_depth is not None and np.isfinite(median_depth):
            depth_in = median_depth
        else:
            depth_in = center_depth_m
        info = estimate_portion_for_class(
            cls_name=cls_name,
            area_in_pct=area_pct,
            z_in_m=depth_in,
            default_z_in=center_depth_m,
        )
        if info is None:
            continue
        # Portion record (depth/area rendered as 2-decimal strings).
        portions.append({
            "class": info["class"],
            "portion": round(info["estimated_portion_g"], 2),
            "portion_label": "gram",
            "center_depth": f"{info['z_in_m']:.2f}",
            "mask_region": f"{info['area_in_pct']:.2f}",
        })
        # Dish-level nutrition scaled to the estimated mass.
        dish_n = estimate_nutrition_for_mass(cls_name, info["estimated_portion_g"])
        if dish_n:
            dish_nutrition.append({
                "class": dish_n["class"],
                "mass_g": round(dish_n["mass_g"], 2),
                "calories": round(dish_n["calories"], 1),
                "protein": round(dish_n["protein"], 1),
                "fat": round(dish_n["fat"], 1),
                "carbohydrates": round(dish_n["carbohydrates"], 1),
                "sodium": round(dish_n["sodium"], 1),
            })
        # Ingredient-level breakdown: every ingredient is listed, even
        # those without nutrition data (nutrient fields stay None).
        ing_masses, ing_nutrition = breakdown_ingredients(
            dish_class_name=cls_name,
            dish_mass_g=info["estimated_portion_g"],
        )
        # Lookup: (dish_class, ingredient_name) -> nutrition dict.
        by_key = {
            (n.get("dish_class", cls_name), n["class"]): n
            for n in ing_nutrition
        }
        for rec in ing_masses:
            dish_cls = rec["dish_class"]
            ing_name = rec["ingredient"]
            row = {
                "dish_class": dish_cls,
                "ingredient": ing_name,
                "mass_g": round(rec["mass_g"], 2),
            }
            matched = by_key.get((dish_cls, ing_name))
            if matched is not None:
                row["calories"] = round(matched["calories"], 1)
                row["protein"] = round(matched["protein"], 1)
                row["fat"] = round(matched["fat"], 1)
                row["carbohydrates"] = round(matched["carbohydrates"], 1)
                row["sodium"] = round(matched["sodium"], 1)
            else:
                row["calories"] = None
                row["protein"] = None
                row["fat"] = None
                row["carbohydrates"] = None
                row["sodium"] = None
            ingredient_nutrition.append(row)
    return portions, dish_nutrition, ingredient_nutrition
# -----------------------------------------------------------
# 8. Main pipeline: Depth Pro + YOLO segmentation + post-processing
# -----------------------------------------------------------
def analyze_image(pil_img: Image.Image):
    """
    Full pipeline for one uploaded image: YOLO segmentation, Depth Pro
    depth estimation, then portion + nutrition post-processing.

    Args:
        pil_img: uploaded image (PIL), or None when nothing was uploaded.

    Returns a 7-tuple matching the Gradio outputs, in order:
        (depth_vis_big, seg_vis, depth_text, rows,
         portions_table_rows, dish_table_rows, ingredient_table_rows)
    On error paths the image slots are filled with a 10x10 black image and
    the table slots with empty lists.
    """
    # ---------- safety ----------
    if pil_img is None:
        blank = np.zeros((10, 10, 3), dtype=np.uint8)
        return blank, blank, "Please upload an image first.", [], [], [], []
    # Ensure RGB
    pil_img = pil_img.convert("RGB")
    rgb_np = np.array(pil_img)
    H_s, W_s, _ = rgb_np.shape  # source (uploaded) image size
    # =======================================================
    # A) YOLO segmentation (for mask & class percentages)
    # =======================================================
    seg_vis = rgb_np.copy()
    class_to_mask = {}  # class_name -> combined bool mask H_s x W_s
    # YOLO expects BGR typically; convert
    bgr_np = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2BGR)
    try:
        results = yolo_model.predict(
            source=bgr_np,
            save=False,  # we don't save images to disk
            conf=0.25,
            iou=0.7,
            verbose=False
        )
        r = results[0]
        # visualization (BGR -> RGB)
        seg_plot_bgr = r.plot()
        seg_vis = cv2.cvtColor(seg_plot_bgr, cv2.COLOR_BGR2RGB)
        if r.masks is not None and len(r.masks.data) > 0:
            # NOTE(review): r.masks.data is typically at the model's
            # inference resolution, which may differ from H_s x W_s; if so,
            # the area % below divides by the wrong denominator. Confirm
            # the mask size (or resize masks to (W_s, H_s)) before trusting
            # the portion estimates.
            masks = r.masks.data.cpu().numpy()  # [N, H, W] in YOLO image space
            boxes = r.boxes
            for i in range(len(masks)):
                cls_id = int(boxes.cls[i])
                cls_name = yolo_model.names[cls_id]
                mask_i = masks[i] > 0.5  # bool H_s x W_s
                # Merge multiple instances of the same class into one mask.
                if cls_name not in class_to_mask:
                    class_to_mask[cls_name] = mask_i
                else:
                    class_to_mask[cls_name] |= mask_i
        else:
            print("[YOLO] No masks found.")
    except Exception as e:
        # Best-effort: a segmentation failure still lets depth run below.
        print("[YOLO ERROR]", e)
    seg_vis = seg_vis.astype(np.uint8)
    # =======================================================
    # B) Depth Pro (distance from camera)
    # =======================================================
    try:
        dp_in = dp_transform(pil_img).to(device)
        with torch.no_grad():
            # f_px=None lets Depth Pro estimate focal length itself.
            pred = dp_model.infer(dp_in, f_px=None)
        depth = pred["depth"]
        if isinstance(depth, torch.Tensor):
            depth = depth.squeeze().cpu().numpy()
    except Exception as e:
        blank = np.zeros((10, 10, 3), dtype=np.uint8)
        return blank, seg_vis, f"Depth estimation error: {e}", [], [], [], []
    if depth is None or not np.isfinite(depth).any():
        blank = np.zeros((10, 10, 3), dtype=np.uint8)
        return blank, seg_vis, "Depth map invalid (NaN/empty).", [], [], [], []
    H_d, W_d = depth.shape  # depth-map resolution (may differ from H_s x W_s)
    # depth visualization (resized to original image size)
    depth_vis = make_depth_vis(depth)
    depth_vis_big = cv2.resize(depth_vis, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
    depth_vis_big = depth_vis_big.astype(np.uint8)
    # -------------------------------------------------------
    # Global depth summary (center + ROI)
    # -------------------------------------------------------
    cx_d, cy_d = W_d // 2, H_d // 2
    center_depth = float(depth[cy_d, cx_d])
    # Central 20% x 20% window; median over finite values only.
    roi = depth[int(H_d * 0.4):int(H_d * 0.6), int(W_d * 0.4):int(W_d * 0.6)]
    roi = roi[np.isfinite(roi)]
    roi_depth = float(np.median(roi)) if roi.size > 0 else float("nan")
    depth_lines = [
        "### Depth Estimate",
        f"- Center depth: **{center_depth:.2f} m**",
    ]
    if np.isfinite(roi_depth):
        depth_lines.append(f"- Center ROI median depth: **{roi_depth:.2f} m**")
    # =======================================================
    # C) Compute % area + median depth per class
    # =======================================================
    total_pixels = H_s * W_s
    rows = []  # for segmentation stats table: [class, area%, median_depth]
    for cls_name, mask in class_to_mask.items():
        # percentage of image area
        area_px = int(mask.sum())
        area_pct = 100.0 * area_px / total_pixels if total_pixels > 0 else 0.0
        # resize mask to depth resolution to sample depth correctly
        mask_u8 = (mask.astype(np.uint8) * 255)
        mask_depth = cv2.resize(
            mask_u8, (W_d, H_d), interpolation=cv2.INTER_NEAREST
        ) > 0
        obj_depths = depth[mask_depth & np.isfinite(depth)]
        if obj_depths.size > 0:
            median_depth = float(np.median(obj_depths))
        else:
            median_depth = float("nan")
        # NaN median is stored as None so downstream code and the UI table
        # can distinguish "no depth" cleanly.
        rows.append([
            cls_name,
            round(area_pct, 2),
            None if not np.isfinite(median_depth) else round(median_depth, 2)
        ])
    # Post-processing: portions + nutrition based on rows + center_depth
    portions_json, dish_nutr_json, ingredient_nutr_json = postprocess_ai_results(
        rows, center_depth
    )
    if rows:
        depth_lines.append("\n### Object distances (per class)")
        for cls_name, area_pct, md in rows:
            if md is None:
                depth_lines.append(
                    f"- {cls_name}: {area_pct:.2f}% of image, depth: N/A"
                )
            else:
                depth_lines.append(
                    f"- {cls_name}: {area_pct:.2f}% of image, median depth **{md:.2f} m**"
                )
    else:
        depth_lines.append("\n_No segmentation masks detected._")
    depth_text = "\n".join(depth_lines)
    # -------------------------------------------------------
    # Convert JSON-like results to table rows for Dataframe
    # -------------------------------------------------------
    # Portions table: class, portion(g), center_depth(m), mask_region(%)
    portions_table_rows = [
        [
            p["class"],
            p["portion"],
            p["portion_label"],
            p["center_depth"],
            p["mask_region"],
        ]
        for p in portions_json
    ]
    # Dish nutrition table: class, mass_g, kcal, protein, fat, carbs, sodium
    dish_table_rows = [
        [
            d["class"],
            d["mass_g"],
            d["calories"],
            d["protein"],
            d["fat"],
            d["carbohydrates"],
            d["sodium"],
        ]
        for d in dish_nutr_json
    ]
    # Ingredient nutrition table:
    # dish_class, ingredient, mass_g, kcal, protein, fat, carbs, sodium
    ingredient_table_rows = [
        [
            ing["dish_class"],
            ing["ingredient"],
            ing["mass_g"],
            ing["calories"],
            ing["protein"],
            ing["fat"],
            ing["carbohydrates"],
            ing["sodium"],
        ]
        for ing in ingredient_nutr_json
    ]
    # Order must match the Gradio outputs list wired to run_btn.click.
    return (
        depth_vis_big,
        seg_vis,
        depth_text,
        rows,
        portions_table_rows,
        dish_table_rows,
        ingredient_table_rows,
    )
# -----------------------------------------------------------
# 9. Gradio UI (using tables/Dataframe instead of JSON)
# -----------------------------------------------------------
# Gradio UI: one image input, two image outputs, a markdown summary, and
# four result tables. Output order mirrors analyze_image's return tuple.
with gr.Blocks() as demo:
    gr.Markdown(
        "<h2 style='text-align:center;'>Depth Pro + YOLO Segmentation + Nutrition Demo</h2>"
        "<p style='text-align:center;'>"
        "Upload a food image → get depth map, object distance, estimated portion, and nutrition per dish & ingredient."
        "</p>"
    )
    with gr.Row():
        input_img = gr.Image(label="Upload food image", type="pil")
    with gr.Row():
        depth_out = gr.Image(label="Depth overlay", type="numpy")
        seg_out = gr.Image(label="Segmentation result", type="numpy")
    with gr.Row():
        depth_info = gr.Markdown(label="Depth estimate")
    # rows from analyze_image: [class, area %, median depth (or None)]
    seg_table = gr.Dataframe(
        headers=["Class", "Area % of image", "Median depth (m)"],
        datatype=["str", "number", "number"],
        label="Segmentation stats"
    )
    # Depth/area columns are preformatted strings, hence datatype "str".
    portions_table = gr.Dataframe(
        headers=["Class", "Portion (g)", "Unit", "Center depth (m)", "Mask region (%)"],
        datatype=["str", "number", "str", "str", "str"],
        label="Estimated Portions (per class)",
    )
    dish_nutrition_table = gr.Dataframe(
        headers=["Class", "Mass (g)", "Calories", "Protein (g)", "Fat (g)", "Carbs (g)", "Sodium (mg)"],
        datatype=["str", "number", "number", "number", "number", "number", "number"],
        label="Dish Nutrition (per class)",
    )
    ingredient_nutrition_table = gr.Dataframe(
        headers=["Dish", "Ingredient", "Mass (g)", "Calories", "Protein (g)", "Fat (g)", "Carbs (g)", "Sodium (mg)"],
        datatype=["str", "str", "number", "number", "number", "number", "number", "number"],
        label="Ingredient Nutrition (per ingredient)",
    )
    run_btn = gr.Button("Run analysis")
    run_btn.click(
        fn=analyze_image,
        inputs=input_img,
        outputs=[
            depth_out,
            seg_out,
            depth_info,
            seg_table,
            portions_table,
            dish_nutrition_table,
            ingredient_nutrition_table,
        ],
    )
# 0.0.0.0 exposes the demo on all network interfaces (e.g., Docker/Spaces).
demo.launch(server_name="0.0.0.0", server_port=7860)