Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """Gladio-webapp.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/11rgLJIwe-BYZs3NcVMFz4hnq6XIzxfsv | |
| """ | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| import PIL | |
| from PIL import Image | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| from torchvision import transforms | |
| from torchvision.models import EfficientNet_B0_Weights | |
| from pathlib import Path | |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
class ImageEncoder(nn.Module):
    """Embed one image into a fixed-size vector.

    A torchvision backbone (ResNet-50 or EfficientNet-B0) has its final
    classification layer replaced by ``nn.Identity``, so the backbone outputs
    the features that fed the removed layer (``in_features`` wide). A small
    projection block then maps them to ``embed_dim``.

    Args:
        backbone: ``"resnet50"`` selects ResNet-50; any other value falls back
            to EfficientNet-B0.
        embed_dim: Output width of the projection block.
        pretrained: Load torchvision ImageNet weights into the backbone.
        train_backbone: Value written to ``requires_grad`` on every backbone
            parameter (False freezes the backbone).
    """

    def __init__(self, backbone="efficientnet_b0", embed_dim=512, pretrained=True, train_backbone=False):
        super().__init__()
        if backbone == "resnet50":
            init_weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            net = models.resnet50(weights=init_weights)
            in_features = net.fc.in_features
            net.fc = nn.Identity()
        else:
            init_weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
            net = models.efficientnet_b0(weights=init_weights)
            in_features = net.classifier[1].in_features
            net.classifier = nn.Identity()
        self.backbone = net

        # Freeze (or unfreeze) the whole backbone in one pass.
        for param in self.backbone.parameters():
            param.requires_grad = train_backbone

        # Keep this layer order: checkpoint keys are proj.0 ... proj.3.
        self.proj = nn.Sequential(
            nn.Linear(in_features, embed_dim),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(embed_dim),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Return the projected embedding, shape ``(B, embed_dim)``."""
        return self.proj(self.backbone(x))
class TabularEncoder(nn.Module):
    """MLP that embeds a flat row of tabular features.

    Pipeline: BatchNorm over the raw inputs -> Linear(in_dim, 256) -> ReLU
    -> Dropout(0.2) -> Linear(256, out_dim) -> ReLU.

    Args:
        in_dim: Number of input features per row.
        out_dim: Width of the output embedding.
    """

    def __init__(self, in_dim, out_dim=128):
        super().__init__()
        hidden = 256
        # Keep this exact order: checkpoint keys are net.0 ... net.5.
        layers = [
            nn.BatchNorm1d(in_dim),
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden, out_dim),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(B, in_dim)`` batch to ``(B, out_dim)``."""
        return self.net(x)
class MultimodalNet(nn.Module):
    """Fuse a front image, a back image and tabular features into class logits.

    Both images are embedded by one shared ``ImageEncoder`` and the two
    embeddings are averaged. The result is concatenated with the
    ``TabularEncoder`` output and passed through an MLP head.

    Args:
        tab_in_dim: Width of the tabular feature row.
        num_classes: Number of output logits.
        img_embed_dim: Image embedding width.
        tab_embed_dim: Tabular embedding width.
        backbone, pretrained, train_backbone: Forwarded to ``ImageEncoder``.
    """

    def __init__(self, tab_in_dim, num_classes=4, img_embed_dim=512, tab_embed_dim=128,
                 backbone="efficientnet_b0", pretrained=True, train_backbone=False):
        super().__init__()
        self.img_enc = ImageEncoder(
            backbone=backbone,
            embed_dim=img_embed_dim,
            pretrained=pretrained,
            train_backbone=train_backbone,
        )
        self.tab_enc = TabularEncoder(in_dim=tab_in_dim, out_dim=tab_embed_dim)
        fused_dim = img_embed_dim + tab_embed_dim
        # Layer order defines checkpoint keys head.0 ... head.7; do not reorder.
        self.head = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, front_img, back_img, tab_x):
        """Return raw logits of shape ``(B, num_classes)``."""
        # Same encoder (shared weights) for both views; average the two embeddings.
        img_feat = 0.5 * (self.img_enc(front_img) + self.img_enc(back_img))
        tab_feat = self.tab_enc(tab_x)
        return self.head(torch.cat([img_feat, tab_feat], dim=1))
# ===== Force the tabular width to 38 regardless of the checkpoint metadata =====
FORCE_TAB_DIM = 38
FORCE_NUM_CLASSES = None  # set to an int to override; None = read from ckpt (falls back to 4)
BASE_DIR = Path(__file__).resolve().parent
CKPT_PATH = BASE_DIR / "best_multimodal.pt"  # or BASE_DIR / "models" / "best_multimodal.pt"
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
ckpt = torch.load(str(CKPT_PATH), map_location=DEVICE)

# Do not take tab_in_dim from the checkpoint metadata (it has reported 14).
tab_in_dim = FORCE_TAB_DIM
num_classes = int(ckpt.get("num_classes", 4) if FORCE_NUM_CLASSES is None else FORCE_NUM_CLASSES)
print("[INFO] FORCE tab_in_dim:", tab_in_dim, "| num_classes:", num_classes)

# Build a fresh model sized for the 38-wide tabular input.
model = MultimodalNet(
    tab_in_dim,
    num_classes=num_classes,
    backbone="efficientnet_b0",
    pretrained=False,  # skip the ImageNet download: weights come from the checkpoint
    train_backbone=False,
).to(DEVICE)

# The checkpoint holds either {"model": state_dict, ...} or a bare state_dict.
raw_state = ckpt.get("model", ckpt)

# Only discard the tabular encoder when its tensors really do not fit the
# 38-wide model (e.g. a 14-wide checkpoint); load_state_dict raises on size
# mismatches even with strict=False. A checkpoint whose tab_enc already
# matches is now loaded in full instead of being thrown away unconditionally.
model_state = model.state_dict()
tab_mismatch = [
    k for k, v in raw_state.items()
    if k.startswith("tab_enc.") and k in model_state
    and tuple(getattr(v, "shape", ())) != tuple(model_state[k].shape)
]
if tab_mismatch:
    state = {k: v for k, v in raw_state.items() if not k.startswith("tab_enc.")}
    print("[WARN] checkpoint tab_enc does not fit tab_in_dim =", tab_in_dim,
          "-> tab_enc stays randomly initialised. Mismatched keys:", tab_mismatch)
else:
    state = dict(raw_state)

missing, unexpected = model.load_state_dict(state, strict=False)
print("[load_state_dict] Missing:", missing)
print("[load_state_dict] Unexpected:", unexpected)
# Sanity check: expect BatchNorm1d(38) and Linear(in_features=38, out_features=256).
print("[VERIFY] model.tab_enc.net =", model.tab_enc.net)
model.eval()
# The 14 base tabular features, in the order used at training time.
BASE_14 = [
    "pilling", "condition", "pattern", "stains", "holes",
    "damage_count", "damage_severity",
    "brand", "type", "size", "season", "category", "main_color", "usage",
]
NUM_COLS_BASE = ["pilling", "condition", "damage_count"]
CUT_CATEGORIES = ['collar', 'v-collar', 'tight', 'loose', 'regular', 'turtle-neck', 'cropped', 'long']
MATERIAL_CATEGORIES = [
    'cotton', 'polyester', 'viscose', 'acrylic', 'nylon',
    'elastane', 'wool', 'rayon', 'silk', 'linen', 'spandex',
    'lycra', 'bamboo', 'alpaca', 'lyocell', 'cashmere',
]
# Lock the feature scheme to the 38-wide layout (base 14 + cut 8 + material 16).
SCHEME = {"use_cut": True, "use_mat_vec": True, "use_mat_count": False}
tab_in_dim = 38


def encode_material_rows(flat_vals, categories=None):
    """Sum (percent, material) pairs into one percentage per material category.

    Args:
        flat_vals: Flat sequence ``[p1, m1, p2, m2, ...]``. Percentages that
            are None or cannot be converted to float count as 0, negative
            values are clamped to 0, materials not in ``categories`` are
            ignored, and a trailing unpaired value is dropped.
        categories: Output vocabulary/order; defaults to MATERIAL_CATEGORIES.

    Returns:
        list[float] with one summed percentage per category, in order.
    """
    cats = MATERIAL_CATEGORIES if categories is None else list(categories)
    totals = {k: 0.0 for k in cats}
    it = iter(flat_vals)
    for raw_pct, mat in zip(it, it):  # walk the flat list two items at a time
        try:
            pct = float(raw_pct) if raw_pct is not None else 0.0
        except (TypeError, ValueError, OverflowError):
            # Narrow catch: a bare except would also swallow KeyboardInterrupt.
            pct = 0.0
        if mat in totals:
            totals[mat] += max(0.0, pct)
    return [totals[k] for k in cats]
# ---------- 1) Pick the tabular feature layout from tab_in_dim ----------
def get_feature_scheme(tab_in_dim):
    """Describe which feature groups a model with ``tab_in_dim`` inputs expects.

    Supported widths:
      - 14: base 14
      - 23: base 14 + cut(8) + material_count(1)
      - 30: base 14 + material_vector(16)
      - 38: base 14 + cut(8) + material_vector(16)

    Returns:
        A new dict with boolean flags ``use_base``, ``use_cut``,
        ``use_mat_vec`` and ``use_mat_count``.

    Raises:
        ValueError: for any other width.
    """
    # (width, use_cut, use_mat_vec, use_mat_count)
    layouts = (
        (14, False, False, False),
        (23, True, False, True),
        (30, False, True, False),
        (38, True, True, False),
    )
    for width, use_cut, use_mat_vec, use_mat_count in layouts:
        if tab_in_dim == width:
            return {
                "use_base": True,
                "use_cut": use_cut,
                "use_mat_vec": use_mat_vec,
                "use_mat_count": use_mat_count,
            }
    raise ValueError(f"ไม่รู้จัก tab_in_dim={tab_in_dim} (รองรับ 14/23/30/38)")
# Active feature layout, derived from the model's tabular width.
SCHEME = get_feature_scheme(tab_in_dim)

# ---------- 2) The 14 base features, in training order ----------
BASE_14 = """
    pilling condition pattern stains holes
    damage_count damage_severity
    brand type size season category main_color usage
""".split()
NUM_COLS_BASE = "pilling condition damage_count".split()
# The categorical code tables (cat_maps for brand/type/size/pattern/stains/holes/
# damage_severity/usage/main_color/season/category) are defined further down.
# ---------- 3) CUT & MATERIAL utilities ----------
import ast
import re

CUT_CATEGORIES = ['collar', 'v-collar', 'tight', 'loose', 'regular', 'turtle-neck', 'cropped', 'long']


def clean_cut(cut_list):
    """Normalise raw cut annotations to canonical CUT_CATEGORIES names.

    Accepts a list of labels, a Python-literal list string such as
    "['v collar', 'long']", or a comma-separated string. Aliases
    ('c-collar', 'oversize', 'turtleneck', ...) map to their canonical name,
    unknown labels are dropped, and the result follows CUT_CATEGORIES order.
    None yields [].
    """
    if cut_list is None:
        return []
    if isinstance(cut_list, str):
        try:
            # ast.literal_eval, not eval: this string can come from a user form
            # and must never be executed as code.
            cut_list = ast.literal_eval(cut_list) if cut_list.strip().startswith("[") else cut_list.split(',')
        except Exception:
            cut_list = [cut_list]
    cut_list = [str(c).strip().lower() for c in cut_list if c]
    mapping = {
        'c-collar': 'collar', 'c collar': 'collar', 'collar': 'collar',
        'v-collar': 'v-collar', 'v collar': 'v-collar',
        'tight': 'tight', 'loose': 'loose', 'oversize': 'loose',
        'regular': 'regular',
        'turtle neck': 'turtle-neck', 'turtleneck': 'turtle-neck',
        'cropped': 'cropped', 'long': 'long'
    }
    cleaned = {mapping.get(x, x) for x in cut_list}
    # Iterate CUT_CATEGORIES so the order is deterministic (set order is hash-seeded).
    return [c for c in CUT_CATEGORIES if c in cleaned]


def cuts_to_multihot(cuts):
    """Return an 8-element 0/1 list aligned with CUT_CATEGORIES (None -> all zeros)."""
    return [1 if cat in (cuts or []) else 0 for cat in CUT_CATEGORIES]


MATERIAL_CATEGORIES = [
    'cotton', 'polyester', 'viscose', 'acrylic', 'nylon',
    'elastane', 'wool', 'rayon', 'silk', 'linen', 'spandex',
    'lycra', 'bamboo', 'alpaca', 'lyocell', 'cashmere'
]


def parse_material(text):
    """Parse a composition string such as '60% cotton 40% polyester' into {material: percent}.

    None and the placeholder strings 'not available', 'unknown', '' and
    'scanner can not read material.' yield {}.
    """
    if text is None:
        return {}
    text = str(text).strip().lower()
    if text in ['not available', 'unknown', '', 'scanner can not read material.']:
        return {}
    comps = re.findall(r'(\d+)\s*%\s*([a-z]+)', text)
    out = {}
    for pct, mat in comps:
        try:
            out[mat] = int(pct)
        except ValueError:
            pass
    return out


def material_to_vector(mat_dict):
    """Return the 16 percentages aligned with MATERIAL_CATEGORIES (missing -> 0)."""
    return [mat_dict.get(cat, 0) for cat in MATERIAL_CATEGORIES]


def material_count_from_dict(mat_dict):
    """Count the materials whose share is > 0."""
    return sum(1 for v in mat_dict.values() if float(v) > 0)
# ---------- 4) "Dynamic" encoder that follows the active SCHEME ----------
def encode_tab_from_form(base_vals, cut_selected=None, mat_count_val=None, mat_text_val=None):
    """Encode raw form values into a ``(1, tab_in_dim)`` float32 tensor on DEVICE.

    Args:
        base_vals: Values for BASE_14, in the same order. Columns in
            NUM_COLS_BASE are cast to float; the others are mapped through
            cat_maps, unknown labels falling back to that column's first code.
        cut_selected: Cut labels (anything clean_cut accepts); used when
            SCHEME["use_cut"].
        mat_count_val: Material count; used when SCHEME["use_mat_count"]
            (None -> 0).
        mat_text_val: Composition text such as "60% cotton 40% polyester";
            used when SCHEME["use_mat_vec"].

    Raises:
        ValueError: if the encoded width differs from tab_in_dim.
    """
    vec = []
    # 4.1 base 14
    for col, v in zip(BASE_14, base_vals):
        if col in NUM_COLS_BASE:
            vec.append(float(v))
        else:
            codes = cat_maps[col]
            vec.append(float(codes.get(v, next(iter(codes.values())))))  # unknown -> first code
    # 4.2 cut multi-hot (8)
    if SCHEME["use_cut"]:
        cleaned = clean_cut(cut_selected) if cut_selected else []
        vec.extend(cuts_to_multihot(cleaned))
    # 4.3 material
    if SCHEME["use_mat_count"]:
        vec.append(0 if mat_count_val is None else float(mat_count_val))
    if SCHEME["use_mat_vec"]:
        vec.extend(material_to_vector(parse_material(mat_text_val)))
    # Explicit check rather than assert, so it still runs under `python -O`.
    if len(vec) != tab_in_dim:
        raise ValueError(f"Encoded dim {len(vec)} != tab_in_dim {tab_in_dim}")
    return torch.tensor([vec], dtype=torch.float32, device=DEVICE)
import ast
import re

# ----- CUT -----
CUT_CATEGORIES = ['collar', 'v-collar', 'tight', 'loose', 'regular', 'turtle-neck', 'cropped', 'long']


def clean_cut(cut_list):
    """Normalise raw cut annotations to canonical CUT_CATEGORIES names.

    Accepts a list of labels, a Python-literal list string such as
    "['v collar', 'long']", or a comma-separated string. Aliases
    ('c-collar', 'oversize', 'turtleneck', ...) map to their canonical name,
    unknown labels are dropped, and the result follows CUT_CATEGORIES order.
    None yields [].
    """
    if cut_list is None:
        return []
    if isinstance(cut_list, str):
        try:
            # ast.literal_eval, not eval: this string can come from a user form
            # and must never be executed as code.
            if cut_list.strip().startswith("["):
                cut_list = ast.literal_eval(cut_list)
            else:
                cut_list = cut_list.split(',')
        except Exception:
            cut_list = [cut_list]
    mapping = {
        'c-collar': 'collar', 'c collar': 'collar', 'collar': 'collar',
        'v-collar': 'v-collar', 'v collar': 'v-collar',
        'tight': 'tight', 'loose': 'loose', 'oversize': 'loose',
        'regular': 'regular',
        'turtle neck': 'turtle-neck', 'turtleneck': 'turtle-neck',
        'cropped': 'cropped', 'long': 'long'
    }
    cleaned = set()
    for item in cut_list:
        if not item:
            continue
        key = str(item).strip().lower()
        cleaned.add(mapping.get(key, key))
    # Iterate CUT_CATEGORIES so the order is deterministic (set order is hash-seeded).
    return [c for c in CUT_CATEGORIES if c in cleaned]


def cuts_to_multihot(cuts):
    """Return an 8-element 0/1 list aligned with CUT_CATEGORIES (None -> all zeros)."""
    return [1 if cat in (cuts or []) else 0 for cat in CUT_CATEGORIES]
# ----- MATERIAL -----
MATERIAL_CATEGORIES = (
    "cotton polyester viscose acrylic nylon "
    "elastane wool rayon silk linen spandex "
    "lycra bamboo alpaca lyocell cashmere"
).split()


def parse_material(text):
    """Parse a composition string into a {material: percent} dict.

    Example: '60% cotton 40% polyester' -> {'cotton': 60, 'polyester': 40}.
    None and the placeholders 'not available', 'unknown', '' and
    'scanner can not read material.' (case/whitespace-insensitive) give {}.
    A material that appears twice keeps its last percentage.
    """
    if text is None:
        return {}
    normalized = str(text).strip().lower()
    placeholders = ['not available', 'unknown', '', 'scanner can not read material.']
    if normalized in placeholders:
        return {}
    result = {}
    for match in re.finditer(r'(\d+)\s*%\s*([a-z]+)', normalized):
        percent, material = match.groups()
        try:
            result[material] = int(percent)
        except Exception:
            pass
    return result


def material_to_vector(mat_dict):
    """Return one percentage (0..100) per MATERIAL_CATEGORIES entry; absent materials are 0."""
    return [mat_dict.get(name, 0) for name in MATERIAL_CATEGORIES]


def material_count_from_dict(mat_dict):
    """Count materials with a share > 0 (the single extra feature of the 23-wide layout)."""
    return len([share for share in mat_dict.values() if float(share) > 0])
def get_feature_scheme(tab_in_dim):
    """Return the feature-group flags a model with ``tab_in_dim`` inputs expects.

    - 14: base 14
    - 23: base 14 + cut(8) + material_count(1)
    - 30: base 14 + material_vector(16)
    - 38: base 14 + cut(8) + material_vector(16)

    Raises:
        ValueError: if ``tab_in_dim`` is not one of the widths above.
    """
    matched = next((width for width in (14, 23, 30, 38) if tab_in_dim == width), None)
    if matched is None:
        raise ValueError(f"ไม่รู้จัก tab_in_dim={tab_in_dim} (รองรับ 14/23/30/38)")
    return {
        "use_base": True,
        "use_cut": matched in (23, 38),
        "use_mat_vec": matched in (30, 38),
        "use_mat_count": matched == 23,
    }
# Normalisation constants come from the EfficientNet-B0 ImageNet preset.
weights = EfficientNet_B0_Weights.IMAGENET1K_V1
_preset = weights.transforms()
img_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_preset.mean, std=_preset.std),
])


def preprocess_image(pil_img: Image.Image):
    """Turn a PIL image into a normalised (1, 3, 224, 224) float tensor on DEVICE."""
    rgb = pil_img.convert("RGB")
    batch = img_tf(rgb).unsqueeze(0)
    return batch.to(DEVICE)
# ===== The 14 tabular features, in the order used at training time =====
TAB_FEATS = [
    "pilling", "condition", "pattern", "stains", "holes",
    "damage_count", "damage_severity",
    "brand", "type", "size", "season", "category", "main_color", "usage",
]


def _codes(*labels):
    """Map each label to its position: _codes('a', 'b') -> {'a': 0, 'b': 1}."""
    return {label: code for code, label in enumerate(labels)}


# ===== Label -> integer code tables from the training notebook =====
# Codes are the label positions, so insertion order is also the dropdown order.
cat_maps = {
    "brand": _codes(
        "Non-Brand",
        "Fast Fashion & High Street Retailers",
        "Other Brands",
        "Store Brands",
        "Niche Brands",
        "Premium & Designer",
        "Sportswear & Outdoor",
    ),
    "pattern": _codes("Solid", "Printed", "Texture_Embellishment", "Other"),
    "type": _codes("topwear", "dresswear", "bottomwear", "outerwear", "other", "sleepwear"),
    "size": _codes("unknown", "xs", "s", "m", "l", "xl", "xxl", "kids", "onesize"),
    "season": _codes("All", "Summer", "Spring", "Autumn", "None", "Winter"),
    "category": _codes("Ladies", "Men", "Children", "Unisex"),
    "main_color": _codes(
        "black", "white", "blue", "multicolor", "pink",
        "grey", "beige", "red", "green", "purple",
        "brown", "yellow", "orange", "turquoise", "none",
    ),
    "usage": _codes("export", "reuse", "recycle", "repair"),
    "stains": _codes("No", "Yes"),
    "holes": _codes("None", "Minor", "Major"),
    "damage_severity": _codes("No Damage", "Minor Damage", "Moderate Damage", "Severe Damage"),
}

# ===== Widget specs used to build the Gradio form =====
# Default selection for each categorical dropdown (choices come from cat_maps).
_CAT_DEFAULTS = {
    "pattern": "Solid",
    "stains": "No",
    "holes": "None",
    "damage_severity": "No Damage",
    "brand": "Non-Brand",
    "type": "topwear",
    "size": "m",
    "season": "All",
    "category": "Ladies",
    "main_color": "black",
    "usage": "reuse",
}
FEATURE_SPECS = {
    # numeric features: raw values, no scaling
    "pilling": {"kind": "number", "min": 0, "max": 5, "step": 1, "default": 3},
    "condition": {"kind": "number", "min": 0, "max": 5, "step": 1, "default": 2},
    "damage_count": {"kind": "number", "min": 0, "max": 20, "step": 1, "default": 0},
}
FEATURE_SPECS.update({
    col: {"kind": "cat", "choices": list(cat_maps[col].keys()), "default": default}
    for col, default in _CAT_DEFAULTS.items()
})

NUM_COLS = ["pilling", "condition", "damage_count"]
CAT_COLS = [col for col in TAB_FEATS if col not in NUM_COLS]
def encode_tab(tab_dict):
    """Encode a {feature: value} form dict into a (1, len(TAB_FEATS)) float32 tensor.

    Columns follow TAB_FEATS order. NUM_COLS values are used as floats
    directly (no scaler); categorical values are mapped through cat_maps,
    and labels missing from a map fall back to that column's first code.
    """
    def _code(col):
        value = tab_dict[col]
        if col in NUM_COLS:
            return float(value)
        codes = cat_maps[col]
        return float(codes.get(value, list(codes.values())[0]))

    row = [_code(col) for col in TAB_FEATS]
    return torch.tensor([row], dtype=torch.float32, device=DEVICE)
FX_RATE = 3.4  # 1 SEK ≈ 3.4 THB (adjust as needed)
CLASS_NAMES = ["<50", "50-100", "100-150", "150+"]  # SEK price-band labels (4 classes)
import re


def convert_label_sek_to_thb(label, rate=FX_RATE):
    """Convert a SEK price-band label into a Thai-baht label.

    Examples (rate=3.4): "<50" -> "<170 บาท", "50-100" -> "170-340 บาท",
    "150+" -> "510+ บาท". Decimal bounds ("10.5-20") and en/em-dash ranges
    ("50–100") are also recognised; previously "10.5" was split into the two
    numbers 10 and 5. Converted bounds are rounded with round().

    Args:
        label: Band label; converted with str() before parsing.
        rate: THB per SEK.

    Returns:
        The baht label, or ``label`` unchanged when it contains no number.
    """
    s = str(label).strip().lower()
    # Treat en/em dashes as range separators too.
    s = s.replace("\u2013", "-").replace("\u2014", "-")
    nums = [float(x) for x in re.findall(r"\d+(?:\.\d+)?", s)]
    if not nums:
        return label

    def to_thb(amount):
        return int(round(amount * rate))

    if s.startswith("<"):
        return f"<{to_thb(nums[0])} บาท"
    if s.endswith("+"):
        return f"{to_thb(nums[0])}+ บาท"
    if "-" in s and len(nums) == 2:
        return f"{to_thb(nums[0])}-{to_thb(nums[1])} บาท"
    return f"{to_thb(nums[0])} บาท"
def _split_form_values(vals, base_count):
    """Split predict()'s *vals into (base_vals, cut_selected, flat_material_pairs).

    Layout as wired in the UI: ``base_count`` feature values, one cut
    selection, then (percent, type) pairs. The number of pairs is derived from
    len(vals) instead of a hard-coded MAX_MATS, so the UI can offer any number
    of material rows; a dangling unpaired value is ignored.
    """
    base_vals = list(vals[:base_count])
    cut_selected = vals[base_count]
    rest = list(vals[base_count + 1:])
    flat = rest[: 2 * (len(rest) // 2)]
    return base_vals, cut_selected, flat


def _encode_tab_vector(base_vals, cut_selected, flat_materials):
    """Build the tabular row: base codes, then cut multi-hot, then per-material percentages."""
    vec = []
    for col, v in zip(BASE_14, base_vals):
        if col in NUM_COLS_BASE:
            vec.append(float(v))
        else:
            codes = cat_maps[col]
            vec.append(float(codes.get(v, list(codes.values())[0])))  # unknown -> first code
    cleaned = clean_cut(cut_selected) if cut_selected else []
    vec.extend(cuts_to_multihot(cleaned))
    vec.extend(encode_material_rows(flat_materials))  # summed percentage per material
    return vec


def predict(front_img, back_img, *vals):
    """Gradio callback: predict the price band for one garment.

    Args:
        front_img, back_img: PIL images from the two gr.Image inputs.
        *vals: Flattened form values -- the BASE_14 fields, the cut
            CheckboxGroup value, then (percent, material) pairs.

    Returns:
        ``(summary_text, probability_table)``. When an image is missing or any
        error occurs, the table is None and the text carries the message
        (plus a traceback for errors).
    """
    try:
        if front_img is None or back_img is None:
            return "กรุณาอัปโหลดรูปทั้งสองภาพ", None

        base_vals, cut_selected, flat = _split_form_values(vals, len(BASE_14))
        vec = _encode_tab_vector(base_vals, cut_selected, flat)
        # Explicit check instead of assert, so it also runs under `python -O`.
        if len(vec) != tab_in_dim:
            raise ValueError(f"Encoded dim {len(vec)} != tab_in_dim {tab_in_dim}")
        xt = torch.tensor([vec], dtype=torch.float32, device=DEVICE)

        # ---------- infer ----------
        with torch.no_grad():
            logits = model(preprocess_image(front_img), preprocess_image(back_img), xt)
            probs = torch.softmax(logits, dim=1).cpu().numpy()[0]

        top_idx = int(np.argmax(probs))
        top_name_thb = convert_label_sek_to_thb(str(CLASS_NAMES[top_idx]))
        rows = [
            [convert_label_sek_to_thb(str(name)), round(float(p), 4)]
            for name, p in zip(CLASS_NAMES, probs.tolist())
        ]
        df = pd.DataFrame(rows, columns=["class_THB", "probability"])
        return f"ผลทำนาย: {top_name_thb}", df
    except Exception as e:
        import traceback
        return f"Error: {type(e).__name__} - {e}\n" + traceback.format_exc(), None
def encode_material_rows(flat_vals, categories=None):
    """Sum the (percent, material) pairs sent by predict() into one value per material.

    Args:
        flat_vals: Flat sequence ``[p1, m1, p2, m2, ...]``. Percentages that
            are None or cannot be converted to float count as 0, negative
            values are clamped to 0, materials not in ``categories`` are
            ignored, and a trailing unpaired value is dropped.
        categories: Output vocabulary/order; defaults to MATERIAL_CATEGORIES.

    Returns:
        list[float] with one summed percentage per category, in fixed order.
    """
    cats = MATERIAL_CATEGORIES if categories is None else list(categories)
    agg = {k: 0.0 for k in cats}
    it = iter(flat_vals)
    for p, mat in zip(it, it):  # walk the flat list two items at a time
        try:
            pct = float(p) if p is not None else 0.0
        except (TypeError, ValueError, OverflowError):
            # Narrow catch: a bare except would also swallow KeyboardInterrupt.
            pct = 0.0
        if mat in agg:
            agg[mat] += max(0.0, pct)  # guard against negative input
    return [agg[k] for k in cats]
with gr.Blocks(title="Multimodal (2 Images + Tabular)") as demo:
    gr.Markdown("### โมเดลจำแนกด้วย 2 รูป + คุณลักษณะตาราง")
    with gr.Row():
        img_front = gr.Image(type="pil", label="รูปด้านหน้า")
        img_back = gr.Image(type="pil", label="รูปด้านหลัง")

    # Base 14 inputs: sliders for numeric features, dropdowns for categorical ones.
    tab_inputs = []
    with gr.Row():
        for k in BASE_14:
            spec = FEATURE_SPECS[k]
            if spec["kind"] == "number":
                tab_inputs.append(gr.Slider(minimum=spec["min"], maximum=spec["max"],
                                            step=spec["step"], value=spec["default"], label=k))
            else:
                tab_inputs.append(gr.Dropdown(choices=spec["choices"],
                                              value=spec["default"], label=k))

    # CUT (multi-select)
    cut_input = gr.CheckboxGroup(label="cut (เลือกได้หลายค่า)",
                                 choices=CUT_CATEGORIES, value=[])

    # MATERIAL_VECTOR: up to MAX_MATS (percent, type) rows.
    MAX_MATS = 5
    material_pairs = []
    gr.Markdown("**วัสดุ (เปอร์เซ็นต์ + ชนิด)** เช่น 60% cotton, 40% polyester")
    with gr.Column():
        for i in range(MAX_MATS):
            with gr.Row():
                p = gr.Number(value=0, label=f"material_{i+1}_percent")
                m = gr.Dropdown(choices=MATERIAL_CATEGORIES, value=MATERIAL_CATEGORIES[0],
                                label=f"material_{i+1}_type")
            material_pairs.append((p, m))

    # Order must match predict(front_img, back_img, *vals):
    # base 14 fields, the cut selection, then (percent, type) pairs.
    predict_inputs = [img_front, img_back] + tab_inputs + [cut_input]
    for p, m in material_pairs:
        predict_inputs.extend([p, m])

    # Button + outputs
    btn = gr.Button("ทำนาย")
    out_txt = gr.Textbox(label="สรุปผล")
    # Headers now match the DataFrame columns predict() returns ("class_THB", "probability").
    out_tbl = gr.Dataframe(headers=["class_THB", "probability"],
                           datatype=["str", "number"], label="ความน่าจะเป็น")
    btn.click(predict, inputs=predict_inputs, outputs=[out_txt, out_tbl])

demo.launch()