| import torch |
| import torch.nn as nn |
| import gradio as gr |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
| import torchvision.transforms as T |
| import torchvision.models as tv_models |
|
|
# Pick the accelerator once at import time; every model and input tensor is
# moved to this device.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Square input resolution and ImageNet per-channel statistics used by the
# shared preprocessing transform.
IMG_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Label vocabulary: logit index i of every classifier maps to CLASS_NAMES[i].
CLASS_NAMES = [
    "Banh beo",
    "Banh bot loc",
    "Banh can",
    "Banh canh",
    "Banh chung",
    "Banh cuon",
    "Banh duc",
    "Banh gio",
    "Banh khot",
    "Banh mi",
    "Banh pia",
    "Banh tet",
    "Banh trang nuong",
    "Banh xeo",
    "Bun bo Hue",
    "Bun dau mam tom",
    "Bun mam",
    "Bun rieu",
    "Bun thit nuong",
    "Ca kho to",
    "Canh chua",
    "Cao lau",
    "Chao long",
    "Com tam",
    "Goi cuon",
    "Hu tieu",
    "Mi quang",
    "Nem chua",
    "Pho",
    "Xoi xeo",
]
NUM_CLASSES = len(CLASS_NAMES)
|
|
| |
# Preprocessing shared by every model at inference time:
#   1. resize straight to IMG_SIZE x IMG_SIZE (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. normalise each channel with the ImageNet mean/std.
transform = T.Compose(
    [
        T.Resize(size=(IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
|
|
|
|
| |
| |
| |
class ConvBlock(nn.Module):
    """Double 3x3 convolution stage with optional downsampling and dropout.

    Layout: Conv -> BN -> ReLU -> Conv -> BN -> ReLU -> [MaxPool 2x2] -> [Dropout2d].

    Both convolutions use ``padding=1`` (spatial size preserved) and no bias,
    since the following BatchNorm supplies the shift. The optional max-pool
    halves height and width. The layer order determines the ``block.<i>``
    state_dict keys, so it must stay fixed for saved weights to load.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by both convolutions.
        pool: Append a 2x2, stride-2 max-pool when true.
        drop: Channel-dropout probability; no dropout layer when not > 0.
    """

    def __init__(self, in_ch, out_ch, pool=True, drop=0.25):
        super().__init__()
        modules = []
        # Two identical conv/BN/ReLU triplets; only the first changes width.
        for src, dst in ((in_ch, out_ch), (out_ch, out_ch)):
            modules += [
                nn.Conv2d(src, dst, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(dst),
                nn.ReLU(inplace=True),
            ]
        if pool:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        if drop > 0:
            modules.append(nn.Dropout2d(drop))
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stage to a batch of feature maps."""
        return self.block(x)
|
|
|
|
class CustomCNN(nn.Module):
    """Five-stage convolutional classifier for 224x224 images.

    Every stage pools 2x2, so the spatial size shrinks
    224 -> 112 -> 56 -> 28 -> 14 -> 7. Global average pooling then
    collapses the 512-channel map to a vector, and an MLP head maps it to
    ``num_classes`` logits.

    Module order determines the ``features.<i>`` / ``classifier.<i>``
    state_dict keys and must not be changed.
    """

    def __init__(self, num_classes=30):
        super().__init__()
        # (in_channels, out_channels, dropout) for each stage, in order.
        stages = (
            (3, 64, 0.20),
            (64, 128, 0.20),
            (128, 256, 0.25),
            (256, 512, 0.25),
            (512, 512, 0.30),
        )
        self.features = nn.Sequential(
            *[ConvBlock(c_in, c_out, pool=True, drop=p) for c_in, c_out, p in stages]
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        head = [
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of normalised images."""
        pooled = self.gap(self.features(x))
        return self.classifier(pooled)
|
|
|
|
| |
| |
| |
def load_state_dict_from_pth(path, map_location=None, weights_only=False):
    """Load a state_dict from a file holding either a checkpoint or a plain state_dict.

    Training checkpoints that wrap the weights under ``"model_state_dict"``,
    ``"net"`` or ``"state_dict"`` are unwrapped (keys checked in that order).
    Anything else is returned unchanged.

    Args:
        path: File path handed to ``torch.load``.
        map_location: Device to map tensors onto. ``None`` (the default)
            means the module-level ``DEVICE``.
        weights_only: Forwarded to ``torch.load``. The default stays False so
            full training checkpoints that pickle non-tensor objects still
            load. Pass True when the file only contains tensors and plain
            containers.

    Returns:
        The unwrapped state_dict, or the loaded object as-is.
    """
    if map_location is None:
        map_location = DEVICE
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # which can run code embedded in the file. In this app the paths come
    # from hf_hub_download() of third-party repos, so prefer
    # weights_only=True wherever the checkpoint format allows it.
    data = torch.load(path, map_location=map_location, weights_only=weights_only)
    if isinstance(data, dict):
        for wrapper_key in ("model_state_dict", "net", "state_dict"):
            if wrapper_key in data:
                return data[wrapper_key]
    return data
|
|
|
|
def load_custom_cnn(repo_id="quynong/vnfood-cnn"):
    """Download ``pytorch_model.bin`` from *repo_id* and return a ready CustomCNN.

    The file is read with ``torch.load(weights_only=True)`` and applied with
    ``load_state_dict``'s default strict matching. The returned model is on
    ``DEVICE`` and in eval mode.
    """
    ckpt_file = hf_hub_download(repo_id, "pytorch_model.bin")
    net = CustomCNN(num_classes=NUM_CLASSES)
    net.load_state_dict(
        torch.load(ckpt_file, map_location=DEVICE, weights_only=True)
    )
    # Module.to() and .eval() both return the module itself.
    return net.to(DEVICE).eval()
|
|
|
|
def load_vgg16(repo_id="minfu2k5/vgg16-30vnfoods"):
    """Build a torchvision VGG-16 with a custom head and load fine-tuned weights.

    The checkpoint ``best_vgg16_30vnfoods.pt`` is fetched from *repo_id*,
    unwrapped by ``load_state_dict_from_pth`` and loaded strictly. The
    returned model is on ``DEVICE`` and in eval mode.
    """
    ckpt_file = hf_hub_download(repo_id, "best_vgg16_30vnfoods.pt")
    net = tv_models.vgg16(weights=None)
    # The head's layer order defines the "classifier.<i>" keys in the
    # checkpoint, so it must match exactly for the strict load below.
    head_layers = [
        nn.Linear(25088, 512),  # 25088 = 512 * 7 * 7 flattened VGG features
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(0.5),
        nn.Linear(512, NUM_CLASSES),
    ]
    net.classifier = nn.Sequential(*head_layers)
    net.load_state_dict(load_state_dict_from_pth(ckpt_file), strict=True)
    return net.to(DEVICE).eval()
|
|
|
|
def load_resnet18(repo_id="trinhtrantran122/resnet18-vnfoods-v3"):
    """Build a torchvision ResNet-18 and load fine-tuned weights from *repo_id*.

    Handles two checkpoint quirks:
      * keys prefixed with ``"module."`` (saved from a DataParallel-style
        wrapper) are un-prefixed;
      * the classifier head may have been saved either as a bare Linear
        (``fc.weight``) or as a two-element Sequential whose Linear sits at
        index 1 (``fc.1.weight``). The model's ``fc`` is shaped to match.

    Returns:
        The model, strictly loaded, on ``DEVICE`` and in eval mode.
    """
    weights_path = hf_hub_download(repo_id, "resnet18_v2_processed_best.pth")
    state = load_state_dict_from_pth(weights_path)

    # Strip "module." only as a *prefix*. The previous
    # k.replace("module.", "", 1) also rewrote keys that merely contained
    # "module." further in, corrupting them.
    prefix = "module."
    if any(k.startswith(prefix) for k in state):
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state.items()
        }

    model = tv_models.resnet18(weights=None)
    in_features = model.fc.in_features

    if any(k.startswith("fc.1.") for k in state):
        # Index 0 of the saved head has no parameters (presumably a Dropout
        # during training — TODO confirm). An Identity placeholder reproduces
        # the "fc.1.*" key layout and is equivalent to Dropout in eval mode.
        model.fc = nn.Sequential(
            nn.Identity(),
            nn.Linear(in_features, NUM_CLASSES),
        )
    else:
        model.fc = nn.Linear(in_features, NUM_CLASSES)

    model.load_state_dict(state, strict=True)
    model.to(DEVICE).eval()
    return model
|
|
|
|
def load_mobilenetv2(repo_id="shidamaring/cs231-30food-mobilnetv2"):
    """Build a torchvision MobileNetV2 with a custom MLP head and load weights.

    The checkpoint ``finetune_best_lan8.pth`` is fetched from *repo_id*,
    unwrapped by ``load_state_dict_from_pth`` and loaded strictly. The
    returned model is on ``DEVICE`` and in eval mode.
    """
    ckpt_file = hf_hub_download(repo_id, "finetune_best_lan8.pth")
    net = tv_models.mobilenet_v2(weights=None)
    # Layer order fixes the "classifier.<i>" checkpoint keys (1 and 4 hold
    # the Linear weights), so it must not be rearranged.
    head_layers = [
        nn.Dropout(0.3),
        nn.Linear(1280, 256),  # 1280 = MobileNetV2 final feature width
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, NUM_CLASSES),
    ]
    net.classifier = nn.Sequential(*head_layers)
    net.load_state_dict(load_state_dict_from_pth(ckpt_file), strict=True)
    return net.to(DEVICE).eval()
|
|
|
|
| |
| |
| |
def _load_all_models():
    """Download and build every classifier once, printing progress.

    Returns:
        dict mapping display name -> model. Each loader returns its model on
        ``DEVICE`` in eval mode. The names are exactly the keys that
        ``predict_all`` looks up.
    """
    loaders = (
        ("Custom CNN", load_custom_cnn, "quynong/vnfood-cnn"),
        ("VGG-16", load_vgg16, "minfu2k5/vgg16-30vnfoods"),
        ("ResNet-18", load_resnet18, "trinhtrantran122/resnet18-vnfoods-v3"),
        ("MobileNetV2", load_mobilenetv2, "shidamaring/cs231-30food-mobilnetv2"),
    )
    loaded = {}
    for i, (name, loader, repo_id) in enumerate(loaders, start=1):
        print(f" [{i}/{len(loaders)}] {name} …")
        loaded[name] = loader(repo_id)
    return loaded


# Models are loaded once at import time and shared by all requests.
# (The progress strings were previously mis-encoded UTF-8.)
print("Loading models …")
models_dict = _load_all_models()
print("All models loaded ✓")
|
|
|
|
| |
| |
| |
@torch.no_grad()
def predict_single(model, image: Image.Image, top_k=5):
    """Classify *image* with *model* and return its top-k predictions.

    Args:
        model: Classifier whose logit index i corresponds to CLASS_NAMES[i].
        image: Input PIL image. It is converted to RGB first, so grayscale,
            RGBA and palette uploads are accepted.
        top_k: Number of classes to return. Clamped to the number of logits
            the model produces, so asking for more than the class count no
            longer makes ``topk`` raise.

    Returns:
        dict mapping class name -> softmax probability (Python float),
        ordered from most to least likely.
    """
    batch = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    probs = torch.softmax(model(batch), dim=1)[0].cpu()
    k = min(top_k, probs.numel())
    top_probs, top_idxs = probs.topk(k)  # sorted, largest first
    return {
        CLASS_NAMES[idx]: prob
        for idx, prob in zip(top_idxs.tolist(), top_probs.tolist())
    }
|
|
|
|
def predict_all(image):
    """Run every model on *image* and return one result per output panel.

    Args:
        image: PIL image from the Gradio input, or a NumPy array (converted
            with ``Image.fromarray``), or None when nothing was uploaded.

    Returns:
        4-tuple of ``{class_name: probability}`` dicts ordered as
        (Custom CNN, VGG-16, ResNet-18, MobileNetV2). This matches the
        ``outputs=[...]`` list wired to the Predict button. When *image* is
        None, every element is None, which clears the labels.
    """
    order = ("Custom CNN", "VGG-16", "ResNet-18", "MobileNetV2")
    if image is None:
        # Gradio maps a returned tuple positionally onto the four outputs.
        # The former dict keyed by model-name strings matched no output
        # component and made the event fail.
        return tuple(None for _ in order)
    pil = image if isinstance(image, Image.Image) else Image.fromarray(image)
    return tuple(predict_single(models_dict[name], pil) for name in order)
|
|
|
|
| |
| |
| |
# Gradio UI: one image input, a Predict button, and a 2x2 grid of label
# panels, one per model. User-facing strings were previously mis-encoded
# UTF-8 and have been restored.
with gr.Blocks(title="VNFood Classifier — 30 Vietnamese Dishes") as demo:
    gr.Markdown(
        "## 🍜 VNFood Classifier\n"
        "Upload ảnh món ăn Việt Nam → xem kết quả phân loại từ **4 model** khác nhau."
    )

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload ảnh món ăn")

    btn = gr.Button("Predict", variant="primary")

    # Panel order must match the tuple returned by predict_all().
    with gr.Row():
        out_cnn = gr.Label(num_top_classes=5, label="Custom CNN")
        out_vgg = gr.Label(num_top_classes=5, label="VGG-16")
    with gr.Row():
        out_res = gr.Label(num_top_classes=5, label="ResNet-18")
        out_mob = gr.Label(num_top_classes=5, label="MobileNetV2")

    btn.click(
        fn=predict_all,
        inputs=image_input,
        outputs=[out_cnn, out_vgg, out_res, out_mob],
    )

    # Credits: these must name the repos actually loaded above (the ResNet
    # entry previously omitted the "-v3" suffix).
    gr.Markdown(
        "**Models:** `quynong/vnfood-cnn` · `minfu2k5/vgg16-30vnfoods` · "
        "`trinhtrantran122/resnet18-vnfoods-v3` · `shidamaring/cs231-30food-mobilnetv2`"
    )


if __name__ == "__main__":
    demo.launch()
|
|