ptschandl committed on
Commit
065ca40
·
verified ·
1 Parent(s): 00f18b0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from urllib.request import urlopen
5
+ from open_clip import create_model_from_pretrained, get_tokenizer
6
+
7
# Load the BiomedCLIP model (with its image preprocessing transform) and the
# matching tokenizer from the Hugging Face Hub.
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# Prompt template prepended to every candidate label for zero-shot classification.
template = 'this is a photo of '

# Device configuration: prefer Apple MPS, then CUDA, then fall back to CPU.
# Fix: the documented availability check is torch.backends.mps.is_available();
# the original torch.mps.is_available() raises AttributeError on most
# torch versions, crashing the app at import time.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
model.eval()  # inference only: disable dropout/batch-norm training behavior
18
+
19
def classify_image(image, candidate_labels, context_length=256):
    """Zero-shot classify *image* against a comma-separated label list.

    Parameters
    ----------
    image : PIL.Image.Image
        Image uploaded through the Gradio UI.
    candidate_labels : str
        Comma-separated label names, e.g. "melanoma, nevus, keratosis".
    context_length : int, optional
        Maximum token length passed to the tokenizer. Defaults to 256,
        matching the BiomedCLIP text encoder's context window.

    Returns
    -------
    list[dict]
        One ``{"label": str, "score": float}`` entry per candidate label,
        sorted by descending probability (scores sum to 1 via softmax).
    """
    # Robustness: drop empty entries produced by stray/trailing commas
    # (the original would build a prompt from an empty label).
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]

    # Preprocess the image into a (1, C, H, W) batch on the model's device.
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Tokenize every candidate label with the shared prompt template.
    texts = tokenizer([template + label for label in labels], context_length=context_length).to(device)

    # Perform inference; .detach() is redundant under no_grad and was dropped.
    with torch.no_grad():
        image_features, text_features, logit_scale = model(image_input, texts)
        # Scaled cosine similarities -> probabilities over the label set.
        logits = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)
        sorted_indices = torch.argsort(logits, dim=-1, descending=True)
    logits = logits.cpu().numpy()
    sorted_indices = sorted_indices.cpu().numpy()

    # Results ordered from most to least probable label.
    return [
        {"label": labels[idx], "score": float(logits[0][idx])}
        for idx in sorted_indices[0]
    ]
48
+
49
# Assemble the Gradio UI for the zero-shot classifier.
_image_widget = gr.Image(type="pil", label="Upload Image")
_labels_widget = gr.Textbox(lines=2, placeholder="Enter candidate labels, separated by commas...")

iface = gr.Interface(
    fn=classify_image,
    inputs=[_image_widget, _labels_widget],
    outputs=gr.JSON(),
    title="Zero-Shot Image Classification",
    description="Upload an image and enter candidate labels to classify the image.",
)

# Start the web server (blocks until the app is stopped).
iface.launch()