cfoli committed on
Commit
3e94470
·
verified ·
1 Parent(s): 5168a3d

Upload zero_shot_classification.py

Browse files
Files changed (1) hide show
  1. zero_shot_classification.py +153 -0
zero_shot_classification.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """zero-shot-classification.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1dme6a-Yhl1xYbXobqu56YnjKyr0q9VbL
8
+
9
+ ### Summary
10
+
11
+ ---
12
+
13
+ These models perform zero-shot classification of images of "homographs" (words that could be associated with different meanings/objects/concepts, without any changes in spelling or pronunciation)
14
+ ###### *** 02.2026
15
+
16
+ ### Load dependencies
17
+
18
+ ---
19
+ """
20
+
21
+ import torch
22
+ from transformers import pipeline
23
+ import os
24
+ import gradio
25
+
26
+ """### Initialize containers
27
+
28
+ ---
29
+
30
+
31
+ """
32
+
33
+ MODEL_CACHE = {}
34
+
35
+ MODEL_OPTIONS = {
36
+ "CLIP-base": "openai/clip-vit-base-patch32",
37
+ "CLIP-large": "openai/clip-vit-large-patch14",
38
+ "SigLIP-base": "google/siglip-base-patch16-224",
39
+ "SigLIP-large": "google/siglip-so400m-patch14-384"}
40
+
41
+ CANDIDATE_LABELS = ["a bat (baseball)", "a bat (mammal)",
42
+ "a flower (plant)", "flour (baking powder)",
43
+ "a mouse (mammal)", "a mouse (electronic)",
44
+ "a nail (human finger)", "a nail (metal)",
45
+ "a nut (fruit seed)", "a nut (metal)"]
46
+
47
+ LABELS_MAP = ["Bat (baseball)", "Bat (mammal)",
48
+ "Flower (plant)", "Flour (baking powder)",
49
+ "Mouse (mammal)", "Mouse (electronic)",
50
+ "Nail (human finger)", "Nail (metal)",
51
+ "Nut (fruit seed)", "Nut (metal)"]
52
+
53
+ """### Select, load and run pre-trained zero-shot multi-modal model
54
+
55
+ ---
56
+
57
+
58
+ """
59
+
60
+ model_key = "CLIP-large"
61
+
62
+ # Load model (cache for speed)
63
+ if model_key not in MODEL_CACHE:
64
+ MODEL_CACHE[model_key] = pipeline(task = "zero-shot-image-classification",
65
+ model = MODEL_OPTIONS[model_key])
66
+ classifier = MODEL_CACHE[model_key]
67
+
68
+ BASE_DIR = '/content/drive/MyDrive/ML Projects/Zero-shot Image Classification/Images'
69
+ image_path = os.path.join(BASE_DIR, 'Mouse1_2.png')
70
+
71
+ output = classifier(
72
+ image = image_path,
73
+ candidate_labels = CANDIDATE_LABELS,
74
+ hypothesis_template = "This image shows {}")
75
+
76
+ print("\n\n=============================================================================")
77
+ print(f"\nPrediction: This image shows {output[0]["label"]} | Confidence (probability): {100*output[0]["score"]: .1f}%")
78
+
79
+ def run_classifer(model_key, image_path, prob_threshold = None):
80
+ # model_key: name of backbone zero-shot-image-classification model to use
81
+ # image_path: path to test image
82
+ # prob_threshold: confidence (i.e., probability) threshold above which to consider a prediction valid
83
+
84
+ device = 0 if torch.cuda.is_available() else -1
85
+
86
+ # Load model (cache for speed)
87
+ if model_key not in MODEL_CACHE:
88
+ MODEL_CACHE[model_key] = pipeline(task = "zero-shot-image-classification",
89
+ model = MODEL_OPTIONS[model_key],
90
+ device = device)
91
+ classifier = MODEL_CACHE[model_key]
92
+
93
+ outputs = classifier(
94
+ image = image_path,
95
+ candidate_labels = CANDIDATE_LABELS,
96
+ hypothesis_template = "This image shows {}")
97
+
98
+ # label_str = f"This image shows {output[0]["label"]}",
99
+ # prob_str = f"{100*output[0]["score"]: .1f}%"
100
+
101
+ label_lookup = dict(zip(CANDIDATE_LABELS, LABELS_MAP))
102
+
103
+ # Dictionary mapping all candidate labels to their predicted probabilities
104
+ prob_dict = {label_lookup[output['label']]: round(output["score"], 4) for output in outputs}
105
+
106
+ predicted_label_str = f"This image shows {outputs[0]['label']} | Confidence (probability): {100*outputs[0]['score']:.1f}%" if outputs[0]['score'] > prob_threshold else "No prediction"
107
+
108
+ return predicted_label_str, prob_dict
109
+
110
+ # example run
111
+ model_key = "CLIP-large"
112
+ BASE_DIR = '/content/drive/MyDrive/ML Projects/Zero-shot Image Classification/Images'
113
+ image_path = os.path.join(BASE_DIR, 'Nail2_1.png')
114
+
115
+ predicted_label_str, prob_dict = run_classifer(model_key, image_path, prob_threshold = 0.4)
116
+ print("\n\n=============================================================================")
117
+ # print(f"\nPrediction: {predicted_label_str} | Confidence (probability): {100*output[0]['score']:.1f}%")
118
+ print(f"\nPrediction: {predicted_label_str}")
119
+
120
+ prob_dict
121
+
122
+ """### Gradio App
123
+
124
+ ---
125
+
126
+
127
+ """
128
+
129
+ example_list_dir = os.path.join(os.getcwd(), "Example images")
130
+ example_list_img_names = os.listdir(example_list_dir)
131
+
132
+ example_list = [
133
+ ["CLIP-large", os.path.join(os.getcwd(), example_img)]
134
+ for example_img in example_list_img_names
135
+ if example_img.lower().endswith(".png")]
136
+
137
+ gradio_app = gradio.Interface(
138
+ fn = run_classifer,
139
+ inputs = [gradio.Dropdown(["CLIP-base", "CLIP-large", "SigLIP-base", "SigLIP-large"], value="CLIP-large", label = "Select Classifier"),
140
+ gradio.Image(type="pil", label="Load sample image here"),
141
+ gradio.Slider(minimum = 0.1, maximum = 0.9, step = 0.05, value = 0.25, label = "Set Prediction Threshold")
142
+ ],
143
+
144
+ outputs = [gradio.Textbox(label="Image Classification"),
145
+ gradio.Label(label="Prediction Probabilities", show_label=False)],
146
+
147
+ examples = example_list,
148
+ cache_examples = True,
149
+ title = "ChestVision",
150
+ description = "Multi-modal models for zero-shot classification of images of homophones and homographs",
151
+ article = "Author: C. Foli (02.2026) | Website: coming soon...")
152
+
153
+ gradio_app.launch()