# vit-cat-dog / app.py
# Parth1503's picture
# Upload app.py
# a88b562 verified
import gradio as gr
import torch
import timm
from torchvision import transforms
from PIL import Image
# Instantiate timm's 'vit_base_patch16_224' architecture with a 2-class output
# head. pretrained=False because the weights come from the local checkpoint
# loaded just below, not from timm's hosted weights.
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=2)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# The result is passed straight to load_state_dict, so nothing beyond a plain
# state dict is needed. map_location='cpu' allows loading without a GPU.
model.load_state_dict(
    torch.load("vis_trans_cat_dog.pth", map_location='cpu', weights_only=True)
)
# Switch to evaluation mode for inference (e.g. turns off dropout).
model.eval()
# Preprocessing applied to every uploaded image before it reaches the model.
transform = transforms.Compose(
    [
        # Force the 224x224 spatial size (aspect ratio is not preserved).
        transforms.Resize(size=(224, 224)),
        # PIL image -> float tensor in CHW layout with values scaled to [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1]; the single-element
        # mean/std is broadcast over every channel.
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
def predict(img):
    """Classify an uploaded image as a cat or a dog.

    Args:
        img: PIL image supplied by the Gradio image input, or ``None`` when
            the user submits without providing an image.

    Returns:
        ``"Cat"`` if the highest-scoring class index is 0, otherwise ``"Dog"``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an opaque ``AttributeError``.
    """
    if img is None:
        raise gr.Error("Please upload an image before submitting.")

    # Normalise grayscale / RGBA / palette inputs to three channels.
    img = img.convert("RGB")
    # Add a leading batch dimension: the model is called on a batch of one.
    input_tensor = transform(img).unsqueeze(0)
    # No gradients are needed for inference.
    with torch.no_grad():
        outputs = model(input_tensor)
    # Index of the largest logit along the class dimension.
    predicted = outputs.argmax(dim=1)
    return "Cat" if predicted.item() == 0 else "Dog"
# Gradio UI: one image input wired to predict(), result shown in a Label.
interface = gr.Interface(
    title="ViT Cat vs Dog Classifier 🐱🐶",
    description="Upload an image of a cat or dog and get a prediction from a Vision Transformer model.",
    fn=predict,
    # type="pil" makes Gradio hand predict() a PIL image.
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()