"""Gradio demo: score a single chart image against text prompts with CLIP.

The app loads OpenAI's CLIP ViT-B/32 once at import time, then lets the
user upload an image and compare it against positive/negative prompt sets.
"""

import os

import clip
import gradio as gr
import pandas as pd
import torch
from PIL import Image

# Fallback prompt sets used when the corresponding textbox is left blank.
DEFAULT_POSITIVE_PROMPTS = [
    "A creative data visualization design",
    "A clear and understandable data visualization",
]
DEFAULT_NEGATIVE_PROMPTS = [
    "A common and boring data visualization design",
    "A confused or messy data visualization",
    "An unrelated or off-topic image",
]


def get_device() -> str:
    """Return ``"cuda"`` when a GPU is available, otherwise ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def load_model():
    """Load CLIP ViT-B/32 and its preprocessing transform.

    Returns:
        Tuple ``(device, model, preprocess)``.
    """
    device = get_device()
    model, preprocess = clip.load("ViT-B/32", device=device)
    return device, model, preprocess


# Loaded once at import time so every request reuses the same weights.
device, model, preprocess = load_model()


def parse_prompts(prompt_text: str, defaults):
    """Parse user-entered prompts into a list of strings.

    Accepts one prompt per line; a single line containing commas is split
    on commas instead. Blank or whitespace-only input falls back to
    *defaults*.

    Args:
        prompt_text: Raw textbox contents (may be ``None``/empty).
        defaults: Prompt list returned when nothing usable was entered.

    Returns:
        Non-empty list of stripped prompt strings.
    """
    if not prompt_text or not prompt_text.strip():
        return defaults
    lines = [line.strip() for line in prompt_text.splitlines() if line.strip()]
    # A single comma-separated line is treated as multiple prompts.
    if len(lines) == 1 and "," in lines[0]:
        lines = [seg.strip() for seg in lines[0].split(",") if seg.strip()]
    return lines if lines else defaults


def compute_similarity(image: Image.Image, positive_text: str, negative_text: str):
    """Score *image* against the combined positive + negative prompts.

    Args:
        image: Uploaded PIL image (``None`` raises a user-facing error).
        positive_text: Raw positive-prompt textbox contents.
        negative_text: Raw negative-prompt textbox contents.

    Returns:
        ``(summary, df)`` where *summary* names the top match and *df* is a
        DataFrame of prompt/probability/logit rows sorted by probability.

    Raises:
        gr.Error: When no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    positive_prompts = parse_prompts(positive_text, DEFAULT_POSITIVE_PROMPTS)
    negative_prompts = parse_prompts(negative_text, DEFAULT_NEGATIVE_PROMPTS)
    prompts = positive_prompts + negative_prompts

    image = image.convert("RGB")
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_tokens = clip.tokenize(prompts).to(device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits_per_image, _ = model(image_input, text_tokens)
    logits = logits_per_image[0].cpu().numpy()
    # Softmax over the prompt axis: probabilities sum to 1 across prompts.
    probs = logits_per_image.softmax(dim=-1)[0].cpu().numpy()

    df = pd.DataFrame(
        {
            "prompt": prompts,
            "probability": probs,
            "logit": logits,
        }
    ).sort_values("probability", ascending=False)
    top_prompt = df.iloc[0]["prompt"]
    top_prob = df.iloc[0]["probability"]
    summary = f"Top match: {top_prompt} ({top_prob * 100:.2f}%)"
    return summary, df


description = (
    "Upload a single chart image and compare it against text prompts using CLIP. "
    "Provide your own prompts (one per line or comma-separated), or leave blank to use defaults."
)

with gr.Blocks() as demo:
    gr.Markdown("# CLIP Single-Image Similarity")
    gr.Markdown(description)
    with gr.Row():
        image_input = gr.Image(type="pil", label="Chart Image")
        positive_input = gr.Textbox(
            label="Positive Prompts (optional)",
            placeholder="One per line or comma-separated",
            lines=6,
            value="\n".join(DEFAULT_POSITIVE_PROMPTS),
        )
        negative_input = gr.Textbox(
            label="Negative Prompts (optional)",
            placeholder="One per line or comma-separated",
            lines=6,
            value="\n".join(DEFAULT_NEGATIVE_PROMPTS),
        )
    run_btn = gr.Button("Run Similarity")
    summary_out = gr.Textbox(label="Summary", interactive=False)
    table_out = gr.Dataframe(
        label="Similarity Scores",
        headers=["prompt", "probability", "logit"],
        interactive=False,
    )
    run_btn.click(
        compute_similarity,
        inputs=[image_input, positive_input, negative_input],
        outputs=[summary_out, table_out],
    )
    gr.Examples(
        examples=[],
        inputs=[image_input, positive_input, negative_input],
        label="Examples (add images in the repository to enable)",
    )

if __name__ == "__main__":
    demo.launch()