image-tagging / app.py
user-agent's picture
Update app.py
4ecab25 verified
raw
history blame
2.96 kB
import logging
from io import BytesIO

import requests
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoConfig
import gradio as gr
import spaces
# Hub checkpoint used for tagging; its config also supplies the label lookup
# used when formatting predictions.
model_id = "thelabel/240903-image-tagging"
config = AutoConfig.from_pretrained(model_id)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.Module.to() moves the parameters in place and returns the same module,
# so the load and the device move can be chained.
model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
# Standard ViT image transforms
image_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def load_image_from_url(url, timeout=10):
    """Fetch an image over HTTP(S) and return it as an RGB ``PIL.Image``.

    Args:
        url: Address of the image. Surrounding whitespace (e.g. from a pasted
            URL) is ignored.
        timeout: Seconds passed to ``requests.get`` before the request is
            abandoned. Defaults to 10, the value previously hard-coded.

    Returns:
        The decoded image converted to RGB mode, or ``None`` when the URL is
        empty, the request fails or returns an HTTP error status, or the
        payload cannot be decoded as an image. Failures are logged rather than
        raised because the caller turns ``None`` into a user-facing message.
    """
    # NOTE(review): the URL comes straight from the UI, so this fetches
    # arbitrary user-supplied addresses with no response-size cap. Consider
    # restricting schemes/hosts and bounding the download if exposed publicly.
    url = (url or "").strip()
    if not url:
        return None
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # convert() returns a new image, so the decoder handle can be closed
        # as soon as the conversion is done.
        with Image.open(BytesIO(response.content)) as image:
            return image.convert("RGB")
    except Exception as exc:  # untrusted-input boundary: report, don't crash
        logging.getLogger(__name__).warning(
            "Could not load image from %r: %s", url, exc
        )
        return None
@spaces.GPU
def predict_tags(image_url, threshold=0.5):
    """Tag the image found at ``image_url`` with the loaded classifier.

    Args:
        image_url: URL of the image to tag.
        threshold: Minimum sigmoid probability for a tag to be reported.

    Returns:
        A ``(results, error)`` pair. On success, ``results`` is a list of
        ``(label, probability)`` tuples whose probability is at least
        ``threshold``, sorted from most to least probable, and ``error`` is
        ``None``. If the image cannot be loaded, ``results`` is ``[]`` and
        ``error`` is a message string.
    """
    image = load_image_from_url(image_url)
    if image is None:
        return [], "Could not load image from the provided URL."

    # Add a batch dimension of 1 and place the input on the model's device.
    image_tensor = image_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(image_tensor).logits

    # Sigmoid, not softmax: every tag is scored independently.
    # Take row 0 of the batch rather than squeeze(): squeeze() would reduce a
    # single-label output to a 0-d tensor, on which len()/indexing fail.
    # tolist() copies to host once, instead of syncing the device for every
    # per-element comparison and float() conversion.
    probs = torch.sigmoid(logits)[0].tolist()

    # Label lookup is keyed by the string form of the output index.
    labels = config.idx_to_label
    results = [
        (labels[str(i)], p)
        for i, p in enumerate(probs)
        if p >= threshold
    ]
    results.sort(key=lambda item: item[1], reverse=True)
    return results, None
def gradio_predict(url, threshold):
    """Gradio callback: render predicted tags as text and echo the URL.

    Returns a ``(tags_text, preview)`` pair. When prediction reports an error,
    ``tags_text`` is that error message and ``preview`` is ``None``. Otherwise
    ``tags_text`` has one ``"<tag>: <score>"`` line per tag, with the score
    shown to two decimals, and ``preview`` is the input URL.
    """
    tags, error = predict_tags(url, threshold)
    if not error:
        lines = (f"{tag}: {score:.2f}" for tag, score in tags)
        return "\n".join(lines), url
    return error, None
# The textbox default is also the first example; keep the URL in one place so
# the two cannot drift apart.
DEFAULT_IMAGE_URL = "https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiJoaWNjdXAtaW1hZ2UtaG9zdGluZyIsImtleSI6ImhpY2N1cC1wcm9kdWN0cy9GQVFZTFkyNzFGLmpwZWciLCJlZGl0cyI6eyJyZXNpemUiOnsid2lkdGgiOjI1NjAsImhlaWdodCI6Mzg0MCwiZml0IjoiY292ZXIifX19?v=1748968367"

# UI: URL + threshold in; formatted tags and an image preview of the URL out.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="Image URL", value=DEFAULT_IMAGE_URL),
        gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
    ],
    outputs=[
        gr.Textbox(label="Tags"),
        # gr.Image's `type` must be "numpy", "pil" or "filepath"; "url" is
        # rejected with ValueError when the component is built. The callback
        # returns the image URL as a string, which fits "filepath".
        gr.Image(label="Preview", type="filepath"),
    ],
    title="Image Tagging with ViT",
    description="Paste an image URL and get predicted tags using thelabel/240903-image-tagging model.",
    examples=[
        [DEFAULT_IMAGE_URL, 0.5],
        [
            "https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiJoaWNjdXAtaW1hZ2UtaG9zdGluZyIsImtleSI6ImhpY2N1cC1wcm9kdWN0cy9ON01aQkpUMDlFLmpwZWciLCJlZGl0cyI6eyJyZXNpemUiOnsid2lkdGgiOjI1NjAsImhlaWdodCI6Mzg0MCwiZml0IjoiY292ZXIifX19?v=1748968367",
            0.5,
        ],
    ],
)

if __name__ == "__main__":
    demo.launch()