# CS231 / app.py
# quynong's picture
# Update app.py
# efab2c6 verified
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import torchvision.transforms as T
import torchvision.models as tv_models
# Inference device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Side length, in pixels, that input images are resized to by `transform`.
IMG_SIZE = 224

# Per-channel statistics handed to the normalization step of `transform`.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Display name for every class index; `predict_single` looks names up by the
# index of the model output.
# NOTE(review): presumably this order must match the label order used during
# training of all four models — confirm.
CLASS_NAMES = [
    "Banh beo",
    "Banh bot loc",
    "Banh can",
    "Banh canh",
    "Banh chung",
    "Banh cuon",
    "Banh duc",
    "Banh gio",
    "Banh khot",
    "Banh mi",
    "Banh pia",
    "Banh tet",
    "Banh trang nuong",
    "Banh xeo",
    "Bun bo Hue",
    "Bun dau mam tom",
    "Bun mam",
    "Bun rieu",
    "Bun thit nuong",
    "Ca kho to",
    "Canh chua",
    "Cao lau",
    "Chao long",
    "Com tam",
    "Goi cuon",
    "Hu tieu",
    "Mi quang",
    "Nem chua",
    "Pho",
    "Xoi xeo",
]

# Size of the classification head shared by every model in this app.
NUM_CLASSES = len(CLASS_NAMES)
# ── Transform ────────────────────────────────────────────────────────────────
# Preprocessing pipeline applied to every image before it is fed to a model:
# resize to IMG_SIZE x IMG_SIZE, convert to a tensor, then normalize each
# channel with IMAGENET_MEAN / IMAGENET_STD.
# NOTE(review): presumably mirrors the evaluation-time preprocessing used when
# the four models were trained — confirm.
transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# ══════════════════════════════════════════════════════════════════════════════
# Custom CNN architecture (must match training exactly)
# ══════════════════════════════════════════════════════════════════════════════
class ConvBlock(nn.Module):
    """Two 3x3 convolution stages followed by optional pooling and dropout.

    Layer order: Conv → BN → ReLU → Conv → BN → ReLU → MaxPool → Dropout.

    Args:
        in_ch: Input channels of the first convolution.
        out_ch: Output channels of both convolutions.
        pool: When true, append a 2x2 max-pool with stride 2.
        drop: Probability for the channel-wise dropout layer; the layer is
            only appended when this value is greater than 0.
    """

    def __init__(self, in_ch, out_ch, pool=True, drop=0.25):
        super().__init__()

        def conv_bn_relu(channels_in):
            # bias=False: every convolution is immediately followed by BatchNorm.
            return [
                nn.Conv2d(channels_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        stages = conv_bn_relu(in_ch) + conv_bn_relu(out_ch)
        if pool:
            stages += [nn.MaxPool2d(2, 2)]
        if drop > 0:
            stages += [nn.Dropout2d(drop)]

        # A single Sequential named `block` keeps parameter keys of the form
        # `block.<index>.*`.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block and return the result."""
        return self.block(x)
class CustomCNN(nn.Module):
    """Five-block CNN for 224×224 images.

    Spatial size per stage: 224 → 112 → 56 → 28 → 14 → 7 → GAP → FC.

    Args:
        num_classes: Width of the final linear layer.
    """

    def __init__(self, num_classes=30):
        super().__init__()

        # (in_channels, out_channels, dropout) for each pooled stage.
        stage_specs = (
            (3, 64, 0.20),     # 112×112
            (64, 128, 0.20),   # 56×56
            (128, 256, 0.25),  # 28×28
            (256, 512, 0.25),  # 14×14
            (512, 512, 0.30),  # 7×7
        )
        self.features = nn.Sequential(
            *[
                ConvBlock(c_in, c_out, pool=True, drop=p)
                for c_in, c_out, p in stage_specs
            ]
        )

        # Global average pooling collapses each feature map to 1×1.
        self.gap = nn.AdaptiveAvgPool2d(1)

        head = [
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        pooled = self.gap(self.features(x))
        return self.classifier(pooled)
# ══════════════════════════════════════════════════════════════════════════════
# Model loading helpers
# ══════════════════════════════════════════════════════════════════════════════
def load_state_dict_from_pth(path):
    """Load model weights from a checkpoint file.

    The file may contain a plain ``state_dict`` or a training checkpoint that
    wraps the weights under ``"model_state_dict"``, ``"net"`` or
    ``"state_dict"`` (checked in that order).

    Args:
        path: Filesystem path of the checkpoint to read.

    Returns:
        The wrapped weights when one of the known wrapper keys is present,
        otherwise the loaded object unchanged.
    """
    # SECURITY NOTE(review): weights_only=False makes torch.load unpickle
    # arbitrary Python objects, which can execute code embedded in the file.
    # The callers in this file pass checkpoints downloaded from third-party
    # Hub repositories; only keep this setting for repositories you trust, or
    # re-export the checkpoints so weights_only=True can be used.
    data = torch.load(path, map_location=DEVICE, weights_only=False)
    if isinstance(data, dict):
        # "state_dict" is checked last so checkpoints that were already
        # handled by the first two keys keep resolving the same way.
        for wrapper_key in ("model_state_dict", "net", "state_dict"):
            if wrapper_key in data:
                return data[wrapper_key]
    return data
def load_custom_cnn(repo_id="quynong/vnfood-cnn"):
    """Build a ``CustomCNN`` with weights downloaded from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that hosts ``pytorch_model.bin``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    checkpoint_file = hf_hub_download(repo_id, "pytorch_model.bin")
    net = CustomCNN(num_classes=NUM_CLASSES)
    # weights_only=True: this file is loaded directly as the state dict.
    weights = torch.load(checkpoint_file, map_location=DEVICE, weights_only=True)
    net.load_state_dict(weights)
    net = net.to(DEVICE)
    net.eval()
    return net
def load_vgg16(repo_id="minfu2k5/vgg16-30vnfoods"):
    """Build a VGG-16 with a custom head and weights from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that hosts ``best_vgg16_30vnfoods.pt``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    checkpoint_file = hf_hub_download(repo_id, "best_vgg16_30vnfoods.pt")

    net = tv_models.vgg16(weights=None)
    # Replacement classifier; the layer order fixes the state_dict key indices.
    # NOTE(review): 25088 is presumably 512 * 7 * 7, the flattened VGG-16
    # feature map — confirm.
    head_layers = [
        nn.Linear(25088, 512),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(0.5),
        nn.Linear(512, NUM_CLASSES),
    ]
    net.classifier = nn.Sequential(*head_layers)

    weights = load_state_dict_from_pth(checkpoint_file)
    net.load_state_dict(weights, strict=True)
    net = net.to(DEVICE)
    net.eval()
    return net
def load_resnet18(repo_id="trinhtrantran122/resnet18-vnfoods-v3"):
    """Build a ResNet-18 with weights downloaded from the Hugging Face Hub.

    Handles checkpoints saved from a ``DataParallel`` wrapper (keys prefixed
    with ``module.``) and two layouts of the classification head.

    Args:
        repo_id: Hub repository that hosts ``resnet18_v2_processed_best.pth``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    weights_path = hf_hub_download(repo_id, "resnet18_v2_processed_best.pth")
    state = load_state_dict_from_pth(weights_path)

    # Handle DataParallel-trained checkpoints. Only a leading "module." is
    # removed; a plain str.replace would also rewrite keys that merely contain
    # the text (e.g. "submodule.weight").
    prefix = "module."
    if any(key.startswith(prefix) for key in state):
        state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()
        }

    model = tv_models.resnet18(weights=None)
    in_features = model.fc.in_features

    # Support both checkpoint heads:
    # - old: fc.weight / fc.bias
    # - new: fc.1.weight / fc.1.bias (Sequential head)
    if any(key.startswith("fc.1.") for key in state):
        model.fc = nn.Sequential(
            # Parameter-free placeholder so the linear layer sits at index 1.
            nn.Identity(),
            nn.Linear(in_features, NUM_CLASSES),
        )
    else:
        model.fc = nn.Linear(in_features, NUM_CLASSES)

    model.load_state_dict(state, strict=True)
    model.to(DEVICE).eval()
    return model
def load_mobilenetv2(repo_id="shidamaring/cs231-30food-mobilnetv2"):
    """Build a MobileNetV2 with a custom head and weights from the Hub.

    Args:
        repo_id: Hub repository that hosts ``finetune_best_lan8.pth``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    checkpoint_file = hf_hub_download(repo_id, "finetune_best_lan8.pth")

    backbone = tv_models.mobilenet_v2(weights=None)
    # Replacement classifier; the layer order fixes the state_dict key indices.
    head_layers = [
        nn.Dropout(0.3),
        nn.Linear(1280, 256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, NUM_CLASSES),
    ]
    backbone.classifier = nn.Sequential(*head_layers)

    weights = load_state_dict_from_pth(checkpoint_file)
    backbone.load_state_dict(weights, strict=True)
    backbone = backbone.to(DEVICE)
    backbone.eval()
    return backbone
# ══════════════════════════════════════════════════════════════════════════════
# Load all 4 models at startup
# ══════════════════════════════════════════════════════════════════════════════
print("Loading models …")
models_dict = {}
print(" [1/4] Custom CNN …")
models_dict["Custom CNN"] = load_custom_cnn("quynong/vnfood-cnn")
print(" [2/4] VGG-16 …")
models_dict["VGG-16"] = load_vgg16("minfu2k5/vgg16-30vnfoods")
print(" [3/4] ResNet-18 …")
models_dict["ResNet-18"] = load_resnet18("trinhtrantran122/resnet18-vnfoods-v3")
print(" [4/4] MobileNetV2 …")
models_dict["MobileNetV2"] = load_mobilenetv2("shidamaring/cs231-30food-mobilnetv2")
print("All models loaded βœ“")
# ══════════════════════════════════════════════════════════════════════════════
# Prediction
# ══════════════════════════════════════════════════════════════════════════════
@torch.no_grad()
def predict_single(model, image: Image.Image, top_k=5):
    """Classify one image with one model.

    Args:
        model: Module called on a single-image batch; its output is
            soft-maxed over dimension 1.
        image: PIL image; it is converted to RGB before preprocessing.
        top_k: Maximum number of classes to report. Values larger than the
            number of model outputs are capped.

    Returns:
        A dict mapping class name to probability (Python float) for the
        highest-scoring classes.
    """
    batch = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    probs = torch.softmax(model(batch), dim=1)[0].cpu()

    # topk raises when k exceeds the number of scores, so cap the request.
    k = min(top_k, probs.numel())
    top_probs, top_idxs = probs.topk(k)

    # tolist() yields plain ints/floats, so the list lookup and the returned
    # values do not depend on 0-dim tensors.
    return {
        CLASS_NAMES[idx]: prob
        for idx, prob in zip(top_idxs.tolist(), top_probs.tolist())
    }
# Model names in the order of the four values returned by predict_all, which
# is the order of the output components wired to it.
_OUTPUT_ORDER = ("Custom CNN", "VGG-16", "ResNet-18", "MobileNetV2")


def predict_all(image):
    """Run every loaded model on ``image``.

    Args:
        image: A PIL image, an array accepted by ``Image.fromarray``, or
            ``None`` when nothing has been uploaded.

    Returns:
        A 4-tuple of ``{class name: probability}`` dicts ordered as
        Custom CNN, VGG-16, ResNet-18, MobileNetV2. When ``image`` is
        ``None`` every entry is an empty dict.
    """
    if image is None:
        # Must return one value per output, exactly like the normal path.
        return tuple({} for _ in _OUTPUT_ORDER)

    pil = image if isinstance(image, Image.Image) else Image.fromarray(image)
    return tuple(predict_single(models_dict[name], pil) for name in _OUTPUT_ORDER)
# ══════════════════════════════════════════════════════════════════════════════
# Gradio UI
# ══════════════════════════════════════════════════════════════════════════════
with gr.Blocks(title="VNFood Classifier – 30 Vietnamese Dishes") as demo:
gr.Markdown(
"## 🍜 VNFood Classifier\n"
"Upload αΊ£nh mΓ³n Δƒn Việt Nam β†’ xem kαΊΏt quαΊ£ phΓ’n loαΊ‘i tα»« **4 model** khΓ‘c nhau."
)
with gr.Row():
image_input = gr.Image(type="pil", label="Upload αΊ£nh mΓ³n Δƒn")
btn = gr.Button("Predict", variant="primary")
with gr.Row():
out_cnn = gr.Label(num_top_classes=5, label="Custom CNN")
out_vgg = gr.Label(num_top_classes=5, label="VGG-16")
with gr.Row():
out_res = gr.Label(num_top_classes=5, label="ResNet-18")
out_mob = gr.Label(num_top_classes=5, label="MobileNetV2")
btn.click(
fn=predict_all,
inputs=image_input,
outputs=[out_cnn, out_vgg, out_res, out_mob],
)
gr.Markdown(
"**Models:** `quynong/vnfood-cnn` Β· `minfu2k5/vgg16-30vnfoods` Β· "
"`trinhtrantran122/resnet18-vnfoods` Β· `shidamaring/cs231-30food-mobilnetv2`"
)
if __name__ == "__main__":
demo.launch()