File size: 6,399 Bytes
26ab04f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# app.py - Gradio demo for testing your image-classification model (Food-101)
# - Replace FOOD101_MODEL_ID with your model repo id (e.g., "yourname/your-food-model")
# - If your model is private, add HF_TOKEN as a secret in your Space (or set env var)

import os
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
import gradio as gr

# ---- CONFIG: set this to your model repo id ----
# NOTE(review): non-PEP8 constant name (Tell_Me_Recipe); it doubles as the
# env-var name, so renaming would require updating every use below.
Tell_Me_Recipe = os.environ.get("Tell_Me_Recipe", "YOUR_USERNAME/YOUR_FOOD101_MODEL")
# optional gate model (food / not-food). Change or set to None to skip.
GATE_MODEL_ID = os.environ.get("GATE_MODEL_ID", "prithivMLmods/Food-or-Not-SigLIP2")
HF_TOKEN = os.environ.get("HF_TOKEN")  # used for private model repos

# ---- device: prefer CUDA when available, otherwise CPU ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- load models (loaded once at startup) ----
def load_model(model_id):
    """Load an image processor + classification model from the Hugging Face Hub.

    Returns (processor, model) with the model moved to `device` and put in
    eval mode, or (None, None) when `model_id` is None/blank. Any load
    failure is logged and re-raised so callers can decide how to recover.
    """
    if model_id is None or model_id.strip() == "":
        return None, None
    try:
        # `token=` replaces the deprecated `use_auth_token=` keyword
        # (removed in recent transformers releases); needed for private repos.
        processor = AutoImageProcessor.from_pretrained(model_id, token=HF_TOKEN)
        model = AutoModelForImageClassification.from_pretrained(model_id, token=HF_TOKEN)
        model.to(device)
        model.eval()
        return processor, model
    except Exception as e:
        print(f"Failed to load {model_id}: {e}")
        raise

# Load the main classifier eagerly so the first request is fast; a failure
# here propagates and aborts startup (load_model re-raises).
print("Loading main classification model:", Tell_Me_Recipe)
clf_processor, clf_model = load_model(Tell_Me_Recipe)

# The gate model is best-effort: if it cannot be loaded (or no id is
# configured) the app keeps running and simply skips the food/not-food check.
gate_processor = gate_model = None
if GATE_MODEL_ID:
    try:
        print("Loading gate model (food vs not-food):", GATE_MODEL_ID)
        gate_processor, gate_model = load_model(GATE_MODEL_ID)
    except Exception as e:
        print("Gate model load failed — continuing without gate:", e)
        gate_processor, gate_model = None, None

# ---- small recipe template mapping (demo data only) ----
RECIPES = {
    "pizza": {
        "ingredients": ["高筋麵粉 300g", "水 180ml", "酵母 3g", "番茄醬", "Mozzarella 起司"],
        "steps": ["揉麵發酵", "桿皮抹醬加起司配料", "220°C 烤 10-15 分鐘"]
    },
    "ramen": {
        "ingredients": ["中華麵", "高湯 400ml", "醬油", "叉燒", "溏心蛋"],
        "steps": ["煮麵 / 熬高湯 / 擺盤"]
    },
    "cheesecake": {
        "ingredients": ["奶油乳酪 200g", "蛋 2 顆", "砂糖 60g", "消化餅底"],
        "steps": ["餅乾壓底 / 乳酪餡打勻倒入 / 160°C 烤 40-50 分鐘"]
    }
}

def simple_recipe_for(label: str):
    """Return a formatted demo recipe whose RECIPES key appears in *label*.

    The label is normalized (underscores -> spaces, lowercased) and matched
    by substring, so e.g. "tonkotsu_ramen" hits the "ramen" template. Falls
    back to a "no exact recipe" message when nothing matches.
    """
    normalized = label.replace("_", " ").lower()
    for name, recipe in RECIPES.items():
        if name not in normalized:
            continue
        ingredient_lines = "\n".join(f"- {item}" for item in recipe["ingredients"])
        step_lines = "\n".join(
            f"{num}. {step}" for num, step in enumerate(recipe["steps"], start=1)
        )
        return f"【材料】\n{ingredient_lines}\n\n【步驟】\n{step_lines}"
    return f"找不到精確食譜。模型預測:{label.replace('_',' ')}。\n建議搜尋該菜名的食譜或連到 RecipeNLG 做檢索。"

# ---- inference helpers ----
@torch.inference_mode()
def is_food_image(image: Image.Image, threshold: float = 0.5):
    """Run the optional food/not-food gate model on `image`.

    Returns (is_food, score, label):
    - is_food: False only when the gate's top label looks like "not food"
      AND its confidence reaches `threshold`. (The parameter was previously
      accepted but ignored; it is now actually applied. Combined behavior
      with analyze_image is unchanged, since that caller re-checks the score.)
    - score: the gate model's top-1 softmax probability.
    - label: the gate's predicted label, or "no-gate" when no gate is loaded.
    """
    if gate_model is None or gate_processor is None:
        return True, 1.0, "no-gate"
    inputs = gate_processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    out = gate_model(**inputs)
    probs = out.logits.softmax(-1).squeeze(0)
    topv, topi = torch.max(probs, dim=-1)
    label = gate_model.config.id2label[int(topi)]
    score = float(topv)
    # Heuristic: assumes the gate model encodes "not food" in its label names
    # (e.g. "not-food") — TODO confirm against the actual gate model's labels.
    not_food_names = ["not-food", "not_food", "not food", "notfood", "no-food"]
    is_not_food = any(n in label.lower() for n in not_food_names)
    # Only flag non-food when the gate is confident enough.
    is_food = not (is_not_food and score >= threshold)
    return is_food, score, label

@torch.inference_mode()
def predict_label(image: Image.Image, topk: int = 3):
    """Classify *image* with the main model; return top-k (label, prob) pairs.

    Labels have underscores replaced with spaces; probabilities are plain
    Python floats. k is clamped to the number of classes.
    """
    batch = clf_processor(images=image, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    logits = clf_model(**batch).logits
    probabilities = logits.softmax(-1).squeeze(0)
    k = min(topk, probabilities.shape[0])
    values, indices = torch.topk(probabilities, k=k)
    results = []
    for value, index in zip(values, indices):
        label = clf_model.config.id2label[int(index)]
        results.append((label.replace("_", " "), float(value)))
    return results

# ---- gradio UI function ----
def analyze_image(image: Image.Image, topk: int=3, gate_threshold: float=0.5, use_gate: bool=True):
    """Full UI pipeline: optional gate check, top-k prediction, recipe lookup.

    Always returns a 4-tuple of strings (gate message, top-1 label,
    top-k table, recipe text); errors are reported in the first slot.
    """
    if image is None:
        return "請上傳圖片", "", "", ""
    try:
        # Step 1 (optional): reject images the gate confidently calls non-food.
        if use_gate and gate_model is not None:
            is_food, score, gate_label = is_food_image(image, gate_threshold)
            if not is_food and score >= gate_threshold:
                return f"判斷:非食物({gate_label},score={score:.2f})", "", "", "這張圖被判定為「非食物」,不做菜名預測。"
        # Step 2: classify, format the top-k table, and fetch a demo recipe.
        predictions = predict_label(image, topk=topk)
        lines = []
        for name, prob in predictions:
            lines.append(f"{name}:{prob:.3f}")
        topk_text = "\n".join(lines)
        best_label = predictions[0][0]
        recipe_txt = simple_recipe_for(best_label)
        return "判斷:是食物(gate ok)", best_label, topk_text, recipe_txt
    except Exception as e:
        return f"發生錯誤:{e}", "", "", ""

# ---- build UI ----
# Layout: image input + controls in the top row, then four stacked output
# textboxes (gate message, top-1 label, top-k table, recipe). Widget
# creation order inside the context managers defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Food → 菜名 + 簡易食譜 Demo")
    with gr.Row():
        img = gr.Image(type="pil", label="上傳一張照片(任意圖片)")
        with gr.Column():
            # Inference controls; defaults mirror analyze_image's defaults
            # (topk=3, gate on, threshold 0.5).
            topk = gr.Slider(1, 10, value=3, step=1, label="Top-K")
            gate_check = gr.Checkbox(value=True, label="使用 food vs not-food gate(建議開)")
            gate_th = gr.Slider(0.0, 1.0, value=0.5, label="Gate threshold")
            run_btn = gr.Button("分析照片")
    with gr.Row():
        out0 = gr.Textbox(label="是否為食物 / Gate 訊息", lines=1)
    with gr.Row():
        out1 = gr.Textbox(label="預測菜名(Top 1)", lines=1)
    with gr.Row():
        out2 = gr.Textbox(label="Top-K 預測 & 機率", lines=6)
    with gr.Row():
        out3 = gr.Textbox(label="自動回傳的簡易食譜(示範)", lines=12)
    # The inputs list is positional: [img, topk, gate_th, gate_check] must
    # match analyze_image(image, topk, gate_threshold, use_gate) in order.
    run_btn.click(fn=analyze_image, inputs=[img, topk, gate_th, gate_check], outputs=[out0, out1, out2, out3])

if __name__ == "__main__":
    # Bind to all interfaces; hosting platforms inject PORT, default 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)