import io

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, SwinForImageClassification
# Single source of truth for the checkpoint identifier passed to
# `from_pretrained`. The preprocessor and the classifier must be loaded from
# the same checkpoint so the resize/normalisation settings match the weights;
# keeping one constant prevents the two calls from drifting apart.
MODEL_CHECKPOINT = "microsoft/swin-small-patch4-window7-224"

# Both objects are loaded once at import time and reused by every request.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = SwinForImageClassification.from_pretrained(MODEL_CHECKPOINT)
def classify_image(url):
    """Download the image at *url* and return the model's top predicted label.

    Args:
        url: HTTP(S) address of an image file.

    Returns:
        The class name that ``model.config.id2label`` maps to the index of
        the highest-scoring logit.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within 10 seconds.
        PIL.UnidentifiedImageError: if the payload is not a decodable image.
    """
    # Read the full body instead of `.raw`: `.raw` skips requests'
    # Content-Encoding (gzip/deflate) decoding. The `with` block releases the
    # connection, and the timeout stops a stalled server from hanging the app.
    with requests.get(url, timeout=10) as response:
        response.raise_for_status()
        payload = response.content

    # Force 3-channel RGB: RGBA, palette ("P") or grayscale ("L") images would
    # otherwise reach the feature extractor with the wrong channel layout.
    image = Image.open(io.BytesIO(payload)).convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: no gradients are needed, so don't build autograd state.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Sample image offered as a clickable example under the input box.
_SAMPLE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Gradio expects one inner list per example, with one value per input
# component; this interface has a single text input.
examples = [[_SAMPLE_URL]]

# Text in (an image URL), text out (the predicted label).
iface = gr.Interface(
    fn=classify_image,
    inputs="text",
    outputs="text",
    examples=examples,
)
iface.launch()