|
|
import gradio as gr
import torch
from PIL import Image
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoint id used for both preprocessing and classification.
model_name = 'google/vit-base-patch16-224'

# Both are loaded once at import time, so every prediction reuses the same
# processor and weights. ViTImageProcessor replaces the deprecated
# ViTFeatureExtractor (which is only a warning-emitting subclass of it), so the
# preprocessing is unchanged. The `feature_extractor` name is kept so existing
# references keep working.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
|
|
|
|
|
|
|
|
def preprocess_image(image):
    """Convert *image* into the pixel tensor fed to the ViT model.

    Args:
        image: A PIL image, or any other input accepted by ``feature_extractor``.

    Returns:
        The ``pixel_values`` entry produced by ``feature_extractor`` with
        ``return_tensors="pt"`` (a PyTorch tensor).

    Raises:
        ValueError: If *image* is ``None`` (e.g. the form was submitted
            without an upload).
    """
    if image is None:
        raise ValueError("No image provided.")
    # The processor normalizes with per-RGB-channel mean/std, so grayscale
    # ("L"), palette ("P") or RGBA images must be converted to 3-channel RGB
    # first, or normalization fails on the channel-count mismatch.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    return inputs["pixel_values"]
|
|
|
|
|
|
|
|
def predict_image(image, *, top_k=None):
    """Classify *image* with the module-level ViT model.

    Args:
        image: Input image; forwarded unchanged to ``preprocess_image``.
        top_k: Keyword-only. ``None`` (the default) returns only the single
            highest-scoring label, exactly as before. A positive int instead
            returns a ``{label: probability}`` dict for the ``top_k``
            highest-scoring classes. The probabilities come from a softmax
            over the logits, which is the shape ``gr.Label`` needs to show
            several classes.

    Returns:
        The predicted label string, or a ``{label: probability}`` dict when
        ``top_k`` is given.

    Raises:
        ValueError: If ``top_k`` is given but is less than 1.
    """
    # Validate before running the comparatively expensive forward pass.
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

    pixel_values = preprocess_image(image)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(pixel_values).logits

    id2label = model.config.id2label

    if top_k is None:
        predicted_class_idx = logits.argmax(-1).item()
        return id2label[predicted_class_idx]

    # Assumes a single image (batch of 1), as the top-1 path's `.item()` does;
    # take that row's class probabilities.
    probs = logits.softmax(-1)[0]
    # Clamp k so asking for more classes than the head has doesn't make topk raise.
    values, indices = probs.topk(min(top_k, probs.numel()))
    return {
        id2label[idx]: prob
        for prob, idx in zip(values.tolist(), indices.tolist())
    }
|
|
|
|
|
|
|
|
# `gr.inputs.*` / `gr.outputs.*` were deprecated in Gradio 3 and removed in
# Gradio 4.0; the top-level components below are the supported equivalents.

# Uploaded images are handed to `predict_image` as PIL images.
image_input = gr.Image(type="pil")

# NOTE(review): `num_top_classes` only has an effect when `fn` returns a
# {label: confidence} mapping. `predict_image` called with just the image
# returns a single label string, so the UI shows only that one label. Wire in
# a function that returns confidences to actually display the top 3.
label_output = gr.Label(num_top_classes=3)

interface = gr.Interface(
    fn=predict_image,
    inputs=image_input,
    outputs=label_output,
    title="Image Classification with ViT",
    description="Upload an image and get the predicted label using Vision Transformer (ViT).",
)


if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    interface.launch()
|
|
|