# Page header residue (author / last commit), kept as comments so the file parses:
# ptschandl — "Update app.py" — 260b822 (verified)
import torch
import gradio as gr
from PIL import Image
from urllib.request import urlopen
from open_clip import create_model_from_pretrained, get_tokenizer
# Hugging Face Hub identifier shared by the model and its tokenizer, so the two
# can never drift apart.
_BIOMEDCLIP_HUB_ID = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'

# Load the pretrained model together with its matching image preprocessing
# transform, then the text tokenizer, both from the same Hub repository.
model, preprocess = create_model_from_pretrained(_BIOMEDCLIP_HUB_ID)
tokenizer = get_tokenizer(_BIOMEDCLIP_HUB_ID)
# Zero-shot image classification: every candidate label is appended to this
# prefix to form the text prompt that is compared against the image.
template: str = 'this is a photo of '
# Device configuration
device = torch.device('mps') if torch.mps.is_available() else torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()
def classify_image(image, candidate_labels):
    """Score an image against comma-separated text labels (zero-shot).

    Each label becomes the prompt ``template + label``. The model scores it
    against the image, and the scaled similarity logits are softmax-normalised
    across the labels.

    Args:
        image: Image supplied by the Gradio ``Image`` component (a PIL image,
            or ``None`` when nothing was uploaded).
        candidate_labels: Comma-separated label string from the textbox.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts sorted by
        descending score.

    Raises:
        gr.Error: If no image was provided or no non-empty label was entered.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Split on commas and drop blanks so inputs like "a, b," or "a,,b" do not
    # add an empty "this is a photo of " prompt that competes for probability.
    labels = [label.strip() for label in (candidate_labels or "").split(",")]
    labels = [label for label in labels if label]
    if not labels:
        raise gr.Error("Please enter at least one candidate label.")

    # Token length passed to the tokenizer; presumably matches the 256-token
    # text encoder named in the model ID — confirm if the model is changed.
    context_length = 256

    # Add a batch dimension of 1 for the single image.
    image_input = preprocess(image).unsqueeze(0).to(device)
    texts = tokenizer(
        [template + label for label in labels],
        context_length=context_length,
    ).to(device)

    with torch.no_grad():
        image_features, text_features, logit_scale = model(image_input, texts)
        # One row (the single image) of probabilities over the labels.
        probs = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)

    scores = probs[0].cpu().tolist()
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return [{"label": label, "score": float(score)} for label, score in ranked]
# Create the Gradio interface: an image and comma-separated labels go in,
# a ranked list of label scores comes out as JSON.
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            lines=2,
            label="Candidate labels",
            placeholder="Enter candidate labels, separated by commas...",
        ),
    ],
    outputs=gr.JSON(),
    title="Zero-Shot Image Classification",
    # The description is rendered as Markdown, where a single "\n" is folded
    # into a space; a blank line ("\n\n") is needed for a visible break.
    description=(
        "Technical example of using the "
        "`microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224` model.\n\n"
        "Upload an image and enter candidate labels to classify the image."
    ),
)
# Launch the web UI only when run as a script (e.g. `python app.py`), so the
# module can be imported without starting a server.
if __name__ == "__main__":
    iface.launch()