# app.py
import os
import sys
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
# Config
# Path of the checkpoint file read by load_model(); relative paths are resolved
# against the process working directory.
CKPT_PATH = "vit_cnn_110class.pt" # put the file in the repo root (or update path)
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Written to stderr so the chosen device shows up in the server log at import time.
print("Device:", DEVICE, file=sys.stderr)
# Label lists (CIFAR-10 first, then CIFAR-100 with its indices shifted by 10).
# NOTE(review): the position of each name is the class index used to decode the
# model output in predict(); presumably this order matches the label encoding
# used at training time - confirm before reordering or editing any entry.
cifar10_classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
cifar100_classes = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
    'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
    'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
    'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]
# unified label list 0..109 (0-9 CIFAR10, 10-109 CIFAR100)
# Plain concatenation: a CIFAR-100 name at position i ends up at index 10 + i.
LABELS = cifar10_classes + cifar100_classes
# Model architecture
class ConvPatchEmbed(nn.Module):
    """Convolutional stem that turns an image batch into a sequence of tokens.

    Three 3x3 Conv -> BatchNorm -> ReLU stages. The first keeps the spatial
    size, the second and third use stride 2, so each spatial side is reduced
    roughly 4x. The resulting feature map is flattened into one token per
    spatial position.

    Args:
        in_chans: Number of channels of the input images.
        embed_dim: Channel width of the last conv stage, i.e. the token size.
        img_size: Side length of the square input the surrounding model is
            built for. It is only used to derive ``n_patches``; the default
            of 32 reproduces the previously hard-coded value of 64 tokens.

    Attributes:
        n_patches: Number of tokens produced for an ``img_size`` x ``img_size``
            input.
    """

    def __init__(self, in_chans=3, embed_dim=384, img_size=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_chans, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, embed_dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )
        # A kernel-3 / stride-2 / padding-1 conv maps a side of s to ceil(s / 2).
        # Applying that twice gives the exact output side, also for sizes that
        # are not multiples of 4 (where ``img_size // 4`` would undercount).
        side = (img_size + 1) // 2
        side = (side + 1) // 2
        self.n_patches = side ** 2

    def forward(self, x):
        """Map images of shape (B, in_chans, H, W) to tokens of shape (B, H' * W', embed_dim)."""
        x = self.conv(x)
        # (B, C, H', W') -> (B, C, H' * W') -> (B, H' * W', C)
        x = x.flatten(2).transpose(1, 2)
        return x
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the last dimension of the input and of the output.
        hidden_features: Width of the hidden layer. Falls back to
            ``in_features`` when omitted (or otherwise falsy).
        drop: Dropout probability applied after the activation and again
            after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        # One Dropout module is shared by both dropout positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``; the shape is preserved."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Size of each token vector.
        num_heads: Number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``. Without this
            check the mismatch only surfaces later as an opaque reshape error
            inside ``forward``.
    """

    def __init__(self, dim, num_heads=6):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Standard 1 / sqrt(head_dim) scaling of the attention logits.
        self.scale = head_dim ** -0.5
        # A single projection produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over the token axis of ``x`` (shape (B, N, C)); returns the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
class _StochasticDepth(nn.Module):
def __init__(self,p): super().__init__(); self.p = p
def forward(self,x):
if not self.training or self.p==0: return x
keep = torch.rand(x.shape[0],1,1,device=x.device) >= self.p
return x * keep / (1 - self.p)
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each with a residual.

    Args:
        dim: Token size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop: Dropout probability used inside the MLP.
        drop_path: Stochastic-depth probability applied to both residual
            branches; 0 disables it.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., drop_path=0.):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads)
        # Both residual branches share this one module.
        if drop_path == 0:
            self.drop_path = nn.Identity()
        else:
            self.drop_path = _StochasticDepth(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden_dim, drop)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.norm2(x))
        x = x + self.drop_path(mlp_branch)
        return x
class ViT110(nn.Module):
    """Vision transformer with a convolutional patch stem and a linear classifier head.

    Pipeline: ConvPatchEmbed -> prepend class token -> add positional
    embedding -> dropout -> ``depth`` transformer blocks -> LayerNorm ->
    linear head on the class token.

    Args:
        emb_dim: Token size used throughout the model.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``emb_dim``.
        num_classes: Number of output logits.
        drop: Dropout probability (after the positional embedding and inside the MLPs).
        drop_path: Stochastic-depth probability of the last block; earlier
            blocks get linearly smaller values starting from 0.
    """

    def __init__(self, emb_dim=384, depth=8, num_heads=6, mlp_ratio=4.0, num_classes=110, drop=0.1, drop_path=0.1):
        super().__init__()
        # The stem is always built for 3-channel input.
        self.patch_embed = ConvPatchEmbed(3, emb_dim)
        token_count = 1 + self.patch_embed.n_patches  # class token + patch tokens

        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, emb_dim))
        self.pos_drop = nn.Dropout(p=drop)

        # Linearly increasing drop-path rate, one value per block.
        rates = torch.linspace(0, drop_path, depth).tolist()
        layers = []
        for rate in rates:
            layers.append(Block(emb_dim, num_heads, mlp_ratio, drop=drop, drop_path=rate))
        self.blocks = nn.ModuleList(layers)

        self.norm = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for an image batch ``x``."""
        batch = x.shape[0]
        tokens = self.patch_embed(x)
        # Broadcast the single learned class token over the batch and prepend it.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)
        # Only the class-token position feeds the classifier.
        return self.head(tokens[:, 0])
# Load model
def load_model(ckpt_path=CKPT_PATH, device=DEVICE):
    """Build a ViT110, load weights from ``ckpt_path`` and return it in eval mode.

    The checkpoint may be a raw state dict or a dict that wraps one under the
    ``"state_dict"`` key. Loading stays non-strict so a partially matching
    checkpoint still works, but any missing or unexpected keys are reported
    on stderr instead of being silently ignored (a model running on randomly
    initialised weights would otherwise go unnoticed).

    Args:
        ckpt_path: Path of the checkpoint file.
        device: Device the model is moved to.

    Returns:
        The ViT110 instance on ``device`` with ``eval()`` applied.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
    """
    # Fail fast before spending time building the model.
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    model = ViT110().to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the torch version and
    # the checkpoint contents allow it).
    sd = torch.load(ckpt_path, map_location="cpu")
    # Check the type first: the membership test is only meaningful on a dict.
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    incompatible = model.load_state_dict(sd, strict=False)
    if incompatible.missing_keys:
        print(
            f"Warning: {len(incompatible.missing_keys)} key(s) missing from checkpoint:",
            incompatible.missing_keys,
            file=sys.stderr,
        )
    if incompatible.unexpected_keys:
        print(
            f"Warning: {len(incompatible.unexpected_keys)} unexpected key(s) in checkpoint:",
            incompatible.unexpected_keys,
            file=sys.stderr,
        )
    model.eval()
    return model
# Loaded once at import time and shared by every call to predict().
MODEL = load_model()
# Transforms (CIFAR-style)
# Resize the shorter image side to 40 px, take the central 32x32 crop, convert
# to a float tensor and normalise each channel with fixed mean / std values.
# NOTE(review): the mean/std constants are presumably the statistics used at
# training time - confirm against the training pipeline.
transform = transforms.Compose([
    transforms.Resize(40),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761)),
])
# Example images
# One single-element list per example, matching the single image input of the
# UI. Paths are relative, so the files are expected next to the app.
examples_list = [
    ["cat.avif"],
    ["Red_Kangaroo_Peter_and_Shelly_some_rights_res.width-1200.c03bc40.jpg"],
    ["beagle-hound-dog.webp"],
    ["niko-photos-tGTVxeOr_Rs-unsplash.jpg"],
    ["1_9527341a-93b9-4566-9eb3-3bfe92cfed5f.webp"],
    ["Feng-shui-fish-acquarium_0_1200.jpg.webp"],
    ["ED-ARTICLE-IMAGES-21.png"],
    ["apples-101-about-1440x810.webp"],
    ["beautiful-overhead-cityscape-shot-with-drone.jpg"],
    ["crocodile-Nile-swath-one-sub-Saharan-Africa-Madagascar.webp"],
    ["detect(1).jpg"]
]
# UI CSS and pretty display
# Stylesheet handed to gr.Blocks(css=...). The .output-card, .model-badge,
# .conf-bar-container, .conf-bar and .router-meta classes are the ones emitted
# by pretty_display(); .image-upload-container and .json-output are attached to
# components through elem_classes.
custom_css = """
/* ---------- GLOBAL ---------- */
body {
font-family: 'Inter', sans-serif !important;
}
.gradio-container {
max-width: 960px !important;
margin: auto !important;
}
#app-title {
text-align: center;
font-size: 30px;
font-weight: 800;
margin-bottom: 6px;
}
#app-subtitle {
text-align: center;
font-size: 15px;
opacity: 0.85;
margin-top: -3px;
margin-bottom: 18px;
}
.image-upload-container {
border-radius: 14px !important;
padding: 12px;
transition: 0.25s ease;
}
.image-upload-container:hover {
box-shadow: 0 8px 22px rgba(0,0,0,0.12);
transform: translateY(-3px);
}
.output-card {
background: var(--block-background-fill);
padding: 18px;
border-radius: 12px;
box-shadow: 0 8px 20px rgba(0,0,0,0.10);
transition: 0.22s ease;
}
.model-badge {
display: inline-block;
padding: 5px 10px;
border-radius: 10px;
font-size: 13px;
font-weight: 700;
margin-bottom: 8px;
background-color: #4f46e5;
color: white;
}
.conf-bar-container { height: 12px; background: #e6e7ea; border-radius: 10px; overflow: hidden; margin-top: 8px; }
.conf-bar { height: 100%; background: linear-gradient(90deg, #10b981, #059669); width: 0%; transition: width 0.8s ease; }
.json-output pre { font-size: 13px; background: #0f1724; color: #e6eef6; border-radius: 8px; padding: 12px; }
.router-meta { font-size: 13px; color: #6b7280; margin-top: 8px; }
"""
# ---------------------------
def predict(img: Image.Image):
    """Classify one image and return the prediction as a JSON-friendly dict.

    Args:
        img: PIL image to classify, or ``None`` when nothing was provided.

    Returns:
        On success a dict with ``predicted_class``, ``class_index``,
        ``confidence`` (softmax probability of the top class) and a
        ``router_info`` sub-dict repeating those values plus the model name.
        On failure a dict with a single ``error`` message.
    """
    if img is None:
        return {"error": "no image provided"}
    try:
        # The normalisation expects exactly three channels, so RGBA, grayscale
        # and palette images must be converted before the transform runs.
        rgb = img.convert("RGB")
        x = transform(rgb).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits = MODEL(x)
            probs = F.softmax(logits, dim=1)[0]
        conf, idx = probs.max(0)
        conf = float(conf)
        idx = int(idx)
        label = LABELS[idx]
        router_info = {
            "class_index": idx,
            "pred_label": label,
            "confidence": round(conf, 6),
            "model_used": "Unified ViT-110"
        }
        return {"predicted_class": label, "class_index": idx, "confidence": conf, "router_info": router_info}
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing the
        # request, but also leave a trace in the server log.
        print(f"predict failed: {e!r}", file=sys.stderr)
        return {"error": str(e)}
def pretty_display(result):
    """Render a prediction dict as an HTML result card.

    Args:
        result: Mapping with the keys ``predicted_class``, ``class_index``,
            ``confidence`` and ``router_info``, or a mapping holding only
            ``error``, or ``None``.

    Returns:
        An HTML string: a "No prediction" card for ``None``, an "Error" card
        when ``error`` is present, otherwise the prediction card with a
        confidence bar. Every interpolated value is HTML-escaped.
    """
    # Imported locally so the block does not rely on a module-level import.
    from html import escape

    if result is None:
        return "<div class='output-card'><div class='model-badge'>No prediction</div><div>No result returned.</div></div>"
    if "error" in result:
        # Exception text can contain '<' / '>' (e.g. object reprs); escape it so
        # it is displayed rather than interpreted as markup.
        message = escape(str(result["error"]))
        return f"<div class='output-card'><div class='model-badge'>Error</div><div>{message}</div></div>"

    cls = escape(str(result.get("predicted_class", "unknown")))
    idx = escape(str(result.get("class_index", -1)))
    # Tolerate a missing / null confidence and keep the bar width inside 0-100%.
    conf = float(result.get("confidence") or 0.0)
    conf_pct = round(min(max(conf, 0.0), 1.0) * 100, 2)
    info = result.get("router_info") or {}
    model_used = escape(str(info.get("model_used", "Unified ViT-110")))

    meta_html = (
        f"<div class='router-meta'><b>Index:</b> {idx} | <b>Model:</b> {model_used}"
        f" | <b>Confidence:</b> {conf_pct}%</div>"
    )
    card = f"""
<div class="output-card">
<div class="model-badge">Unified ViT-110</div>
<h2 style="margin-top:4px;margin-bottom:6px;font-size:22px;">
Prediction: <span style="color:#10b981;font-weight:700">{cls}</span>
</h2>
<div style="font-size:15px;opacity:0.85;">Confidence: {conf_pct}%</div>
<div class="conf-bar-container"><div class="conf-bar" style="width:{conf_pct}%;"></div></div>
{meta_html}
</div>
"""
    return card
# Gradio UI
# NOTE(review): the nesting below is reconstructed from the order of the
# statements (the source had lost its indentation) - confirm that the layout
# matches the intended one: a title, a subtitle, then one row with an input
# column on the left and an output column on the right.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("<div id='app-title'>ViT-Fusion: Hybrid Transformer for 110 CIFAR Classes</div>")
    gr.HTML("<div id='app-subtitle'>Hybrid Vision Transformer for unified 110-class CIFAR image recognition</div>")
    with gr.Row():
        # Left column: image upload, action buttons and clickable examples.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Upload image", elem_classes=["image-upload-container"])
            submit = gr.Button("Classify", variant="primary", size="lg")
            clear = gr.Button("Clear", variant="secondary")
            examples = gr.Examples(examples=examples_list, inputs=image_in, label="Try example images")
        # Right column: rendered result card above the raw prediction dict.
        with gr.Column(scale=1):
            html_out = gr.HTML(label="Prediction")
            json_out = gr.JSON(label="Raw output", elem_classes=["json-output"])
    # Two-step chain: predict() fills the JSON panel first, then
    # pretty_display() turns that JSON value into the HTML card.
    submit.click(predict, inputs=image_in, outputs=json_out).then(pretty_display, inputs=json_out, outputs=html_out)
    # Reset the image input and both output panels.
    clear.click(lambda: (None, None, None), outputs=[image_in, html_out, json_out])
# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
# end of app.py