# app.py import os import sys import torch import torch.nn.functional as F import torch.nn as nn from torchvision import transforms from PIL import Image import gradio as gr # Config CKPT_PATH = "vit_cnn_110class.pt" # put the file in the repo root (or update path) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Device:", DEVICE, file=sys.stderr) # Label lists (CIFAR-10 then CIFAR-100 shifted) cifar10_classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck'] cifar100_classes = [ 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm' ] # unified label list 0..109 (0-9 CIFAR10, 10-109 CIFAR100) LABELS = cifar10_classes + cifar100_classes # Model architecture class ConvPatchEmbed(nn.Module): def __init__(self, in_chans=3, embed_dim=384): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_chans, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 128, 3, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, embed_dim, 3, 2, 1, bias=False), nn.BatchNorm2d(embed_dim), nn.ReLU(inplace=True), ) self.n_patches = (32 // 4) ** 2 def forward(self, x): x = self.conv(x) x = x.flatten(2).transpose(1,2) return x class MLP(nn.Module): def __init__(self, in_features, hidden_features=None, drop=0.): super().__init__() hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = nn.GELU() self.fc2 = nn.Linear(hidden_features, in_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x); x = self.act(x); x = self.drop(x) x = self.fc2(x); x = self.drop(x) return x class Attention(nn.Module): def __init__(self, dim, num_heads=6): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim*3) self.proj = nn.Linear(dim, dim) def forward(self, x): B,N,C = x.shape qkv = self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4) q,k,v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2,-1)) * self.scale attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1,2).reshape(B,N,C) return self.proj(x) class _StochasticDepth(nn.Module): def __init__(self,p): super().__init__(); self.p = p def forward(self,x): if not self.training or self.p==0: return x keep = torch.rand(x.shape[0],1,1,device=x.device) >= self.p return x * keep / (1 - self.p) class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., drop_path=0.): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = Attention(dim, num_heads) self.drop_path = nn.Identity() if drop_path==0 else _StochasticDepth(drop_path) self.norm2 = nn.LayerNorm(dim) self.mlp = MLP(dim, int(dim*mlp_ratio), drop) def forward(self,x): x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class ViT110(nn.Module): def __init__(self, emb_dim=384, depth=8, num_heads=6, mlp_ratio=4.0, num_classes=110, drop=0.1, drop_path=0.1): super().__init__() cfg = {"in_channels":3, "emb_dim":emb_dim, "depth":depth, "num_heads":num_heads, "mlp_ratio":mlp_ratio, "drop":drop, "drop_path":drop_path} self.patch_embed = ConvPatchEmbed(cfg["in_channels"], cfg["emb_dim"]) n_patches = self.patch_embed.n_patches self.cls_token = nn.Parameter(torch.zeros(1,1,cfg["emb_dim"])) self.pos_embed = nn.Parameter(torch.zeros(1, 1 + n_patches, cfg["emb_dim"])) self.pos_drop = nn.Dropout(p=cfg["drop"]) dpr = torch.linspace(0, drop_path, depth).tolist() self.blocks = nn.ModuleList([Block(cfg["emb_dim"], cfg["num_heads"], cfg["mlp_ratio"], drop=cfg["drop"], drop_path=dpr[i]) for i in range(depth)]) self.norm = nn.LayerNorm(cfg["emb_dim"]) self.head = nn.Linear(cfg["emb_dim"], num_classes) def forward(self, x): B = x.shape[0] x = self.patch_embed(x) cls = self.cls_token.expand(B,-1,-1) x = torch.cat([cls,x],dim=1) x = x + self.pos_embed x = self.pos_drop(x) for blk in self.blocks: x = blk(x) x = self.norm(x) return self.head(x[:,0]) # Load model def load_model(ckpt_path=CKPT_PATH, device=DEVICE): model = ViT110().to(device) if not os.path.exists(ckpt_path): raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") sd = torch.load(ckpt_path, map_location="cpu") # sd may be state_dict or plain dict; try both if "state_dict" in sd and isinstance(sd, dict): sd = sd["state_dict"] # filter mismatch keys (if any), load with strict=False model.load_state_dict(sd, strict=False) model.eval() return model MODEL = load_model() # Transforms (CIFAR-style) transform = transforms.Compose([ transforms.Resize(40), transforms.CenterCrop(32), transforms.ToTensor(), transforms.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761)), ]) # Example images examples_list = [ ["cat.avif"], ["Red_Kangaroo_Peter_and_Shelly_some_rights_res.width-1200.c03bc40.jpg"], ["beagle-hound-dog.webp"], ["niko-photos-tGTVxeOr_Rs-unsplash.jpg"], ["1_9527341a-93b9-4566-9eb3-3bfe92cfed5f.webp"], ["Feng-shui-fish-acquarium_0_1200.jpg.webp"], ["ED-ARTICLE-IMAGES-21.png"], ["apples-101-about-1440x810.webp"], ["beautiful-overhead-cityscape-shot-with-drone.jpg"], ["crocodile-Nile-swath-one-sub-Saharan-Africa-Madagascar.webp"], ["detect(1).jpg"] ] # UI CSS and pretty display custom_css = """ /* ---------- GLOBAL ---------- */ body { font-family: 'Inter', sans-serif !important; } .gradio-container { max-width: 960px !important; margin: auto !important; } #app-title { text-align: center; font-size: 30px; font-weight: 800; margin-bottom: 6px; } #app-subtitle { text-align: center; font-size: 15px; opacity: 0.85; margin-top: -3px; margin-bottom: 18px; } .image-upload-container { border-radius: 14px !important; padding: 12px; transition: 0.25s ease; } .image-upload-container:hover { box-shadow: 0 8px 22px rgba(0,0,0,0.12); transform: translateY(-3px); } .output-card { background: var(--block-background-fill); padding: 18px; border-radius: 12px; box-shadow: 0 8px 20px rgba(0,0,0,0.10); transition: 0.22s ease; } .model-badge { display: inline-block; padding: 5px 10px; border-radius: 10px; font-size: 13px; font-weight: 700; margin-bottom: 8px; background-color: #4f46e5; color: white; } .conf-bar-container { height: 12px; background: #e6e7ea; border-radius: 10px; overflow: hidden; margin-top: 8px; } .conf-bar { height: 100%; background: linear-gradient(90deg, #10b981, #059669); width: 0%; transition: width 0.8s ease; } .json-output pre { font-size: 13px; background: #0f1724; color: #e6eef6; border-radius: 8px; padding: 12px; } .router-meta { font-size: 13px; color: #6b7280; margin-top: 8px; } """ # --------------------------- def predict(img: Image.Image): if img is None: return {"error": "no image provided"} try: x = transform(img).unsqueeze(0).to(DEVICE) with torch.no_grad(): logits = MODEL(x) probs = F.softmax(logits, dim=1)[0] conf, idx = probs.max(0) conf = float(conf) idx = int(idx) label = LABELS[idx] router_info = { "class_index": idx, "pred_label": label, "confidence": round(conf,6), "model_used": "Unified ViT-110" } return {"predicted_class": label, "class_index": idx, "confidence": conf, "router_info": router_info} except Exception as e: return {"error": str(e)} def pretty_display(result): if result is None: return "