Chao-Ying commited on
Commit
3aa784b
·
verified ·
1 Parent(s): c6b5eb9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +363 -0
app.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import types
4
+ import importlib.machinery
5
+ from typing import List, Dict
6
+
7
+ import gradio as gr
8
+ import torch
9
+ from PIL import Image
10
+
11
+ # =========================
12
+ # 1) 偽裝 flash_attn(避免硬相依)
13
+ # =========================
14
+ def _make_pkg_stub(fullname: str):
15
+ m = types.ModuleType(fullname)
16
+ m.__file__ = f"<stub {fullname}>"
17
+ m.__package__ = fullname.rpartition('.')[0]
18
+ m.__path__ = []
19
+ m.__spec__ = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)
20
+ sys.modules[fullname] = m
21
+ return m
22
+
23
+ for name in [
24
+ "flash_attn",
25
+ "flash_attn.ops",
26
+ "flash_attn.layers",
27
+ "flash_attn.functional",
28
+ "flash_attn.bert_padding",
29
+ "flash_attn.flash_attn_interface",
30
+ ]:
31
+ if name not in sys.modules:
32
+ _make_pkg_stub(name)
33
+
34
+ # =========================
35
+ # 2) Florence-2 載入(eager + dtype 對齊 + 關 cache)
36
+ # =========================
37
+ from transformers import AutoProcessor, AutoModelForCausalLM
38
+
39
+ MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
+
42
+ TASK_TOKENS = {
43
+ "caption": "<CAPTION>",
44
+ "object_detection": "<OBJECT_DETECTION>",
45
+ }
46
+
47
+ _processor = None
48
+ _model = None
49
+
50
+ def get_florence2():
51
+ global _processor, _model
52
+ if _processor is None or _model is None:
53
+ _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
54
+ _model = AutoModelForCausalLM.from_pretrained(
55
+ MODEL_ID,
56
+ trust_remote_code=True,
57
+ attn_implementation="eager",
58
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
59
+ ).to(device).eval()
60
+ # 關掉快取,避免某些環境下 beam/cache 造成錯誤
61
+ _model.config.use_cache = False
62
+ return _processor, _model
63
+
64
+ @torch.inference_mode()
65
+ def florence2_text(image: Image.Image, task: str = "caption"):
66
+ proc, mdl = get_florence2()
67
+ token = TASK_TOKENS.get(task, "<CAPTION>")
68
+ text = token # 這兩個任務都是「不帶額外輸入」
69
+
70
+ batch = proc(text=text, images=image, return_tensors="pt")
71
+ inputs = {}
72
+ for k, v in batch.items():
73
+ if isinstance(v, torch.Tensor):
74
+ if v.is_floating_point():
75
+ inputs[k] = v.to(device=device, dtype=mdl.dtype)
76
+ else:
77
+ inputs[k] = v.to(device=device)
78
+ else:
79
+ inputs[k] = v
80
+
81
+ ids = mdl.generate(
82
+ **inputs,
83
+ max_new_tokens=128,
84
+ do_sample=False,
85
+ num_beams=1, # 用貪婪生成以提升穩定性
86
+ use_cache=False, # 關 cache 避免空 past_key_values 問題
87
+ early_stopping=False,
88
+ eos_token_id=getattr(getattr(proc, "tokenizer", None), "eos_token_id", None),
89
+ )
90
+ out = proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
91
+ if ">" in out:
92
+ out = out.split(">", 1)[-1].strip()
93
+ return out
94
+
95
+ # =========================
96
+ # 3) 營養資料 / 同義詞 / 規則
97
+ # =========================
98
+ FOOD_DB = {
99
+ "rice": {"kcal":130, "carb_g":28, "protein_g":2.4, "fat_g":0.3, "sodium_mg":0, "cat":"全榖雜糧類", "base_g":150, "tip":"主食可改糙米/全穀增加膳食纖維"},
100
+ "noodles":{"kcal":138, "carb_g":25, "protein_g":4.5, "fat_g":1.9, "sodium_mg":170, "cat":"全榖雜糧類", "base_g":180, "tip":"清湯少油,避免重鹹湯底"},
101
+ "bread": {"kcal":265, "carb_g":49, "protein_g":9.0, "fat_g":3.2, "sodium_mg":490, "cat":"全榖雜糧類", "base_g":60, "tip":"可選全麥減少抹醬、甜餡"},
102
+ "broccoli":{"kcal":35, "carb_g":7, "protein_g":2.4, "fat_g":0.4, "sodium_mg":33, "cat":"蔬菜類", "base_g":80, "tip":"川燙/清炒保留口感與維生素"},
103
+ "spinach":{"kcal":23, "carb_g":3.6,"protein_g":2.9,"fat_g":0.4,"sodium_mg":70, "cat":"蔬菜類", "base_g":80, "tip":"川燙後快炒,少鹽少油"},
104
+ "chicken":{"kcal":215,"carb_g":0, "protein_g":27, "fat_g":12, "sodium_mg":90, "cat":"豆魚蛋肉類", "base_g":120, "tip":"去皮烹調、烤/氣炸取代油炸"},
105
+ "soy_braised_chicken_leg":{"kcal":220,"carb_g":0,"protein_g":24,"fat_g":12,"sodium_mg":550,"cat":"豆魚蛋肉類","base_g":130,"tip":"減醬油與滷汁、可先汆燙再滷"},
106
+ "salmon":{"kcal":208,"carb_g":0, "protein_g":20, "fat_g":13, "sodium_mg":60, "cat":"豆魚蛋肉類", "base_g":120, "tip":"烤/蒸保留 Omega-3,少鹽少醬"},
107
+ "pork_chop":{"kcal":242,"carb_g":0,"protein_g":27,"fat_g":14,"sodium_mg":75, "cat":"豆魚蛋肉類", "base_g":120, "tip":"少裹粉油炸,改煎烤並瀝油"},
108
+ "tofu": {"kcal":76, "carb_g":1.9,"protein_g":8.1,"fat_g":4.8,"sodium_mg":7, "cat":"豆魚蛋肉類", "base_g":120, "tip":"少勾芡、少滷汁,清蒸清爽"},
109
+ "egg": {"kcal":155,"carb_g":1.1,"protein_g":13, "fat_g":11, "sodium_mg":124, "cat":"豆魚蛋肉類", "base_g":60, "tip":"水煮/荷包少油,避免重鹹醬料"},
110
+ "banana":{"kcal":89, "carb_g":23, "protein_g":1.1,"fat_g":0.3,"sodium_mg":1, "cat":"水果類", "base_g":100, "tip":"控制份量,避免一次過量"},
111
+ "miso_soup":{"kcal":36,"carb_g":4.3,"protein_g":2.0,"fat_g":1.3,"sodium_mg":550, "cat":"湯品/飲品", "base_g":200, "tip":"味噌湯偏鹹,建議少量品嚐"},
112
+ # 泛化:若你想讓 salad/fish 直接有數值,可打開下列兩筆
113
+ # "salad": {"kcal":30,"carb_g":5,"protein_g":1.5,"fat_g":0.5,"sodium_mg":40,"cat":"蔬菜類","base_g":100,"tip":"少醬少油,優先清爽調味"},
114
+ # "fish": {"kcal":170,"carb_g":0,"protein_g":22,"fat_g":8,"sodium_mg":70,"cat":"豆魚蛋肉類","base_g":120,"tip":"蒸/烤/煎少油,避免重鹹醬汁"},
115
+ }
116
+
117
+ ALIASES = {
118
+ "white rice":"rice","steamed rice":"rice","飯":"rice","白飯":"rice",
119
+ "麵":"noodles","拉麵":"noodles","麵條":"noodles","義大利麵":"noodles",
120
+ "麵包":"bread","吐司":"bread",
121
+ "雞肉":"chicken","雞胸":"chicken","烤雞":"chicken",
122
+ "滷雞腿":"soy_braised_chicken_leg","醬油雞腿":"soy_braised_chicken_leg",
123
+ "鮭魚":"salmon","三文魚":"salmon",
124
+ "豬排":"pork_chop",
125
+ "豆腐":"tofu",
126
+ "蛋":"egg","水煮蛋":"egg","荷包蛋":"egg",
127
+ "花椰菜":"broccoli","青花菜":"broccoli","菠菜":"spinach",
128
+ "香蕉":"banana","味噌湯":"miso_soup",
129
+ }
130
+
131
+ RULES = {"T2DM": {"carb_g_per_meal_max": 60}, "HTN": {"sodium_mg_per_meal_max": 600}}
132
+ PORTION_MUL = {"小":0.8, "中":1.0, "大":1.2}
133
+ DEFAULT_BASE_G = 100
134
+
135
+ # 類別對應(泛稱 → 類別)
136
+ GENERIC_TO_CATEGORY = {
137
+ "vegetable":"蔬菜類","vegetables":"蔬菜類","greens":"蔬菜類","salad":"蔬菜類",
138
+ "meat":"豆魚蛋肉類","seafood":"豆魚蛋肉類","fish":"豆魚蛋肉類",
139
+ "noodles":"全榖雜糧類","bread":"全榖雜糧類","rice":"全榖雜糧類",
140
+ "soup":"湯品/飲品","drink":"湯品/飲品","beverage":"湯品/飲品"
141
+ }
142
+
143
+ # =========================
144
+ # 4) 文本抽詞(偵測文字 → 食物詞)
145
+ # =========================
146
+ import re
147
+
148
+ # 砍掉「環境尾巴」:on (top of) a table / tray / desk ...
149
+ ENV_TAIL = re.compile(
150
+ r"\b(on\s+(?:top\s+of\s+)?(?:a|the)?\s*(?:table|tray|desk|counter|tabletop))\b.*$",
151
+ flags=re.I
152
+ )
153
+
154
+ # 停用詞/顏色/器皿/常見形容詞
155
+ STOPWORDS = {
156
+ "a","an","the","with","and","of","on","in","to","served","over","side","sides",
157
+ "set","dish","meal","mixed","assorted","fresh","hot","cold","topped","style","seasoned",
158
+ # 中文
159
+ "便當","套餐","一盤","一碗","配菜","附餐","湯","沙拉","醬","佐","搭配","附","拌","炒","滷","炸","烤","蒸","煮"
160
+ }
161
+ COLOR_WORDS = {"white","black","red","green","yellow","orange","brown","purple","pink","golden"}
162
+ UTENSILS = {"plate","bowl","tray","box","cup","glass","container","table","desk","counter","tabletop"}
163
+ ADJ_MISC = {"filled","placed","served","topped","layered","mixed","assorted","piece","slice","fillet","serving"}
164
+
165
+ FOOD_LIKE = {
166
+ "salad","vegetable","vegetables","greens","meat","seafood","fish",
167
+ "chicken","beef","pork","shrimp","tofu","egg",
168
+ "rice","noodles","bread","soup","fruit","fruits"
169
+ }
170
+
171
+ def detect_foods_from_text(text: str) -> List[str]:
172
+ lower = text.lower()
173
+ labels = set()
174
+ for k in FOOD_DB.keys():
175
+ if k in lower:
176
+ labels.add(k)
177
+ for alias, key in ALIASES.items():
178
+ if alias in text or alias.lower() in lower:
179
+ labels.add(key)
180
+ return list(labels)
181
+
182
+ def extract_food_terms_free(text: str) -> List[str]:
183
+ """
184
+ 從偵測/描述文字中抽食物詞(允許未知):
185
+ - 砍掉環境尾巴(on top of a table ...)
186
+ - 切片(, ; . and with),過濾顏色/器皿/形容詞/停用詞
187
+ - 取片尾名詞;再補 FOOD_LIKE 名詞掃描
188
+ - alias 映射 → 主鍵;未知則保留原字
189
+ """
190
+ t = text.strip().lower()
191
+ t = ENV_TAIL.sub("", t)
192
+
193
+ hits = set()
194
+
195
+ # 解析「X of Y」→ 優先抓 Y(e.g., piece of fish → fish)
196
+ for pat in [r"(?:piece|slice|fillet|serving)\s+of\s+([a-z\u4e00-\u9fff]+)"]:
197
+ for m in re.findall(pat, t, flags=re.I):
198
+ y = m.strip()
199
+ if y in COLOR_WORDS or y in UTENSILS or y in ADJ_MISC or y in STOPWORDS:
200
+ continue
201
+ hits.add(ALIASES.get(y, y))
202
+
203
+ parts = re.split(r"(?:,|;|\.|\band\b|\bwith\b|\n)+", t, flags=re.I)
204
+ for p in parts:
205
+ if not p:
206
+ continue
207
+ toks = re.findall(r"[a-z\u4e00-\u9fff]+", p)
208
+ toks = [
209
+ w for w in toks
210
+ if w not in COLOR_WORDS
211
+ and w not in UTENSILS
212
+ and w not in ADJ_MISC
213
+ and w not in STOPWORDS
214
+ and len(w) >= 2
215
+ ]
216
+ if not toks:
217
+ continue
218
+ head = toks[-1]
219
+ hits.add(ALIASES.get(head, head))
220
+
221
+ for w in FOOD_LIKE:
222
+ if re.search(rf"\b{re.escape(w)}\b", t):
223
+ hits.add(ALIASES.get(w, w))
224
+
225
+ return list(hits)
226
+
227
+ # =========================
228
+ # 5) 估重 / 營養 / 規則
229
+ # =========================
230
+ def estimate_weight(name: str, plate_cm: int, portion: str) -> int:
231
+ base = FOOD_DB.get(name, {}).get("base_g", DEFAULT_BASE_G)
232
+ mul = PORTION_MUL.get(portion, 1.0)
233
+ grams = int(base * mul * (plate_cm / 24))
234
+ return max(10, grams)
235
+
236
+ def grams_to_nutrition(name: str, grams: int) -> Dict:
237
+ info = FOOD_DB[name]
238
+ ratio = grams / 100.0
239
+ out = {"name": name, "cat": info["cat"], "weight_g": grams, "tip": info.get("tip","")}
240
+ for k in ("kcal","carb_g","protein_g","fat_g","sodium_mg"):
241
+ out[k] = round(info[k] * ratio, 1)
242
+ return out
243
+
244
+ def make_placeholder_item(name: str, plate_cm: int, portion: str, cat: str = "未分類"):
245
+ grams = int(DEFAULT_BASE_G * (plate_cm / 24) * PORTION_MUL.get(portion, 1.0))
246
+ return {
247
+ "name": name, "cat": cat, "weight_g": grams,
248
+ "kcal": "待新增資訊", "carb_g": "待新增資訊", "protein_g": "待新增資訊",
249
+ "fat_g": "待新增資訊", "sodium_mg": "待新增資訊", "tip": "待新增資訊"
250
+ }
251
+
252
+ def eval_rules(items: List[Dict], conditions: List[str]):
253
+ totals = {}
254
+ for it in items:
255
+ if isinstance(it.get("kcal"), (int, float)):
256
+ for k in ("kcal","carb_g","protein_g","fat_g","sodium_mg"):
257
+ totals[k] = round(totals.get(k,0) + float(it[k]), 1)
258
+ advice = []
259
+ if "T2DM" in conditions and totals.get("carb_g",0) > RULES["T2DM"]["carb_g_per_meal_max"]:
260
+ advice.append("【糖尿病】碳水偏高,建議主食減量或改全穀。")
261
+ if "HTN" in conditions and totals.get("sodium_mg",0) > RULES["HTN"]["sodium_mg_per_meal_max"]:
262
+ advice.append("【高血壓】鈉含量偏高,少鹽、避免重口味與滷味/湯品。")
263
+ cats = {}
264
+ for it in items:
265
+ cats[it["cat"]] = cats.get(it["cat"], 0) + 1
266
+ return totals, advice, cats
267
+
268
+ # =========================
269
+ # 6) Pipeline(偵測為主,Caption 顯示)
270
+ # =========================
271
+ def run_pipeline(image, plate_cm, portion, conditions, task_mode, dev_mode):
272
+ if image is None:
273
+ return "請先上傳一張照片。", "", [], {}
274
+
275
+ # 先用偵測任務決定清單來源
276
+ if dev_mode:
277
+ det_txt = "rice, vegetables, grilled chicken"
278
+ else:
279
+ det_txt = florence2_text(image, task="object_detection")
280
+
281
+ # 再跑 caption 只用來顯示(不影響清單)
282
+ if dev_mode:
283
+ cap_txt = "A bento with white rice, broccoli and grilled chicken thigh."
284
+ else:
285
+ cap_txt = florence2_text(image, task="caption")
286
+
287
+ src_text = det_txt # 清單來源固定用偵測文字
288
+ labels_known = detect_foods_from_text(src_text)
289
+ labels_free = extract_food_terms_free(src_text)
290
+
291
+ labels_all, seen = [], set()
292
+ for term in labels_free + labels_known:
293
+ key = ALIASES.get(term, term)
294
+ if key not in seen:
295
+ labels_all.append(key); seen.add(key)
296
+
297
+ items = []
298
+ for name in labels_all[:6]:
299
+ if name in FOOD_DB:
300
+ g = estimate_weight(name, plate_cm, portion)
301
+ items.append(grams_to_nutrition(name, g))
302
+ else:
303
+ cat = GENERIC_TO_CATEGORY.get(name, "未分類")
304
+ items.append(make_placeholder_item(name, plate_cm, portion, cat=cat))
305
+
306
+ totals, advice, cats = eval_rules([it for it in items if isinstance(it.get("kcal"), (int,float))], conditions)
307
+
308
+ # 組輸出:顯示偵測 + 描述;清單以偵測為準
309
+ labels_display = [it["name"] for it in items]
310
+ lines = [
311
+ f"模型輸出(偵測):{det_txt}",
312
+ f"模型輸出(描述):{cap_txt}",
313
+ ""
314
+ ]
315
+ if labels_display:
316
+ lines.append("偵測到: " + ", ".join(labels_display))
317
+ else:
318
+ lines.append("偵測到: (無)")
319
+ lines.append("")
320
+ for it in items:
321
+ kcal = it['kcal'] if isinstance(it['kcal'], (int, float)) else it['kcal']
322
+ carb = it['carb_g'] if isinstance(it['carb_g'], (int, float)) else it['carb_g']
323
+ prot = it['protein_g'] if isinstance(it['protein_g'], (int, float)) else it['protein_g']
324
+ fat = it['fat_g'] if isinstance(it['fat_g'], (int, float)) else it['fat_g']
325
+ na = it['sodium_mg'] if isinstance(it['sodium_mg'], (int, float)) else it['sodium_mg']
326
+ lines.append(f"- {it['name']} ({it['cat']}) {it['weight_g']} g → "
327
+ f"{kcal} kcal, C{carb} g, P{prot} g, F{fat} g, Na{na} mg")
328
+
329
+ if totals:
330
+ lines.append("")
331
+ lines.append(f"總計:{totals.get('kcal',0)} kcal,碳水 {totals.get('carb_g',0)} g,蛋白 {totals.get('protein_g',0)} g,脂肪 {totals.get('fat_g',0)} g,鈉 {totals.get('sodium_mg',0)} mg")
332
+ if advice:
333
+ lines.append("建議:" + " ".join(advice))
334
+
335
+ # 「模型原始輸出」欄位顯示 caption(較好讀)
336
+ return "\n".join(lines), cap_txt, items, totals
337
+
338
+ # =========================
339
+ # 7) Gradio 介面
340
+ # =========================
341
+ with gr.Blocks(title="FoodAI · Florence-2 Demo") as demo:
342
+ gr.Markdown("# 🍱 FoodAI · Florence-2 Demo\n上傳餐點 → 偵測(主)/描述(輔) → 估營養/建議\n\n> 開發模式:不跑模型,固定假字串���便測試 UI/流程。")
343
+ with gr.Row():
344
+ with gr.Column(scale=1):
345
+ img = gr.Image(type="pil", label="上傳圖片")
346
+ plate = gr.Slider(18, 28, value=24, step=1, label="盤子直徑 (cm)")
347
+ portion = gr.Radio(["小", "中", "大"], value="中", label="份量")
348
+ cond = gr.CheckboxGroup(["T2DM", "HTN"], label="狀況")
349
+ # 預設改為「偵測」
350
+ task_mode = gr.Radio(["描述 (Caption)", "偵測 (Object Detection)"], value="偵測 (Object Detection)", label="任務")
351
+ dev_mode = gr.Checkbox(label="開發模式(不跑模型)", value=False)
352
+ btn = gr.Button("開始分析", variant="primary")
353
+ with gr.Column(scale=1):
354
+ out_md = gr.Markdown(label="結果")
355
+ raw = gr.Textbox(label="模型原始輸出(Caption)", lines=4)
356
+ js = gr.JSON(label="逐項結果")
357
+ total = gr.JSON(label="總計")
358
+
359
+ btn.click(run_pipeline, inputs=[img, plate, portion, cond, task_mode, dev_mode], outputs=[out_md, raw, js, total])
360
+
361
+ if __name__ == "__main__":
362
+ PORT = int(os.getenv("PORT", "7860"))
363
+ demo.launch(server_name="0.0.0.0", server_port=PORT)