Spaces:
Sleeping
Sleeping
import logging
from io import BytesIO

import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoConfig
# Hugging Face Hub checkpoint that supplies both the configuration (including
# the label mapping read at prediction time) and the classifier weights.
model_id = "thelabel/240903-image-tagging"
config = AutoConfig.from_pretrained(model_id)

# Run on the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForImageClassification.from_pretrained(model_id).to(device)

# Preprocessing applied to every input image: resize to 224x224, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def load_image_from_url(url):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded image converted to RGB, or ``None`` when ``url`` is empty
        or not a string, when the request fails (connection error, timeout
        after 10 seconds, non-success HTTP status), or when the response body
        cannot be decoded as an image.
    """
    if not isinstance(url, str) or not url.strip():
        # Nothing usable to fetch; skip the network round-trip entirely.
        return None
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:
        # Best-effort by design: callers translate None into a user-facing
        # message. Record the cause so failures can still be diagnosed.
        logging.getLogger(__name__).warning(
            "Could not load image from %r", url, exc_info=True
        )
        return None
def predict_tags(image_url, threshold=0.5):
    """Predict tags for the image found at ``image_url``.

    Each logit is passed through an independent sigmoid, and every label whose
    probability is at least ``threshold`` is reported.

    Args:
        image_url: URL of the image to tag.
        threshold: Minimum probability (inclusive) for a tag to be reported.

    Returns:
        A ``(results, error)`` pair. On success ``results`` is a list of
        ``(label, probability)`` tuples sorted by descending probability and
        ``error`` is ``None``. When the image cannot be loaded, ``results`` is
        an empty list and ``error`` is a human-readable message.
    """
    image = load_image_from_url(image_url)
    if image is None:
        return [], "Could not load image from the provided URL."

    # Add a leading batch dimension of size 1 before moving to the device.
    image_tensor = image_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(image_tensor).logits

    # Drop only the batch dimension (assumes logits are (1, num_labels)).
    # A bare squeeze() would also collapse the label dimension for a
    # single-label model, leaving a 0-d tensor that has no len().
    probs = torch.sigmoid(logits).squeeze(0)

    # Threshold once on the tensor and copy values to the host in bulk,
    # instead of one comparison and one conversion per label.
    selected = torch.nonzero(probs >= threshold, as_tuple=True)[0].tolist()
    scores = probs.tolist()

    # NOTE(review): label names are read from config.idx_to_label with string
    # keys; presumably a custom field of this checkpoint's config - confirm.
    results = [(config.idx_to_label[str(i)], scores[i]) for i in selected]
    results.sort(key=lambda item: item[1], reverse=True)
    return results, None
def gradio_predict(url, threshold):
    """Gradio callback that formats tag predictions for display.

    Args:
        url: Image URL entered by the user.
        threshold: Minimum probability for a tag to be listed.

    Returns:
        A ``(text, preview)`` pair. On failure ``text`` is the error message
        and ``preview`` is ``None``. On success ``text`` has one
        ``"tag: score"`` line per predicted tag (score with two decimals) and
        ``preview`` is the input URL.
    """
    predictions, failure = predict_tags(url, threshold)
    if failure:
        return failure, None

    lines = []
    for tag, score in predictions:
        lines.append(f"{tag}: {score:.2f}")
    return "\n".join(lines), url
# Sample images offered as one-click examples; the first one also pre-fills
# the URL box so the demo works without typing anything.
_EXAMPLE_IMAGE_URLS = [
    "https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiJoaWNjdXAtaW1hZ2UtaG9zdGluZyIsImtleSI6ImhpY2N1cC1wcm9kdWN0cy9GQVFZTFkyNzFGLmpwZWciLCJlZGl0cyI6eyJyZXNpemUiOnsid2lkdGgiOjI1NjAsImhlaWdodCI6Mzg0MCwiZml0IjoiY292ZXIifX19?v=1748968367",
    "https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiJoaWNjdXAtaW1hZ2UtaG9zdGluZyIsImtleSI6ImhpY2N1cC1wcm9kdWN0cy9ON01aQkpUMDlFLmpwZWciLCJlZGl0cyI6eyJyZXNpemUiOnsid2lkdGgiOjI1NjAsImhlaWdodCI6Mzg0MCwiZml0IjoiY292ZXIifX19?v=1748968367",
]

# Web UI: a URL box and a threshold slider in; the formatted tag list and a
# preview of the image out.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="Image URL", value=_EXAMPLE_IMAGE_URLS[0]),
        gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
    ],
    outputs=[
        gr.Textbox(label="Tags"),
        # gr.Image only accepts type "numpy", "pil" or "filepath"; "url" is
        # rejected when the component is constructed. The callback returns the
        # image URL as a string for this output.
        gr.Image(label="Preview", type="filepath"),
    ],
    title="Image Tagging with ViT",
    description="Paste an image URL and get predicted tags using thelabel/240903-image-tagging model.",
    examples=[[example_url, 0.5] for example_url in _EXAMPLE_IMAGE_URLS],
)

if __name__ == "__main__":
    demo.launch()