Boonyaratt's picture
change path
d9df2d4
# -*- coding: utf-8 -*-
"""Gladio-webapp.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/11rgLJIwe-BYZs3NcVMFz4hnq6XIzxfsv
"""
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from PIL import Image
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torchvision.models import EfficientNet_B0_Weights
from pathlib import Path
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
class ImageEncoder(nn.Module):
    """CNN backbone followed by a projection head that maps an image batch to embeddings.

    Args:
        backbone: "resnet50" selects ResNet-50; any other value selects EfficientNet-B0.
        embed_dim: width of the projected embedding.
        pretrained: load torchvision ImageNet weights for the backbone
            (ResNet50 IMAGENET1K_V2 / EfficientNet-B0 IMAGENET1K_V1).
        train_backbone: whether the backbone parameters receive gradients.
    """

    def __init__(self, backbone="efficientnet_b0", embed_dim=512, pretrained=True, train_backbone=False):
        super().__init__()
        if backbone == "resnet50":
            init_weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            net = models.resnet50(weights=init_weights)
            feat_dim = net.fc.in_features
            net.fc = nn.Identity()  # strip the classifier so the net returns pooled features
        else:
            init_weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
            net = models.efficientnet_b0(weights=init_weights)
            feat_dim = net.classifier[1].in_features
            net.classifier = nn.Identity()  # strip the classifier head
        self.backbone = net
        # Freeze (or unfreeze) every backbone parameter in one call.
        self.backbone.requires_grad_(train_backbone)
        # Attribute name and layer order define checkpoint keys (proj.0 ... proj.3).
        self.proj = nn.Sequential(
            nn.Linear(feat_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(embed_dim),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Return the projected embedding (B, embed_dim) for an image batch x."""
        return self.proj(self.backbone(x))
class TabularEncoder(nn.Module):
    """MLP that embeds the tabular feature vector.

    Layer order (indices form the checkpoint keys ``net.<i>.*``):
    BatchNorm1d(in_dim) -> Linear(in_dim, 256) -> ReLU -> Dropout(0.2)
    -> Linear(256, out_dim) -> ReLU.
    """

    def __init__(self, in_dim, out_dim=128):
        super().__init__()
        hidden = 256
        layers = [
            nn.BatchNorm1d(in_dim),  # batch-normalise the raw input features
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden, out_dim),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (B, in_dim) batch to (B, out_dim) embeddings."""
        return self.net(x)
class MultimodalNet(nn.Module):
    """Two-view image + tabular classifier.

    The front and back images share a single ImageEncoder and their embeddings are
    averaged. The result is concatenated with the TabularEncoder embedding and passed
    through an MLP head that outputs class logits.
    """

    def __init__(self, tab_in_dim, num_classes=4, img_embed_dim=512, tab_embed_dim=128,
                 backbone="efficientnet_b0", pretrained=True, train_backbone=False):
        super().__init__()
        self.img_enc = ImageEncoder(
            backbone=backbone,
            embed_dim=img_embed_dim,
            pretrained=pretrained,
            train_backbone=train_backbone,
        )
        self.tab_enc = TabularEncoder(in_dim=tab_in_dim, out_dim=tab_embed_dim)
        fused_dim = img_embed_dim + tab_embed_dim
        # Layer order defines checkpoint keys head.0 ... head.7.
        self.head = nn.Sequential(
            nn.Linear(fused_dim, 256), nn.ReLU(inplace=True), nn.BatchNorm1d(256), nn.Dropout(0.4),
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, front_img, back_img, tab_x):
        """Return (B, num_classes) logits for paired front/back images and tabular rows."""
        # Same encoder for both views; average the two embeddings.
        img_feat = 0.5 * (self.img_enc(front_img) + self.img_enc(back_img))
        tab_feat = self.tab_enc(tab_x)
        return self.head(torch.cat([img_feat, tab_feat], dim=1))
# ===== Force the tabular input width to 38, whatever the checkpoint records =====
FORCE_TAB_DIM = 38
# Set to an int to force the class count; leave None to read it from the checkpoint (default 4).
FORCE_NUM_CLASSES = None
BASE_DIR = Path(__file__).resolve().parent
CKPT_PATH = BASE_DIR / "best_multimodal.pt"  # or BASE_DIR / "models" / "best_multimodal.pt"
ckpt = torch.load(str(CKPT_PATH), map_location=DEVICE)

# Do not take tab_in_dim from the checkpoint (it has come back as 14 before).
tab_in_dim = FORCE_TAB_DIM
num_classes = int(ckpt.get("num_classes", 4) if FORCE_NUM_CLASSES is None else FORCE_NUM_CLASSES)
print("[INFO] FORCE tab_in_dim:", tab_in_dim, "| num_classes:", num_classes)

# Build a fresh model sized for the 38-feature tabular input.
model = MultimodalNet(
    tab_in_dim,
    num_classes=num_classes,
    backbone="efficientnet_b0",
    pretrained=False,  # skip the ImageNet download; weights come from the checkpoint
    train_backbone=False,
).to(DEVICE)

# Accept either a bare state_dict or a dict that stores it under "model".
raw_state = ckpt.get("model", ckpt)
model_state = model.state_dict()
# Keep every tensor whose name exists in the model AND whose shape matches.
# load_state_dict(strict=False) tolerates missing/extra keys but still raises on size
# mismatches. A blanket "drop all tab_enc.*" filter would also discard tab_enc
# tensors that do fit (e.g. when the checkpoint really is 38-wide).
state = {}
skipped = []
for key, value in raw_state.items():
    target = model_state.get(key)
    if target is not None and getattr(value, "shape", None) == target.shape:
        state[key] = value
    else:
        skipped.append(key)
missing, unexpected = model.load_state_dict(state, strict=False)
print("[load_state_dict] Missing:", missing)
print("[load_state_dict] Unexpected:", unexpected)
print("[load_state_dict] Skipped (unknown key or shape mismatch):", skipped)
# Sanity check: expect BatchNorm1d(38) and Linear(in_features=38 -> 256) here.
print("[VERIFY] model.tab_enc.net =", model.tab_enc.net)
model.eval()
# The 14 base features used in training, in model-input order.
BASE_14 = [
    "pilling","condition","pattern","stains","holes",
    "damage_count","damage_severity",
    "brand","type","size","season","category","main_color","usage"
]
# Base columns fed to the model as raw floats; every other base column is categorical.
NUM_COLS_BASE = ["pilling","condition","damage_count"]
# Garment-cut labels; one multi-hot slot each, in this order.
CUT_CATEGORIES = ['collar','v-collar','tight','loose','regular','turtle-neck','cropped','long']
# Material labels; one percentage slot each, in this order.
MATERIAL_CATEGORIES = [
    'cotton','polyester','viscose','acrylic','nylon',
    'elastane','wool','rayon','silk','linen','spandex',
    'lycra','bamboo','alpaca','lyocell','cashmere'
]
# Lock the scheme to the 38-dim layout: base 14 + cut 8 + material vector 16.
# Same keys as get_feature_scheme(38).
SCHEME = {"use_base": True, "use_cut": True, "use_mat_vec": True, "use_mat_count": False}
# Tie the encoder's expected width to the model width chosen above instead of
# repeating the literal 38, so the two cannot drift apart.
tab_in_dim = FORCE_TAB_DIM
def encode_material_rows(flat_vals, categories=None):
    """Aggregate (percent, material) pairs into a fixed-length percentage vector.

    Args:
        flat_vals: flat sequence [p1, m1, p2, m2, ...], e.g. the UI material rows.
            A trailing unpaired value is ignored.
        categories: material labels defining the output slots and their order;
            defaults to MATERIAL_CATEGORIES.

    Returns:
        list of floats, one per category: the summed percentage for that material.
        Missing, non-numeric or negative percentages contribute 0. Material names are
        matched exactly first, then stripped and lower-cased. Unknown names are ignored.
    """
    if categories is None:
        categories = MATERIAL_CATEGORIES
    totals = {k: 0.0 for k in categories}
    it = iter(flat_vals)
    for p, mat in zip(it, it):  # consume the flat list two items at a time
        try:
            pct = float(p) if p is not None else 0.0
        except (TypeError, ValueError, OverflowError):
            pct = 0.0
        if mat not in totals and isinstance(mat, str):
            mat = mat.strip().lower()
        if mat in totals:
            totals[mat] += max(0.0, pct)  # never let a negative entry subtract
    return [totals[k] for k in categories]
# ---------- 1) Pick the feature scheme that matches tab_in_dim ----------
def get_feature_scheme(tab_in_dim):
    """Describe which feature groups a model with `tab_in_dim` tabular inputs expects.

    Supported widths:
      - 14: base 14
      - 23: base 14 + cut(8) + material_count(1)
      - 30: base 14 + material_vector(16)
      - 38: base 14 + cut(8) + material_vector(16)

    Returns a new dict with boolean keys use_base, use_cut, use_mat_vec and use_mat_count.
    Raises ValueError for any other width.
    """
    # (width, use_cut, use_mat_vec, use_mat_count); every scheme includes the base 14.
    known = (
        (14, False, False, False),
        (23, True, False, True),
        (30, False, True, False),
        (38, True, True, False),
    )
    for width, cut, mat_vec, mat_count in known:
        if tab_in_dim == width:
            return {"use_base": True, "use_cut": cut, "use_mat_vec": mat_vec, "use_mat_count": mat_count}
    raise ValueError(f"ไม่รู้จัก tab_in_dim={tab_in_dim} (รองรับ 14/23/30/38)")
# Resolve which optional feature groups the encoder emits for the current tab_in_dim.
SCHEME = get_feature_scheme(tab_in_dim)
# ---------- 2) Base 14 features ----------
# BASE_14 and NUM_COLS_BASE are defined once near the top of the module and are not
# re-assigned here, so the copies cannot drift apart.
# cat_maps (label -> integer code for each categorical base column) is defined further
# down, before the Gradio UI is built.
# ---------- 3) CUT & MATERIAL utilities ----------
import re
# CUT_CATEGORIES is likewise defined once near the top of the module.
def clean_cut(cut_list):
    """Normalise raw cut labels into the canonical CUT_CATEGORIES vocabulary.

    Args:
        cut_list: a list/tuple/set of labels, or a string that is either a Python-style
            list literal ("['collar', 'tight']") or comma-separated ("collar, tight").
            None or empty input yields [].

    Returns:
        list of canonical labels found in the input, without duplicates, in
        CUT_CATEGORIES order (deterministic, unlike iterating a set). Unrecognised
        labels are dropped.
    """
    import ast  # local import: only needed for the list-literal string form

    if not cut_list:
        return []
    if isinstance(cut_list, str):
        text = cut_list.strip()
        if text.startswith("["):
            # literal_eval accepts only Python literals, so, unlike eval(), a crafted
            # string cannot execute code.
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                parsed = [cut_list]
            cut_list = parsed if isinstance(parsed, (list, tuple, set)) else [parsed]
        else:
            cut_list = cut_list.split(',')
    mapping = {
        'c-collar': 'collar', 'c collar': 'collar', 'collar': 'collar',
        'v-collar': 'v-collar', 'v collar': 'v-collar',
        'tight': 'tight', 'loose': 'loose', 'oversize': 'loose',
        'regular': 'regular',
        'turtle neck': 'turtle-neck', 'turtleneck': 'turtle-neck',
        'cropped': 'cropped', 'long': 'long',
    }
    normalised = (str(c).strip().lower() for c in cut_list if c)
    cleaned = {mapping.get(key, key) for key in normalised}
    return [c for c in CUT_CATEGORIES if c in cleaned]
def cuts_to_multihot(cuts):
    """Return a 0/1 list with one slot per CUT_CATEGORIES entry (falsy `cuts` -> all zeros)."""
    chosen = cuts if cuts else []
    return [int(label in chosen) for label in CUT_CATEGORIES]
# MATERIAL_CATEGORIES (16 material labels, in material-vector order) is defined once
# near the top of the module and is deliberately not re-assigned here, so the copies
# cannot drift apart.
def parse_material(text):
    """Parse a composition string such as '60% cotton 40% polyester'.

    Args:
        text: raw composition text (any object; converted with str()). None and the
            placeholders 'not available', 'unknown', '' and 'scanner can not read material.'
            (case-insensitive) yield {}.

    Returns:
        dict mapping lower-case material name -> percentage, e.g.
        {'cotton': 60, 'polyester': 40}. Whole percentages are ints. Fractional ones
        ('95.5%' or '95,5%') are floats. If a material appears more than once, the last
        occurrence wins.
    """
    if text is None:
        return {}
    text = str(text).strip().lower()
    if text in ('not available', 'unknown', '', 'scanner can not read material.'):
        return {}
    out = {}
    # Allow an optional decimal part. With r'(\d+)' alone, '95.5% cotton' matched only
    # '5% cotton' and recorded 5 instead of 95.5.
    for pct, mat in re.findall(r'(\d+(?:[.,]\d+)?)\s*%\s*([a-z]+)', text):
        value = float(pct.replace(',', '.'))
        out[mat] = int(value) if value.is_integer() else value
    return out
def material_to_vector(mat_dict):
    """Lay out a material -> percentage mapping in MATERIAL_CATEGORIES order (absent -> 0)."""
    lookup = mat_dict.get
    return [lookup(name, 0) for name in MATERIAL_CATEGORIES]
def material_count_from_dict(mat_dict):
    """Return how many materials in `mat_dict` have a strictly positive share."""
    return len([share for share in mat_dict.values() if float(share) > 0])
# ---------- 4) "Dynamic" encoder whose layout follows SCHEME ----------
def encode_tab_from_form(base_vals, cut_selected=None, mat_count_val=None, mat_text_val=None):
    """Encode raw form values into a (1, tab_in_dim) float32 tensor on DEVICE.

    Layout, driven by the module-level SCHEME:
      1. the 14 base features in BASE_14 order. Numeric columns are cast to float.
         Categorical columns are mapped through cat_maps, and unknown labels fall back
         to the mapping's first code.
      2. 8-slot cut multi-hot, if SCHEME["use_cut"];
      3. material count (None -> 0), if SCHEME["use_mat_count"];
      4. 16-slot material-percentage vector parsed from mat_text_val, if SCHEME["use_mat_vec"].

    Raises:
        ValueError: if base_vals does not hold exactly one value per BASE_14 column, or
            the encoded width differs from tab_in_dim. Explicit raises replace `assert`,
            which is stripped under `python -O`.
    """
    base_vals = list(base_vals)
    # zip() would silently drop extra values, so check the count explicitly.
    if len(base_vals) != len(BASE_14):
        raise ValueError(f"Expected {len(BASE_14)} base values, got {len(base_vals)}")
    vec = []
    # 4.1 base 14
    for col, v in zip(BASE_14, base_vals):
        if col in NUM_COLS_BASE:
            vec.append(float(v))
        else:
            m = cat_maps[col]
            vec.append(float(m.get(v, list(m.values())[0])))  # unknown label -> first code
    # 4.2 cut (8)
    if SCHEME["use_cut"]:
        cleaned = clean_cut(cut_selected) if cut_selected else []
        vec.extend(cuts_to_multihot(cleaned))
    # 4.3 material
    if SCHEME["use_mat_count"]:
        vec.append(0.0 if mat_count_val is None else float(mat_count_val))
    if SCHEME["use_mat_vec"]:
        vec.extend(material_to_vector(parse_material(mat_text_val)))
    x = torch.tensor([vec], dtype=torch.float32, device=DEVICE)
    if x.shape[1] != tab_in_dim:
        raise ValueError(f"Encoded dim {x.shape[1]} != tab_in_dim {tab_in_dim}")
    return x
import re

# ----- CUT -----
# CUT_CATEGORIES (8 cut labels, in multi-hot order) is defined once near the top of the
# module and is deliberately not re-assigned here, so the copies cannot drift apart.
def clean_cut(cut_list):
    """Normalise raw cut labels into the canonical CUT_CATEGORIES vocabulary.

    Accepts a list/tuple/set of labels, or a string that is either a Python-style list
    literal ("['collar', 'tight']") or comma-separated ("collar, tight"). None or empty
    input yields [].

    Returns the canonical labels found in the input, without duplicates, in
    CUT_CATEGORIES order (deterministic, unlike iterating a set). Unrecognised labels
    are dropped.
    """
    import ast  # local import: only needed for the list-literal string form

    if not cut_list:
        return []
    if isinstance(cut_list, str):
        text = cut_list.strip()
        if text.startswith("["):
            # literal_eval accepts only Python literals, so, unlike eval(), a crafted
            # string cannot execute code.
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                parsed = [cut_list]
            cut_list = parsed if isinstance(parsed, (list, tuple, set)) else [parsed]
        else:
            cut_list = cut_list.split(',')
    mapping = {
        'c-collar': 'collar', 'c collar': 'collar', 'collar': 'collar',
        'v-collar': 'v-collar', 'v collar': 'v-collar',
        'tight': 'tight', 'loose': 'loose', 'oversize': 'loose',
        'regular': 'regular',
        'turtle neck': 'turtle-neck', 'turtleneck': 'turtle-neck',
        'cropped': 'cropped', 'long': 'long',
    }
    cleaned = set()
    for item in cut_list:
        if not item:
            continue
        key = str(item).strip().lower()
        cleaned.add(mapping.get(key, key))
    return [c for c in CUT_CATEGORIES if c in cleaned]
def cuts_to_multihot(cuts):
    """Return a 0/1 list with one slot per CUT_CATEGORIES entry.

    `cuts` is a collection of canonical cut labels (e.g. clean_cut() output). None or
    empty input yields all zeros instead of raising TypeError.
    """
    selected = cuts or ()
    return [1 if cat in selected else 0 for cat in CUT_CATEGORIES]
# ----- MATERIAL -----
# MATERIAL_CATEGORIES (16 material labels, in material-vector order) is defined once
# near the top of the module and is deliberately not re-assigned here.
def parse_material(text):
    """Parse a composition string such as '60% cotton 40% polyester'.

    None and the placeholders 'not available', 'unknown', '' and
    'scanner can not read material.' (case-insensitive) yield {}.

    Returns a dict such as {'cotton': 60, 'polyester': 40}: lower-case material name ->
    percentage. Whole percentages are ints. Fractional ones ('95.5%' or '95,5%') are
    floats. If a material appears more than once, the last occurrence wins.
    """
    if text is None:
        return {}
    text = str(text).strip().lower()
    if text in ('not available', 'unknown', '', 'scanner can not read material.'):
        return {}
    out = {}
    # Allow an optional decimal part. With r'(\d+)' alone, '95.5% cotton' matched only
    # '5% cotton' and recorded 5 instead of 95.5.
    for pct, mat in re.findall(r'(\d+(?:[.,]\d+)?)\s*%\s*([a-z]+)', text):
        value = float(pct.replace(',', '.'))
        out[mat] = int(value) if value.is_integer() else value
    return out
def material_to_vector(mat_dict):
    """Return the 16 percentages in MATERIAL_CATEGORIES order; materials not in the dict give 0."""
    vector = []
    for material in MATERIAL_CATEGORIES:
        vector.append(mat_dict.get(material, 0))
    return vector
def material_count_from_dict(mat_dict):
    """Count materials with a share > 0 (the single material feature of the 23-dim scheme)."""
    count = 0
    for share in mat_dict.values():
        if float(share) > 0:
            count += 1
    return count
def get_feature_scheme(tab_in_dim):
    """Return a new dict describing which feature groups the model expects.

    Widths:
      - 14: base 14
      - 23: base 14 + cut(8) + material_count(1)
      - 30: base 14 + material_vector(16)
      - 38: base 14 + cut(8) + material_vector(16)
    Any other width raises ValueError.
    """
    def _scheme(cut, mat_vec, mat_count):
        # Every supported width includes the base 14 features.
        return {"use_base": True, "use_cut": cut, "use_mat_vec": mat_vec, "use_mat_count": mat_count}

    if tab_in_dim == 14:
        return _scheme(False, False, False)
    if tab_in_dim == 23:
        return _scheme(True, False, True)
    if tab_in_dim == 30:
        return _scheme(False, True, False)
    if tab_in_dim == 38:
        return _scheme(True, True, False)
    raise ValueError(f"ไม่รู้จัก tab_in_dim={tab_in_dim} (รองรับ 14/23/30/38)")
# ImageNet normalisation statistics come from the EfficientNet-B0 preset; the resize/crop
# pipeline itself is spelled out explicitly.
weights = EfficientNet_B0_Weights.IMAGENET1K_V1
_imagenet_preset = weights.transforms()  # build the preset once and read mean/std from it
img_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_preset.mean, std=_imagenet_preset.std),
])
def preprocess_image(pil_img: Image.Image):
    """Convert a PIL image to a normalised (1, 3, 224, 224) batch on DEVICE."""
    rgb = pil_img.convert("RGB")  # drop alpha / expand greyscale to 3 channels
    batch = img_tf(rgb).unsqueeze(0)
    return batch.to(DEVICE)
# ===== Order of the 14 base tabular features, as used in training =====
# Identical columns and order to BASE_14; derived from it so the lists cannot drift apart.
TAB_FEATS = list(BASE_14)
# ===== Label -> integer-code mappings taken from the training notebook =====
# (indices are ordered to match the mapping used there)
# Each categorical base feature is fed to the model as its code cast to float.
# The keys double as the Gradio dropdown choices (FEATURE_SPECS uses
# list(cat_maps[col].keys())). They presumably must match the label strings seen in
# training exactly; verify against the notebook before editing any code or label.
cat_maps = {
    "brand": {
        "Non-Brand": 0,
        "Fast Fashion & High Street Retailers": 1,
        "Other Brands": 2,
        "Store Brands": 3,
        "Niche Brands": 4,
        "Premium & Designer": 5,
        "Sportswear & Outdoor": 6,
    },
    "pattern": {
        "Solid": 0,
        "Printed": 1,
        "Texture_Embellishment": 2,
        "Other": 3,
    },
    "type": {
        "topwear": 0,
        "dresswear": 1,
        "bottomwear": 2,
        "outerwear": 3,
        "other": 4,
        "sleepwear": 5,
    },
    "size": {
        "unknown": 0,
        "xs": 1,
        "s": 2,
        "m": 3,
        "l": 4,
        "xl": 5,
        "xxl": 6,
        "kids": 7,
        "onesize": 8,
    },
    # "None" here is a literal category label (string), not Python None.
    "season": {
        "All": 0, "Summer": 1, "Spring": 2, "Autumn": 3, "None": 4, "Winter": 5
    },
    "category": {
        "Ladies": 0, "Men": 1, "Children": 2, "Unisex": 3
    },
    "main_color": {
        "black": 0, "white": 1, "blue": 2, "multicolor": 3, "pink": 4,
        "grey": 5, "beige": 6, "red": 7, "green": 8, "purple": 9,
        "brown": 10, "yellow": 11, "orange": 12, "turquoise": 13, "none": 14
    },
    "usage": {
        "export": 0, "reuse": 1, "recycle": 2, "repair": 3
    },
    "stains": {"No": 0, "Yes": 1},
    # Again "None" is a label string meaning "no holes".
    "holes": {"None": 0, "Minor": 1, "Major": 2},
    "damage_severity": {
        "No Damage": 0, "Minor Damage": 1, "Moderate Damage": 2, "Severe Damage": 3
    },
}
# ===== Input specs used to build the Gradio form =====
# Numeric base features: sliders on the raw scale.
FEATURE_SPECS = {
    "pilling": {"kind": "number", "min": 0, "max": 5, "step": 1, "default": 3},
    "condition": {"kind": "number", "min": 0, "max": 5, "step": 1, "default": 2},
    "damage_count": {"kind": "number", "min": 0, "max": 20, "step": 1, "default": 0},
}
# Categorical base features: dropdown choices are the cat_maps labels (in mapping order)
# plus the default label shown on load.
_CAT_DEFAULTS = {
    "pattern": "Solid",
    "stains": "No",
    "holes": "None",
    "damage_severity": "No Damage",
    "brand": "Non-Brand",
    "type": "topwear",
    "size": "m",
    "season": "All",
    "category": "Ladies",
    "main_color": "black",
    "usage": "reuse",
}
FEATURE_SPECS.update(
    (col, {"kind": "cat", "choices": list(cat_maps[col].keys()), "default": default})
    for col, default in _CAT_DEFAULTS.items()
)
# Numeric base columns, passed to the model as raw floats. Derived from NUM_COLS_BASE
# so the two lists cannot disagree.
NUM_COLS = list(NUM_COLS_BASE)
# The remaining TAB_FEATS columns are categorical and encoded through cat_maps.
CAT_COLS = [c for c in TAB_FEATS if c not in NUM_COLS]
def encode_tab(tab_dict):
    """Encode a form dict into a (1, len(TAB_FEATS)) float32 tensor on DEVICE.

    Columns follow TAB_FEATS order:
      - numeric columns (NUM_COLS): the value as a float, no scaler (matching training);
      - categorical columns: label -> code via cat_maps; labels missing from the mapping
        fall back to that mapping's first code.
    Raises KeyError if tab_dict lacks any TAB_FEATS column.
    """
    def to_code(col):
        raw = tab_dict[col]
        if col in NUM_COLS:
            return float(raw)
        mapping = cat_maps[col]
        fallback = list(mapping.values())[0]
        return float(mapping.get(raw, fallback))

    row = [to_code(col) for col in TAB_FEATS]
    return torch.tensor([row], dtype=torch.float32, device=DEVICE)
import re

# Multiplier used to express SEK price buckets in Thai baht (1 SEK ≈ 3.4 THB; adjustable).
FX_RATE = 3.4
# SEK price-bucket label for each model output index (example 4-class setup).
CLASS_NAMES = ["<50", "50-100", "100-150", "150+"]
def convert_label_sek_to_thb(label, rate=None):
    """Convert a SEK price-bucket label into the equivalent Thai-baht label.

    Args:
        label: bucket label in SEK such as "<50", ">150", "50-100" (or en dash) or "150+";
            any object, converted with str().
        rate: SEK -> THB multiplier. None (the default) reads the module-level FX_RATE
            at call time. Updating FX_RATE then takes effect, whereas a
            `rate=FX_RATE` default would have been frozen when the function was defined.

    Returns:
        e.g. "<170 บาท", "170-340 บาท", "510+ บาท" for rate 3.4, with amounts rounded to
        whole baht. The original label is returned unchanged if it contains no number.
    """
    if rate is None:
        rate = FX_RATE
    s = str(label).strip().lower()
    # Accept decimal bounds ("12.5-25"); a bare r"\d+" splits them into extra numbers.
    nums = [float(x) for x in re.findall(r"\d+(?:\.\d+)?", s)]
    if not nums:
        return label

    def thb(amount):
        return int(round(amount * rate))

    if s.startswith("<"):
        return f"<{thb(nums[0])} บาท"
    if s.startswith(">"):
        return f">{thb(nums[0])} บาท"
    if s.endswith("+"):
        return f"{thb(nums[0])}+ บาท"
    if len(nums) == 2 and re.search(r"[-–]", s):
        low, high = nums
        return f"{thb(low)}-{thb(high)} บาท"
    return f"{thb(nums[0])} บาท"
def _split_predict_inputs(vals):
    """Split predict()'s flat *vals into (base_vals, cut_selected, flat_material_pairs).

    Expected layout: len(BASE_14) base values, one cut selection, then any number of
    (percent, material) pairs. The pair count is taken from the input itself, so the UI
    can change its number of material rows without a matching constant here.
    """
    n_base = len(BASE_14)
    if len(vals) < n_base + 1:
        raise ValueError(f"Expected at least {n_base + 1} form values, got {len(vals)}")
    flat = list(vals[n_base + 1:])
    if len(flat) % 2:
        raise ValueError(f"Material inputs must come in (percent, type) pairs; got {len(flat)} values")
    return list(vals[:n_base]), vals[n_base], flat


def _encode_predict_tab(base_vals, cut_selected, flat):
    """Encode the form as a (1, tab_in_dim) float32 tensor: base 14 + cut multi-hot + material vector."""
    vec = []
    for col, v in zip(BASE_14, base_vals):
        if col in NUM_COLS_BASE:
            vec.append(float(v))
        else:
            m = cat_maps[col]
            vec.append(float(m.get(v, list(m.values())[0])))  # unknown label -> first code
    cleaned = clean_cut(cut_selected) if cut_selected else []
    vec.extend(cuts_to_multihot(cleaned))
    vec.extend(encode_material_rows(flat))  # summed percentage per material slot
    xt = torch.tensor([vec], dtype=torch.float32, device=DEVICE)
    if xt.shape[1] != tab_in_dim:
        raise ValueError(f"Encoded dim {xt.shape[1]} != tab_in_dim {tab_in_dim}")
    return xt


def _class_label_thb(idx):
    """Display label for output class idx.

    This is the CLASS_NAMES bucket converted to THB, or a generic name when the
    checkpoint has more classes than CLASS_NAMES lists.
    """
    if idx < len(CLASS_NAMES):
        return convert_label_sek_to_thb(str(CLASS_NAMES[idx]))
    return f"class #{idx}"


def predict(front_img, back_img, *vals):
    """Gradio callback: predict the price bucket from two photos plus the tabular form.

    Args:
        front_img, back_img: images from the two gr.Image(type="pil") inputs (None if empty).
        *vals: flattened inputs in the order wired in the UI. These are the 14 BASE_14
            values, then the cut CheckboxGroup selection, then (percent, material type)
            pairs for each material row.

    Returns:
        (summary text, DataFrame with one row per model output: THB label and
        probability), or (message, None) if an image is missing or any step fails.
    """
    try:
        if front_img is None or back_img is None:
            return "กรุณาอัปโหลดรูปทั้งสองภาพ", None
        base_vals, cut_selected, flat = _split_predict_inputs(vals)
        xt = _encode_predict_tab(base_vals, cut_selected, flat)
        with torch.no_grad():
            x1 = preprocess_image(front_img)
            x2 = preprocess_image(back_img)
            logits = model(x1, x2, xt)
            probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        top_idx = int(np.argmax(probs))
        top_name_thb = _class_label_thb(top_idx)
        # One row per model output; never IndexError or truncate when the count differs
        # from len(CLASS_NAMES).
        rows = [[_class_label_thb(i), round(float(p), 4)] for i, p in enumerate(probs.tolist())]
        df = pd.DataFrame(rows, columns=["class_THB", "probability"])
        return f"ผลทำนาย: {top_name_thb}", df
    except Exception as e:
        # UI boundary: report any failure (with traceback) in the text box instead of
        # breaking the Gradio event.
        import traceback
        return f"Error: {type(e).__name__} - {e}\n" + traceback.format_exc(), None
def encode_material_rows(flat_vals, categories=None):
    """Aggregate the (percent, material) pairs from predict()'s *vals into a percentage vector.

    Args:
        flat_vals: flat sequence [p1, m1, p2, m2, ...]. A trailing unpaired value is ignored.
        categories: material labels defining the output slots and their order;
            defaults to MATERIAL_CATEGORIES.

    Returns:
        list of floats, one per category: the summed percentage for that material.
        Missing, non-numeric or negative percentages contribute 0. Material names are
        matched exactly first, then stripped and lower-cased. Unknown names are ignored.
    """
    if categories is None:
        categories = MATERIAL_CATEGORIES
    totals = {k: 0.0 for k in categories}
    it = iter(flat_vals)
    for p, mat in zip(it, it):  # walk the flat list pairwise
        try:
            pct = float(p) if p is not None else 0.0
        except (TypeError, ValueError, OverflowError):
            pct = 0.0
        if mat not in totals and isinstance(mat, str):
            mat = mat.strip().lower()
        if mat in totals:
            totals[mat] += max(0.0, pct)  # guard against negative entries
    # Emit in the fixed category order.
    return [totals[k] for k in categories]
with gr.Blocks(title="Multimodal (2 Images + Tabular)") as demo:
    gr.Markdown("### โมเดลจำแนกด้วย 2 รูป + คุณลักษณะตาราง")
    with gr.Row():
        img_front = gr.Image(type="pil", label="รูปด้านหน้า")
        img_back = gr.Image(type="pil", label="รูปด้านหลัง")

    # Inputs for the 14 base features, built from FEATURE_SPECS in BASE_14 order.
    tab_inputs = []
    with gr.Row():
        for k in BASE_14:
            spec = FEATURE_SPECS[k]
            if spec["kind"] == "number":
                tab_inputs.append(gr.Slider(minimum=spec["min"], maximum=spec["max"],
                                            step=spec["step"], value=spec["default"], label=k))
            else:
                tab_inputs.append(gr.Dropdown(choices=spec["choices"],
                                              value=spec["default"], label=k))

    # CUT: multi-select over the canonical cut labels.
    cut_input = gr.CheckboxGroup(label="cut (เลือกได้หลายค่า)",
                                 choices=CUT_CATEGORIES, value=[])

    # MATERIAL_VECTOR: up to MAX_MATS (percent, material) rows. predict() derives the
    # pair count from its inputs, so this is the only place to change it.
    MAX_MATS = 5
    material_pairs = []
    gr.Markdown("**วัสดุ (เปอร์เซ็นต์ + ชนิด)** เช่น 60% cotton, 40% polyester")
    with gr.Column():
        for i in range(MAX_MATS):
            with gr.Row():
                p = gr.Number(value=0, label=f"material_{i+1}_percent")
                m = gr.Dropdown(choices=MATERIAL_CATEGORIES, value=MATERIAL_CATEGORIES[0],
                                label=f"material_{i+1}_type")
            material_pairs.append((p, m))

    # Input order must match predict(): images, base 14, cut, then (percent, type) pairs.
    predict_inputs = [img_front, img_back] + tab_inputs + [cut_input]
    for p, m in material_pairs:
        predict_inputs.extend([p, m])

    # Button + outputs. The headers match the DataFrame columns returned by predict().
    btn = gr.Button("ทำนาย")
    out_txt = gr.Textbox(label="สรุปผล")
    out_tbl = gr.Dataframe(headers=["class_THB", "probability"],
                           datatype=["str", "number"], label="ความน่าจะเป็น")
    btn.click(predict, inputs=predict_inputs, outputs=[out_txt, out_tbl])

# Launch only when run as a script (e.g. `python app.py`), not when imported.
if __name__ == "__main__":
    demo.launch()