import os import sys import types import importlib.machinery from typing import List, Dict import gradio as gr import torch from PIL import Image from transformers import AutoProcessor, AutoModelForCausalLM # =============== 避免 flash_attn 強相依(不安裝它) =============== def _make_pkg_stub(fullname: str): m = types.ModuleType(fullname) m.__file__ = f"" m.__package__ = fullname.rpartition('.')[0] m.__path__ = [] m.__spec__ = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True) sys.modules[fullname] = m return m for name in [ "flash_attn","flash_attn.ops","flash_attn.layers", "flash_attn.functional","flash_attn.bert_padding","flash_attn.flash_attn_interface", ]: if name not in sys.modules: _make_pkg_stub(name) # =============== Florence-2 載入(eager + 關 cache) =============== MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base") device = "cuda" if torch.cuda.is_available() else "cpu" _processor = None _model = None def get_florence2(): global _processor, _model if _processor is None or _model is None: _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) _model = AutoModelForCausalLM.from_pretrained( MODEL_ID, trust_remote_code=True, attn_implementation="eager", torch_dtype=torch.float16 if device == "cuda" else torch.float32 ).to(device).eval() _model.config.use_cache = False return _processor, _model @torch.inference_mode() def _generate_text_and_parse(image: Image.Image, task_token: str, text_input: str | None = None): """ 回傳: - 對 caption 任務:{'': '...'} 這類 - 對 grounding 任務:{'': {'bboxes': [...], 'labels': [...]}} """ proc, mdl = get_florence2() text = task_token if text_input is None else (task_token + text_input) batch = proc(text=text, images=image, return_tensors="pt") # 對齊 dtype/device inputs = {} for k, v in batch.items(): if isinstance(v, torch.Tensor): inputs[k] = v.to(device=device, dtype=mdl.dtype if v.is_floating_point() else None) else: inputs[k] = v ids = mdl.generate( **inputs, max_new_tokens=1024, do_sample=False, num_beams=1, # 貪婪,跨環境最穩 use_cache=False, early_stopping=False, eos_token_id=getattr(getattr(proc, "tokenizer", None), "eos_token_id", None), ) gen = proc.batch_decode(ids, skip_special_tokens=False)[0] parsed = proc.post_process_generation( gen, task=task_token, image_size=(image.width, image.height) ) return parsed def florence2_cascade_food_labels(image: Image.Image): """ 級聯路徑: 1) → caption_text 2) (caption_text) → labels(我們拿來當食物候選) """ # step1: 更詳細 caption cap_res = _generate_text_and_parse(image, "") caption_text = cap_res.get("", "") # step2: grounding(把 caption 各片語對齊到框與標籤) grd_res = _generate_text_and_parse(image, "", text_input=caption_text) grounding = grd_res.get("", {}) labels = grounding.get("labels", []) or [] # labels 是一串字串(可能含形容詞),我們後續再做 alias 過濾 return caption_text, labels # =============== 營養資料 / 同義詞 / 規則 =============== FOOD_DB = { "rice": {"kcal":130, "carb_g":28, "protein_g":2.4, "fat_g":0.3, "sodium_mg":0, "cat":"全榖雜糧類", "base_g":150, "tip":"主食可改糙米/全穀增加膳食纖維"}, "noodles":{"kcal":138, "carb_g":25, "protein_g":4.5, "fat_g":1.9, "sodium_mg":170, "cat":"全榖雜糧類", "base_g":180, "tip":"清湯少油,避免重鹹湯底"}, "bread": {"kcal":265, "carb_g":49, "protein_g":9.0, "fat_g":3.2, "sodium_mg":490, "cat":"全榖雜糧類", "base_g":60, "tip":"可選全麥減少抹醬、甜餡"}, "broccoli":{"kcal":35, "carb_g":7, "protein_g":2.4, "fat_g":0.4, "sodium_mg":33, "cat":"蔬菜類", "base_g":80, "tip":"川燙/清炒保留口感與維生素"}, "spinach":{"kcal":23, "carb_g":3.6,"protein_g":2.9,"fat_g":0.4,"sodium_mg":70, "cat":"蔬菜類", "base_g":80, "tip":"川燙後快炒,少鹽少油"}, "chicken":{"kcal":215,"carb_g":0, "protein_g":27, "fat_g":12, "sodium_mg":90, "cat":"豆魚蛋肉類", "base_g":120, "tip":"去皮烹調、烤/氣炸取代油炸"}, "soy_braised_chicken_leg":{"kcal":220,"carb_g":0,"protein_g":24,"fat_g":12,"sodium_mg":550,"cat":"豆魚蛋肉類","base_g":130,"tip":"減醬油與滷汁、可先汆燙再滷"}, "salmon":{"kcal":208,"carb_g":0, "protein_g":20, "fat_g":13, "sodium_mg":60, "cat":"豆魚蛋肉類", "base_g":120, "tip":"烤/蒸保留 Omega-3,少鹽少醬"}, "pork_chop":{"kcal":242,"carb_g":0,"protein_g":27,"fat_g":14,"sodium_mg":75, "cat":"豆魚蛋肉類", "base_g":120, "tip":"少裹粉油炸,改煎烤並瀝油"}, "tofu": {"kcal":76, "carb_g":1.9,"protein_g":8.1,"fat_g":4.8,"sodium_mg":7, "cat":"豆魚蛋肉類", "base_g":120, "tip":"少勾芡、少滷汁,清蒸清爽"}, "egg": {"kcal":155,"carb_g":1.1,"protein_g":13, "fat_g":11, "sodium_mg":124, "cat":"豆魚蛋肉類", "base_g":60, "tip":"水煮/荷包少油,避免重鹹醬料"}, "banana":{"kcal":89, "carb_g":23, "protein_g":1.1,"fat_g":0.3,"sodium_mg":1, "cat":"水果類", "base_g":100, "tip":"控制份量,避免一次過量"}, "miso_soup":{"kcal":36,"carb_g":4.3,"protein_g":2.0,"fat_g":1.3,"sodium_mg":550, "cat":"湯品/飲品", "base_g":200, "tip":"味噌湯偏鹹,建議少量品嚐"}, # 想開放泛化兩筆可解除註解: # "salad": {"kcal":30,"carb_g":5,"protein_g":1.5,"fat_g":0.5,"sodium_mg":40,"cat":"蔬菜類","base_g":100,"tip":"少醬少油,優先清爽調味"}, # "fish": {"kcal":170,"carb_g":0,"protein_g":22,"fat_g":8,"sodium_mg":70,"cat":"豆魚蛋肉類","base_g":120,"tip":"蒸/烤/煎少油,避免重鹹醬汁"}, } ALIASES = { "white rice":"rice","steamed rice":"rice","飯":"rice","白飯":"rice", "麵":"noodles","拉麵":"noodles","麵條":"noodles","義大利麵":"noodles", "麵包":"bread","吐司":"bread", "雞肉":"chicken","雞胸":"chicken","烤雞":"chicken", "滷雞腿":"soy_braised_chicken_leg","醬油雞腿":"soy_braised_chicken_leg", "鮭魚":"salmon","三文魚":"salmon", "豬排":"pork_chop", "豆腐":"tofu", "蛋":"egg","水煮蛋":"egg","荷包蛋":"egg", "花椰菜":"broccoli","青花菜":"broccoli","菠菜":"spinach", "香蕉":"banana","味噌湯":"miso_soup", } RULES = {"T2DM": {"carb_g_per_meal_max": 60}, "HTN": {"sodium_mg_per_meal_max": 600}} PORTION_MUL = {"小":0.8, "中":1.0, "大":1.2} DEFAULT_BASE_G = 100 GENERIC_TO_CATEGORY = { "vegetable":"蔬菜類","vegetables":"蔬菜類","greens":"蔬菜類","salad":"蔬菜類", "meat":"豆魚蛋肉類","seafood":"豆魚蛋肉類","fish":"豆魚蛋肉類", "noodles":"全榖雜糧類","bread":"全榖雜糧類","rice":"全榖雜糧類", "soup":"湯品/飲品","drink":"湯品/飲品","beverage":"湯品/飲品" } # =============== 基本估算/規則 =============== def estimate_weight(name: str, plate_cm: int, portion: str) -> int: base = FOOD_DB.get(name, {}).get("base_g", DEFAULT_BASE_G) mul = PORTION_MUL.get(portion, 1.0) grams = int(base * mul * (plate_cm / 24)) return max(10, grams) def grams_to_nutrition(name: str, grams: int) -> Dict: info = FOOD_DB[name] ratio = grams / 100.0 out = {"name": name, "cat": info["cat"], "weight_g": grams, "tip": info.get("tip","")} for k in ("kcal","carb_g","protein_g","fat_g","sodium_mg"): out[k] = round(info[k] * ratio, 1) return out def make_placeholder_item(name: str, plate_cm: int, portion: str, cat: str = "未分類"): grams = int(DEFAULT_BASE_G * (plate_cm / 24) * PORTION_MUL.get(portion, 1.0)) return { "name": name, "cat": cat, "weight_g": grams, "kcal": "待新增資訊", "carb_g": "待新增資訊", "protein_g": "待新增資訊", "fat_g": "待新增資訊", "sodium_mg": "待新增資訊", "tip": "待新增資訊" } def eval_rules(items: List[Dict], conditions: List[str]): totals = {} for it in items: if isinstance(it.get("kcal"), (int, float)): for k in ("kcal","carb_g","protein_g","fat_g","sodium_mg"): totals[k] = round(totals.get(k,0) + float(it[k]), 1) advice = [] if "T2DM" in conditions and totals.get("carb_g",0) > RULES["T2DM"]["carb_g_per_meal_max"]: advice.append("【糖尿病】碳水偏高,建議主食減量或改全穀。") if "HTN" in conditions and totals.get("sodium_mg",0) > RULES["HTN"]["sodium_mg_per_meal_max"]: advice.append("【高血壓】鈉含量偏高,少鹽、避免重口味與滷味/湯品。") return totals, advice # =============== 主流程(級聯任務為主) =============== def run_pipeline(image, plate_cm, portion, conditions, dev_mode): if image is None: return "請先上傳一張照片。", "", [], {} if dev_mode: caption_text = "A more detailed description of a bento with white rice, broccoli, and grilled chicken thigh." grounded_labels = ["rice","broccoli","grilled chicken thigh"] else: caption_text, grounded_labels = florence2_cascade_food_labels(image) # 清洗 grounded labels → 只留下食物詞 labels = [] for lab in grounded_labels: name = lab.strip().lower() name = ALIASES.get(name, name) # 過濾一些明顯不是食物的字(可再擴充) if name in {"plate","table","box","tray","container","bento","white","filled","topped"}: continue labels.append(name) # 去重並保留順序 seen = set() labels_all = [] for n in labels: if n not in seen: labels_all.append(n); seen.add(n) # 生成逐項 items = [] for name in labels_all[:6]: if name in FOOD_DB: g = estimate_weight(name, plate_cm, portion) items.append(grams_to_nutrition(name, g)) else: cat = GENERIC_TO_CATEGORY.get(name, "未分類") items.append(make_placeholder_item(name, plate_cm, portion, cat=cat)) totals, advice = eval_rules([it for it in items if isinstance(it.get("kcal"), (int,float))], conditions) # 組輸出 lines = [f"模型輸出(More Detailed Caption):{caption_text}"] lines.append("偵測到(Grounding labels): " + (", ".join(labels_all) if labels_all else "(無)")) lines.append("") for it in items: kcal = it['kcal'] if isinstance(it['kcal'], (int, float)) else it['kcal'] carb = it['carb_g'] if isinstance(it['carb_g'], (int, float)) else it['carb_g'] prot = it['protein_g'] if isinstance(it['protein_g'], (int, float)) else it['protein_g'] fat = it['fat_g'] if isinstance(it['fat_g'], (int, float)) else it['fat_g'] na = it['sodium_mg'] if isinstance(it['sodium_mg'], (int, float)) else it['sodium_mg'] lines.append(f"- {it['name']} ({it['cat']}) {it['weight_g']} g → " f"{kcal} kcal, C{carb} g, P{prot} g, F{fat} g, Na{na} mg") if totals: lines.append("") lines.append(f"總計:{totals.get('kcal',0)} kcal,碳水 {totals.get('carb_g',0)} g,蛋白 {totals.get('protein_g',0)} g,脂肪 {totals.get('fat_g',0)} g,鈉 {totals.get('sodium_mg',0)} mg") if advice: lines.append("建議:" + " ".join(advice)) return "\n".join(lines), caption_text, items, totals # =============== Gradio 介面 =============== with gr.Blocks(title="FoodAI · Florence-2 (Cascade Grounding)") as demo: gr.Markdown("# 🍱 FoodAI · Florence-2 (More Detailed Caption + Grounding)\n以級聯任務抽食物詞 → 估營養與建議\n\n> 開發模式:不跑模型,固定假字串方便測流程。") with gr.Row(): with gr.Column(scale=1): img = gr.Image(type="pil", label="上傳圖片") plate = gr.Slider(18, 28, value=24, step=1, label="盤子直徑 (cm)") portion = gr.Radio(["小", "中", "大"], value="中", label="份量") cond = gr.CheckboxGroup(["T2DM", "HTN"], label="狀況") dev_mode = gr.Checkbox(label="開發模式(不跑模型)", value=False) btn = gr.Button("開始分析", variant="primary") with gr.Column(scale=1): out_md = gr.Markdown(label="結果") raw = gr.Textbox(label="模型原始輸出(More Detailed Caption)", lines=4) js = gr.JSON(label="逐項結果") total = gr.JSON(label="總計") btn.click(run_pipeline, inputs=[img, plate, portion, cond, dev_mode], outputs=[out_md, raw, js, total]) if __name__ == "__main__": PORT = int(os.getenv("PORT", "7860")) demo.launch(server_name="0.0.0.0", server_port=PORT)