Spaces:
Sleeping
Sleeping
| import requests | |
| from PIL import Image | |
| from io import BytesIO | |
| import torch | |
| from torchvision import transforms | |
| from transformers import AutoModelForImageClassification, AutoConfig | |
| import gradio as gr | |
| import spaces | |
| import os | |
# Hub credentials come from the environment; the token is None when the
# variable is not set.
token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
model_id = "thelabel/240903-image-tagging"

# Resolve the compute device first so the classifier can be moved onto it
# as part of loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The config object is kept around because predict_tags reads its
# idx_to_label mapping to turn class indices into tag names.
config = AutoConfig.from_pretrained(model_id, token=token)
model = AutoModelForImageClassification.from_pretrained(
    model_id,
    token=token,
).to(device)
# Preprocessing applied to every downloaded image before inference:
# resize to 224x224, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and standard deviation 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def load_image_from_url(url):
    """Download the image at ``url`` and return it as an RGB PIL image.

    This is a best-effort helper: any failure (bad URL, network error,
    non-2xx status, undecodable payload, ...) is logged and reported to the
    caller as ``None`` rather than raised.

    Args:
        url: Address of the image; surrounding whitespace is ignored.

    Returns:
        The decoded image converted to RGB mode, or ``None`` if it could not
        be fetched or decoded.
    """
    import logging  # local import keeps this block self-contained

    try:
        # The 10 second timeout keeps one unresponsive host from stalling
        # the whole request.
        response = requests.get(url.strip(), timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as exc:
        # Deliberately broad: callers rely on None for *any* failure, but the
        # reason is recorded so a failing URL can be diagnosed.
        logging.getLogger(__name__).warning(
            "Could not load image from %r: %s", url, exc
        )
        return None
def predict_tags(image_url, threshold=0.5):
    """Download an image and return the tags whose score meets ``threshold``.

    Args:
        image_url: URL of the image to tag.
        threshold: Minimum sigmoid score (inclusive) a tag needs to be kept.

    Returns:
        A ``(results, error)`` pair. On success ``results`` is a list of
        ``(label, score)`` tuples sorted by descending score (it may be
        empty) and ``error`` is ``None``. If the image cannot be loaded,
        ``results`` is ``None`` and ``error`` is ``"Could not load image."``.
    """
    image = load_image_from_url(image_url)
    if image is None:
        return None, "Could not load image."

    # Add a leading batch dimension of size one before moving to the device.
    image_tensor = image_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(image_tensor).logits

    # Drop only the batch dimension (assumes logits is (1, num_labels)).
    # A bare squeeze() would also collapse a single-label output to a 0-d
    # tensor, which has no len() and cannot be indexed.
    probs = torch.sigmoid(logits).squeeze(0)

    # Apply the threshold on the tensor, then convert once instead of
    # indexing the tensor element by element.
    keep = (probs >= threshold).tolist()
    scores = probs.tolist()
    results = [
        (config.idx_to_label[str(index)], score)
        for index, (score, kept) in enumerate(zip(scores, keep))
        if kept
    ]
    results.sort(key=lambda pair: pair[1], reverse=True)
    return results, None
def gradio_predict(urls, threshold):
    """Tag every image in a comma-separated URL string.

    Args:
        urls: Image URLs separated by commas; blank entries are skipped.
        threshold: Score threshold forwarded to ``predict_tags``.

    Returns:
        The ``str()`` of a list holding one dict per URL. Each dict carries
        ``image_url`` plus either the best ``tag_name``/``tag_score``
        (score rounded to 4 decimals) or an ``error`` message.
    """

    def summarize(image_url):
        # Build the report entry for a single URL.
        tags, error = predict_tags(image_url, threshold)
        if not error and tags:
            best_name, best_score = tags[0]
            return {
                "image_url": image_url,
                "tag_name": best_name,
                "tag_score": round(best_score, 4),
            }
        return {
            "image_url": image_url,
            "error": error or "No tags above threshold.",
        }

    trimmed = (piece.strip() for piece in urls.split(","))
    report = [summarize(candidate) for candidate in trimmed if candidate]
    # Stringified because the result is shown in a plain textbox.
    return str(report)
# Gradio front end: a textbox for the URL list and a slider for the score
# threshold feed gradio_predict, whose string result is shown in a textbox.
demo = gr.Interface(
    title="Batch Image Tagging with ViT",
    description="Paste one or more image URLs separated by commas to get predicted tags using thelabel/240903-image-tagging model.",
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="Image URL(s) (comma-separated)"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Threshold"),
    ],
    outputs=gr.Textbox(label="Tags"),
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()