File size: 2,278 Bytes
065ca40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260b822
065ca40
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import gradio as gr
from PIL import Image
from urllib.request import urlopen
from open_clip import create_model_from_pretrained, get_tokenizer

# Load the BiomedCLIP model (with its preprocessing transform) and the matching
# tokenizer from the Hugging Face Hub.
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# Prompt template prepended to every candidate label for zero-shot classification.
template = 'this is a photo of '

# Device configuration: prefer Apple MPS, then CUDA, then CPU.
# Bug fix: the documented availability check is torch.backends.mps.is_available();
# torch.mps.is_available() raises AttributeError on many PyTorch builds
# (and on non-macOS platforms).
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
model.eval()  # inference only: disable dropout / batch-norm updates

def classify_image(image, candidate_labels):
    """Zero-shot classify *image* against comma-separated *candidate_labels*.

    Args:
        image: a PIL image (as delivered by the Gradio image widget).
        candidate_labels: comma-separated label names, e.g. "x-ray, MRI".

    Returns:
        A list of ``{"label": str, "score": float}`` dicts, one per label,
        ordered by descending probability.
    """
    # Split the free-text input into individual label strings.
    labels = [part.strip() for part in candidate_labels.split(",")]
    context_length = 256

    # Image -> preprocessed tensor with a batch dimension of 1.
    pixel_batch = preprocess(image).unsqueeze(0).to(device)

    # Each label becomes a full prompt via the shared template, then tokenized.
    token_batch = tokenizer(
        [template + label for label in labels], context_length=context_length
    ).to(device)

    # Forward pass; softmax over the scaled image/text similarity gives
    # a probability per candidate label.
    with torch.no_grad():
        image_features, text_features, logit_scale = model(pixel_batch, token_batch)
        probs = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
        ranking = torch.argsort(probs, dim=-1, descending=True)
        probs = probs.cpu().numpy()
        ranking = ranking.cpu().numpy()

    # Emit labels in best-first order with their scores.
    return [
        {"label": labels[idx], "score": float(probs[0][idx])}
        for idx in ranking[0]
    ]

# Build the Gradio UI: an image upload plus a free-text label box feeding
# classify_image, with the ranked scores rendered as JSON.
_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(lines=2, placeholder="Enter candidate labels, separated by commas..."),
]
_description = (
    "Technical example of using the "
    "`microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224` model.\n"
    "Upload an image and enter candidate labels to classify the image."
)
iface = gr.Interface(
    fn=classify_image,
    inputs=_inputs,
    outputs=gr.JSON(),
    title="Zero-Shot Image Classification",
    description=_description,
)

# Start the local web server for the demo.
iface.launch()