| | import torch |
| | import gradio as gr |
| | from PIL import Image |
| | from urllib.request import urlopen |
| | from open_clip import create_model_from_pretrained, get_tokenizer |
| |
|
| | |
# Hugging Face Hub id for BiomedCLIP. The model and its tokenizer are loaded
# from the same constant so they cannot drift out of sync.
MODEL_ID = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'

# `preprocess` is the image transform that pairs with this model.
model, preprocess = create_model_from_pretrained(MODEL_ID)
tokenizer = get_tokenizer(MODEL_ID)
| |
|
| | |
# Prompt prefix put in front of every candidate label before text encoding,
# e.g. "pneumonia" -> "this is a photo of pneumonia". Keep the trailing space.
template = "this is a photo of "
| |
|
| | |
# Choose the compute device, preferring Apple MPS, then CUDA, then CPU.
# torch.backends.mps.is_available() is the documented MPS availability check;
# torch.mps.is_available() is missing from older PyTorch releases and would
# raise AttributeError at import time on CUDA/CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Inference only: put the model in eval mode (affects layers such as dropout)
# and move its weights onto the selected device. The two calls are independent.
model.eval()
model.to(device)
| |
|
| | def classify_image(image, candidate_labels): |
| | |
| | labels = [label.strip() for label in candidate_labels.split(",")] |
| | context_length = 256 |
| |
|
| | |
| | image_input = preprocess(image).unsqueeze(0).to(device) |
| |
|
| | |
| | texts = tokenizer([template + label for label in labels], context_length=context_length).to(device) |
| |
|
| | |
| | with torch.no_grad(): |
| | image_features, text_features, logit_scale = model(image_input, texts) |
| | logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1) |
| | sorted_indices = torch.argsort(logits, dim=-1, descending=True) |
| | logits = logits.cpu().numpy() |
| | sorted_indices = sorted_indices.cpu().numpy() |
| |
|
| | |
| | results = [] |
| | for j in range(len(labels)): |
| | jth_index = sorted_indices[0][j] |
| | results.append({ |
| | "label": labels[jth_index], |
| | "score": float(logits[0][jth_index]) |
| | }) |
| |
|
| | return results |
| |
|
| | |
# Gradio UI: an image upload plus a free-text box of comma-separated labels,
# wired to classify_image; the ranked results are displayed as JSON.
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        # Labelled so the box is not shown with a generic "Textbox" caption.
        gr.Textbox(
            lines=2,
            label="Candidate Labels",
            placeholder="Enter candidate labels, separated by commas...",
        ),
    ],
    outputs=gr.JSON(label="Predictions"),
    title="Zero-Shot Image Classification",
    description="Upload an image and enter candidate labels to classify the image.",
)
| |
|
| | |
# Start the web server only when run as a script, so importing this module
# (e.g. from tests or another app) does not block inside launch().
if __name__ == "__main__":
    iface.launch()
| |
|