import gradio as gr import torch import timm from torchvision import transforms from PIL import Image model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=2) model.load_state_dict(torch.load("vis_trans_cat_dog.pth", map_location='cpu')) model.eval() transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ]) def predict(img): img = img.convert("RGB") input_tensor = transform(img).unsqueeze(0) with torch.no_grad(): outputs = model(input_tensor) _, predicted = torch.max(outputs, 1) return "Cat" if predicted.item() == 0 else "Dog" interface = gr.Interface( fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label(), title="ViT Cat vs Dog Classifier 🐱🐶", description="Upload an image of a cat or dog and get a prediction from a Vision Transformer model." ) if __name__ == "__main__": interface.launch()