"""Gradio demo comparing four classifiers on 30 Vietnamese dishes.

Downloads the weights of a custom CNN, VGG-16, ResNet-18 and MobileNetV2 from
the Hugging Face Hub at import time and serves a UI that shows the top
predictions of every model for an uploaded image.
"""

import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import torchvision.transforms as T
import torchvision.models as tv_models

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

CLASS_NAMES = [
    "Banh beo", "Banh bot loc", "Banh can", "Banh canh", "Banh chung",
    "Banh cuon", "Banh duc", "Banh gio", "Banh khot", "Banh mi",
    "Banh pia", "Banh tet", "Banh trang nuong", "Banh xeo", "Bun bo Hue",
    "Bun dau mam tom", "Bun mam", "Bun rieu", "Bun thit nuong", "Ca kho to",
    "Canh chua", "Cao lau", "Chao long", "Com tam", "Goi cuon",
    "Hu tieu", "Mi quang", "Nem chua", "Pho", "Xoi xeo",
]
NUM_CLASSES = len(CLASS_NAMES)

# ── Transform ────────────────────────────────────────────────────────────────
# Resize to a fixed square, convert to a tensor, normalise with the ImageNet
# statistics declared above.
transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


# ══════════════════════════════════════════════════════════════════════════════
# Custom CNN architecture (must match training exactly)
# ══════════════════════════════════════════════════════════════════════════════
class ConvBlock(nn.Module):
    """Conv → BN → ReLU → Conv → BN → ReLU → MaxPool → Dropout.

    The max-pool stage is appended only when ``pool`` is true and the
    ``Dropout2d`` stage only when ``drop`` is greater than zero.
    """

    def __init__(self, in_ch, out_ch, pool=True, drop=0.25):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        if pool:
            layers.append(nn.MaxPool2d(2, 2))
        if drop > 0:
            layers.append(nn.Dropout2d(drop))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential block."""
        return self.block(x)


class CustomCNN(nn.Module):
    """
    5-block CNN for 224×224 images.
    224 → 112 → 56 → 28 → 14 → 7 → GAP → FC
    """

    def __init__(self, num_classes=30):
        super().__init__()
        self.features = nn.Sequential(
            ConvBlock(3, 64, pool=True, drop=0.20),     # 112×112
            ConvBlock(64, 128, pool=True, drop=0.20),   # 56×56
            ConvBlock(128, 256, pool=True, drop=0.25),  # 28×28
            ConvBlock(256, 512, pool=True, drop=0.25),  # 14×14
            ConvBlock(512, 512, pool=True, drop=0.30),  # 7×7
        )
        self.gap = nn.AdaptiveAvgPool2d(1)              # 1×1
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Apply the feature blocks, global average pooling and the classifier."""
        x = self.features(x)
        x = self.gap(x)
        return self.classifier(x)


# ══════════════════════════════════════════════════════════════════════════════
# Model loading helpers
# ══════════════════════════════════════════════════════════════════════════════
def load_state_dict_from_pth(path):
    """Load state_dict from .pth that may be a full checkpoint or plain state_dict.

    If the loaded object is a dict containing ``"model_state_dict"`` or
    ``"net"``, that entry is returned; otherwise the loaded object itself is
    returned unchanged.
    """
    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary objects, i.e. a downloaded checkpoint can execute code. Kept
    # as-is because the checkpoints may hold non-tensor entries; only use with
    # trusted repositories, or switch to weights_only=True once confirmed.
    data = torch.load(path, map_location=DEVICE, weights_only=False)
    if isinstance(data, dict):
        for key in ("model_state_dict", "net"):
            if key in data:
                return data[key]
    return data


def load_custom_cnn(repo_id="quynong/vnfood-cnn"):
    """Download ``pytorch_model.bin`` from ``repo_id`` and build a ``CustomCNN``.

    Returns the model moved to ``DEVICE`` and switched to eval mode.
    """
    weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
    model = CustomCNN(num_classes=NUM_CLASSES)
    state = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.to(DEVICE).eval()
    return model


def load_vgg16(repo_id="minfu2k5/vgg16-30vnfoods"):
    """Download the VGG-16 checkpoint from ``repo_id`` and build the model.

    The torchvision classifier is replaced by a custom head before the
    weights are loaded with ``strict=True``. Returns the model on ``DEVICE``
    in eval mode.
    """
    weights_path = hf_hub_download(repo_id, "best_vgg16_30vnfoods.pt")
    model = tv_models.vgg16(weights=None)
    model.classifier = nn.Sequential(
        nn.Linear(25088, 512),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(0.5),
        nn.Linear(512, NUM_CLASSES),
    )
    state = load_state_dict_from_pth(weights_path)
    model.load_state_dict(state, strict=True)
    model.to(DEVICE).eval()
    return model


def load_resnet18(repo_id="trinhtrantran122/resnet18-vnfoods-v3"):
    """Download the ResNet-18 checkpoint from ``repo_id`` and build the model.

    Handles checkpoints whose keys carry a ``module.`` prefix and both head
    layouts (plain ``fc`` or a two-element ``Sequential``). Returns the model
    on ``DEVICE`` in eval mode.
    """
    weights_path = hf_hub_download(repo_id, "resnet18_v2_processed_best.pth")
    state = load_state_dict_from_pth(weights_path)

    # Handle DataParallel-trained checkpoints. Only a *leading* "module." is
    # stripped so that keys merely containing that text are left intact.
    prefix = "module."
    if any(k.startswith(prefix) for k in state):
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state.items()
        }

    model = tv_models.resnet18(weights=None)
    in_features = model.fc.in_features

    # Support both checkpoint heads:
    # - old: fc.weight / fc.bias
    # - new: fc.1.weight / fc.1.bias (Sequential head)
    if any(k.startswith("fc.1.") for k in state):
        model.fc = nn.Sequential(
            nn.Identity(),
            nn.Linear(in_features, NUM_CLASSES),
        )
    else:
        model.fc = nn.Linear(in_features, NUM_CLASSES)

    model.load_state_dict(state, strict=True)
    model.to(DEVICE).eval()
    return model


def load_mobilenetv2(repo_id="shidamaring/cs231-30food-mobilnetv2"):
    """Download the MobileNetV2 checkpoint from ``repo_id`` and build the model.

    The torchvision classifier is replaced by a custom head before the
    weights are loaded with ``strict=True``. Returns the model on ``DEVICE``
    in eval mode.
    """
    weights_path = hf_hub_download(repo_id, "finetune_best_lan8.pth")
    model = tv_models.mobilenet_v2(weights=None)
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(1280, 256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, NUM_CLASSES),
    )
    state = load_state_dict_from_pth(weights_path)
    model.load_state_dict(state, strict=True)
    model.to(DEVICE).eval()
    return model


# ══════════════════════════════════════════════════════════════════════════════
# Load all 4 models at startup
# ══════════════════════════════════════════════════════════════════════════════
print("Loading models …")
models_dict = {}

print("  [1/4] Custom CNN …")
models_dict["Custom CNN"] = load_custom_cnn("quynong/vnfood-cnn")

print("  [2/4] VGG-16 …")
models_dict["VGG-16"] = load_vgg16("minfu2k5/vgg16-30vnfoods")

print("  [3/4] ResNet-18 …")
models_dict["ResNet-18"] = load_resnet18("trinhtrantran122/resnet18-vnfoods-v3")

print("  [4/4] MobileNetV2 …")
models_dict["MobileNetV2"] = load_mobilenetv2("shidamaring/cs231-30food-mobilnetv2")

print("All models loaded ✓")


# ══════════════════════════════════════════════════════════════════════════════
# Prediction
# ══════════════════════════════════════════════════════════════════════════════
@torch.no_grad()
def predict_single(model, image: Image.Image, top_k=5) -> dict:
    """Return the ``top_k`` most probable classes of ``model`` for ``image``.

    The image is converted to RGB, passed through ``transform`` and fed to the
    model as a batch of one. The result maps class name to softmax
    probability. ``top_k`` is clamped to the number of model outputs so that
    ``topk`` cannot raise for an oversized request.
    """
    tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    probs = torch.softmax(model(tensor), dim=1)[0].cpu()
    k = min(top_k, probs.numel())
    top_probs, top_idxs = probs.topk(k)
    return {
        CLASS_NAMES[idx]: prob
        for idx, prob in zip(top_idxs.tolist(), top_probs.tolist())
    }


def predict_all(image):
    """Run every loaded model on ``image`` and return one result per output.

    Returns a 4-tuple ordered as (Custom CNN, VGG-16, ResNet-18, MobileNetV2),
    matching the ``outputs`` list of the click handler. When ``image`` is
    ``None`` each element is an empty dict.
    """
    if image is None:
        # One value per output component. A dict keyed by model-name strings
        # would not be a valid return value for a handler with four outputs.
        return tuple({} for _ in models_dict)

    pil = image if isinstance(image, Image.Image) else Image.fromarray(image)
    results = {
        name: predict_single(model, pil) for name, model in models_dict.items()
    }

    return (
        results["Custom CNN"],
        results["VGG-16"],
        results["ResNet-18"],
        results["MobileNetV2"],
    )


# ══════════════════════════════════════════════════════════════════════════════
# Gradio UI
# ══════════════════════════════════════════════════════════════════════════════
with gr.Blocks(title="VNFood Classifier – 30 Vietnamese Dishes") as demo:
    gr.Markdown(
        "## 🍜 VNFood Classifier\n"
        "Upload ảnh món ăn Việt Nam → xem kết quả phân loại từ **4 model** khác nhau."
    )

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload ảnh món ăn")

    btn = gr.Button("Predict", variant="primary")

    with gr.Row():
        out_cnn = gr.Label(num_top_classes=5, label="Custom CNN")
        out_vgg = gr.Label(num_top_classes=5, label="VGG-16")
    with gr.Row():
        out_res = gr.Label(num_top_classes=5, label="ResNet-18")
        out_mob = gr.Label(num_top_classes=5, label="MobileNetV2")

    btn.click(
        fn=predict_all,
        inputs=image_input,
        outputs=[out_cnn, out_vgg, out_res, out_mob],
    )

    # Repo ids listed here mirror the ones passed to the loaders above.
    gr.Markdown(
        "**Models:** `quynong/vnfood-cnn` · `minfu2k5/vgg16-30vnfoods` · "
        "`trinhtrantran122/resnet18-vnfoods-v3` · `shidamaring/cs231-30food-mobilnetv2`"
    )

if __name__ == "__main__":
    demo.launch()