Spaces:
Runtime error
Runtime error
| import data | |
| import torch | |
| import gradio as gr | |
| from models import imagebind_model | |
| from models.imagebind_model import ModalityType | |
# Module-level setup: pick GPU 0 when CUDA is available, else fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load the pretrained ImageBind "huge" checkpoint (downloads weights on first
# run) and freeze it in inference mode on the chosen device.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)
def image_text_zeroshot(image, text_list):
    """Zero-shot classify an image against candidate text labels with ImageBind.

    Args:
        image: Filesystem path to the input image (the Gradio component is
            expected to supply a path string).
        text_list: Candidate labels as a single '|'-separated string,
            e.g. ``"A dog|A car|A bird"``.

    Returns:
        dict mapping each label to its softmax similarity score in [0, 1].
    """
    image_paths = [image]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        # Bug fix: embed the parsed label list, not the raw '|'-delimited
        # string, so each candidate label gets its own text embedding and
        # the scores align with `labels` below.
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    # Bug fix: compare vision against TEXT embeddings. The original indexed
    # ModalityType.AUDIO, which is absent from `inputs`/`embeddings` and
    # raised a KeyError at runtime.
    scores = (
        torch.softmax(
            embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T,
            dim=-1,
        )
        .squeeze(0)
        .tolist()
    )
    return {label: score for label, score in zip(labels, scores)}
# Gradio UI wiring.
# Bug fix: use type="filepath" instead of type="file". With "file" the handler
# receives a tempfile wrapper object, but image_text_zeroshot treats its first
# argument as a path string (it is passed straight into ImageBind's
# load_and_transform_vision_data, which opens it from disk). "filepath"
# delivers the path string the handler expects.
inputs = [
    gr.inputs.Image(type="filepath", label="Input image"),
    gr.inputs.Textbox(lines=1, label="Candidate texts"),
]

iface = gr.Interface(
    image_text_zeroshot,
    inputs,
    # Output is a label component rendering the {label: score} dict.
    "label",
    examples=[
        [".assets/dog_image.jpg", "A dog|A car|A bird"],
        [".assets/car_image.jpg", "A dog|A car|A bird"],
        [".assets/bird_image.jpg", "A dog|A car|A bird"],
    ],
    description="""Zeroshot test""",
    title="Zero-shot Classification",
)

iface.launch()