File size: 4,292 Bytes
3e94470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7b5a5d
 
0718174
a7b5a5d
3e94470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56bb014
3e94470
 
 
 
614a821
 
005bf67
3e94470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56bb014
 
3e94470
 
 
 
 
 
 
 
 
 
fcf5257
3e94470
 
 
3fd308e
3e94470
 
 
 
 
0718174
3e94470
56bb014
3e94470
 
 
 
 
 
826f6b8
 
231e4ea
3e94470
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# -*- coding: utf-8 -*-
"""zero-shot-classification.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1dme6a-Yhl1xYbXobqu56YnjKyr0q9VbL

### Summary

---

These models perform zero-shot classification of images of "homographs" (words that could be associated with different meanings/objects/concepts, without any changes in spelling or pronunciation)
###### *** 02.2026

### Load dependencies

---
"""

import torch
from transformers import pipeline
import os
import gradio

"""### Initialize containers

---


"""

# Cache of instantiated pipelines, keyed by model name, so each backbone is
# loaded from the Hub at most once per session.
MODEL_CACHE = {}

# UI-facing model names mapped to their Hugging Face Hub checkpoint IDs.
MODEL_OPTIONS = {
    "CLIP-base":    "openai/clip-vit-base-patch32",
    "CLIP-large":   "openai/clip-vit-large-patch14",
    # "SigLIP-base":  "google/siglip-base-patch16-224",
    # "SigLIP-large": "google/siglip-so400m-patch14-384"
    "ALIGN": "kakaobrain/align-base"
}

# Candidate label strings fed to the zero-shot classifier; each homograph word
# appears twice, once per sense, with a disambiguating hint in parentheses.
CANDIDATE_LABELS = ["a bat (baseball)", "a bat (mammal)",
                    "a flower (plant)", "flour (baking powder)",
                    "a mouse (mammal)", "a mouse (electronic)",
                    "a nail (human finger)", "a nail (metal)",
                    "a nut (fruit seed)", "a nut (metal)"]

# Display names for the UI, parallel to CANDIDATE_LABELS (same order/length);
# zipped together in run_classifer to build the label lookup.
LABELS_MAP = ["Bat (baseball)", "Bat (mammal)",
              "Flower (plant)", "Flour (baking powder)",
              "Mouse (mammal)", "Mouse (electronic)",
              "Nail (human finger)", "Nail (metal)",
              "Nut (fruit seed)", "Nut (metal)"]

"""### Select, load and run pre-trained zero-shot multi-modal model

---


"""

def run_classifer(model_key, image_path, prob_threshold):
    """Classify an image with the selected zero-shot vision-language model.

    Args:
        model_key: key into MODEL_OPTIONS naming the backbone
            zero-shot-image-classification model (e.g. "CLIP-base").
        image_path: the test image — a file path or a PIL image
            (the Gradio Image component supplies a PIL image).
        prob_threshold: confidence (probability) threshold above which the
            top prediction is considered valid.

    Returns:
        Tuple of (prediction string — or "No prediction" when the top score
        does not exceed prob_threshold — and a dict mapping display labels
        to predicted probabilities rounded to 4 decimals).
    """
    # Use the first GPU when available; -1 selects CPU for the pipeline.
    device = 0 if torch.cuda.is_available() else -1

    # Load model lazily and cache it so repeated calls skip re-downloading.
    if model_key not in MODEL_CACHE:
        MODEL_CACHE[model_key] = pipeline(task   = "zero-shot-image-classification",
                                          model  = MODEL_OPTIONS[model_key],
                                          device = device)
    classifier = MODEL_CACHE[model_key]

    # outputs is a list of {"label", "score"} dicts sorted by descending score.
    outputs = classifier(
        image               = image_path,
        candidate_labels    = CANDIDATE_LABELS,
        hypothesis_template = "This image shows {}")

    # Map raw candidate labels to their prettier UI display names.
    label_lookup = dict(zip(CANDIDATE_LABELS, LABELS_MAP))

    # Dictionary mapping all candidate labels to their predicted probabilities
    prob_dict = {label_lookup[output['label']]: round(output["score"], 4) for output in outputs}

    # Report the top prediction only when it clears the confidence threshold.
    predicted_label_str = (
        f"This image shows {outputs[0]['label']} | Confidence (probability): {100*outputs[0]['score']:.1f}%"
        if float(outputs[0]['score']) > prob_threshold
        else "No prediction")

    return predicted_label_str, prob_dict

"""### Gradio App

---


"""

# Directory of bundled demo images offered as clickable examples in the UI.
# Guard against a missing directory (os.listdir would raise FileNotFoundError
# and prevent the app from launching at all) — fall back to no examples.
example_list_dir = os.path.join(os.getcwd(), "Example images")
example_list_img_names = (os.listdir(example_list_dir)
                          if os.path.isdir(example_list_dir) else [])

# Each example pre-fills (model, image path, threshold). Accept the common
# raster formats Gradio can load, not only PNG.
example_list = [
    ["CLIP-base", os.path.join(example_list_dir, example_img), 0.4]
    for example_img in example_list_img_names
    if example_img.lower().endswith((".png", ".jpg", ".jpeg"))]

# Wire the classifier into a three-input / two-output Gradio interface.
gradio_app = gradio.Interface(
    fn     = run_classifer,
    inputs = [gradio.Dropdown(["CLIP-base", "CLIP-large", "ALIGN"], value="CLIP-base", label  = "Select Classifier"),
              gradio.Image(type="pil", label="Load sample image here"),
              gradio.Slider(minimum = 0.1, maximum = 0.9, step = 0.05, value = 0.25, label = "Set Prediction Threshold")
              ],

    outputs = [gradio.Textbox(label="Image Classification"),
             gradio.Label(label="Prediction Probabilities", show_label=False)],

    examples       = example_list,
    cache_examples = False,
    title          = "SemanticVision",
    description    = "Vision-language models for zero-shot classification of images of homophones and homographs",
    article        = "Author: C. Foli (02.2026) | Website: coming soon...")

gradio_app.launch()