Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """zero-shot-classification.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1dme6a-Yhl1xYbXobqu56YnjKyr0q9VbL | |
| ### Summary | |
| --- | |
| These models perform zero-shot classification of images of "homographs" (words that could be associated with different meanings/objects/concepts, without any changes in spelling or pronunciation) | |
| ###### *** 02.2026 | |
| ### Load dependencies | |
| --- | |
| """ | |
| import torch | |
| from transformers import pipeline | |
| import os | |
| import gradio | |
| """### Initialize containers | |
| --- | |
| """ | |
| MODEL_CACHE = {} | |
| MODEL_OPTIONS = { | |
| "CLIP-base": "openai/clip-vit-base-patch32", | |
| "CLIP-large": "openai/clip-vit-large-patch14", | |
| # "SigLIP-base": "google/siglip-base-patch16-224", | |
| # "SigLIP-large": "google/siglip-so400m-patch14-384" | |
| "ALIGN": "kakaobrain/align-base" | |
| } | |
| CANDIDATE_LABELS = ["a bat (baseball)", "a bat (mammal)", | |
| "a flower (plant)", "flour (baking powder)", | |
| "a mouse (mammal)", "a mouse (electronic)", | |
| "a nail (human finger)", "a nail (metal)", | |
| "a nut (fruit seed)", "a nut (metal)"] | |
| LABELS_MAP = ["Bat (baseball)", "Bat (mammal)", | |
| "Flower (plant)", "Flour (baking powder)", | |
| "Mouse (mammal)", "Mouse (electronic)", | |
| "Nail (human finger)", "Nail (metal)", | |
| "Nut (fruit seed)", "Nut (metal)"] | |
| """### Select, load and run pre-trained zero-shot multi-modal model | |
| --- | |
| """ | |
| def run_classifer(model_key, image_path, prob_threshold): | |
| # model_key: name of backbone zero-shot-image-classification model to use | |
| # image_path: path to test image | |
| # prob_threshold: confidence (i.e., probability) threshold above which to consider a prediction valid | |
| # if prob_threshold is None: | |
| # prob_threshold = 0.4 | |
| device = 0 if torch.cuda.is_available() else -1 | |
| # Load model (cache for speed) | |
| if model_key not in MODEL_CACHE: | |
| MODEL_CACHE[model_key] = pipeline(task = "zero-shot-image-classification", | |
| model = MODEL_OPTIONS[model_key], | |
| device = device) | |
| classifier = MODEL_CACHE[model_key] | |
| outputs = classifier( | |
| image = image_path, | |
| candidate_labels = CANDIDATE_LABELS, | |
| hypothesis_template = "This image shows {}") | |
| # label_str = f"This image shows {output[0]["label"]}", | |
| # prob_str = f"{100*output[0]["score"]: .1f}%" | |
| label_lookup = dict(zip(CANDIDATE_LABELS, LABELS_MAP)) | |
| # Dictionary mapping all candidate labels to their predicted probabilities | |
| prob_dict = {label_lookup[output['label']]: round(output["score"], 4) for output in outputs} | |
| predicted_label_str = f"This image shows {outputs[0]['label']} | Confidence (probability): {100*outputs[0]['score']:.1f}%" if float(outputs[0]['score']) > prob_threshold else "No prediction" | |
| # predicted_label_str = f"This image shows {outputs[0]['label']} | Confidence (probability): {100*outputs[0]['score']:.1f}%" | |
| return predicted_label_str, prob_dict | |
| """### Gradio App | |
| --- | |
| """ | |
| example_list_dir = os.path.join(os.getcwd(), "Example images") | |
| example_list_img_names = os.listdir(example_list_dir) | |
| example_list = [ | |
| ["CLIP-base", os.path.join(example_list_dir, example_img), 0.4] | |
| for example_img in example_list_img_names | |
| if example_img.lower().endswith(".png")] | |
| gradio_app = gradio.Interface( | |
| fn = run_classifer, | |
| inputs = [gradio.Dropdown(["CLIP-base", "CLIP-large", "ALIGN"], value="CLIP-base", label = "Select Classifier"), | |
| gradio.Image(type="pil", label="Load sample image here"), | |
| gradio.Slider(minimum = 0.1, maximum = 0.9, step = 0.05, value = 0.25, label = "Set Prediction Threshold") | |
| ], | |
| outputs = [gradio.Textbox(label="Image Classification"), | |
| gradio.Label(label="Prediction Probabilities", show_label=False)], | |
| examples = example_list, | |
| cache_examples = False, | |
| title = "SemanticVision", | |
| description = "Vision-language models for zero-shot classification of images of homophones and homographs", | |
| article = "Author: C. Foli (02.2026) | Website: coming soon...") | |
| gradio_app.launch() |