import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification


# Hugging Face Hub checkpoint shared by the preprocessor and the classifier,
# defined once so the two loaders can never point at different models.
MODEL_ID = "google/vit-base-patch16-224"

# Loaded once at module import and reused by every classification call.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)


def classify_image(image):
    """Predict the top-1 class label for a single image.

    Args:
        image: The picture to classify, either as a NumPy array (what
            ``gr.Image(type="numpy")`` delivers) or as an already-decoded
            ``PIL.Image.Image``.

    Returns:
        The label string that ``model.config.id2label`` maps the
        highest-scoring class index to.

    Raises:
        gr.Error: If ``image`` is ``None`` (e.g. the user pressed Submit
            without uploading anything), so the UI shows a readable message
            instead of a traceback from ``Image.fromarray``.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    # Accept PIL images as-is; decode NumPy arrays coming from the widget.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # Normalise grayscale / RGBA / palette inputs to 3-channel RGB.
    pil_image = pil_image.convert("RGB")

    inputs = feature_extractor(images=pil_image, return_tensors="pt")
    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


# Web UI: the user uploads an image and gets back the predicted label as text.
app = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Image Classification App",
)


# Start the server only when executed as a script (e.g. `python app.py`);
# importing this module (tests, reuse of `app` elsewhere) no longer launches it.
if __name__ == "__main__":
    app.launch()