from transformers import ViTImageProcessor, ViTForImageClassification import torch import gradio as gr from PIL import Image # Load general ViT model (ImageNet pretrained) model_name = "google/vit-base-patch16-224" processor = ViTImageProcessor.from_pretrained(model_name) model = ViTForImageClassification.from_pretrained(model_name) def predict(image): if image is None: return "Please upload an image." # Preprocess image inputs = processor(images=image, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits probs = torch.nn.functional.softmax(logits, dim=1) conf, predicted_class = torch.max(probs, dim=1) label = model.config.id2label[predicted_class.item()] confidence = conf.item() * 100 # This label will be a general ImageNet class, e.g. "banana", "bee", "daisy" return f"Detected class: {label}\nConfidence: {confidence:.2f}%" gr.Interface( fn=predict, inputs=gr.Image(type="pil"), outputs="text", title="General Image Classification with ViT", description="Upload an image to classify using ViT pretrained on ImageNet." ).launch()