# clip_test / app.py
# CLIP-based single-image similarity demo (Gradio app).
# Author: Hangyang Shen — commit c583e10
import os
import torch
import clip
import gradio as gr
import pandas as pd
from PIL import Image
# Fallback positive prompts used when the "Positive Prompts" textbox is blank.
DEFAULT_POSITIVE_PROMPTS = [
    "A creative data visualization design",
    "A clear and understandable data visualization",
]
# Fallback negative prompts used when the "Negative Prompts" textbox is blank.
DEFAULT_NEGATIVE_PROMPTS = [
    "A common and boring data visualization design",
    "A confused or messy data visualization",
    "An unrelated or off-topic image",
]
def get_device():
    """Pick the torch device string: ``"cuda"`` when a GPU is available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def load_model():
    """Load the CLIP ViT-B/32 checkpoint onto the auto-detected device.

    Returns:
        A ``(device, model, preprocess)`` triple, where ``preprocess`` is the
        image transform paired with the checkpoint.
    """
    target_device = get_device()
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=target_device)
    return target_device, clip_model, clip_preprocess
# Load the model once at import time so every Gradio callback reuses it.
device, model, preprocess = load_model()
def parse_prompts(prompt_text: str, defaults):
    """Split user-entered prompt text into a list of prompt strings.

    Accepts one prompt per line; a single line may instead be comma-separated.
    Falls back to *defaults* when the input is empty, whitespace-only, or
    yields no non-blank prompts.
    """
    if not prompt_text or not prompt_text.strip():
        return defaults
    prompts = [chunk.strip() for chunk in prompt_text.splitlines() if chunk.strip()]
    # A single line containing commas is treated as a comma-separated list.
    if len(prompts) == 1 and "," in prompts[0]:
        prompts = [part.strip() for part in prompts[0].split(",") if part.strip()]
    return prompts or defaults
def compute_similarity(image: Image.Image, positive_text: str, negative_text: str):
    """Score one image against positive and negative text prompts with CLIP.

    Args:
        image: Uploaded PIL image (raises ``gr.Error`` when ``None``).
        positive_text: User prompt text; blank falls back to defaults.
        negative_text: User prompt text; blank falls back to defaults.

    Returns:
        A ``(summary, table)`` pair: a one-line top-match string and a
        DataFrame of prompt / probability / logit rows sorted by probability
        in descending order.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    pos_prompts = parse_prompts(positive_text, DEFAULT_POSITIVE_PROMPTS)
    neg_prompts = parse_prompts(negative_text, DEFAULT_NEGATIVE_PROMPTS)
    all_prompts = pos_prompts + neg_prompts
    # Preprocess to a single-image batch and tokenize all prompts together so
    # the softmax is taken over the combined positive+negative set.
    pixel_batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
    token_batch = clip.tokenize(all_prompts).to(device)
    with torch.no_grad():
        logits_per_image, _ = model(pixel_batch, token_batch)
    raw_logits = logits_per_image[0].cpu().numpy()
    probabilities = logits_per_image.softmax(dim=-1)[0].cpu().numpy()
    table = pd.DataFrame(
        {
            "prompt": all_prompts,
            "probability": probabilities,
            "logit": raw_logits,
        }
    ).sort_values("probability", ascending=False)
    best = table.iloc[0]
    top_prompt = best["prompt"]
    top_prob = best["probability"]
    summary = f"Top match: {top_prompt} ({top_prob * 100:.2f}%)"
    return summary, table
# Intro text rendered below the page title in the Gradio UI.
description = (
    "Upload a single chart image and compare it against text prompts using CLIP. "
    "Provide your own prompts (one per line or comma-separated), or leave blank to use defaults."
)
# Build the Gradio interface; `demo` is launched from the __main__ guard below.
# NOTE(review): SOURCE indentation was lost — the image and both prompt boxes
# are assumed to sit side-by-side inside the Row; confirm against the repo.
with gr.Blocks() as demo:
    gr.Markdown("# CLIP Single-Image Similarity")
    gr.Markdown(description)
    with gr.Row():
        # Inputs: one chart image plus optional positive/negative prompt lists.
        image_input = gr.Image(type="pil", label="Chart Image")
        positive_input = gr.Textbox(
            label="Positive Prompts (optional)",
            placeholder="One per line or comma-separated",
            lines=6,
            value="\n".join(DEFAULT_POSITIVE_PROMPTS),
        )
        negative_input = gr.Textbox(
            label="Negative Prompts (optional)",
            placeholder="One per line or comma-separated",
            lines=6,
            value="\n".join(DEFAULT_NEGATIVE_PROMPTS),
        )
    run_btn = gr.Button("Run Similarity")
    # Outputs: one-line top-match summary and the full per-prompt score table.
    summary_out = gr.Textbox(label="Summary", interactive=False)
    table_out = gr.Dataframe(
        label="Similarity Scores",
        headers=["prompt", "probability", "logit"],
        interactive=False,
    )
    # Wire the button to the scoring function.
    run_btn.click(
        compute_similarity,
        inputs=[image_input, positive_input, negative_input],
        outputs=[summary_out, table_out],
    )
    # Placeholder examples gallery; empty until images are added to the repo.
    gr.Examples(
        examples=[],
        inputs=[image_input, positive_input, negative_input],
        label="Examples (add images in the repository to enable)",
    )
if __name__ == "__main__":
    # Start the Gradio server when run as a script (not on import).
    demo.launch()