Boonyaratt commited on
Commit
ddab628
·
1 Parent(s): 82a3d9c

Add application file

Browse files
Files changed (2) hide show
  1. best_multimodal.pt +3 -0
  2. gladio_webapp.py +637 -0
best_multimodal.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbe394b179a6b5f334f17725961b971ee50342e2d1d7867da43d51a64cbc45b7
3
+ size 19924917
gladio_webapp.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Gladio-webapp.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/11rgLJIwe-BYZs3NcVMFz4hnq6XIzxfsv
8
+ """
9
+
10
+
11
+
12
+ import gradio as gr
13
+
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ import os
17
+ import PIL
18
+ from PIL import Image
19
+ import pandas as pd
20
+ import torch
21
+ import torch.nn as nn
22
+ import torchvision.models as models
23
+ from torchvision import transforms
24
+ from torchvision.models import EfficientNet_B0_Weights
25
+
26
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ class ImageEncoder(nn.Module):
29
+ def __init__(self, backbone="efficientnet_b0", embed_dim=512, pretrained=True, train_backbone=False):
30
+ super().__init__()
31
+ if backbone == "resnet50":
32
+ base = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)
33
+ feat_dim = base.fc.in_features
34
+ base.fc = nn.Identity()
35
+ self.backbone = base
36
+ else:
37
+ base = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None)
38
+ feat_dim = base.classifier[1].in_features
39
+ base.classifier = nn.Identity()
40
+ self.backbone = base
41
+
42
+ for p in self.backbone.parameters():
43
+ p.requires_grad = train_backbone
44
+
45
+ self.proj = nn.Sequential(
46
+ nn.Linear(feat_dim, embed_dim),
47
+ nn.ReLU(inplace=True),
48
+ nn.BatchNorm1d(embed_dim),
49
+ nn.Dropout(0.2),
50
+ )
51
+
52
+ def forward(self, x):
53
+ f = self.backbone(x) # (B, feat_dim)
54
+ f = self.proj(f) # (B, embed_dim)
55
+ return f
56
+
57
+ class TabularEncoder(nn.Module):
58
+ def __init__(self, in_dim, out_dim=128):
59
+ super().__init__()
60
+ self.net = nn.Sequential(
61
+ nn.BatchNorm1d(in_dim),
62
+ nn.Linear(in_dim, 256), nn.ReLU(inplace=True),
63
+ nn.Dropout(0.2),
64
+ nn.Linear(256, out_dim), nn.ReLU(inplace=True),
65
+ )
66
+ def forward(self, x):
67
+ return self.net(x)
68
+
69
+ class MultimodalNet(nn.Module):
70
+ def __init__(self, tab_in_dim, num_classes=4, img_embed_dim=512, tab_embed_dim=128,
71
+ backbone="efficientnet_b0", pretrained=True, train_backbone=False):
72
+ super().__init__()
73
+ self.img_enc = ImageEncoder(backbone=backbone, embed_dim=img_embed_dim,
74
+ pretrained=pretrained, train_backbone=train_backbone)
75
+ self.tab_enc = TabularEncoder(in_dim=tab_in_dim, out_dim=tab_embed_dim)
76
+ self.head = nn.Sequential(
77
+ nn.Linear(img_embed_dim + tab_embed_dim, 256),
78
+ nn.ReLU(inplace=True),
79
+ nn.BatchNorm1d(256),
80
+ nn.Dropout(0.4),
81
+ nn.Linear(256, 128),
82
+ nn.ReLU(inplace=True),
83
+ nn.Dropout(0.3),
84
+ nn.Linear(128, num_classes)
85
+ )
86
+
87
+ def forward(self, front_img, back_img, tab_x):
88
+ f_front = self.img_enc(front_img)
89
+ f_back = self.img_enc(back_img)
90
+ f_img = 0.5 * (f_front + f_back) # average two views
91
+ f_tab = self.tab_enc(tab_x)
92
+ fused = torch.cat([f_img, f_tab], dim=1)
93
+ logits = self.head(fused)
94
+ return logits
95
+
96
+ # ===== Force tabular dim to 38 no matter what's inside ckpt =====
97
+ FORCE_TAB_DIM = 38
98
+ FORCE_NUM_CLASSES = None # ตั้งเป็นเลขจริงถ้าอยากบังคับ, หรือปล่อย None ให้ดึงจาก ckpt/ดีฟอลต์
99
+
100
+ CKPT_PATH = "/content/best_multimodal.pt" # <-- แก้เป็น path ของคุณ
101
+ ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
102
+
103
+ # อย่าอ่าน tab_in_dim จาก ckpt แล้วเผลอได้ 14 มาอีก
104
+ tab_in_dim = FORCE_TAB_DIM
105
+ num_classes = int(ckpt.get("num_classes", 4) if FORCE_NUM_CLASSES is None else FORCE_NUM_CLASSES)
106
+
107
+ print("[INFO] FORCE tab_in_dim:", tab_in_dim, "| num_classes:", num_classes)
108
+
109
+ # สร้างโมเดลใหม่ให้รองรับ 38 ช่อง
110
+ model = MultimodalNet(
111
+ tab_in_dim,
112
+ num_classes=num_classes,
113
+ backbone="efficientnet_b0",
114
+ pretrained=False, # ไม่โหลด imagenet เมื่อมี ckpt เอง
115
+ train_backbone=False
116
+ ).to(DEVICE)
117
+
118
+ # โหลดเฉพาะพารามิเตอร์ที่ 'เข้ากัน' และ *ตัด* ของ tab_enc (เพราะ shape ไม่ตรง)
119
+ raw_state = ckpt.get("model", ckpt)
120
+ # กรองทิ้งทั้งหมดที่ขึ้นต้นด้วย 'tab_enc.' เพื่อกันการโหลด buffer/พารามิเตอร์ 14 ช่อง
121
+ state = {k: v for k, v in raw_state.items() if not k.startswith("tab_enc.")}
122
+ missing, unexpected = model.load_state_dict(state, strict=False)
123
+ print("[load_state_dict] Missing:", missing)
124
+ print("[load_state_dict] Unexpected:", unexpected)
125
+
126
+ # ตรว��สอบว่า TabularEncoder เป็น 38 จริง (ต้องเห็น BatchNorm1d(38), Linear(in=38 -> 256))
127
+ print("[VERIFY] model.tab_enc.net =", model.tab_enc.net)
128
+
129
+ model.eval()
130
+
131
+ # base 14 ฟีเจอร์ที่ใช้ตอนเทรน
132
+ BASE_14 = [
133
+ "pilling","condition","pattern","stains","holes",
134
+ "damage_count","damage_severity",
135
+ "brand","type","size","season","category","main_color","usage"
136
+ ]
137
+ NUM_COLS_BASE = ["pilling","condition","damage_count"]
138
+
139
+ CUT_CATEGORIES = ['collar','v-collar','tight','loose','regular','turtle-neck','cropped','long']
140
+
141
+ MATERIAL_CATEGORIES = [
142
+ 'cotton','polyester','viscose','acrylic','nylon',
143
+ 'elastane','wool','rayon','silk','linen','spandex',
144
+ 'lycra','bamboo','alpaca','lyocell','cashmere'
145
+ ]
146
+
147
+ # ล็อก scheme เป็น 38
148
+ SCHEME = {"use_cut": True, "use_mat_vec": True, "use_mat_count": False}
149
+ tab_in_dim = 38
150
+
151
+ def encode_material_rows(flat_vals):
152
+ """
153
+ flat_vals: ลิสต์เรียงเป็น [p1, m1, p2, m2, ...]
154
+ คืนเวกเตอร์ยาว 16 ตาม MATERIAL_CATEGORIES
155
+ """
156
+ agg = {k: 0.0 for k in MATERIAL_CATEGORIES}
157
+ it = iter(flat_vals)
158
+ for p, mat in zip(it, it): # เดินทีละคู่
159
+ try:
160
+ pct = float(p) if p is not None else 0.0
161
+ except:
162
+ pct = 0.0
163
+ if mat in agg:
164
+ agg[mat] += max(0.0, pct)
165
+ return [agg[k] for k in MATERIAL_CATEGORIES]
166
+
167
+ # ---------- 1) เลือกสคีมาฟีเจอร์ตาม tab_in_dim ----------
168
+ def get_feature_scheme(tab_in_dim):
169
+ """
170
+ - 14: base 14
171
+ - 23: base 14 + cut(8) + material_count(1)
172
+ - 30: base 14 + material_vector(16)
173
+ - 38: base 14 + cut(8) + material_vector(16)
174
+ """
175
+ if tab_in_dim == 14:
176
+ return {"use_base": True, "use_cut": False, "use_mat_vec": False, "use_mat_count": False}
177
+ if tab_in_dim == 23:
178
+ return {"use_base": True, "use_cut": True, "use_mat_vec": False, "use_mat_count": True}
179
+ if tab_in_dim == 30:
180
+ return {"use_base": True, "use_cut": False, "use_mat_vec": True, "use_mat_count": False}
181
+ if tab_in_dim == 38:
182
+ return {"use_base": True, "use_cut": True, "use_mat_vec": True, "use_mat_count": False}
183
+ raise ValueError(f"ไม่รู้จัก tab_in_dim={tab_in_dim} (รองรับ 14/23/30/38)")
184
+
185
+ SCHEME = get_feature_scheme(tab_in_dim)
186
+
187
+ # ---------- 2) ฟีเจอร์ฐาน 14 ตัว (ของคุณ) ----------
188
+ BASE_14 = [
189
+ "pilling","condition","pattern","stains","holes",
190
+ "damage_count","damage_severity",
191
+ "brand","type","size","season","category","main_color","usage"
192
+ ]
193
+ NUM_COLS_BASE = ["pilling","condition","damage_count"]
194
+
195
+ # (ตรงนี้วาง cat_maps ทั้งชุดของคุณ: brand/type/size/pattern/stains/holes/damage_severity/usage/main_color/season/category)
196
+
197
+ # ---------- 3) CUT & MATERIAL utilities (จากที่คุณส่งมา) ----------
198
+ import re
199
+
200
+ CUT_CATEGORIES = ['collar','v-collar','tight','loose','regular','turtle-neck','cropped','long']
201
+
202
+ def clean_cut(cut_list):
203
+ if isinstance(cut_list, str):
204
+ try:
205
+ cut_list = eval(cut_list) if cut_list.strip().startswith("[") else cut_list.split(',')
206
+ except Exception:
207
+ cut_list = [cut_list]
208
+ cut_list = [c.strip().lower() for c in cut_list if c]
209
+ mapping = {
210
+ 'c-collar':'collar', 'c collar':'collar', 'collar':'collar',
211
+ 'v-collar':'v-collar', 'v collar':'v-collar',
212
+ 'tight':'tight', 'loose':'loose', 'oversize':'loose',
213
+ 'regular':'regular',
214
+ 'turtle neck':'turtle-neck', 'turtleneck':'turtle-neck',
215
+ 'cropped':'cropped', 'long':'long'
216
+ }
217
+ cleaned = set(mapping.get(x, x) for x in cut_list)
218
+ return [c for c in cleaned if c in CUT_CATEGORIES]
219
+
220
+ def cuts_to_multihot(cuts):
221
+ return [1 if cat in (cuts or []) else 0 for cat in CUT_CATEGORIES]
222
+
223
+ MATERIAL_CATEGORIES = [
224
+ 'cotton','polyester','viscose','acrylic','nylon',
225
+ 'elastane','wool','rayon','silk','linen','spandex',
226
+ 'lycra','bamboo','alpaca','lyocell','cashmere'
227
+ ]
228
+
229
+ def parse_material(text):
230
+ if text is None:
231
+ return {}
232
+ text = str(text).strip().lower()
233
+ if text in ['not available','unknown','','scanner can not read material.']:
234
+ return {}
235
+ comps = re.findall(r'(\d+)\s*%\s*([a-z]+)', text)
236
+ out = {}
237
+ for pct, mat in comps:
238
+ try:
239
+ out[mat] = int(pct)
240
+ except:
241
+ pass
242
+ return out
243
+
244
+ def material_to_vector(mat_dict):
245
+ return [mat_dict.get(cat, 0) for cat in MATERIAL_CATEGORIES]
246
+
247
+ def material_count_from_dict(mat_dict):
248
+ return sum(1 for v in mat_dict.values() if float(v) > 0)
249
+
250
+ # ---------- 4) encoder แบบ “dynamic” ให้ตรงกับโมเดล ----------
251
+ def encode_tab_from_form(base_vals, cut_selected=None, mat_count_val=None, mat_text_val=None):
252
+ # 4.1 base 14
253
+ vec = []
254
+ for col, v in zip(BASE_14, base_vals):
255
+ if col in NUM_COLS_BASE:
256
+ vec.append(float(v))
257
+ else:
258
+ m = cat_maps[col]
259
+ idx = m.get(v, list(m.values())[0]) # fallback
260
+ vec.append(float(idx))
261
+
262
+ # 4.2 cut (8)
263
+ if SCHEME["use_cut"]:
264
+ cleaned = clean_cut(cut_selected) if cut_selected else []
265
+ vec.extend(cuts_to_multihot(cleaned))
266
+
267
+ # 4.3 material
268
+ if SCHEME["use_mat_count"]:
269
+ val = 0 if mat_count_val is None else float(mat_count_val)
270
+ vec.append(val)
271
+
272
+ if SCHEME["use_mat_vec"]:
273
+ mdict = parse_material(mat_text_val)
274
+ vec.extend(material_to_vector(mdict))
275
+
276
+ x = torch.tensor([vec], dtype=torch.float32, device=DEVICE)
277
+ assert x.shape[1] == tab_in_dim, f"Encoded dim {x.shape[1]} != tab_in_dim {tab_in_dim}"
278
+ return x
279
+
280
+ import re
281
+
282
+ # ----- CUT -----
283
+ CUT_CATEGORIES = ['collar', 'v-collar', 'tight', 'loose', 'regular', 'turtle-neck', 'cropped', 'long']
284
+
285
+ def clean_cut(cut_list):
286
+ # รองรับทั้ง list และ string
287
+ if isinstance(cut_list, str):
288
+ try:
289
+ cut_list = eval(cut_list) if cut_list.strip().startswith("[") else cut_list.split(',')
290
+ except Exception:
291
+ cut_list = [cut_list]
292
+ cut_list = [c.strip().lower() for c in cut_list if c]
293
+
294
+ mapping = {
295
+ 'c-collar': 'collar', 'c collar': 'collar', 'collar': 'collar',
296
+ 'v-collar': 'v-collar', 'v collar': 'v-collar',
297
+ 'tight': 'tight', 'loose': 'loose', 'oversize': 'loose',
298
+ 'regular': 'regular',
299
+ 'turtle neck': 'turtle-neck', 'turtleneck': 'turtle-neck',
300
+ 'cropped': 'cropped', 'long': 'long'
301
+ }
302
+ cleaned = set()
303
+ for item in cut_list:
304
+ key = item.strip().lower()
305
+ cleaned.add(mapping.get(key, key))
306
+ return [c for c in cleaned if c in CUT_CATEGORIES]
307
+
308
+ def cuts_to_multihot(cuts):
309
+ return [1 if cat in cuts else 0 for cat in CUT_CATEGORIES]
310
+
311
+ # ----- MATERIAL -----
312
+ MATERIAL_CATEGORIES = [
313
+ 'cotton','polyester','viscose','acrylic','nylon',
314
+ 'elastane','wool','rayon','silk','linen','spandex',
315
+ 'lycra','bamboo','alpaca','lyocell','cashmere'
316
+ ]
317
+
318
+ def parse_material(text):
319
+ """
320
+ รับสตริงแบบ '60% cotton 40% polyester' หรือกรณีไม่พร้อมใช้งาน
321
+ คืน dict เช่น {'cotton':60, 'polyester':40}
322
+ """
323
+ if text is None:
324
+ return {}
325
+ text = str(text).strip().lower()
326
+ if text in ['not available','unknown','','scanner can not read material.']:
327
+ return {}
328
+ comps = re.findall(r'(\d+)\s*%\s*([a-z]+)', text)
329
+ out = {}
330
+ for pct, mat in comps:
331
+ try:
332
+ out[mat] = int(pct)
333
+ except Exception:
334
+ pass
335
+ return out
336
+
337
+ def material_to_vector(mat_dict):
338
+ """เวกเตอร์ยาว 16 ตาม MATERIAL_CATEGORIES (ค่าร้อยละ 0..100)"""
339
+ return [mat_dict.get(cat, 0) for cat in MATERIAL_CATEGORIES]
340
+
341
+ def material_count_from_dict(mat_dict):
342
+ """นับชนิดวัสดุที่มีสัดส่วน > 0 เพื่อใช้กับเคส 23 มิติ"""
343
+ return sum(1 for v in mat_dict.values() if float(v) > 0)
344
+
345
+ def get_feature_scheme(tab_in_dim):
346
+ """
347
+ คืน dict ที่อธิบายว่าโมเดลต้องการฟีเจอร์อะไรบ้าง
348
+ - 14: base 14
349
+ - 23: base 14 + cut(8) + material_count(1)
350
+ - 30: base 14 + material_vector(16)
351
+ - 38: base 14 + cut(8) + material_vector(16)
352
+ """
353
+ if tab_in_dim == 14:
354
+ return {"use_base": True, "use_cut": False, "use_mat_vec": False, "use_mat_count": False}
355
+ if tab_in_dim == 23:
356
+ return {"use_base": True, "use_cut": True, "use_mat_vec": False, "use_mat_count": True}
357
+ if tab_in_dim == 30:
358
+ return {"use_base": True, "use_cut": False, "use_mat_vec": True, "use_mat_count": False}
359
+ if tab_in_dim == 38:
360
+ return {"use_base": True, "use_cut": True, "use_mat_vec": True, "use_mat_count": False}
361
+ raise ValueError(f"ไม่รู้จัก tab_in_dim={tab_in_dim} (รองรับ 14/23/30/38)")
362
+
363
+ weights = EfficientNet_B0_Weights.IMAGENET1K_V1
364
+ img_tf = transforms.Compose([
365
+ transforms.Resize(256),
366
+ transforms.CenterCrop(224),
367
+ transforms.ToTensor(),
368
+ transforms.Normalize(mean=weights.transforms().mean, std=weights.transforms().std),
369
+ ])
370
+
371
+ def preprocess_image(pil_img: Image.Image):
372
+ return img_tf(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE) # (1,3,224,224)
373
+
374
+ # ===== ใช้ลำดับฟีเจอร์ 14 ช่อง ตามที่คุณเทรน =====
375
+ TAB_FEATS = [
376
+ "pilling","condition","pattern","stains","holes",
377
+ "damage_count","damage_severity",
378
+ "brand","type","size","season","category","main_color","usage"
379
+ ]
380
+
381
+ # ===== mapping จริงจากโน้ตบุ๊ก (เรียง index ให้ตรงกับที่ map ในไฟล์) =====
382
+ cat_maps = {
383
+ "brand": {
384
+ "Non-Brand": 0,
385
+ "Fast Fashion & High Street Retailers": 1,
386
+ "Other Brands": 2,
387
+ "Store Brands": 3,
388
+ "Niche Brands": 4,
389
+ "Premium & Designer": 5,
390
+ "Sportswear & Outdoor": 6,
391
+ },
392
+ "pattern": {
393
+ "Solid": 0,
394
+ "Printed": 1,
395
+ "Texture_Embellishment": 2,
396
+ "Other": 3,
397
+ },
398
+ "type": {
399
+ "topwear": 0,
400
+ "dresswear": 1,
401
+ "bottomwear": 2,
402
+ "outerwear": 3,
403
+ "other": 4,
404
+ "sleepwear": 5,
405
+ },
406
+ "size": {
407
+ "unknown": 0,
408
+ "xs": 1,
409
+ "s": 2,
410
+ "m": 3,
411
+ "l": 4,
412
+ "xl": 5,
413
+ "xxl": 6,
414
+ "kids": 7,
415
+ "onesize": 8,
416
+ },
417
+ "season": {
418
+ "All": 0, "Summer": 1, "Spring": 2, "Autumn": 3, "None": 4, "Winter": 5
419
+ },
420
+ "category": {
421
+ "Ladies": 0, "Men": 1, "Children": 2, "Unisex": 3
422
+ },
423
+ "main_color": {
424
+ "black": 0, "white": 1, "blue": 2, "multicolor": 3, "pink": 4,
425
+ "grey": 5, "beige": 6, "red": 7, "green": 8, "purple": 9,
426
+ "brown": 10, "yellow": 11, "orange": 12, "turquoise": 13, "none": 14
427
+ },
428
+ "usage": {
429
+ "export": 0, "reuse": 1, "recycle": 2, "repair": 3
430
+ },
431
+ "stains": {"No": 0, "Yes": 1},
432
+ "holes": {"None": 0, "Minor": 1, "Major": 2},
433
+ "damage_severity": {
434
+ "No Damage": 0, "Minor Damage": 1, "Moderate Damage": 2, "Severe Damage": 3
435
+ },
436
+ }
437
+
438
+ # ===== สเปกอินพุตสำหรับสร้าง UI ใน Gradio =====
439
+ FEATURE_SPECS = {
440
+ # numeric (ใช้ค่าเดิม)
441
+ "pilling": {"kind":"number","min":0,"max":5,"step":1,"default":3},
442
+ "condition":{"kind":"number","min":0,"max":5,"step":1,"default":2},
443
+ "damage_count":{"kind":"number","min":0,"max":20,"step":1,"default":0},
444
+
445
+ # categorical (choices = list(cat_maps[col].keys()))
446
+ "pattern":{"kind":"cat","choices":list(cat_maps["pattern"].keys()),"default":"Solid"},
447
+ "stains":{"kind":"cat","choices":list(cat_maps["stains"].keys()),"default":"No"},
448
+ "holes":{"kind":"cat","choices":list(cat_maps["holes"].keys()),"default":"None"},
449
+ "damage_severity":{"kind":"cat","choices":list(cat_maps["damage_severity"].keys()),"default":"No Damage"},
450
+
451
+ "brand":{"kind":"cat","choices":list(cat_maps["brand"].keys()),"default":"Non-Brand"},
452
+ "type":{"kind":"cat","choices":list(cat_maps["type"].keys()),"default":"topwear"},
453
+ "size":{"kind":"cat","choices":list(cat_maps["size"].keys()),"default":"m"},
454
+ "season":{"kind":"cat","choices":list(cat_maps["season"].keys()),"default":"All"},
455
+ "category":{"kind":"cat","choices":list(cat_maps["category"].keys()),"default":"Ladies"},
456
+ "main_color":{"kind":"cat","choices":list(cat_maps["main_color"].keys()),"default":"black"},
457
+ "usage":{"kind":"cat","choices":list(cat_maps["usage"].keys()),"default":"reuse"},
458
+ }
459
+
460
+ NUM_COLS = ["pilling","condition","damage_count"]
461
+ CAT_COLS = [c for c in TAB_FEATS if c not in NUM_COLS]
462
+
463
+ def encode_tab(tab_dict):
464
+ """
465
+ แปลงค่าจากฟอร์ม → เวกเตอร์ตามลำดับ TAB_FEATS
466
+ - number: ใช้ค่า float ตรง ๆ (ไม่มี scaler ตามไฟล์เทรนของคุณ)
467
+ - categorical: map ชื่อ → index ตาม cat_maps (unknown → index 0 ของคอลัมน์นั้น)
468
+ """
469
+ vec = []
470
+ for col in TAB_FEATS:
471
+ if col in NUM_COLS:
472
+ vec.append(float(tab_dict[col]))
473
+ else:
474
+ m = cat_maps[col]
475
+ # ถ้าผู้ใช้ส่งค่าที่ไม่มีใน mapping ให้ fallback เป็นตัวแรก
476
+ idx = m.get(tab_dict[col], list(m.values())[0])
477
+ vec.append(float(idx))
478
+ return torch.tensor([vec], dtype=torch.float32, device=DEVICE)
479
+
480
+ FX_RATE = 3.4 # 1 SEK ≈ 3.4 บาท (เปลี่ยนได้)
481
+ CLASS_NAMES = ["<50", "50-100", "100-150", "150+"] # ตัวอย่าง 4 คลาส
482
+
483
+ import re
484
+ def convert_label_sek_to_thb(label, rate=FX_RATE):
485
+ """
486
+ label: สตริงช่วงราคาเป็น SEK เช่น "<50", "50-100", "150+"
487
+ คืนค่า: สตริงช่วงราคาเป็นบาท เช่น "<170 บาท", "170-340 บาท", "510+ บาท"
488
+ """
489
+ s = str(label).strip().lower()
490
+ nums = [int(x) for x in re.findall(r"\d+", s)]
491
+ if not nums:
492
+ return label
493
+
494
+ if s.startswith("<"):
495
+ return f"<{int(round(nums[0]*rate))} บาท"
496
+ if s.endswith("+"):
497
+ return f"{int(round(nums[0]*rate))}+ บาท"
498
+ if "-" in s and len(nums) == 2:
499
+ a, b = nums
500
+ return f"{int(round(a*rate))}-{int(round(b*rate))} บาท"
501
+ return f"{int(round(nums[0]*rate))} บาท"
502
+
503
+
504
+ def predict(front_img, back_img, *vals):
505
+ try:
506
+ if front_img is None or back_img is None:
507
+ return "กรุณาอัปโหลดร��ปทั้งสองภาพ", None
508
+
509
+ base_count = len(BASE_14)
510
+ i = 0
511
+ base_vals = list(vals[i:i+base_count]); i += base_count
512
+
513
+ # cut (CheckboxGroup)
514
+ cut_selected = vals[i]; i += 1
515
+
516
+ # material vector: N คู่ (percent, type) — ต้องเท่ากับ MAX_MATS ใน cell 16
517
+ MAX_MATS = 5
518
+ flat = []
519
+ for _ in range(MAX_MATS):
520
+ p = vals[i]; i += 1
521
+ m = vals[i]; i += 1
522
+ flat.extend([p, m])
523
+
524
+ # ---------- encode ----------
525
+ vec = []
526
+ # 1) base 14
527
+ for col, v in zip(BASE_14, base_vals):
528
+ if col in NUM_COLS_BASE:
529
+ vec.append(float(v))
530
+ else:
531
+ m = cat_maps[col]
532
+ idx = m.get(v, list(m.values())[0]) # fallback
533
+ vec.append(float(idx))
534
+
535
+ # 2) cut → multihot 8
536
+ cleaned = clean_cut(cut_selected) if cut_selected else []
537
+ vec.extend(cuts_to_multihot(cleaned))
538
+
539
+ # 3) material vector → 16
540
+ mvec = encode_material_rows(flat) # รวมเปอร์เซ็นต์ตามชนิด → ลิสต์ 16 ช่อง
541
+ vec.extend(mvec)
542
+
543
+ xt = torch.tensor([vec], dtype=torch.float32, device=DEVICE)
544
+ assert xt.shape[1] == tab_in_dim, f"Encoded dim {xt.shape[1]} != tab_in_dim {tab_in_dim}"
545
+
546
+ # ---------- infer ----------
547
+ with torch.no_grad():
548
+ x1 = preprocess_image(front_img)
549
+ x2 = preprocess_image(back_img)
550
+ logits = model(x1, x2, xt)
551
+ probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
552
+
553
+ top_idx = int(np.argmax(probs))
554
+ top_name_sek = str(CLASS_NAMES[top_idx]) # เช่น "<50"
555
+ top_name_thb = convert_label_sek_to_thb(top_name_sek)
556
+
557
+ import pandas as pd
558
+ rows = []
559
+ for name, p in zip(CLASS_NAMES, probs.tolist()):
560
+ thb = convert_label_sek_to_thb(str(name))
561
+ rows.append([thb, round(float(p), 4)])
562
+
563
+ df = pd.DataFrame(rows, columns=["class_THB", "probability"])
564
+
565
+ return f"ผลทำนาย: {top_name_thb}", df
566
+
567
+ except Exception as e:
568
+ import traceback
569
+ return f"Error: {type(e).__name__} - {e}\n" + traceback.format_exc(), None
570
+
571
+ def encode_material_rows(flat_vals):
572
+ """
573
+ flat_vals: ลิสต์เรียงเป็น [p1, m1, p2, m2, ...] ที่มากับ *vals ใน predict()
574
+ คืนเวกเตอร์ยาว 16 ตรงตาม MATERIAL_CATEGORIES (ค่าร้อยละรวมกันตามชนิด)
575
+ """
576
+ # รวมเปอร์เซ็นต์ตามชนิด
577
+ agg = {k: 0.0 for k in MATERIAL_CATEGORIES}
578
+ it = iter(flat_vals)
579
+ for p, mat in zip(it, it): # เดินทีละคู่
580
+ try:
581
+ pct = float(p) if p is not None else 0.0
582
+ except:
583
+ pct = 0.0
584
+ if mat in agg:
585
+ agg[mat] += max(0.0, pct) # กันค่าติดลบ
586
+ # เติมเป็นลิสต์ตามลำดับคงที่
587
+ return [agg[k] for k in MATERIAL_CATEGORIES]
588
+
589
+ with gr.Blocks(title="Multimodal (2 Images + Tabular)") as demo:
590
+ gr.Markdown("### โมเดลจำแนกด้วย 2 รูป + คุณลักษณะตาราง")
591
+
592
+ with gr.Row():
593
+ img_front = gr.Image(type="pil", label="รูปด้านหน้า")
594
+ img_back = gr.Image(type="pil", label="รูปด้านหลัง")
595
+
596
+ # อินพุต base 14
597
+ tab_inputs = []
598
+ with gr.Row():
599
+ for k in BASE_14:
600
+ spec = FEATURE_SPECS[k] # ต้องมี FEATURE_SPECS ตาม mapping ที่คุณใส่ไว้
601
+ if spec["kind"] == "number":
602
+ tab_inputs.append(gr.Slider(minimum=spec["min"], maximum=spec["max"],
603
+ step=spec["step"], value=spec["default"], label=k))
604
+ else:
605
+ tab_inputs.append(gr.Dropdown(choices=spec["choices"],
606
+ value=spec["default"], label=k))
607
+
608
+ # CUT
609
+ cut_input = gr.CheckboxGroup(label="cut (เลือกได้หลายค่า)",
610
+ choices=CUT_CATEGORIES, value=[])
611
+
612
+ # MATERIAL_VECTOR — ให้กรอกได้สูงสุด 5 ชนิด
613
+ MAX_MATS = 5
614
+ material_pairs = []
615
+ gr.Markdown("**วัสดุ (เปอร์เซ็นต์ + ชนิด)** เช่น 60% cotton, 40% polyester")
616
+ with gr.Column():
617
+ for i in range(MAX_MATS):
618
+ with gr.Row():
619
+ p = gr.Number(value=0, label=f"material_{i+1}_percent")
620
+ m = gr.Dropdown(choices=MATERIAL_CATEGORIES, value=MATERIAL_CATEGORIES[0],
621
+ label=f"material_{i+1}_type")
622
+ material_pairs.append((p, m))
623
+
624
+ # รวมอินพุตทั้งหมด
625
+ predict_inputs = [img_front, img_back] + tab_inputs + [cut_input]
626
+ for p, m in material_pairs:
627
+ predict_inputs.extend([p, m])
628
+
629
+ # ปุ่ม + เอาต์พุต
630
+ btn = gr.Button("ทำนาย")
631
+ out_txt = gr.Textbox(label="สรุปผล")
632
+ out_tbl = gr.Dataframe(headers=["class","probability"],
633
+ datatype=["str","number"], label="ความน่าจะเป็น")
634
+
635
+ btn.click(predict, inputs=predict_inputs, outputs=[out_txt, out_tbl])
636
+
637
+ demo.launch()