# image-tagging / app.py
# user-agent's picture
# Update app.py
# aab5af1 verified
import logging
import os
from io import BytesIO

import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoConfig
# Hub access token taken from the environment; None when the variable is unset.
token = os.getenv("HUGGINGFACE_HUB_TOKEN")
model_id = "thelabel/240903-image-tagging"

# Fetch the configuration and the classification weights for the same
# checkpoint.
config = AutoConfig.from_pretrained(model_id, token=token)
model = AutoModelForImageClassification.from_pretrained(model_id, token=token)

# Use the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Preprocessing applied to every input image: fixed 224x224 resize, tensor
# conversion, then per-channel normalisation with mean 0.5 and std 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def load_image_from_url(url):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image; surrounding whitespace is stripped before
            the request is made.

    Returns:
        The decoded image converted to RGB mode, or ``None`` when the
        download or decoding fails for any reason.
    """
    try:
        response = requests.get(url.strip(), timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:
        # Deliberately best-effort: callers treat None as "could not load".
        # Log the cause so failures are diagnosable instead of silent.
        logging.getLogger(__name__).warning(
            "Failed to load image from %r", url, exc_info=True
        )
        return None
@spaces.GPU
def predict_tags(image_url, threshold=0.5):
    """Predict tags for the image found at ``image_url``.

    Args:
        image_url: URL of the image to classify.
        threshold: Minimum sigmoid probability a label needs to be reported.

    Returns:
        A ``(results, error)`` pair. On success ``results`` is a list of
        ``(label, probability)`` tuples sorted by descending probability
        (possibly empty) and ``error`` is ``None``. When the image cannot be
        loaded, ``results`` is ``None`` and ``error`` is a message string.
    """
    image = load_image_from_url(image_url)
    if image is None:
        return None, "Could not load image."

    # Add a leading batch axis of size 1 before running the model.
    image_tensor = image_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(image_tensor).logits

    # Drop only the batch axis. A bare squeeze() would also collapse the label
    # axis for a single-label model, leaving a 0-d tensor that has no len().
    # tolist() then brings every score to the host in a single transfer.
    # NOTE(review): assumes logits are (1, num_labels) - confirm for this model.
    scores = torch.sigmoid(logits).squeeze(0).tolist()

    results = [
        (config.idx_to_label[str(index)], score)
        for index, score in enumerate(scores)
        if score >= threshold
    ]
    results.sort(key=lambda pair: pair[1], reverse=True)
    return results, None
def gradio_predict(urls, threshold):
    """Tag every image in a comma-separated URL string.

    Args:
        urls: One or more image URLs separated by commas; blank entries are
            ignored.
        threshold: Probability cut-off forwarded to ``predict_tags``.

    Returns:
        The ``str()`` of a list with one dict per URL. Each dict holds
        ``image_url`` plus either the best tag (``tag_name``, ``tag_score``
        rounded to four decimals) or an ``error`` message.
    """
    entries = []
    for piece in urls.split(","):
        link = piece.strip()
        if not link:
            continue

        tags, error = predict_tags(link, threshold)
        if tags and not error:
            # Tags arrive sorted by descending score, so the first is the best.
            best_name, best_score = tags[0]
            entry = {
                "image_url": link,
                "tag_name": best_name,
                "tag_score": round(best_score, 4),
            }
        else:
            entry = {
                "image_url": link,
                "error": error or "No tags above threshold.",
            }
        entries.append(entry)

    # Stringified so the result can be shown in a plain textbox.
    return str(entries)
# Input widgets: a text box for the URL list and a 0-1 slider for the cut-off.
url_box = gr.Textbox(label="Image URL(s) (comma-separated)")
threshold_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold")

# Web UI wiring the two inputs to gradio_predict and showing its string result.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[url_box, threshold_slider],
    outputs=gr.Textbox(label="Tags"),
    title="Batch Image Tagging with ViT",
    description="Paste one or more image URLs separated by commas to get predicted tags using thelabel/240903-image-tagging model.",
)

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()