"""Gradio demo: classify an uploaded image with a pretrained ViT model."""

import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

# Load the model and image processor once at import time so every request
# reuses them instead of reloading per call.
model_name = 'google/vit-base-patch16-224'
# ViTImageProcessor supersedes the deprecated ViTFeatureExtractor and is
# called the same way; the variable name is kept for existing callers.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)


def preprocess_image(image):
    """Convert *image* into the pixel tensor the ViT model consumes.

    Args:
        image: A PIL image (or any other input the image processor accepts).

    Returns:
        The ``pixel_values`` tensor produced by the processor with
        ``return_tensors="pt"``.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # The processor normalizes with 3-channel mean/std; RGBA, L, P, ...
        # images would otherwise fail or be normalized incorrectly.
        image = image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    return inputs['pixel_values']


def _logits(image):
    """Run the model on a single *image* and return its 1-D logits vector."""
    pixel_values = preprocess_image(image)
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(pixel_values=pixel_values)
    # The batch holds exactly one image, so take row 0.
    return outputs.logits[0]


def predict_image(image):
    """Return the name of the single highest-scoring class for *image*.

    The name is looked up in ``model.config.id2label``.
    """
    predicted_class_idx = _logits(image).argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def predict_top_classes(image, top_k=3):
    """Return the *top_k* most likely class names with softmax confidences.

    Args:
        image: The uploaded image, or ``None`` if the user submitted nothing.
        top_k: How many classes to return (clamped to the number of classes).

    Returns:
        A ``{label: probability}`` dict suitable for ``gr.Label``, or ``None``
        when no image was supplied (which clears the Label output).
    """
    if image is None:
        return None
    probs = _logits(image).softmax(-1)
    k = min(top_k, probs.numel())
    top_probs, top_idx = probs.topk(k)
    return {
        model.config.id2label[idx]: prob
        for idx, prob in zip(top_idx.tolist(), top_probs.tolist())
    }


# Define the Gradio interface. The legacy gr.inputs / gr.outputs namespaces
# were removed in Gradio 4; the components now live at the top level.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=3)

interface = gr.Interface(
    fn=predict_top_classes,
    inputs=image_input,
    outputs=label_output,
    title="Image Classification with ViT",
    description="Upload an image and get the predicted label using Vision Transformer (ViT)."
)

# Launch the interface only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()