# Spaces: this app previously showed "Runtime error" in the Hugging Face
# Spaces log — the source had been mangled by markdown-table pipe characters.
import gradio as gr
import torch
import clip
from PIL import Image

# Prefer GPU when available; CLIP inference is substantially faster on CUDA.
print("Getting device...")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained ViT-B/32 CLIP model together with its matching
# image-preprocessing pipeline (resize/crop/normalize to model input).
print("Loading model...")
model, preprocess = clip.load("ViT-B/32", device=device)
print("Loaded model.")
def process(image, prompt):
    """Score an image against newline-separated text prompts with CLIP.

    Args:
        image: PIL image supplied by the Gradio ``Image`` component.
        prompt: Newline-separated candidate captions from the textbox.

    Returns:
        dict mapping each non-blank prompt to its softmax probability
        (values sum to 1 across the prompts).
    """
    print("Inferring...")
    # Batch of one image, moved to the same device as the model.
    image = preprocess(image).unsqueeze(0).to(device)
    # Drop blank lines so a stray trailing newline doesn't create an
    # empty-string label in the output.
    prompts = [line for line in prompt.split("\n") if line.strip()]
    text = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        logits_per_image, logits_per_text = model(image, text)
        # Softmax over prompts -> probability per caption; move to CPU
        # so .item() below is cheap and device-independent.
        probs = logits_per_image.softmax(dim=-1).cpu()
    return {k: v.item() for k, v in zip(prompts, probs[0])}
# Gradio UI: one image + one multi-line textbox in, a label widget out.
# The "label" output renders the returned {caption: probability} dict as
# a ranked confidence bar chart.
iface = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(lines=5, label="Prompts (newline-separated)"),
    ],
    outputs="label",
    examples=[
        ["dog.jpg", "a photo of a dog\na photo of a cat"],
        ["cat.jpg", "a photo of a dog\na photo of a cat"],
        ["car.jpg", "a red car on a golf course\na red sports car on a road\na blue sports car\na red family car"],
    ],
)

if __name__ == "__main__":
    iface.launch()