SemanticVision / app.py
cfoli's picture
Update app.py
231e4ea verified
# -*- coding: utf-8 -*-
"""zero-shot-classification.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1dme6a-Yhl1xYbXobqu56YnjKyr0q9VbL
### Summary
---
These models perform zero-shot classification of images of "homographs" (words that can be associated with different meanings/objects/concepts without any change in spelling or pronunciation).
###### *** 02.2026
### Load dependencies
---
"""
import torch
from transformers import pipeline
import os
import gradio
"""### Initialize containers
---
"""
# Loaded pipelines, keyed by the same names used in MODEL_OPTIONS, so each
# backbone is only instantiated once per process.
MODEL_CACHE = {}

# Selectable backbones: display name -> Hugging Face model identifier.
# The SigLIP checkpoints below are currently disabled:
#   "SigLIP-base":  "google/siglip-base-patch16-224"
#   "SigLIP-large": "google/siglip-so400m-patch14-384"
MODEL_OPTIONS = {
    "CLIP-base": "openai/clip-vit-base-patch32",
    "CLIP-large": "openai/clip-vit-large-patch14",
    "ALIGN": "kakaobrain/align-base",
}

# Text prompts handed to the classifier, one per word sense.
# NOTE(review): "baking powder" is a leavening agent rather than flour --
# confirm this gloss is the intended prompt.
CANDIDATE_LABELS = [
    "a bat (baseball)",
    "a bat (mammal)",
    "a flower (plant)",
    "flour (baking powder)",
    "a mouse (mammal)",
    "a mouse (electronic)",
    "a nail (human finger)",
    "a nail (metal)",
    "a nut (fruit seed)",
    "a nut (metal)",
]

# Display names shown in the UI. This list is positionally paired with
# CANDIDATE_LABELS (entry i here names entry i there), so the two must be
# kept the same length and in the same order.
LABELS_MAP = [
    "Bat (baseball)",
    "Bat (mammal)",
    "Flower (plant)",
    "Flour (baking powder)",
    "Mouse (mammal)",
    "Mouse (electronic)",
    "Nail (human finger)",
    "Nail (metal)",
    "Nut (fruit seed)",
    "Nut (metal)",
]
"""### Select, load and run pre-trained zero-shot multi-modal model
---
"""
def run_classifer(model_key, image_path, prob_threshold=0.4):
    """Classify an image against the candidate labels with a zero-shot model.

    Args:
        model_key: Key into ``MODEL_OPTIONS`` naming the backbone
            zero-shot-image-classification model to use.
        image_path: The test image, passed unchanged to the classifier
            (the Gradio interface in this file supplies a PIL image).
            ``None`` means no image was loaded.
        prob_threshold: Confidence (i.e., probability) threshold the top
            score must exceed for the prediction to be reported. ``None``
            falls back to 0.4.

    Returns:
        A ``(predicted_label_str, prob_dict)`` tuple. ``predicted_label_str``
        describes the top-scoring label and its confidence, or is
        ``"No prediction"`` when the top score does not exceed the threshold
        or no image was given. ``prob_dict`` maps every display label from
        ``LABELS_MAP`` to its score rounded to 4 decimals (empty when no
        image was given).

    Raises:
        KeyError: If ``model_key`` is not a key of ``MODEL_OPTIONS``.
    """
    # Nothing to classify (e.g. the form was submitted without an image):
    # answer in the same shape as a normal result instead of letting the
    # classifier fail on a None input.
    if image_path is None:
        return "No prediction", {}

    if prob_threshold is None:
        prob_threshold = 0.4

    # Load model (cache for speed). The device is only needed when a
    # pipeline is first built, so the CUDA probe is skipped on cache hits.
    classifier = MODEL_CACHE.get(model_key)
    if classifier is None:
        device = 0 if torch.cuda.is_available() else -1
        classifier = pipeline(
            task="zero-shot-image-classification",
            model=MODEL_OPTIONS[model_key],
            device=device,
        )
        MODEL_CACHE[model_key] = classifier

    outputs = classifier(
        image=image_path,
        candidate_labels=CANDIDATE_LABELS,
        hypothesis_template="This image shows {}",
    )

    # Prompt text -> display name (the two lists are positionally paired).
    label_lookup = dict(zip(CANDIDATE_LABELS, LABELS_MAP))

    # Dictionary mapping all candidate labels to their predicted probabilities
    prob_dict = {
        label_lookup[output["label"]]: round(output["score"], 4)
        for output in outputs
    }

    # Pick the best-scoring entry explicitly rather than relying on the
    # order in which the classifier returns its results.
    top = max(outputs, key=lambda output: output["score"])
    if float(top["score"]) > prob_threshold:
        predicted_label_str = (
            f"This image shows {top['label']} | "
            f"Confidence (probability): {100 * top['score']:.1f}%"
        )
    else:
        predicted_label_str = "No prediction"

    return predicted_label_str, prob_dict
"""### Gradio App
---
"""
# Example images are optional: if the folder is absent the app still starts,
# just without clickable examples. Names are sorted so the examples appear in
# a stable order regardless of how the filesystem lists them.
example_list_dir = os.path.join(os.getcwd(), "Example images")
example_list_img_names = (
    sorted(os.listdir(example_list_dir))
    if os.path.isdir(example_list_dir)
    else []
)

# One row per PNG, in the same order as the Interface inputs:
# [classifier name, image path, prediction threshold].
example_list = [
    ["CLIP-base", os.path.join(example_list_dir, example_img), 0.4]
    for example_img in example_list_img_names
    if example_img.lower().endswith(".png")
]

gradio_app = gradio.Interface(
    fn=run_classifer,
    inputs=[
        # Choices come from MODEL_OPTIONS so the dropdown cannot drift from
        # the set of models run_classifer is able to load.
        gradio.Dropdown(
            list(MODEL_OPTIONS),
            value="CLIP-base",
            label="Select Classifier",
        ),
        gradio.Image(type="pil", label="Load sample image here"),
        gradio.Slider(
            minimum=0.1,
            maximum=0.9,
            step=0.05,
            value=0.25,
            label="Set Prediction Threshold",
        ),
    ],
    outputs=[
        gradio.Textbox(label="Image Classification"),
        gradio.Label(label="Prediction Probabilities", show_label=False),
    ],
    examples=example_list,
    cache_examples=False,
    title="SemanticVision",
    description="Vision-language models for zero-shot classification of images of homophones and homographs",
    article="Author: C. Foli (02.2026) | Website: coming soon...",
)

gradio_app.launch()