ViT-One110 / app.py
Aumkeshchy2003's picture
Update app.py
d8b8755 verified
# app.py
import os
import sys
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
# Config
# Checkpoint file name, resolved relative to the working directory
# (put the file in the repo root, or update this path).
CKPT_PATH = "vit_cnn_110class.pt"

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Report the selected device on stderr so it shows up in the server logs.
print("Device:", DEVICE, file=sys.stderr)
# Label lists (CIFAR-10 then CIFAR-100 shifted)
# Names are whitespace-separated and split into plain lists of str.
cifar10_classes = (
    "airplane automobile bird cat deer dog frog horse ship truck"
).split()

cifar100_classes = """
    apple aquarium_fish baby bear beaver bed bee beetle
    bicycle bottle bowl boy bridge bus butterfly camel
    can castle caterpillar cattle chair chimpanzee clock
    cloud cockroach couch crab crocodile cup dinosaur
    dolphin elephant flatfish forest fox girl hamster
    house kangaroo keyboard lamp lawn_mower leopard lion
    lizard lobster man maple_tree motorcycle mountain mouse
    mushroom oak_tree orange orchid otter palm_tree pear
    pickup_truck pine_tree plain plate poppy porcupine
    possum rabbit raccoon ray road rocket rose sea
    seal shark shrew skunk skyscraper snail snake spider
    squirrel streetcar sunflower sweet_pepper table tank
    telephone television tiger tractor train trout tulip
    turtle wardrobe whale willow_tree wolf woman worm
""".split()

# unified label list 0..109 (0-9 CIFAR10, 10-109 CIFAR100)
LABELS = [*cifar10_classes, *cifar100_classes]
# Model architecture
class ConvPatchEmbed(nn.Module):
    """Convolutional stem that turns an image into a sequence of patch tokens.

    Three Conv-BatchNorm-ReLU stages are applied: the first keeps the spatial
    resolution (stride 1), the next two each use stride 2. The resulting
    feature map is flattened into a token sequence.

    Args:
        in_chans: Number of input image channels.
        embed_dim: Channel width of the final stage, i.e. the token dimension.
        img_size: Side length of the (square) input image used to compute
            ``n_patches``. Defaults to 32, which reproduces the previously
            hard-coded value of 64 patches.

    Attributes:
        n_patches: Number of tokens produced for an ``img_size`` x ``img_size``
            input.
    """

    def __init__(self, in_chans=3, embed_dim=384, img_size=32):
        super().__init__()
        # NOTE: layer order inside the Sequential fixes the state_dict keys
        # (conv.0, conv.1, ...); do not regroup these layers.
        self.conv = nn.Sequential(
            nn.Conv2d(in_chans, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )
        # A 3x3 / stride-2 / padding-1 convolution maps a side of length s to
        # (s - 1) // 2 + 1; there are two such stages.
        side = img_size
        for _ in range(2):
            side = (side - 1) // 2 + 1
        self.n_patches = side ** 2

    def forward(self, x):
        """Return patch tokens of shape (batch, tokens, embed_dim).

        Args:
            x: Image batch laid out as (batch, channels, height, width).
        """
        x = self.conv(x)
        # (B, C, H', W') -> (B, C, H'*W') -> (B, H'*W', C)
        x = x.flatten(2).transpose(1, 2)
        return x
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the input and of the output features.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy (``None`` or 0).
        drop: Dropout probability; the same dropout module is applied after
            the activation and after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, drop=0.):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: Token (embedding) dimension.
        num_heads: Number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``. Without this
            check the mismatch only surfaces later as a reshape error inside
            ``forward``.
    """

    def __init__(self, dim, num_heads=6):
        super().__init__()
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = head_dim ** -0.5
        # Single projection producing queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Return the attended sequence, same shape as ``x``: (B, N, C)."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
class _StochasticDepth(nn.Module):
def __init__(self,p): super().__init__(); self.p = p
def forward(self,x):
if not self.training or self.p==0: return x
keep = torch.rand(x.shape[0],1,1,device=x.device) >= self.p
return x * keep / (1 - self.p)
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual branch.

    Args:
        dim: Token dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop: Dropout probability forwarded to the MLP.
        drop_path: Stochastic-depth probability; 0 disables it (an identity
            module is used instead).
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., drop_path=0.):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads)
        # The same drop-path module is shared by both residual branches.
        if drop_path == 0:
            self.drop_path = nn.Identity()
        else:
            self.drop_path = _StochasticDepth(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden_dim, drop)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
class ViT110(nn.Module):
    """Vision transformer with a convolutional patch stem and a linear head.

    The input is embedded by ``ConvPatchEmbed``, a learnable class token is
    prepended, learnable position embeddings are added, the sequence passes
    through ``depth`` transformer blocks, and the normalized class token is
    fed to the classification head.

    Args:
        emb_dim: Token dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``emb_dim``.
        num_classes: Number of output logits.
        drop: Dropout probability (position dropout and MLP dropout).
        drop_path: Maximum stochastic-depth rate; per-block rates increase
            linearly from 0 to this value.
    """

    def __init__(self, emb_dim=384, depth=8, num_heads=6, mlp_ratio=4.0, num_classes=110, drop=0.1, drop_path=0.1):
        super().__init__()
        in_channels = 3
        self.patch_embed = ConvPatchEmbed(in_channels, emb_dim)
        seq_len = 1 + self.patch_embed.n_patches  # class token + patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, emb_dim))
        self.pos_drop = nn.Dropout(p=drop)
        # Linearly increasing drop-path rate, one value per block.
        drop_path_rates = torch.linspace(0, drop_path, depth).tolist()
        self.blocks = nn.ModuleList(
            [
                Block(emb_dim, num_heads, mlp_ratio, drop=drop, drop_path=rate)
                for rate in drop_path_rates
            ]
        )
        self.norm = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        batch_size = x.shape[0]
        tokens = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)
        # Only the class token (position 0) is used for classification.
        return self.head(tokens[:, 0])
# Load model
def load_model(ckpt_path=CKPT_PATH, device=DEVICE):
    """Build a ``ViT110`` and load its weights from ``ckpt_path``.

    The checkpoint may be either a bare state dict or a dict that wraps the
    state dict under the ``"state_dict"`` key. Weights are loaded
    non-strictly; any missing or unexpected keys are reported on stderr so a
    mismatched checkpoint does not silently leave layers randomly initialized.

    Args:
        ckpt_path: Path to the checkpoint file.
        device: Device the model is moved to.

    Returns:
        The model, on ``device`` and in eval mode.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
    """
    # Fail fast before allocating the model.
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    model = ViT110().to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (or pass weights_only=True where supported).
    sd = torch.load(ckpt_path, map_location="cpu")
    # Check the type first: `in` on a non-dict object may raise or misbehave.
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    # strict=False tolerates key mismatches; surface them instead of hiding them.
    incompatible = model.load_state_dict(sd, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "Checkpoint key mismatch:",
            f"missing={list(incompatible.missing_keys)}",
            f"unexpected={list(incompatible.unexpected_keys)}",
            file=sys.stderr,
        )
    model.eval()
    return model


# Loaded once at import time and shared by every prediction request.
MODEL = load_model()
# Transforms (CIFAR-style)
# Per-channel statistics handed to Normalize below.
# NOTE(review): these look like the commonly quoted CIFAR-100 mean/std — confirm
# they match the values used at training time.
_NORM_MEAN = (0.5071, 0.4867, 0.4408)
_NORM_STD = (0.2675, 0.2565, 0.2761)

# Resize to 40, center-crop to 32x32, convert to a tensor, then normalize.
_preprocess_steps = [
    transforms.Resize(40),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# Example images
# File names of the bundled sample images. gr.Examples expects one row per
# example and one entry per input component, hence the single-item lists.
_example_files = (
    "cat.avif",
    "Red_Kangaroo_Peter_and_Shelly_some_rights_res.width-1200.c03bc40.jpg",
    "beagle-hound-dog.webp",
    "niko-photos-tGTVxeOr_Rs-unsplash.jpg",
    "1_9527341a-93b9-4566-9eb3-3bfe92cfed5f.webp",
    "Feng-shui-fish-acquarium_0_1200.jpg.webp",
    "ED-ARTICLE-IMAGES-21.png",
    "apples-101-about-1440x810.webp",
    "beautiful-overhead-cityscape-shot-with-drone.jpg",
    "crocodile-Nile-swath-one-sub-Saharan-Africa-Madagascar.webp",
    "detect(1).jpg",
)
examples_list = [[file_name] for file_name in _example_files]
# UI CSS and pretty display
# Stylesheet handed to gr.Blocks(css=custom_css). It styles:
#   - the page title/subtitle elements (#app-title, #app-subtitle),
#   - the upload widget and JSON panel (via their elem_classes),
#   - the HTML card emitted by pretty_display (.output-card, .model-badge,
#     .conf-bar-container, .conf-bar, .router-meta).
custom_css = """
/* ---------- GLOBAL ---------- */
body {
font-family: 'Inter', sans-serif !important;
}
.gradio-container {
max-width: 960px !important;
margin: auto !important;
}
#app-title {
text-align: center;
font-size: 30px;
font-weight: 800;
margin-bottom: 6px;
}
#app-subtitle {
text-align: center;
font-size: 15px;
opacity: 0.85;
margin-top: -3px;
margin-bottom: 18px;
}
.image-upload-container {
border-radius: 14px !important;
padding: 12px;
transition: 0.25s ease;
}
.image-upload-container:hover {
box-shadow: 0 8px 22px rgba(0,0,0,0.12);
transform: translateY(-3px);
}
.output-card {
background: var(--block-background-fill);
padding: 18px;
border-radius: 12px;
box-shadow: 0 8px 20px rgba(0,0,0,0.10);
transition: 0.22s ease;
}
.model-badge {
display: inline-block;
padding: 5px 10px;
border-radius: 10px;
font-size: 13px;
font-weight: 700;
margin-bottom: 8px;
background-color: #4f46e5;
color: white;
}
.conf-bar-container { height: 12px; background: #e6e7ea; border-radius: 10px; overflow: hidden; margin-top: 8px; }
.conf-bar { height: 100%; background: linear-gradient(90deg, #10b981, #059669); width: 0%; transition: width 0.8s ease; }
.json-output pre { font-size: 13px; background: #0f1724; color: #e6eef6; border-radius: 8px; padding: 12px; }
.router-meta { font-size: 13px; color: #6b7280; margin-top: 8px; }
"""
# ---------------------------
def predict(img: "Image.Image | None"):
    """Classify one image and return a JSON-serializable result dict.

    Args:
        img: PIL image to classify, or ``None`` when nothing was uploaded.

    Returns:
        On success, a dict with ``predicted_class``, ``class_index``,
        ``confidence`` and a nested ``router_info`` dict. On failure (no image
        or any exception during inference), a dict with a single ``error`` key.
    """
    if img is None:
        return {"error": "no image provided"}
    try:
        # Normalize in the preprocessing pipeline is configured for exactly
        # three channels, so grayscale / palette / RGBA inputs are converted.
        x = transform(img.convert("RGB")).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits = MODEL(x)
        probs = F.softmax(logits, dim=1)[0]
        conf, idx = probs.max(0)
        conf = float(conf)
        idx = int(idx)
        label = LABELS[idx]
        router_info = {
            "class_index": idx,
            "pred_label": label,
            "confidence": round(conf, 6),
            "model_used": "Unified ViT-110",
        }
        return {
            "predicted_class": label,
            "class_index": idx,
            "confidence": conf,
            "router_info": router_info,
        }
    except Exception as e:
        # UI boundary: report the failure to the caller instead of crashing
        # the request, but also leave a trace in the server log.
        print(f"predict failed: {e!r}", file=sys.stderr)
        return {"error": str(e)}
def pretty_display(result):
    """Render a prediction result dict as an HTML card.

    Args:
        result: Dict as returned by ``predict`` (either the success fields or
            a single ``error`` key), or ``None``.

    Returns:
        An HTML string. All values taken from ``result`` are HTML-escaped
        before interpolation, so exception messages or labels containing
        ``<``, ``>``, ``&`` or quotes cannot break or inject markup.
    """
    # Imported locally because this function uses no other stdlib module and
    # the name `html` would otherwise be easy to shadow.
    from html import escape

    if result is None:
        return "<div class='output-card'><div class='model-badge'>No prediction</div><div>No result returned.</div></div>"
    if "error" in result:
        error_text = escape(str(result["error"]))
        return f"<div class='output-card'><div class='model-badge'>Error</div><div>{error_text}</div></div>"
    cls = escape(str(result.get("predicted_class", "unknown")))
    idx = escape(str(result.get("class_index", -1)))
    conf = result.get("confidence", 0.0)
    conf_pct = round(conf * 100, 2)
    # Tolerate an explicit null for router_info.
    info = result.get("router_info") or {}
    model_used = escape(str(info.get("model_used", "Unified ViT-110")))
    meta_html = f"<div class='router-meta'><b>Index:</b> {idx} &nbsp;|&nbsp; <b>Model:</b> {model_used} &nbsp;|&nbsp; <b>Confidence:</b> {conf_pct}%</div>"
    card_html = f"""
<div class="output-card">
<div class="model-badge">Unified ViT-110</div>
<h2 style="margin-top:4px;margin-bottom:6px;font-size:22px;">
Prediction: <span style="color:#10b981;font-weight:700">{cls}</span>
</h2>
<div style="font-size:15px;opacity:0.85;">Confidence: {conf_pct}%</div>
<div class="conf-bar-container"><div class="conf-bar" style="width:{conf_pct}%;"></div></div>
{meta_html}
</div>
"""
    return card_html
# Gradio UI
# Two-column layout: image upload + buttons + examples on the left,
# rendered HTML card + raw JSON on the right.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened — confirm it matches the intended layout.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Title and subtitle; the element ids are styled in custom_css.
    gr.HTML("<div id='app-title'>ViT-Fusion: Hybrid Transformer for 110 CIFAR Classes</div>")
    gr.HTML("<div id='app-subtitle'>Hybrid Vision Transformer for unified 110-class CIFAR image recognition</div>")
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand a PIL image to predict().
            image_in = gr.Image(type="pil", label="Upload image", elem_classes=["image-upload-container"])
            submit = gr.Button("Classify", variant="primary", size="lg")
            clear = gr.Button("Clear", variant="secondary")
            examples = gr.Examples(examples=examples_list, inputs=image_in, label="Try example images")
        with gr.Column(scale=1):
            html_out = gr.HTML(label="Prediction")
            json_out = gr.JSON(label="Raw output", elem_classes=["json-output"])
    # Classify: predict() fills the JSON panel, then pretty_display() renders
    # that JSON into the HTML card.
    submit.click(predict, inputs=image_in, outputs=json_out).then(pretty_display, inputs=json_out, outputs=html_out)
    # Clear: reset the image, the HTML card and the JSON panel.
    clear.click(lambda: (None, None, None), outputs=[image_in, html_out, json_out])
# Start the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()