# app.py — Gradio demo: multimodal consistency scoring & visual grounding
# (HuggingFace Space by samiulhaq; revision 0a3ca6f "Update app.py", verified)
import gradio as gr
from PIL import Image
from similarity_score import JinaV4SimilarityMapper
import torch
import base64
import io
import logging
# Configure Logging
logging.basicConfig(level=logging.INFO)
import gradio as gr
from PIL import Image
# Assuming similarity_score is a local module you have
from similarity_score import JinaV4SimilarityMapper
import torch
import base64
import io
import logging
import os
# Configure Logging
logging.basicConfig(level=logging.INFO)
# --- Helper Functions ---
def decode_base64_to_pil(base64_str):
    """Decode a base64-encoded image payload into a PIL Image.

    Returns None for an empty/None input so the result can be fed
    straight into a Gradio image component without extra checks.
    """
    if not base64_str:
        return None
    raw_bytes = base64.b64decode(base64_str)
    return Image.open(io.BytesIO(raw_bytes))
def update_heatmap_display(selected_token, heatmaps_dict):
    """Return the decoded heatmap for the clicked token, or None if unavailable.

    Called when the user picks a token in the Radio selector; looks the token
    up in the cached {token: base64} dict and decodes it for display.
    """
    if not heatmaps_dict or not selected_token:
        return None
    encoded = heatmaps_dict.get(selected_token)
    if encoded is None:
        return None
    return decode_base64_to_pil(encoded)
def analyze_multimodal(api_key, source_text, target_text, image_upload, image_url):
    """Run the full analysis pipeline for one submission.

    1. Validates inputs (API key, image source, target text).
    2. Computes source/candidate/image consistency scores.
    3. Builds per-token visual-grounding heatmaps for the candidate text.

    Returns a 5-tuple matching the Gradio outputs:
    (score dict for the JSON view, Radio update with token choices,
     first heatmap as a PIL image or None, {token: base64} state dict,
     visibility update for the results group).
    Raises gr.Error on invalid input or any backend failure.
    """
    # 1. Validation & Setup
    if not api_key or not api_key.strip():
        raise gr.Error("Please provide a valid Jina API Key.")

    # Determine Image Source (Priority: Upload > URL)
    if image_upload is not None:
        final_image = image_upload
    elif image_url and image_url.strip():
        final_image = image_url.strip()
    else:
        raise gr.Error("Please provide an image via Upload or URL.")

    if not target_text:
        raise gr.Error("Target Candidate text is required for heatmap generation.")

    try:
        # 2. Initialize Mapper: 'web' client uses the hosted API with the
        # user-supplied key; 'text-matching' gives symmetric embeddings
        # suitable for semantic comparison.
        mapper = JinaV4SimilarityMapper(
            client_type="web",
            task="text-matching",
            device="cpu",
        )
        mapper.model.set_api_key(api_key)

        # 3. Step A: consistency scores over the source/candidate/image triangle.
        score_results = mapper.calculate_multimodal_consistency(
            source=source_text,
            candidate=target_text,
            image=final_image,
        )

        # 4. Step B: token-level heatmaps for the CANDIDATE (target) text.
        # BUGFIX: the original passed query=source_text, contradicting the UI
        # label ("Visual Grounding (Candidate Text)") and the validation above,
        # which requires target_text precisely for heatmap generation.
        ui_tokens, heatmaps_dict, _ = mapper.get_token_similarity_maps(
            query=target_text,
            image=final_image,
        )

        # 5. Prepare initial UI state.
        if not ui_tokens:
            # No tokens produced: show scores, hide the token selector.
            return (
                score_results,
                gr.update(choices=[], visible=False),
                None,
                {},
                gr.update(visible=True),
            )

        first_token = ui_tokens[0]
        return (
            score_results,
            gr.update(choices=ui_tokens, value=first_token, visible=True),
            decode_base64_to_pil(heatmaps_dict[first_token]),
            heatmaps_dict,
            gr.update(visible=True),
        )
    except Exception as e:
        logging.error(f"Analysis Failed: {e}")
        raise gr.Error(f"An unexpected error occurred: {str(e)}")
# --- Gradio UI Layout ---
# Custom CSS for the token Radio selector (attached via elem_classes
# "token-selector"): renders options as compact chip-like tags, with an
# orange highlight on the selected token.
css = """
.token-selector .wrap {
gap: 5px;
}
.token-selector .item {
padding: 5px 10px;
border-radius: 5px;
border: 1px solid #ddd;
background: #f9f9f9;
}
.token-selector .item.selected {
background: #ffe0b2;
border-color: #ffb74d;
font-weight: bold;
}
"""
with gr.Blocks(title="Multimodal Consistency & Grounding", css=css) as demo:
    # State holding the heavy {token: base64 heatmap} dict so a token click
    # only decodes an image instead of re-running the analysis.
    heatmaps_state = gr.State({})
    gr.Markdown(
        """
# 📐 Multimodal Consistency & Visual Grounding
**Jina Embeddings v4**: Evaluate translation quality and visualize word-to-pixel attention.
"""
    )
    with gr.Row():
        # --- Left Column: Inputs ---
        with gr.Column(scale=1):
            api_key_input = gr.Textbox(
                label="Jina API Key",
                type="password",
                placeholder="jina_...",
                # Keeping your default key here
            )
            source_input = gr.Textbox(
                label="Source Text",
                placeholder="Original English text...",
                lines=2
            )
            target_input = gr.Textbox(
                label="Candidate Text (Target)",
                placeholder="Translated or Candidate text...",
                lines=2
            )
            # Two ways to supply the image; the analysis gives priority to
            # the upload when both are filled.
            with gr.Tab("Image Upload"):
                img_upload_input = gr.Image(label="Upload", type="pil")
            with gr.Tab("Image URL"):
                img_url_input = gr.Textbox(label="URL", placeholder="https://...")
            submit_btn = gr.Button("Analyze & Visualize", variant="primary")
            # --- EXAMPLES SECTION ADDED HERE ---
            # NOTE(review): assumes 'cat.png' and 'dog.png' sit in the same
            # directory as this script — confirm they ship with the app.
            gr.Examples(
                examples=[
                    [
                        "A grey cat is sleeping on a blue velvet sofa.",
                        "Eine graue Katze schläft auf einem blauen Samtsofa.",
                        "cat.png",  # This must map to img_upload_input
                        None  # This maps to img_url_input
                    ],
                    [
                        "A grey dog is sleeping on a blue velvet sofa.",
                        "Eine graue Hund schläft auf einem blauen Samtsofa.",
                        "dog.png",  # This must map to img_upload_input
                        None  # This maps to img_url_input
                    ]
                ],
                inputs=[
                    source_input,
                    target_input,
                    img_upload_input,
                    img_url_input
                ],
                label="Click to load Example",
                cache_examples=False  # Set to True to pre-compute (requires API access on launch)
            )
        # --- Right Column: Outputs ---
        with gr.Column(scale=1):
            # 1. Scores Section
            gr.Markdown("### 📊 Consistency Scores")
            json_output = gr.JSON(label="Metric Results")
            # 2. Visual Grounding Section (hidden until an analysis runs)
            with gr.Group(visible=False) as visual_group:
                gr.Markdown("### 👁️ Visual Grounding (Candidate Text)")
                gr.Markdown("_Click on a word below to see where the model looks in the image._")
                image_display = gr.Image(
                    label="Heatmap Overlay",
                    type="pil",
                    interactive=False
                )
                token_selector = gr.Radio(
                    choices=[],
                    label="Select Token",
                    interactive=True,
                    elem_classes="token-selector"  # picks up the chip CSS defined above
                )
    # --- Event Wiring ---
    # Main analysis: fills scores, token list, first heatmap, cached state,
    # and reveals the visual-grounding group.
    submit_btn.click(
        fn=analyze_multimodal,
        inputs=[
            api_key_input,
            source_input,
            target_input,
            img_upload_input,
            img_url_input
        ],
        outputs=[
            json_output,
            token_selector,
            image_display,
            heatmaps_state,
            visual_group
        ]
    )
    # Token click: swap the displayed heatmap using the cached state only.
    token_selector.change(
        fn=update_heatmap_display,
        inputs=[token_selector, heatmaps_state],
        outputs=[image_display]
    )
# --- Launch ---
if __name__ == "__main__":
    # Blocking call: starts the Gradio web server.
    demo.launch()
# --- Helper Functions ---
def decode_base64_to_pil(base64_str):
    """Turn a base64 image string into a PIL Image (None for empty input).

    NOTE(review): duplicate of an identical definition earlier in this
    file — the whole script appears to have been pasted twice.
    """
    if not base64_str:
        return None
    return Image.open(io.BytesIO(base64.b64decode(base64_str)))
def update_heatmap_display(selected_token, heatmaps_dict):
    """Look up the clicked token's cached base64 heatmap and decode it.

    Returns None when there is no selection, no cache, or no entry for the
    token, which simply clears the image component.
    """
    token_known = (
        bool(selected_token)
        and bool(heatmaps_dict)
        and selected_token in heatmaps_dict
    )
    if not token_known:
        return None
    return decode_base64_to_pil(heatmaps_dict[selected_token])
def analyze_multimodal(api_key, source_text, target_text, image_upload, image_url):
    """Run the full analysis pipeline for one submission.

    1. Calculates consistency scores (JSON) over source/candidate/image.
    2. Generates visual-grounding heatmaps for the candidate (target) text.

    Returns a 5-tuple matching the Gradio outputs:
    (score dict, Radio update with token choices, first heatmap PIL image
     or None, {token: base64} state dict, visibility update for the group).
    Raises gr.Error on invalid input or any backend failure.
    """
    # 1. Validation & Setup
    if not api_key or not api_key.strip():
        raise gr.Error("Please provide a valid Jina API Key.")

    # Determine Image Source (Priority: Upload > URL)
    if image_upload is not None:
        final_image = image_upload
    elif image_url and image_url.strip():
        final_image = image_url.strip()
    else:
        raise gr.Error("Please provide an image via Upload or URL.")

    if not target_text:
        raise gr.Error("Target Candidate text is required for heatmap generation.")

    try:
        # 2. Initialize Mapper
        # Force 'web' client to use the API key and 'text-matching' for
        # semantic comparison.
        mapper = JinaV4SimilarityMapper(
            client_type="web",
            task="text-matching",
            device="cpu",
        )
        mapper.model.set_api_key(api_key)

        # 3. Step A: Calculate Scores (The Triangle)
        score_results = mapper.calculate_multimodal_consistency(
            source=source_text,
            candidate=target_text,
            image=final_image,
        )

        # 4. Step B: heatmaps showing how the CANDIDATE grounds to the image.
        # BUGFIX: the original passed query=source_text despite its own
        # comment ("We want to see how the Candidate grounds to the image")
        # and the target_text validation above.
        ui_tokens, heatmaps_dict, _ = mapper.get_token_similarity_maps(
            query=target_text,
            image=final_image,
        )

        # 5. Prepare initial UI state.
        if not ui_tokens:
            # No tokens: show scores, hide the selector.
            return (
                score_results,
                gr.update(choices=[], visible=False),
                None,
                {},
                gr.update(visible=True),
            )

        first_token = ui_tokens[0]
        return (
            score_results,
            gr.update(choices=ui_tokens, value=first_token, visible=True),
            decode_base64_to_pil(heatmaps_dict[first_token]),
            heatmaps_dict,
            gr.update(visible=True),
        )
    except Exception as e:
        logging.error(f"Analysis Failed: {e}")
        raise gr.Error(f"An unexpected error occurred: {str(e)}")
# --- Gradio UI Layout ---
# Custom CSS to make the Radio buttons look like tags/chips
# (applied to the Radio component via elem_classes "token-selector").
css = """
.token-selector .wrap {
gap: 5px;
}
.token-selector .item {
padding: 5px 10px;
border-radius: 5px;
border: 1px solid #ddd;
background: #f9f9f9;
}
.token-selector .item.selected {
background: #ffe0b2; /* Orange highlight for Jina style */
border-color: #ffb74d;
font-weight: bold;
}
"""
with gr.Blocks(title="Multimodal Consistency & Grounding", css=css) as demo:
    # State to hold the heavy base64 strings so we don't re-compute on every click
    heatmaps_state = gr.State({})
    gr.Markdown(
        """
# 📐 Multimodal Consistency & Visual Grounding
**Jina Embeddings v4**: Evaluate translation quality and visualize word-to-pixel attention.
"""
    )
    with gr.Row():
        # --- Left Column: Inputs ---
        with gr.Column(scale=1):
            api_key_input = gr.Textbox(
                label="Jina API Key",
                type="password",
                placeholder="jina_..."
            )
            # Pre-filled demo pair: English source + German candidate.
            source_input = gr.Textbox(
                label="Source Text",
                placeholder="Original English text...",
                lines=2,
                value="A group of cyclists riding nearby the ocean"
            )
            target_input = gr.Textbox(
                label="Candidate Text (Target)",
                placeholder="Translated or Candidate text...",
                lines=2,
                value="Eine Gruppe von Radfahrern fährt in der Nähe des Ozeans"
            )
            with gr.Tab("Image Upload"):
                img_upload_input = gr.Image(label="Upload", type="pil")
            with gr.Tab("Image URL"):
                # Default demo image URL; an upload (if provided) takes priority.
                img_url_input = gr.Textbox(label="URL", placeholder="https://...", value = "https://cdn.duvine.com/wp-content/uploads/2016/04/17095703/Slides_mallorca_FOR-WEB.jpg")
            submit_btn = gr.Button("Analyze & Visualize", variant="primary")
        # --- Right Column: Outputs ---
        with gr.Column(scale=1):
            # 1. Scores Section
            gr.Markdown("### 📊 Consistency Scores")
            json_output = gr.JSON(label="Metric Results")
            # 2. Visual Grounding Section (Hidden until run)
            with gr.Group(visible=False) as visual_group:
                gr.Markdown("### 👁️ Visual Grounding (Candidate Text)")
                gr.Markdown("_Click on a word below to see where the model looks in the image._")
                # The Image Display
                image_display = gr.Image(
                    label="Heatmap Overlay",
                    type="pil",
                    interactive=False
                )
                # The Interactive Token Selector
                token_selector = gr.Radio(
                    choices=[],
                    label="Select Token",
                    interactive=True,
                    elem_classes="token-selector"  # apply CSS
                )
    # --- Event Wiring ---
    # 1. Main Analysis Event
    submit_btn.click(
        fn=analyze_multimodal,
        inputs=[
            api_key_input,
            source_input,
            target_input,
            img_upload_input,
            img_url_input
        ],
        outputs=[
            json_output,        # The Score JSON
            token_selector,     # The Radio Buttons
            image_display,      # The Image (sets the first one)
            heatmaps_state,     # The hidden state
            visual_group        # Visibility toggle
        ]
    )
    # 2. Token Click Event
    # When user clicks a word, update the image from State
    token_selector.change(
        fn=update_heatmap_display,
        inputs=[token_selector, heatmaps_state],
        outputs=[image_display]
    )
# --- Launch ---
if __name__ == "__main__":
    # NOTE(review): the whole app is defined twice in this file; this second
    # launch() only runs after the first (earlier) server exits.
    demo.launch()