simkyuri committed
Commit 0c636ec · verified · 1 Parent(s): bc807a0

Upload app.py

Files changed (1)
app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image
from torchvision import transforms
import gradio as gr

from transformers import (
    CLIPModel,
    CLIPProcessor,
    BlipProcessor,
    BlipForConditionalGeneration,
)

# =========================================
# 0. Paths / device setup
# =========================================
CLIP_EMBED_PATH = "multimodal_assets/clip_text_embeds.pt"
MODEL_WEIGHTS_PATH = "models/convnext_base_merged_ema.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# =========================================
# 1. Load CLIP text embeddings
# =========================================
print("Loading CLIP text embeddings...")
clip_data = torch.load(CLIP_EMBED_PATH, map_location="cpu")

merged_class_names = clip_data["class_names"]
clip_prompts = clip_data["prompts"]
text_embeds = clip_data["text_embeds"]
clip_model_name = clip_data["clip_model_name"]

text_embeds = text_embeds.to(device)

print("Number of merged classes:", len(merged_class_names))
print("Merged classes:", merged_class_names)

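# NOTE (assumption): clip_text_embeds.pt comes from a separate export script
# that is not part of this commit. The loads above assume a dict shaped
# roughly like this (names inferred from the keys used here):
#
#   torch.save(
#       {
#           "class_names": class_names,      # list[str], merged class labels
#           "prompts": prompts,              # list[str], one text prompt per class
#           "text_embeds": text_embeds,      # FloatTensor [num_classes, embed_dim]
#           "clip_model_name": model_name,   # HF id of the CLIP model that encoded them
#       },
#       "multimodal_assets/clip_text_embeds.pt",
#   )
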
# =========================================
# 2. Load the ConvNeXt-Base classifier
# =========================================
print("Loading ConvNeXt-Base model (timm)...")
num_classes = len(merged_class_names)

convnext_model = timm.create_model(
    "convnext_base",
    pretrained=False,
    num_classes=num_classes,
)

state_dict = torch.load(MODEL_WEIGHTS_PATH, map_location="cpu")
convnext_model.load_state_dict(state_dict)
convnext_model.to(device)
convnext_model.eval()

print("Loaded trained ConvNeXt-Base weights.")

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

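# The mean/std above are the standard ImageNet normalization statistics, and
# Resize(256) + CenterCrop(224) is the usual ImageNet eval pipeline. This is
# assumed to match the preprocessing used when the checkpoint was trained;
# the training script is not included in this commit.
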
# =========================================
# 3. Load the CLIP model
# =========================================
print(f"Loading CLIP model... ({clip_model_name})")
clip_model = CLIPModel.from_pretrained(clip_model_name)
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model.to(device)
clip_model.eval()

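# Optional sanity check (illustrative): the loaded CLIP model's projection
# dimension should match the stored text embeddings. Uncomment to verify:
# assert clip_model.config.projection_dim == text_embeds.shape[-1]
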
# =========================================
# 4. Load the BLIP captioning model
# =========================================
print("Loading BLIP model...")
blip_model_name = "Salesforce/blip-image-captioning-base"
blip_processor = BlipProcessor.from_pretrained(blip_model_name)
blip_model = BlipForConditionalGeneration.from_pretrained(blip_model_name).to(device)
blip_model.eval()

# =========================================
# 5. Fine-grained menus / calorie table
# =========================================
fine_grained_menus = [
    "간장라불덮밥", "고추치킨카레동", "공기밥", "김치어묵우동", "닭강정",
    "돈까스오므라이스", "돈까스우동세트", "돈까스카레동", "등심돈까스",
    "마그마새우튀김알밥", "마그마치킨마요", "베이컨 알리오올리오", "삼겹된장짜글이",
    "삼겹살강된장비빔밥", "새우튀김알밥", "새우튀김우동", "소떡소떡",
    "신라면(계란)", "신라면(계란+치즈)", "양념치킨오므라이스", "어묵우동",
    "에비카레동", "오므라이스", "쫑쫑이덮밥", "치킨마요", "케네디소시지",
    "케네디소시지오므라이스",
]

merged_to_fine = {
    "오므라이스류": ["오므라이스", "돈까스오므라이스", "케네디소시지오므라이스"],
    "치킨마요류": ["치킨마요", "마그마치킨마요"],
    "새우튀김알밥류": ["새우튀김알밥", "마그마새우튀김알밥"],
    "라면류": ["신라면(계란)", "신라면(계란+치즈)"],
}

default_detail = {
    "오므라이스류": "오므라이스",
    "치킨마요류": "치킨마요",
    "새우튀김알밥류": "새우튀김알밥",
    "라면류": "신라면(계란)",
}

calorie_table = {
    "간장라불덮밥": 800, "고추치킨카레동": 900, "공기밥": 300,
    "김치어묵우동": 500, "닭강정": 450, "돈까스오므라이스": 950,
    "돈까스우동세트": 900, "돈까스카레동": 900, "등심돈까스": 700,
    "마그마새우튀김알밥": 800, "마그마치킨마요": 850,
    "베이컨 알리오올리오": 800, "삼겹된장짜글이": 750,
    "삼겹살강된장비빔밥": 800, "새우튀김알밥": 750, "새우튀김우동": 550,
    "소떡소떡": 450, "신라면(계란)": 570, "신라면(계란+치즈)": 630,
    "양념치킨오므라이스": 950, "어묵우동": 450, "에비카레동": 800,
    "오므라이스": 730, "쫑쫑이덮밥": 700, "치킨마요": 800,
    "케네디소시지": 280, "케네디소시지오므라이스": 1000,
}

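# How these tables interact (illustrative walk-through): ConvNeXt predicts a
# merged class; if it is a group like "오므라이스류" and the user made no
# explicit choice, default_detail picks a representative dish (e.g.
# "오므라이스"), and calorie_table maps that dish to a rough kcal estimate.
# Dishes missing from calorie_table fall back to a "no calorie info" message
# in calorie_comment() below.
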
# =========================================
# 6. Model inference helpers
# =========================================

def predict_convnext(image: Image.Image):
    """Top-1 merged-class prediction from the fine-tuned ConvNeXt."""
    img_t = val_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = convnext_model(img_t)
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
    top1 = int(np.argmax(probs))
    top1_prob = float(probs[top1])
    return merged_class_names[top1], top1_prob

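# Quick local smoke test (illustrative; "sample.jpg" is a hypothetical path):
# print(predict_convnext(Image.open("sample.jpg").convert("RGB")))
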
def recommend_with_clip(image: Image.Image):
    """Zero-shot Top-3 menu candidates via CLIP image-text similarity."""
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_feat = clip_model.get_image_features(**inputs)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        sims = (img_feat @ text_embeds.T).squeeze(0)
    topk = sims.topk(3)
    result = [
        (merged_class_names[i], float(s))
        for i, s in zip(topk.indices.tolist(), topk.values.tolist())
    ]
    return result

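# NOTE (assumption): the dot product above equals cosine similarity only if
# the rows of text_embeds were L2-normalized by the export script. If they
# were not, normalize once at load time:
# text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
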
def generate_caption(image: Image.Image):
    """Generate a short English caption for the image with BLIP."""
    inputs = blip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=20)
    return blip_processor.decode(out[0], skip_special_tokens=True)

def calorie_comment(menu_name: str, activity: str):
    """Look up a rough per-serving calorie estimate for the menu."""
    kcal = calorie_table.get(menu_name)
    if kcal is None:
        return "칼로리 정보 없음"
    return f"{menu_name}: 약 {kcal} kcal"

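# The `activity` argument is accepted but not used yet. A possible extension
# (hypothetical, with made-up daily budgets) would relate the meal to an
# activity-adjusted calorie budget, e.g.:
# budget = {"거의 안 움직임": 1800, "보통 활동": 2200, "많이 움직임": 2600}[activity]
# share = kcal / budget
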
# =========================================
# 7. Web app main
# =========================================

def analyze_menu(image, activity_level, detail_menu_choice):
    if image is None:
        return "이미지를 업로드하세요.", "", "", ""

    # Ensure 3-channel input regardless of upload format (e.g. RGBA PNG).
    image = image.convert("RGB")

    # 1) ConvNeXt: coarse (merged) class prediction
    big_cls, prob = predict_convnext(image)

    # 2) Decide the fine-grained menu: an explicit user choice wins,
    #    otherwise fall back to the default dish for the merged class.
    fine_candidates = merged_to_fine.get(big_cls, [])  # currently unused
    if detail_menu_choice != "선택 안 함 (모델에 맡기기)":
        final_menu = detail_menu_choice
    else:
        final_menu = default_detail.get(big_cls, big_cls)

    # 3) CLIP Top-3
    clip_top3 = recommend_with_clip(image)
    clip_text = "\n".join([f"- {n} ({s:.4f})" for n, s in clip_top3])

    # 4) BLIP caption
    caption = generate_caption(image)

    # 5) Calories
    kcal = calorie_comment(final_menu, activity_level)

    # 6) Output
    summary = (
        f"### 최종 메뉴 분석\n"
        f"- 예측 대분류: **{big_cls}** ({prob*100:.2f}%)\n"
        f"- 최종 세부 메뉴: **{final_menu}**\n\n"
        f"### CLIP Top-3\n{clip_text}\n\n"
        f"### BLIP 캡션\n> {caption}\n\n"
        f"### 칼로리 정보\n{kcal}"
    )
    return summary, caption, clip_text, kcal

# =========================================
# 8. Gradio interface
# =========================================

with gr.Blocks() as demo:
    gr.Markdown("## 학식 스캐너")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="메뉴 사진 업로드")

            activity_input = gr.Radio(
                choices=["거의 안 움직임", "보통 활동", "많이 움직임"],
                value="보통 활동",
                label="오늘 활동량",
            )

            detail_menu_input = gr.Dropdown(
                choices=["선택 안 함 (모델에 맡기기)"] + fine_grained_menus,
                value="선택 안 함 (모델에 맡기기)",
                label="세부 메뉴 선택",
            )

            btn = gr.Button("분석하기")

        with gr.Column():
            summary_output = gr.Markdown()
            caption_output = gr.Textbox(label="BLIP 캡션")
            clip_output = gr.Textbox(label="CLIP Top-3")
            kcal_output = gr.Textbox(label="칼로리")

    btn.click(
        analyze_menu,
        inputs=[img_input, activity_input, detail_menu_input],
        outputs=[summary_output, caption_output, clip_output, kcal_output],
    )

demo.launch()
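
# To run locally: `python app.py` (Gradio serves on http://127.0.0.1:7860 by
# default); on Hugging Face Spaces, app.py at the repo root is launched
# automatically.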