Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import types | |
| import importlib.machinery | |
| from typing import List, Dict | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| # =============== 避免 flash_attn 強相依(不安裝它) =============== | |
| def _make_pkg_stub(fullname: str): | |
| m = types.ModuleType(fullname) | |
| m.__file__ = f"<stub {fullname}>" | |
| m.__package__ = fullname.rpartition('.')[0] | |
| m.__path__ = [] | |
| m.__spec__ = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True) | |
| sys.modules[fullname] = m | |
| return m | |
# Pre-register a stub for the flash_attn package and each submodule listed,
# parent package first. Only sys.modules is consulted, so a real flash_attn
# that is installed but not yet imported would be shadowed by the stub.
for name in (
    "flash_attn",
    "flash_attn.ops",
    "flash_attn.layers",
    "flash_attn.functional",
    "flash_attn.bert_padding",
    "flash_attn.flash_attn_interface",
):
    if name in sys.modules:
        continue  # already imported (real module or earlier stub): leave it alone
    _make_pkg_stub(name)
# =============== Florence-2 loading (eager attention + KV cache disabled) ===============
# Hugging Face checkpoint id; overridable through the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "microsoft/Florence-2-base")
# Use the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Lazily-initialised singletons, filled in by get_florence2() on first use.
_processor = None
_model = None
def get_florence2():
    """Return the shared ``(processor, model)`` pair, loading it on first use.

    Both are loaded from ``MODEL_ID`` with ``trust_remote_code=True``. The
    model uses eager attention, float16 weights on CUDA (float32 otherwise),
    is moved to ``device``, switched to eval mode, and has
    ``config.use_cache`` turned off. The pair is memoised in the module
    globals ``_processor`` / ``_model``.

    NOTE(review): there is no lock, so two concurrent first calls could both
    trigger a load.
    """
    global _processor, _model
    if _processor is not None and _model is not None:
        return _processor, _model

    _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    weight_dtype = torch.float16 if device == "cuda" else torch.float32
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=weight_dtype,
    )
    _model = loaded.to(device).eval()
    _model.config.use_cache = False
    return _processor, _model
def _generate_text_and_parse(image: Image.Image, task_token: str, text_input: str | None = None):
    """Run a single Florence-2 task on *image* and return the processor's parsed output.

    Args:
        image: Input picture. Images whose mode is not "RGB" (for example
            RGBA PNGs, grayscale or palette images) are converted to RGB
            first so the image processor always receives three channels.
        task_token: Florence-2 task prompt, e.g. ``"<MORE_DETAILED_CAPTION>"``.
        text_input: Optional text appended directly after ``task_token``
            (used by text-conditioned tasks such as grounding).

    Returns:
        The result of ``post_process_generation`` keyed by ``task_token``, e.g.
        - caption task: ``{'<MORE_DETAILED_CAPTION>': '...'}``
        - grounding task: ``{'<CAPTION_TO_PHRASE_GROUNDING>': {'bboxes': [...], 'labels': [...]}}``
    """
    proc, mdl = get_florence2()
    if image.mode != "RGB":
        image = image.convert("RGB")
    text = task_token if text_input is None else task_token + text_input
    batch = proc(text=text, images=image, return_tensors="pt")

    # Move every tensor to the model's device; only floating-point tensors
    # (e.g. pixel values) are cast to the model dtype, so integer ids keep theirs.
    inputs = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(device=device, dtype=mdl.dtype if v.is_floating_point() else None)
        else:
            inputs[k] = v

    ids = mdl.generate(
        **inputs,
        max_new_tokens=1024,
        do_sample=False,
        num_beams=1,  # greedy decoding: deterministic and the most portable setting
        use_cache=False,  # consistent with model.config.use_cache = False set at load time
        early_stopping=False,
        eos_token_id=getattr(getattr(proc, "tokenizer", None), "eos_token_id", None),
    )
    # Special tokens are deliberately kept; presumably post_process_generation
    # needs them to parse task-specific output (verify against the processor).
    gen = proc.batch_decode(ids, skip_special_tokens=False)[0]
    parsed = proc.post_process_generation(
        gen, task=task_token, image_size=(image.width, image.height)
    )
    return parsed
def florence2_cascade_food_labels(image: Image.Image):
    """Extract candidate food phrases from *image* with a two-step Florence-2 cascade.

    1) ``<MORE_DETAILED_CAPTION>`` produces a free-text caption.
    2) ``<CAPTION_TO_PHRASE_GROUNDING>`` grounds that caption's phrases to
       boxes; the grounded labels are used as food candidates.

    Returns:
        ``(caption_text, labels)``. ``labels`` is a list of raw phrases (they
        may contain adjectives); alias mapping and filtering are done by the
        caller. When the caption is empty or blank, or the grounding result
        is not a dict, ``labels`` is ``[]``.
    """
    # Step 1: detailed caption.
    cap_res = _generate_text_and_parse(image, "<MORE_DETAILED_CAPTION>")
    caption_text = cap_res.get("<MORE_DETAILED_CAPTION>") or ""
    if not caption_text.strip():
        # Nothing to ground; skip the second generate call on an empty prompt.
        return caption_text, []

    # Step 2: ground each caption phrase to a box + label.
    grd_res = _generate_text_and_parse(image, "<CAPTION_TO_PHRASE_GROUNDING>", text_input=caption_text)
    grounding = grd_res.get("<CAPTION_TO_PHRASE_GROUNDING>", {})
    if not isinstance(grounding, dict):
        return caption_text, []
    labels = list(grounding.get("labels") or [])
    return caption_text, labels
# =============== Nutrition data / aliases / rules ===============
# Nutrient values are per 100 g (grams_to_nutrition scales them by grams / 100).
# "base_g" is the reference portion in grams for a medium portion on a 24 cm
# plate (estimate_weight scales it by PORTION_MUL and plate_cm / 24).
FOOD_DB = {
    "rice": {"kcal":130, "carb_g":28, "protein_g":2.4, "fat_g":0.3, "sodium_mg":0, "cat":"全榖雜糧類", "base_g":150, "tip":"主食可改糙米/全穀增加膳食纖維"},
    "noodles":{"kcal":138, "carb_g":25, "protein_g":4.5, "fat_g":1.9, "sodium_mg":170, "cat":"全榖雜糧類", "base_g":180, "tip":"清湯少油,避免重鹹湯底"},
    "bread": {"kcal":265, "carb_g":49, "protein_g":9.0, "fat_g":3.2, "sodium_mg":490, "cat":"全榖雜糧類", "base_g":60, "tip":"可選全麥減少抹醬、甜餡"},
    "broccoli":{"kcal":35, "carb_g":7, "protein_g":2.4, "fat_g":0.4, "sodium_mg":33, "cat":"蔬菜類", "base_g":80, "tip":"川燙/清炒保留口感與維生素"},
    "spinach":{"kcal":23, "carb_g":3.6,"protein_g":2.9,"fat_g":0.4,"sodium_mg":70, "cat":"蔬菜類", "base_g":80, "tip":"川燙後快炒,少鹽少油"},
    "chicken":{"kcal":215,"carb_g":0, "protein_g":27, "fat_g":12, "sodium_mg":90, "cat":"豆魚蛋肉類", "base_g":120, "tip":"去皮烹調、烤/氣炸取代油炸"},
    "soy_braised_chicken_leg":{"kcal":220,"carb_g":0,"protein_g":24,"fat_g":12,"sodium_mg":550,"cat":"豆魚蛋肉類","base_g":130,"tip":"減醬油與滷汁、可先汆燙再滷"},
    "salmon":{"kcal":208,"carb_g":0, "protein_g":20, "fat_g":13, "sodium_mg":60, "cat":"豆魚蛋肉類", "base_g":120, "tip":"烤/蒸保留 Omega-3,少鹽少醬"},
    "pork_chop":{"kcal":242,"carb_g":0,"protein_g":27,"fat_g":14,"sodium_mg":75, "cat":"豆魚蛋肉類", "base_g":120, "tip":"少裹粉油炸,改煎烤並瀝油"},
    "tofu": {"kcal":76, "carb_g":1.9,"protein_g":8.1,"fat_g":4.8,"sodium_mg":7, "cat":"豆魚蛋肉類", "base_g":120, "tip":"少勾芡、少滷汁,清蒸清爽"},
    "egg": {"kcal":155,"carb_g":1.1,"protein_g":13, "fat_g":11, "sodium_mg":124, "cat":"豆魚蛋肉類", "base_g":60, "tip":"水煮/荷包少油,避免重鹹醬料"},
    "banana":{"kcal":89, "carb_g":23, "protein_g":1.1,"fat_g":0.3,"sodium_mg":1, "cat":"水果類", "base_g":100, "tip":"控制份量,避免一次過量"},
    "miso_soup":{"kcal":36,"carb_g":4.3,"protein_g":2.0,"fat_g":1.3,"sodium_mg":550, "cat":"湯品/飲品", "base_g":200, "tip":"味噌湯偏鹹,建議少量品嚐"},
    # Uncomment these two entries to enable the generic "salad"/"fish" foods:
    # "salad": {"kcal":30,"carb_g":5,"protein_g":1.5,"fat_g":0.5,"sodium_mg":40,"cat":"蔬菜類","base_g":100,"tip":"少醬少油,優先清爽調味"},
    # "fish": {"kcal":170,"carb_g":0,"protein_g":22,"fat_g":8,"sodium_mg":70,"cat":"豆魚蛋肉類","base_g":120,"tip":"蒸/烤/煎少油,避免重鹹醬汁"},
}
# Maps a detected phrase (already stripped + lower-cased by the caller) to a
# FOOD_DB key. Keys must be lower-case. English entries are needed for the
# underscore keys (e.g. "miso soup" -> "miso_soup"), because the model emits
# space-separated phrases that can never equal an underscore key directly.
ALIASES = {
    # Chinese names
    "white rice":"rice","steamed rice":"rice","飯":"rice","白飯":"rice",
    "麵":"noodles","拉麵":"noodles","麵條":"noodles","義大利麵":"noodles",
    "麵包":"bread","吐司":"bread",
    "雞肉":"chicken","雞胸":"chicken","烤雞":"chicken",
    "滷雞腿":"soy_braised_chicken_leg","醬油雞腿":"soy_braised_chicken_leg",
    "鮭魚":"salmon","三文魚":"salmon",
    "豬排":"pork_chop",
    "豆腐":"tofu",
    "蛋":"egg","水煮蛋":"egg","荷包蛋":"egg",
    "花椰菜":"broccoli","青花菜":"broccoli","菠菜":"spinach",
    "香蕉":"banana","味噌湯":"miso_soup",
    # English phrases / plurals as produced by the caption + grounding cascade
    "noodle":"noodles","ramen":"noodles","spaghetti":"noodles","pasta":"noodles",
    "toast":"bread",
    "chicken breast":"chicken","chicken thigh":"chicken","grilled chicken":"chicken",
    "roast chicken":"chicken","grilled chicken thigh":"chicken",
    "soy braised chicken leg":"soy_braised_chicken_leg","braised chicken leg":"soy_braised_chicken_leg",
    "salmon fillet":"salmon","grilled salmon":"salmon",
    "pork chop":"pork_chop","pork chops":"pork_chop",
    "eggs":"egg","boiled egg":"egg","fried egg":"egg",
    "broccoli florets":"broccoli",
    "bananas":"banana",
    "miso soup":"miso_soup",
}
# Per-meal thresholds checked by eval_rules: T2DM = type 2 diabetes (carbs),
# HTN = hypertension (sodium).
RULES = {"T2DM": {"carb_g_per_meal_max": 60}, "HTN": {"sodium_mg_per_meal_max": 600}}
# Portion-size multipliers: small / medium / large.
PORTION_MUL = {"小":0.8, "中":1.0, "大":1.2}
# Reference weight (g) for foods that have no FOOD_DB entry.
DEFAULT_BASE_G = 100
# Category for generic words that have no FOOD_DB entry (placeholder rows).
# "noodles"/"bread"/"rice" are also FOOD_DB keys, and run_pipeline checks
# FOOD_DB first, so those three entries here are never reached.
GENERIC_TO_CATEGORY = {
    "vegetable":"蔬菜類","vegetables":"蔬菜類","greens":"蔬菜類","salad":"蔬菜類",
    "meat":"豆魚蛋肉類","seafood":"豆魚蛋肉類","fish":"豆魚蛋肉類",
    "noodles":"全榖雜糧類","bread":"全榖雜糧類","rice":"全榖雜糧類",
    "soup":"湯品/飲品","drink":"湯品/飲品","beverage":"湯品/飲品"
}
# =============== Basic estimation / rules ===============
def estimate_weight(name: str, plate_cm: int, portion: str) -> int:
    """Estimate the weight in grams of one food item on the plate.

    The FOOD_DB ``base_g`` for *name* (DEFAULT_BASE_G when unknown) is
    multiplied by the PORTION_MUL factor for *portion* (1.0 when unknown) and
    by ``plate_cm / 24`` (linear scaling against a 24 cm reference plate).
    The result is truncated to an int and never goes below 10 g.
    """
    entry = FOOD_DB.get(name, {})
    base_grams = entry.get("base_g", DEFAULT_BASE_G)
    portion_factor = PORTION_MUL.get(portion, 1.0)
    plate_factor = plate_cm / 24
    estimate = int(base_grams * portion_factor * plate_factor)
    return estimate if estimate > 10 else 10
def grams_to_nutrition(name: str, grams: int) -> Dict:
    """Scale the per-100 g FOOD_DB entry for *name* to a serving of *grams*.

    Returns a dict with ``name``, ``cat``, ``weight_g`` and ``tip``, followed
    by the five nutrient fields, each rounded to one decimal place.
    Raises KeyError when *name* is not a FOOD_DB key.
    """
    entry = FOOD_DB[name]
    scale = grams / 100.0
    nutrients = {
        field: round(entry[field] * scale, 1)
        for field in ("kcal", "carb_g", "protein_g", "fat_g", "sodium_mg")
    }
    return {
        "name": name,
        "cat": entry["cat"],
        "weight_g": grams,
        "tip": entry.get("tip", ""),
        **nutrients,
    }
def make_placeholder_item(name: str, plate_cm: int, portion: str, cat: str = "未分類"):
    """Build a result row for a detected food that has no FOOD_DB entry.

    The weight uses the same formula as estimate_weight() does for an unknown
    food: DEFAULT_BASE_G x portion multiplier x (plate_cm / 24), with the same
    multiplication order, truncated to int and never below 10 g. Keeping the
    two identical sizes placeholder rows and known rows consistently.
    Every nutrient field and the tip hold the marker string "待新增資訊"
    ("info to be added"). Because it is not numeric, eval_rules leaves these
    rows out of the totals.
    """
    grams = max(10, int(DEFAULT_BASE_G * PORTION_MUL.get(portion, 1.0) * (plate_cm / 24)))
    pending = "待新增資訊"
    return {
        "name": name, "cat": cat, "weight_g": grams,
        "kcal": pending, "carb_g": pending, "protein_g": pending,
        "fat_g": pending, "sodium_mg": pending, "tip": pending,
    }
def eval_rules(items: List[Dict], conditions: List[str]):
    """Total the nutrients of the numeric items and build condition-specific advice.

    Args:
        items: Item dicts. Only items whose "kcal" is an int/float contribute;
            placeholder rows carry text instead of numbers and are skipped.
        conditions: Selected condition codes ("T2DM", "HTN"). ``None`` is
            treated the same as an empty selection.

    Returns:
        ``(totals, advice)``:
        - ``totals`` maps each nutrient key to its sum rounded to one decimal.
          It is empty when no numeric item exists.
        - ``advice`` lists a warning for each selected condition whose RULES
          threshold the totals exceed.
    """
    conditions = conditions or []
    totals = {}
    for it in items:
        if isinstance(it.get("kcal"), (int, float)):
            for k in ("kcal", "carb_g", "protein_g", "fat_g", "sodium_mg"):
                # Round each running sum to avoid float noise such as 14.499999.
                totals[k] = round(totals.get(k, 0) + float(it[k]), 1)
    advice = []
    if "T2DM" in conditions and totals.get("carb_g", 0) > RULES["T2DM"]["carb_g_per_meal_max"]:
        advice.append("【糖尿病】碳水偏高,建議主食減量或改全穀。")
    if "HTN" in conditions and totals.get("sodium_mg", 0) > RULES["HTN"]["sodium_mg_per_meal_max"]:
        advice.append("【高血壓】鈉含量偏高,少鹽、避免重口味與滷味/湯品。")
    return totals, advice
# =============== Main pipeline (cascade tasks first) ===============
def _normalize_food_labels(grounded_labels):
    """Turn raw grounded phrases into a de-duplicated, order-preserving list of food names.

    Steps:
    - strip and lower-case each phrase;
    - skip blank phrases;
    - map aliases to FOOD_DB keys;
    - drop obvious non-food words (containers, furniture, bare adjectives).
    """
    non_food = {"plate", "table", "box", "tray", "container", "bento", "white", "filled", "topped"}
    cleaned = []
    for lab in grounded_labels:
        name = lab.strip().lower()
        if not name:
            continue  # a blank label would otherwise become a nameless placeholder row
        name = ALIASES.get(name, name)
        if name in non_food:
            continue
        cleaned.append(name)
    # dict.fromkeys keeps first-seen order while removing duplicates.
    return list(dict.fromkeys(cleaned))


def run_pipeline(image, plate_cm, portion, conditions, dev_mode):
    """Gradio click handler: detect foods in *image* and report nutrition and advice.

    Args:
        image: PIL image from the upload widget, or None when nothing was uploaded.
        plate_cm: Plate diameter in cm (slider value).
        portion: Portion size "小" / "中" / "大" (small / medium / large).
        conditions: Selected condition codes ("T2DM", "HTN").
        dev_mode: When true, skip the model and use a fixed fake caption and
            labels to exercise the rest of the flow.

    Returns:
        ``(markdown_report, caption_text, items, totals)``, matching the
        ``outputs`` of the button's click event.
    """
    if image is None:
        return "請先上傳一張照片。", "", [], {}
    if dev_mode:
        caption_text = "A more detailed description of a bento with white rice, broccoli, and grilled chicken thigh."
        grounded_labels = ["rice", "broccoli", "grilled chicken thigh"]
    else:
        caption_text, grounded_labels = florence2_cascade_food_labels(image)

    labels_all = _normalize_food_labels(grounded_labels)

    # One row per food, capped at 6 rows. Unknown foods get placeholder rows.
    items = []
    for name in labels_all[:6]:
        if name in FOOD_DB:
            items.append(grams_to_nutrition(name, estimate_weight(name, plate_cm, portion)))
        else:
            cat = GENERIC_TO_CATEGORY.get(name, "未分類")
            items.append(make_placeholder_item(name, plate_cm, portion, cat=cat))
    totals, advice = eval_rules([it for it in items if isinstance(it.get("kcal"), (int, float))], conditions)

    # The report is rendered by gr.Markdown, which merges single newlines, so
    # blank entries are used to keep each section on its own paragraph.
    lines = [f"模型輸出(More Detailed Caption):{caption_text}", ""]
    lines.append("偵測到(Grounding labels): " + (", ".join(labels_all) if labels_all else "(無)"))
    lines.append("")
    for it in items:
        # Placeholder rows show the marker text in place of numbers.
        lines.append(f"- {it['name']} ({it['cat']}) {it['weight_g']} g → "
                     f"{it['kcal']} kcal, C{it['carb_g']} g, P{it['protein_g']} g, F{it['fat_g']} g, Na{it['sodium_mg']} mg")
    if totals:
        lines.append("")
        lines.append(f"總計:{totals.get('kcal',0)} kcal,碳水 {totals.get('carb_g',0)} g,蛋白 {totals.get('protein_g',0)} g,脂肪 {totals.get('fat_g',0)} g,鈉 {totals.get('sodium_mg',0)} mg")
    if advice:
        lines.append("")
        lines.append("建議:" + " ".join(advice))
    return "\n".join(lines), caption_text, items, totals
# =============== Gradio UI ===============
# Two-column layout: inputs on the left, results on the right; the button runs run_pipeline.
with gr.Blocks(title="FoodAI · Florence-2 (Cascade Grounding)") as demo:
    gr.Markdown("# 🍱 FoodAI · Florence-2 (More Detailed Caption + Grounding)\n以級聯任務抽食物詞 → 估營養與建議\n\n> 開發模式:不跑模型,固定假字串方便測流程。")
    with gr.Row():
        # Left column: image upload and analysis parameters.
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="上傳圖片")  # passed to run_pipeline as a PIL image (None if empty)
            plate = gr.Slider(18, 28, value=24, step=1, label="盤子直徑 (cm)")  # plate diameter in cm
            portion = gr.Radio(["小", "中", "大"], value="中", label="份量")  # small / medium / large (keys of PORTION_MUL)
            cond = gr.CheckboxGroup(["T2DM", "HTN"], label="狀況")  # condition codes checked by eval_rules
            dev_mode = gr.Checkbox(label="開發模式(不跑模型)", value=False)  # dev mode: skip the model, use fixed fake labels
            btn = gr.Button("開始分析", variant="primary")
        # Right column: outputs, in the same order as run_pipeline's return tuple.
        with gr.Column(scale=1):
            out_md = gr.Markdown(label="結果")
            raw = gr.Textbox(label="模型原始輸出(More Detailed Caption)", lines=4)
            js = gr.JSON(label="逐項結果")
            total = gr.JSON(label="總計")
    btn.click(run_pipeline, inputs=[img, plate, portion, cond, dev_mode], outputs=[out_md, raw, js, total])
if __name__ == "__main__":
    # Serve on all interfaces; the port is read from $PORT and defaults to 7860.
    server_port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=server_port)