| import gradio as gr |
| from PIL import Image |
| from similarity_score import JinaV4SimilarityMapper |
| import torch |
| import base64 |
| import io |
| import logging |
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| import gradio as gr |
| from PIL import Image |
| |
| from similarity_score import JinaV4SimilarityMapper |
| import torch |
| import base64 |
| import io |
| import logging |
| import os |
|
|
| |
| logging.basicConfig(level=logging.INFO) |
|
|
| |
|
|
def decode_base64_to_pil(base64_str):
    """Decode a base64-encoded image string into a PIL Image.

    Returns None when ``base64_str`` is empty or None, so callers can pass
    the result straight to an image component.
    """
    if base64_str:
        raw_bytes = base64.b64decode(base64_str)
        # Wrap the decoded bytes in an in-memory buffer for PIL to read.
        return Image.open(io.BytesIO(raw_bytes))
    return None
|
|
def update_heatmap_display(selected_token, heatmaps_dict):
    """Return the heatmap image for the token the user selected.

    Looks up ``selected_token`` in ``heatmaps_dict`` (token -> base64 string)
    and decodes the match. Returns None when no token is selected, the
    mapping is empty, or the token has no entry.
    """
    if selected_token and heatmaps_dict and selected_token in heatmaps_dict:
        return decode_base64_to_pil(heatmaps_dict[selected_token])
    return None
|
|
def analyze_multimodal(api_key, source_text, target_text, image_upload, image_url):
    """Score source/candidate/image consistency and build per-token heatmaps.

    Args:
        api_key: Jina API key. Surrounding whitespace is ignored.
        source_text: Source text passed to the consistency scorer.
        target_text: Candidate text. It is scored against the source and is
            the query for the token heatmaps.
        image_upload: Uploaded image, or None. Takes precedence over
            ``image_url`` when both are given.
        image_url: Image URL, used only when no image was uploaded.

    Returns:
        A 5-tuple: the score results, an update for the token selector, the
        first token's heatmap image (or None), the token -> base64 heatmap
        mapping, and an update that makes the visual group visible.

    Raises:
        gr.Error: If the API key, the image, or the candidate text is
            missing, or if scoring / heatmap generation fails.
    """
    # Use the stripped key everywhere so pasted whitespace cannot break auth.
    api_key = (api_key or "").strip()
    if not api_key:
        raise gr.Error("Please provide a valid Jina API Key.")

    image_url = (image_url or "").strip()
    if image_upload is not None:
        final_image = image_upload
    elif image_url:
        final_image = image_url
    else:
        raise gr.Error("Please provide an image via Upload or URL.")

    # Whitespace-only text is as unusable as an empty string.
    if not target_text or not target_text.strip():
        raise gr.Error("Target Candidate text is required for heatmap generation.")

    try:
        mapper = JinaV4SimilarityMapper(
            client_type="web",
            task="text-matching",
            device="cpu",
        )
        mapper.model.set_api_key(api_key)

        score_results = mapper.calculate_multimodal_consistency(
            source=source_text,
            candidate=target_text,
            image=final_image,
        )

        # The UI labels the heatmaps "Visual Grounding (Candidate Text)" and
        # the validation above requires the candidate for heatmap generation,
        # so the candidate (not the source) is the query here.
        ui_tokens, heatmaps_dict, _ = mapper.get_token_similarity_maps(
            query=target_text,
            image=final_image,
        )

        if not ui_tokens:
            # No tokens: show the scores only and hide the selector.
            return (
                score_results,
                gr.update(choices=[], visible=False),
                None,
                {},
                gr.update(visible=True),
            )

        first_token = ui_tokens[0]
        first_image = decode_base64_to_pil(heatmaps_dict[first_token])
        return (
            score_results,
            gr.update(choices=ui_tokens, value=first_token, visible=True),
            first_image,
            heatmaps_dict,
            gr.update(visible=True),
        )

    except Exception as e:
        # Keep the traceback in the log and chain the cause for debugging.
        logging.exception("Analysis Failed: %s", e)
        raise gr.Error(f"An unexpected error occurred: {e}") from e
|
|
|
|
| |
|
|
# Custom styles for the token selector radio group (elem_classes="token-selector").
css = """
.token-selector .wrap {
    gap: 5px;
}
.token-selector .item {
    padding: 5px 10px;
    border-radius: 5px;
    border: 1px solid #ddd;
    background: #f9f9f9;
}
.token-selector .item.selected {
    background: #ffe0b2;
    border-color: #ffb74d;
    font-weight: bold;
}
"""
|
|
# Build the interface: inputs on the left, results on the right.
with gr.Blocks(title="Multimodal Consistency & Grounding", css=css) as demo:

    # Holds the token -> base64 heatmap mapping produced by analyze_multimodal.
    heatmaps_state = gr.State({})

    gr.Markdown(
        """
        # 📐 Multimodal Consistency & Visual Grounding
        **Jina Embeddings v4**: Evaluate translation quality and visualize word-to-pixel attention.
        """
    )

    with gr.Row():

        # Left column: API key, texts, image source, submit button, examples.
        with gr.Column(scale=1):
            api_key_input = gr.Textbox(
                label="Jina API Key",
                type="password",
                placeholder="jina_...",
            )

            source_input = gr.Textbox(
                label="Source Text",
                placeholder="Original English text...",
                lines=2,
            )

            target_input = gr.Textbox(
                label="Candidate Text (Target)",
                placeholder="Translated or Candidate text...",
                lines=2,
            )

            with gr.Tab("Image Upload"):
                img_upload_input = gr.Image(label="Upload", type="pil")

            with gr.Tab("Image URL"):
                img_url_input = gr.Textbox(label="URL", placeholder="https://...")

            submit_btn = gr.Button("Analyze & Visualize", variant="primary")

            # Components filled by an example row, in column order.
            example_inputs = [
                source_input,
                target_input,
                img_upload_input,
                img_url_input,
            ]
            example_rows = [
                [
                    "A grey cat is sleeping on a blue velvet sofa.",
                    "Eine graue Katze schläft auf einem blauen Samtsofa.",
                    "cat.png",
                    None,
                ],
                [
                    "A grey dog is sleeping on a blue velvet sofa.",
                    "Eine graue Hund schläft auf einem blauen Samtsofa.",
                    "dog.png",
                    None,
                ],
            ]
            gr.Examples(
                examples=example_rows,
                inputs=example_inputs,
                label="Click to load Example",
                cache_examples=False,
            )

        # Right column: scores, plus the heatmap viewer (hidden until a run).
        with gr.Column(scale=1):

            gr.Markdown("### 📊 Consistency Scores")
            json_output = gr.JSON(label="Metric Results")

            with gr.Group(visible=False) as visual_group:
                gr.Markdown("### 👁️ Visual Grounding (Candidate Text)")
                gr.Markdown("_Click on a word below to see where the model looks in the image._")

                image_display = gr.Image(
                    label="Heatmap Overlay",
                    type="pil",
                    interactive=False,
                )

                token_selector = gr.Radio(
                    choices=[],
                    label="Select Token",
                    interactive=True,
                    elem_classes="token-selector",
                )

    # Run the analysis; the output order matches analyze_multimodal's return tuple.
    submit_btn.click(
        fn=analyze_multimodal,
        inputs=[api_key_input, *example_inputs],
        outputs=[
            json_output,
            token_selector,
            image_display,
            heatmaps_state,
            visual_group,
        ],
    )

    # Swap the displayed heatmap whenever a different token is picked.
    token_selector.change(
        fn=update_heatmap_display,
        inputs=[token_selector, heatmaps_state],
        outputs=[image_display],
    )


if __name__ == "__main__":
    demo.launch()
| |
|
|
def decode_base64_to_pil(base64_str):
    """Decode a base64-encoded image string into a PIL Image.

    Returns None when ``base64_str`` is empty or None, so callers can pass
    the result straight to an image component.
    """
    if base64_str:
        raw_bytes = base64.b64decode(base64_str)
        # Wrap the decoded bytes in an in-memory buffer for PIL to read.
        return Image.open(io.BytesIO(raw_bytes))
    return None
|
|
def update_heatmap_display(selected_token, heatmaps_dict):
    """Return the heatmap image for the token the user selected.

    Looks up ``selected_token`` in ``heatmaps_dict`` (token -> base64 string)
    and converts the match to a PIL image. Returns None when no token is
    selected, the mapping is empty, or the token has no entry.
    """
    if selected_token and heatmaps_dict and selected_token in heatmaps_dict:
        return decode_base64_to_pil(heatmaps_dict[selected_token])
    return None
|
|
def analyze_multimodal(api_key, source_text, target_text, image_upload, image_url):
    """Score source/candidate/image consistency and build per-token heatmaps.

    Steps:
    1. Calculates consistency scores (shown as JSON).
    2. Generates visual grounding heatmaps for the target text.

    Args:
        api_key: Jina API key. Surrounding whitespace is ignored.
        source_text: Source text passed to the consistency scorer.
        target_text: Candidate text. It is scored against the source and is
            the query for the token heatmaps.
        image_upload: Uploaded image, or None. Takes precedence over
            ``image_url`` when both are given.
        image_url: Image URL, used only when no image was uploaded.

    Returns:
        A 5-tuple: the score results, an update for the token selector, the
        first token's heatmap image (or None), the token -> base64 heatmap
        mapping, and an update that makes the visual group visible.

    Raises:
        gr.Error: If the API key, the image, or the candidate text is
            missing, or if scoring / heatmap generation fails.
    """
    # Use the stripped key everywhere so pasted whitespace cannot break auth.
    api_key = (api_key or "").strip()
    if not api_key:
        raise gr.Error("Please provide a valid Jina API Key.")

    image_url = (image_url or "").strip()
    if image_upload is not None:
        final_image = image_upload
    elif image_url:
        final_image = image_url
    else:
        raise gr.Error("Please provide an image via Upload or URL.")

    # Whitespace-only text is as unusable as an empty string.
    if not target_text or not target_text.strip():
        raise gr.Error("Target Candidate text is required for heatmap generation.")

    try:
        mapper = JinaV4SimilarityMapper(
            client_type="web",
            task="text-matching",
            device="cpu",
        )
        mapper.model.set_api_key(api_key)

        score_results = mapper.calculate_multimodal_consistency(
            source=source_text,
            candidate=target_text,
            image=final_image,
        )

        # Step 2 targets the candidate text, and the UI labels the heatmaps
        # "Visual Grounding (Candidate Text)", so the candidate (not the
        # source) is the query here.
        ui_tokens, heatmaps_dict, _ = mapper.get_token_similarity_maps(
            query=target_text,
            image=final_image,
        )

        if not ui_tokens:
            # No tokens: show the scores only and hide the selector.
            return (
                score_results,
                gr.update(choices=[], visible=False),
                None,
                {},
                gr.update(visible=True),
            )

        first_token = ui_tokens[0]
        first_image = decode_base64_to_pil(heatmaps_dict[first_token])
        return (
            score_results,
            gr.update(choices=ui_tokens, value=first_token, visible=True),
            first_image,
            heatmaps_dict,
            gr.update(visible=True),
        )

    except Exception as e:
        # Keep the traceback in the log and chain the cause for debugging.
        logging.exception("Analysis Failed: %s", e)
        raise gr.Error(f"An unexpected error occurred: {e}") from e
|
|
|
|
| |
|
|
| |
# Custom styles for the token selector radio group (elem_classes="token-selector").
css = """
.token-selector .wrap {
    gap: 5px;
}
.token-selector .item {
    padding: 5px 10px;
    border-radius: 5px;
    border: 1px solid #ddd;
    background: #f9f9f9;
}
.token-selector .item.selected {
    background: #ffe0b2; /* Orange highlight for Jina style */
    border-color: #ffb74d;
    font-weight: bold;
}
"""
|
|
# Build the interface: inputs on the left, results on the right.
with gr.Blocks(title="Multimodal Consistency & Grounding", css=css) as demo:

    # Holds the token -> base64 heatmap mapping produced by analyze_multimodal.
    heatmaps_state = gr.State({})

    gr.Markdown(
        """
        # 📐 Multimodal Consistency & Visual Grounding
        **Jina Embeddings v4**: Evaluate translation quality and visualize word-to-pixel attention.
        """
    )

    with gr.Row():

        # Left column: API key, texts (pre-filled with a demo pair), image source.
        with gr.Column(scale=1):
            api_key_input = gr.Textbox(
                label="Jina API Key",
                type="password",
                placeholder="jina_...",
            )

            source_input = gr.Textbox(
                label="Source Text",
                placeholder="Original English text...",
                lines=2,
                value="A group of cyclists riding nearby the ocean",
            )

            target_input = gr.Textbox(
                label="Candidate Text (Target)",
                placeholder="Translated or Candidate text...",
                lines=2,
                value="Eine Gruppe von Radfahrern fährt in der Nähe des Ozeans",
            )

            with gr.Tab("Image Upload"):
                img_upload_input = gr.Image(label="Upload", type="pil")

            with gr.Tab("Image URL"):
                img_url_input = gr.Textbox(
                    label="URL",
                    placeholder="https://...",
                    value="https://cdn.duvine.com/wp-content/uploads/2016/04/17095703/Slides_mallorca_FOR-WEB.jpg",
                )

            submit_btn = gr.Button("Analyze & Visualize", variant="primary")

        # Right column: scores, plus the heatmap viewer (hidden until a run).
        with gr.Column(scale=1):

            gr.Markdown("### 📊 Consistency Scores")
            json_output = gr.JSON(label="Metric Results")

            with gr.Group(visible=False) as visual_group:
                gr.Markdown("### 👁️ Visual Grounding (Candidate Text)")
                gr.Markdown("_Click on a word below to see where the model looks in the image._")

                image_display = gr.Image(
                    label="Heatmap Overlay",
                    type="pil",
                    interactive=False,
                )

                token_selector = gr.Radio(
                    choices=[],
                    label="Select Token",
                    interactive=True,
                    elem_classes="token-selector",
                )

    # Argument order matches analyze_multimodal's parameters.
    analysis_inputs = [
        api_key_input,
        source_input,
        target_input,
        img_upload_input,
        img_url_input,
    ]
    # Output order matches analyze_multimodal's return tuple.
    analysis_outputs = [
        json_output,
        token_selector,
        image_display,
        heatmaps_state,
        visual_group,
    ]
    submit_btn.click(
        fn=analyze_multimodal,
        inputs=analysis_inputs,
        outputs=analysis_outputs,
    )

    # Swap the displayed heatmap whenever a different token is picked.
    token_selector.change(
        fn=update_heatmap_display,
        inputs=[token_selector, heatmaps_state],
        outputs=[image_display],
    )


if __name__ == "__main__":
    demo.launch()