| import os | |
| import torch | |
| import clip | |
| import gradio as gr | |
| import pandas as pd | |
| from PIL import Image | |
# Fallback "desirable" prompts, used whenever the positive-prompt box is empty.
# Kept as a list (not a tuple) because it is concatenated with another list.
DEFAULT_POSITIVE_PROMPTS = [
    "A creative data visualization design",  # originality
    "A clear and understandable data visualization",  # readability
]
# Fallback "undesirable" prompts, used whenever the negative-prompt box is
# empty. Kept as a list (not a tuple) because it is concatenated with another list.
DEFAULT_NEGATIVE_PROMPTS = [
    "A common and boring data visualization design",  # unoriginal
    "A confused or messy data visualization",  # hard to read
    "An unrelated or off-topic image",  # not a chart at all
]
def get_device():
    """Pick the torch device string for inference.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` reports a usable GPU,
        otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def load_model(model_name: str = "ViT-B/32"):
    """Load a CLIP checkpoint and its image preprocessing transform.

    Args:
        model_name: Checkpoint name forwarded to ``clip.load``. Defaults to
            ``"ViT-B/32"``, the checkpoint this app was originally built for.

    Returns:
        A ``(device, model, preprocess)`` tuple. ``device`` is the string
        chosen by ``get_device()``. ``model`` and ``preprocess`` are what
        ``clip.load`` returns for that device.
    """
    device = get_device()
    model, preprocess = clip.load(model_name, device=device)
    return device, model, preprocess


# Loaded once at import time so every request reuses the same model weights.
device, model, preprocess = load_model()
def parse_prompts(prompt_text: str, defaults):
    """Turn free-form textbox input into a list of prompts.

    Non-blank lines are stripped and kept in order. If exactly one non-blank
    line remains and it contains a comma, that line is split on commas instead
    (empty pieces dropped).

    Args:
        prompt_text: Raw textbox contents. ``None``, empty and whitespace-only
            input all count as "no input".
        defaults: Returned unchanged (the same object) when no usable prompt
            is found.

    Returns:
        The parsed prompt list, or ``defaults``.
    """
    if not (prompt_text and prompt_text.strip()):
        return defaults

    entries = []
    for raw in prompt_text.splitlines():
        cleaned = raw.strip()
        if cleaned:
            entries.append(cleaned)

    # A single comma-bearing line is treated as a comma-separated list.
    if len(entries) == 1 and "," in entries[0]:
        entries = [piece.strip() for piece in entries[0].split(",") if piece.strip()]

    return entries or defaults
def _clip_scores(image: Image.Image, prompts):
    """Score one image against every prompt with the module-level CLIP model.

    Returns:
        ``(logits, probs)``: 1-D float32 numpy arrays aligned with ``prompts``.
        ``probs`` is the softmax of the logits across all of ``prompts``.
    """
    # CLIP's preprocess expects a 3-channel image; uploads may be RGBA,
    # palette or grayscale.
    image_input = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
    # truncate=True clips prompts longer than CLIP's context window instead of
    # letting clip.tokenize raise RuntimeError on long user input.
    text_tokens = clip.tokenize(prompts, truncate=True).to(device)
    with torch.no_grad():
        logits_per_image, _ = model(image_input, text_tokens)
    # Upcast before softmax/export: on GPU the model may run in half
    # precision, which would make the probabilities coarse.
    row = logits_per_image[0].float()
    return row.cpu().numpy(), row.softmax(dim=-1).cpu().numpy()


def compute_similarity(image: Image.Image, positive_text: str, negative_text: str):
    """Compare one chart image against positive and negative text prompts.

    Args:
        image: The uploaded image, or ``None`` if nothing was uploaded.
        positive_text: Raw positive-prompt textbox contents (see
            ``parse_prompts``); blank falls back to ``DEFAULT_POSITIVE_PROMPTS``.
        negative_text: Raw negative-prompt textbox contents; blank falls back
            to ``DEFAULT_NEGATIVE_PROMPTS``.

    Returns:
        ``(summary, df)``. ``summary`` names the highest-probability prompt.
        ``df`` has columns ``prompt``, ``probability`` and ``logit``, sorted
        by probability descending, with a fresh 0..n-1 index. Probabilities
        are a softmax over all prompts combined, so they are relative to this
        prompt set only.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    positive_prompts = parse_prompts(positive_text, DEFAULT_POSITIVE_PROMPTS)
    negative_prompts = parse_prompts(negative_text, DEFAULT_NEGATIVE_PROMPTS)
    # Concatenation builds a new list, so the shared default lists stay untouched.
    prompts = positive_prompts + negative_prompts

    logits, probs = _clip_scores(image, prompts)

    df = (
        pd.DataFrame({"prompt": prompts, "probability": probs, "logit": logits})
        .sort_values("probability", ascending=False)
        .reset_index(drop=True)
    )
    top = df.iloc[0]
    summary = f"Top match: {top['prompt']} ({top['probability'] * 100:.2f}%)"
    return summary, df
# Intro text rendered beneath the page title.
description = " ".join(
    [
        "Upload a single chart image and compare it against text prompts using CLIP.",
        "Provide your own prompts (one per line or comma-separated), or leave blank to use defaults.",
    ]
)
# Folder next to this script that is scanned for example chart images.
# Image files dropped here appear as clickable examples in the UI.
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
_EXAMPLE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif")


def _collect_examples(directory: str = EXAMPLES_DIR):
    """Build one ``gr.Examples`` row per image file found in ``directory``.

    Each row follows the ``[image, positive prompts, negative prompts]`` input
    order and pairs the image with the default prompt sets. Files are sorted
    by name for a stable display order.

    Returns:
        A list of rows, or an empty list when ``directory`` does not exist.
    """
    if not os.path.isdir(directory):
        return []
    positive_default = "\n".join(DEFAULT_POSITIVE_PROMPTS)
    negative_default = "\n".join(DEFAULT_NEGATIVE_PROMPTS)
    rows = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if os.path.isfile(path) and name.lower().endswith(_EXAMPLE_IMAGE_EXTENSIONS):
            rows.append([path, positive_default, negative_default])
    return rows


with gr.Blocks() as demo:
    gr.Markdown("# CLIP Single-Image Similarity")
    gr.Markdown(description)
    with gr.Row():
        image_input = gr.Image(type="pil", label="Chart Image")
        positive_input = gr.Textbox(
            label="Positive Prompts (optional)",
            placeholder="One per line or comma-separated",
            lines=6,
            value="\n".join(DEFAULT_POSITIVE_PROMPTS),
        )
        negative_input = gr.Textbox(
            label="Negative Prompts (optional)",
            placeholder="One per line or comma-separated",
            lines=6,
            value="\n".join(DEFAULT_NEGATIVE_PROMPTS),
        )
    run_btn = gr.Button("Run Similarity")
    summary_out = gr.Textbox(label="Summary", interactive=False)
    table_out = gr.Dataframe(
        label="Similarity Scores",
        headers=["prompt", "probability", "logit"],
        interactive=False,
    )
    run_btn.click(
        compute_similarity,
        inputs=[image_input, positive_input, negative_input],
        outputs=[summary_out, table_out],
    )
    # Render the Examples section only when example images actually exist,
    # instead of always showing an empty table.
    example_rows = _collect_examples()
    if example_rows:
        gr.Examples(
            examples=example_rows,
            inputs=[image_input, positive_input, negative_input],
            label="Examples",
        )
def main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    main()