# app.py — Gradio demo for the ViT bird-species classifier.
# (Removed Hugging Face Hub page residue that was pasted into the source:
#  "Aumkeshchy2003's picture / Update app.py / f47f169 verified" — that text
#  is web-UI chrome, not Python, and broke the file.)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import json
import torch.nn as nn
import torch.nn.functional as F
# ViT model implementation - matching checkpoint architecture
class ConvPatchEmbed(nn.Module):
    """Convolutional stem that tokenizes a 32x32 image into 1024 patch embeddings.

    All convolutions use stride 1, so the 32x32 spatial resolution is kept and
    every pixel position becomes one token.
    """

    def __init__(self, in_chans=3, embed_dim=128):
        super().__init__()
        # Build the three conv/BN/ReLU stages in a loop; the resulting
        # nn.Sequential has the same child indices (0..8) as a hand-written
        # stack, so checkpoint state_dict keys are unchanged.
        stages = []
        for c_in, c_out in ((in_chans, 64), (64, 128), (128, embed_dim)):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)
        # One token per spatial position: no downsampling -> 32 * 32 patches.
        self.n_patches = 32 * 32

    def forward(self, x):
        """Map images (B, C, 32, 32) to a token sequence (B, 1024, embed_dim)."""
        feats = self.conv(x)                       # (B, embed_dim, 32, 32)
        return feats.flatten(2).transpose(1, 2)    # (B, 1024, embed_dim)
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Linear with dropout."""

    def __init__(self, in_features, hidden_features=None, drop=0.):
        super().__init__()
        # Falsy hidden_features (None/0) falls back to the input width.
        hidden = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Expand, activate, project back; dropout after each linear layer."""
        h = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(h))
class Attention(nn.Module):
    """Standard multi-head self-attention over a (B, N, C) token sequence."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # Scale attention logits by 1/sqrt(head_dim) to keep softmax stable.
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return attended tokens with the same shape as the input."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # Joint q/k/v projection, then split heads: each of q, k, v is
        # (B, num_heads, N, head_dim).
        qkv = self.qkv(x).view(B, N, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        weights = self.attn_drop(weights)
        # Merge heads back into the channel dimension.
        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention + MLP, each residual."""

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
        # Use stochastic depth on the residual branches only when a nonzero
        # rate is requested; otherwise a plain identity.
        if drop_path == 0.:
            self.drop_path = nn.Identity()
        else:
            self.drop_path = _StochasticDepth(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        """Apply norm->attention and norm->MLP, each added back residually."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
class _StochasticDepth(nn.Module):
def __init__(self, p):
super().__init__()
self.p = p
def forward(self, x):
if not self.training or self.p == 0.:
return x
keep = torch.rand(x.shape[0], 1, 1, device=x.device) >= self.p
return x * keep / (1 - self.p)
class ViT(nn.Module):
    """Vision Transformer classifier with a convolutional patch stem.

    ``cfg`` must provide: ``in_channels``, ``emb_dim``, ``depth``,
    ``num_heads``, ``mlp_ratio``, ``drop``, ``num_classes``; optionally
    ``drop_path`` (default 0.2).
    """

    def __init__(self, cfg):
        super().__init__()
        # Fix: the original read cfg["image_size"] / cfg["patch_size"] into
        # locals that were never used — the conv stem fixes the patch
        # geometry itself, so those dead locals are removed.
        self.patch_embed = ConvPatchEmbed(cfg["in_channels"], cfg["emb_dim"])
        n_patches = self.patch_embed.n_patches
        # Learnable [CLS] token plus positional embeddings for CLS + patches.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg["emb_dim"]))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + n_patches, cfg["emb_dim"]))
        self.pos_drop = nn.Dropout(p=cfg["drop"])
        # Stochastic-depth rates increase linearly from 0 to drop_path across
        # the transformer blocks (deeper blocks are dropped more often).
        dpr = [r.item() for r in torch.linspace(0, cfg.get("drop_path", 0.2), cfg["depth"])]
        self.blocks = nn.ModuleList([
            Block(cfg["emb_dim"], num_heads=cfg["num_heads"], mlp_ratio=cfg["mlp_ratio"],
                  drop=cfg["drop"], drop_path=dpr[i])
            for i in range(cfg["depth"])
        ])
        self.norm = nn.LayerNorm(cfg["emb_dim"])
        self.head = nn.Linear(cfg["emb_dim"], cfg["num_classes"])
        # Init: truncated normal for the token/positional parameters, then
        # module-wise init for everything else.
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Module-wise init: Xavier for Linear, Kaiming for Conv2d, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            # Stem convs are bias=False; guard in case other convs appear.
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for images x (B, C, H, W)."""
        B = x.shape[0]
        x = self.patch_embed(x)                         # (B, N, E)
        cls_tokens = self.cls_token.expand(B, -1, -1)   # one CLS per sample
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Classify from the [CLS] token only.
        return self.head(x[:, 0])
# === Load config and model ===
# Model hyperparameters (image_size, emb_dim, depth, ...) come from disk.
with open("config.json", "r") as f:
    cfg = json.load(f)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT(cfg)
# NOTE(review): torch.load unpickles the checkpoint file. For a plain
# state_dict, passing weights_only=True would be safer against tampered
# files — confirm the checkpoint holds only tensors before changing.
model.load_state_dict(torch.load("best_vit_bird_cls_mps_safe.pt", map_location=device))
model.to(device).eval()
# === Preprocessing ===
# Normalization statistics — presumably the same ones used during training;
# verify against the training pipeline before changing.
mean = (0.4914, 0.4822, 0.4465)
std = (0.247, 0.243, 0.261)
transform = transforms.Compose([
    transforms.Resize((cfg["image_size"], cfg["image_size"])),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
# === Bird species labels ===
# predict() maps head output index i -> class_names[i], so this order must
# match the label order used at training time.
class_names = [
    "Common_Myna",
    "Eurasian_Collared-Dove",
    "Female_Rose_Ringed_Parakeet",
    "House_Crow",
    "Male_Rose_Ringed_Parakeet",
    "Rufous_Treepie",
    "Silver_Bill"
]
# === Prediction function ===
def predict(image):
    """Classify a bird image; return {species: probability} for the top 5.

    Returns an empty dict for missing input or on any internal error so the
    Gradio Label output always receives a valid mapping.
    """
    try:
        if image is None:
            return {}
        # Normalize the input to an RGB PIL image.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Explicitly resize up front so the tensor matches the model's
        # expected input size regardless of what Gradio hands us.
        target = (cfg["image_size"], cfg["image_size"])
        image = image.resize(target, Image.BILINEAR)
        # Forward pass without gradients; softmax turns logits into probs.
        batch = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            probs = torch.softmax(model(batch), dim=1)[0]
        # Top-5 (or fewer) classes, converted to plain floats for the widget.
        top_prob, top_idx = probs.topk(min(5, len(class_names)))
        return {
            class_names[int(idx)]: float(p)
            for p, idx in zip(top_prob, top_idx)
        }
    except Exception as e:
        # Boundary handler: log and degrade gracefully instead of crashing
        # the Gradio app on a bad upload.
        print(f"Error in prediction: {str(e)}")
        import traceback
        traceback.print_exc()
        # Return empty dict on error to avoid type issues
        return {}
# === Gradio Interface ===
# Single-image classifier UI: image upload in, top-5 label probabilities out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Bird Image"),
    outputs=gr.Label(num_top_classes=5, label="Predicted Species"),
    title="🐦 Bird Species Classifier (ViT)",
    description="Upload an image of a bird and the ViT model will classify its species.",
    # Clickable example images; these files must exist next to app.py —
    # TODO confirm they are shipped with the deployment.
    examples=[
        "frame_000131.jpg",
        "frame_000181.jpg",
        "frame_000211.jpg",
        "frame_000313.jpg",
        "frame_000665.jpg",
        "Screenshot 2025-11-12 at 4.14.53 PM.png"
    ]
)
if __name__ == "__main__":
    app.launch()