simkyuri committed · verified
Commit d40ff78 · 1 Parent(s): 087ce96

Upload app.py

Files changed (1)
  1. app.py +375 -0
app.py ADDED
import torch
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image

import gradio as gr
from torchvision import transforms

from transformers import (
    CLIPModel,
    CLIPProcessor,
    BlipProcessor,
    BlipForConditionalGeneration,
)

# =========================================
# 0. Paths / device setup
# =========================================
CLIP_EMBED_PATH = "multimodal_assets/clip_text_embeds.pt"
MODEL_WEIGHTS_PATH = "models/convnext_base_merged_ema.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# =========================================
# 1. Load merged class names & CLIP text embeddings
# =========================================
print("Loading CLIP text embeddings...")
clip_data = torch.load(CLIP_EMBED_PATH, map_location="cpu")

merged_class_names = clip_data["class_names"]  # 17 merged class names
clip_prompts = clip_data["prompts"]
text_embeds = clip_data["text_embeds"]  # [17, D]
clip_model_name = clip_data["clip_model_name"]

# Move the text embeddings to the target device
text_embeds = text_embeds.to(device)

print("Number of merged classes:", len(merged_class_names))
print("Merged classes:", merged_class_names)
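
# Note (not in the original app): the asset above is assumed to store one
# unit-normalized text feature per merged class. A minimal sketch of how it
# could be regenerated from the saved prompts:
#
#   _clip = CLIPModel.from_pretrained(clip_model_name).eval()
#   _proc = CLIPProcessor.from_pretrained(clip_model_name)
#   with torch.no_grad():
#       _tok = _proc(text=clip_prompts, padding=True, return_tensors="pt")
#       _emb = _clip.get_text_features(**_tok)
#       _emb = _emb / _emb.norm(dim=-1, keepdim=True)
#   torch.save({"class_names": merged_class_names, "prompts": clip_prompts,
#               "text_embeds": _emb, "clip_model_name": clip_model_name},
#              CLIP_EMBED_PATH)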

# =========================================
# 2. Load the ConvNeXt-Base classifier
# =========================================
print("Loading ConvNeXt-Base model (timm)...")
num_classes = len(merged_class_names)

convnext_model = timm.create_model(
    "convnext_base",
    pretrained=False,
    num_classes=num_classes,
)
state_dict = torch.load(MODEL_WEIGHTS_PATH, map_location="cpu")
convnext_model.load_state_dict(state_dict)
convnext_model.to(device)
convnext_model.eval()

print("ConvNeXt-Base fine-tuned weights loaded")

# Preprocessing for ConvNeXt (validation-style, ImageNet statistics)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
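
# Sanity check (a sketch, not in the original app): the checkpoint's classifier
# head must match the 17 merged classes. The exact state-dict key depends on
# the timm version; "head.fc.weight" is an assumption for convnext_base.
#
#   assert state_dict["head.fc.weight"].shape[0] == num_classes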

# =========================================
# 3. Load the CLIP model
# =========================================
print(f"Loading CLIP model... ({clip_model_name})")
clip_model = CLIPModel.from_pretrained(clip_model_name)
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

clip_model.to(device)
clip_model.eval()

# =========================================
# 4. Load the BLIP captioning model
# =========================================
print("Loading BLIP captioning model... (Salesforce/blip-image-captioning-base)")
blip_model_name = "Salesforce/blip-image-captioning-base"
blip_processor = BlipProcessor.from_pretrained(blip_model_name)
blip_model = BlipForConditionalGeneration.from_pretrained(blip_model_name).to(device)
blip_model.eval()

# =========================================
# 5. Fine-grained menu candidates / calorie info
# =========================================

# The original 27 fine-grained menus. The Korean dish names are kept verbatim:
# they double as lookup keys that must match the class names in the checkpoint.
fine_grained_menus = [
    "간장라불덮밥",
    "고추치킨카레동",
    "공기밥",
    "김치어묵우동",
    "닭강정",
    "돈까스오므라이스",
    "돈까스우동세트",
    "돈까스카레동",
    "등심돈까스",
    "마그마새우튀김알밥",
    "마그마치킨마요",
    "베이컨 알리오올리오",
    "삼겹된장짜글이",
    "삼겹살강된장비빔밥",
    "새우튀김알밥",
    "새우튀김우동",
    "소떡소떡",
    "신라면(계란)",
    "신라면(계란+치즈)",
    "양념치킨오므라이스",
    "어묵우동",
    "에비카레동",
    "오므라이스",
    "쫑쫑이덮밥",
    "치킨마요",
    "케네디소시지",
    "케네디소시지오므라이스",
]

# Merged coarse category -> fine-grained menu candidates
merged_to_fine = {
    "오므라이스류": ["오므라이스", "돈까스오므라이스", "케네디소시지오므라이스"],
    "치킨마요류": ["치킨마요", "마그마치킨마요"],
    "새우튀김알밥류": ["새우튀김알밥", "마그마새우튀김알밥"],
    "라면류": ["신라면(계란)", "신라면(계란+치즈)"],
}

# Representative fine-grained menu (default when the user makes no selection)
default_detail = {
    "오므라이스류": "오므라이스",
    "치킨마요류": "치킨마요",
    "새우튀김알밥류": "새우튀김알밥",
    "라면류": "신라면(계란)",
}

# Very rough calorie table (kcal per serving)
calorie_table = {
    "간장라불덮밥": 800,
    "고추치킨카레동": 900,
    "공기밥": 300,
    "김치어묵우동": 500,
    "닭강정": 450,
    "돈까스오므라이스": 950,
    "돈까스우동세트": 900,
    "돈까스카레동": 900,
    "등심돈까스": 700,
    "마그마새우튀김알밥": 800,
    "마그마치킨마요": 850,
    "베이컨 알리오올리오": 800,
    "삼겹된장짜글이": 750,
    "삼겹살강된장비빔밥": 800,
    "새우튀김알밥": 750,
    "새우튀김우동": 550,
    "소떡소떡": 450,
    "신라면(계란)": 570,
    "신라면(계란+치즈)": 630,
    "양념치킨오므라이스": 950,
    "어묵우동": 450,
    "에비카레동": 800,
    "오므라이스": 730,
    "쫑쫑이덮밥": 700,
    "치킨마요": 800,
    "케네디소시지": 280,
    "케네디소시지오므라이스": 1000,
}
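
# Consistency check (a sketch, not in the original app): every fine-grained
# candidate and every default should have a calorie entry.
#
#   for _cands in merged_to_fine.values():
#       assert all(_c in calorie_table for _c in _cands)
#   assert all(_v in calorie_table for _v in default_detail.values())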

# =========================================
# 6. Utility functions
# =========================================

def predict_convnext(image: Image.Image):
    """Predict the merged coarse category with ConvNeXt-Base."""
    convnext_model.eval()
    # Force 3-channel RGB; uploads may be RGBA or grayscale
    img_t = val_transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = convnext_model(img_t)
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]

    top1_idx = int(np.argmax(probs))
    top1_prob = float(probs[top1_idx])

    # Also report the top-3 predictions
    top3_idx = np.argsort(probs)[::-1][:3]
    top3 = [(merged_class_names[i], float(probs[i])) for i in top3_idx]

    return merged_class_names[top1_idx], top1_prob, top3


def recommend_with_clip(image: Image.Image, top_k=3):
    """Top-K similar menus (merged categories) by CLIP image-text similarity."""
    clip_model.eval()

    inputs = clip_processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        img_feat = clip_model.get_image_features(**inputs)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # Both sides are unit-normalized, so the dot product is cosine similarity
        sims = (img_feat @ text_embeds.T).squeeze(0)  # [17]
        topk = sims.topk(top_k)

    indices = topk.indices.tolist()
    scores = topk.values.tolist()
    result = [(merged_class_names[i], float(s)) for i, s in zip(indices, scores)]
    return result
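
# Note (not in the original app): the scores above are raw cosine similarities,
# roughly in [-1, 1]. To present them as probabilities, one could apply CLIP's
# learned temperature, e.g. (a sketch):
#
#   probs = (clip_model.logit_scale.exp() * sims).softmax(dim=-1)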
217
+ def generate_caption(image: Image.Image):
218
+ """BLIP์œผ๋กœ ์ด๋ฏธ์ง€ ์บก์…˜ ์ƒ์„ฑ"""
219
+ blip_model.eval()
220
+ inputs = blip_processor(images=image, return_tensors="pt").to(device)
221
+ with torch.no_grad():
222
+ out = blip_model.generate(**inputs, max_new_tokens=20)
223
+ caption = blip_processor.decode(out[0], skip_special_tokens=True)
224
+ return caption
225
+
226
+


def calorie_comment(menu_name: str, activity: str):
    kcal = calorie_table.get(menu_name)
    if kcal is None:
        return "No calorie information is registered for this menu."

    base = f"Estimated calories: about {kcal} kcal.\n"

    if activity == "Mostly sedentary":
        if kcal >= 900:
            return base + "Given today's activity level this is quite calorie-dense, so eating it often could be a burden."
        elif kcal >= 600:
            return base + "A moderate amount, but if you also have snacks or other meals, keep an eye on your total intake."
        else:
            return base + "On the light side, so it should be fine without much concern."
    elif activity == "Moderately active":
        if kcal >= 1000:
            return base + "Quite a hearty meal even for your activity level, so keep your other meals a bit lighter."
        elif kcal >= 700:
            return base + "A good amount of calories for the main meal of the day."
        else:
            return base + "A bit on the light side, so you may get hungry again soon."
    else:  # Very active
        if kcal >= 1000:
            return base + "With a high activity level, this many calories will be put to good use!"
        elif kcal >= 700:
            return base + "A reasonable level of energy for a meal before or after a workout."
        else:
            return base + "A bit light for your activity level, so adding a small snack would be fine."


# =========================================
# 7. Main function for the Gradio web app
# =========================================

def analyze_menu(image, activity_level, detail_menu_choice):
    """
    image: uploaded image (PIL)
    activity_level: activity level (radio button)
    detail_menu_choice: fine-grained menu chosen by the user (dropdown)
    """
    if image is None:
        return "Please upload an image.", "", "", ""

    # 1) Predict the merged coarse category with ConvNeXt
    big_cls, big_prob, top3_conv = predict_convnext(image)

    # 2) Check whether this category has fine-grained candidates
    fine_candidates = merged_to_fine.get(big_cls, [])

    # 3) Decide the fine-grained menu
    if detail_menu_choice is not None and detail_menu_choice != "No selection (let the model decide)":
        final_menu = detail_menu_choice
        detail_info = f"Fine-grained menu chosen by the user: **{final_menu}**"
    else:
        # The user made no explicit selection
        if big_cls in default_detail:
            final_menu = default_detail[big_cls]
            detail_info = (
                f"Predicted category: **{big_cls}** (confidence: {big_prob*100:.2f}%)\n"
                f"No fine-grained menu was selected, so calories are based on the representative menu **'{final_menu}'**.\n"
                f"(Choosing a different menu may change the calorie comment)"
            )
        else:
            # The coarse category is itself already a final menu
            final_menu = big_cls
            detail_info = f"Predicted menu: **{final_menu}** (confidence: {big_prob*100:.2f}%)"

    # 4) Top-3 similar merged menus via CLIP
    clip_top3 = recommend_with_clip(image, top_k=3)
    clip_text_lines = []
    for name, score in clip_top3:
        clip_text_lines.append(f"- {name} (similarity: {score:.4f})")
    clip_text = "\n".join(clip_text_lines)

    # 5) Generate a BLIP caption
    caption = generate_caption(image)

    # 6) Calorie comment
    kcal_text = calorie_comment(final_menu, activity_level)

    # 7) Guidance text (show the fine-grained candidates)
    if fine_candidates:
        candidate_text = (
            f"This image was classified as **'{big_cls}'**.\n\n"
            f"Fine-grained menu candidates in this category:\n" +
            "\n".join([f"- {m}" for m in fine_candidates]) +
            "\n\nPick a fine-grained menu from the dropdown above for a more accurate calorie estimate."
        )
    else:
        candidate_text = f"This image was classified as **'{big_cls}'**, a category with no fine-grained menu split."

    # Final summary message
    summary = (
        f"### Final menu analysis\n"
        f"- Predicted category: **{big_cls}** (confidence: {big_prob*100:.2f}%)\n"
        f"- Final reference menu: **{final_menu}**\n"
        f"- Activity level: **{activity_level}**\n\n"
        f"### Fine-grained menu info\n{detail_info}\n\n"
        f"### ConvNeXt top-3 (merged classes)\n" +
        "\n".join([f"- {name} ({p*100:.2f}%)" for name, p in top3_conv]) +
        "\n\n"
        f"### CLIP similar menus top-3 (merged classes)\n{clip_text}\n\n"
        f"### BLIP caption (English)\n> {caption}\n\n"
        f"### Calorie & activity comment\n{kcal_text}\n\n"
        f"---\n"
        f"{candidate_text}"
    )

    return summary, caption, clip_text, kcal_text


# =========================================
# 8. Gradio interface definition
# =========================================

with gr.Blocks() as demo:
    gr.Markdown("## Cafeteria Meal Scanner")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload a menu photo")

            activity_input = gr.Radio(
                choices=["Mostly sedentary", "Moderately active", "Very active"],
                value="Moderately active",
                label="Today's activity level",
            )

            detail_menu_input = gr.Dropdown(
                choices=["No selection (let the model decide)"] + fine_grained_menus,
                value="No selection (let the model decide)",
                label="Fine-grained menu (used for the calorie estimate if selected)",
            )

            run_btn = gr.Button("Run analysis")

        with gr.Column():
            summary_output = gr.Markdown(label="Analysis summary")
            caption_output = gr.Textbox(label="BLIP caption (English)", lines=2)
            clip_output = gr.Textbox(label="CLIP similar merged menus top-3", lines=4)
            kcal_output = gr.Textbox(label="Calorie comment", lines=3)

    run_btn.click(
        fn=analyze_menu,
        inputs=[img_input, activity_input, detail_menu_input],
        outputs=[summary_output, caption_output, clip_output, kcal_output],
    )

demo.launch()
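
# Note (not in the original app): with three models on a CPU-only host each
# request can be slow; enabling Gradio's request queue is a common tweak, e.g.
#
#   demo.queue().launch()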