"""Gradio demo comparing the original DAM model with DAM-QA (sliding window + voting)."""

import os
import json
import time
from collections import defaultdict

import gradio as gr
import plotly.graph_objects as go
import pandas as pd
from PIL import Image

import vlai_template
from src.dam_models import get_dam_original, get_dam_sliding

# App configuration
vlai_template.set_meta(
    project_name="DAM-QA Demo",
    year="2025",
    module="DAM",
    description="DAM-QA performance on Visual Question Answering tasks",
    meta_items=[
        ("Original DAM", "Full image processing"),
        ("DAM-QA", "Sliding window + voting"),
        ("Datasets", "DocVQA, InfographicVQA, TextVQA, ChartQA, VQAv2"),
    ],
)

# Global state for models and samples; filled in by init_models() and
# load_samples() and read by the event handlers below.
STATE = {
    "dam_original": None,
    "dam_sliding": None,
    "samples": [],
}

# Status text shown once both models are available.
_MODELS_READY_MSG = "✅ Both DAM models loaded successfully!"


def _models_loaded():
    """Return True when both model slots in STATE are populated."""
    return STATE["dam_original"] is not None and STATE["dam_sliding"] is not None


def _table_cell(value):
    """Return *value* as text that is safe inside one Markdown table cell.

    Newlines (and other whitespace runs) are collapsed to single spaces and
    pipes are escaped, so a model answer cannot break the table layout.
    """
    return " ".join(str(value).split()).replace("|", "\\|")


# Load sample data
def load_samples():
    """Load sample questions and images from ``samples.json``.

    Returns:
        The parsed samples (also stored in ``STATE["samples"]``), or an empty
        list when the file cannot be read or parsed.
    """
    try:
        # Explicit encoding: questions may contain non-ASCII text and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open("samples.json", "r", encoding="utf-8") as f:
            samples = json.load(f)
        STATE["samples"] = samples
        return samples
    except Exception as e:  # best effort: the demo still works without samples
        print(f"Error loading samples: {e}")
        return []


def init_models():
    """Initialize both DAM models and store them in STATE.

    Returns:
        A human-readable status message (success or error).
    """
    try:
        STATE["dam_original"] = get_dam_original()
        STATE["dam_sliding"] = get_dam_sliding()
        return _MODELS_READY_MSG
    except Exception as e:  # UI boundary: report the failure instead of crashing
        error_msg = f"❌ Error loading models: {str(e)}"
        print(error_msg)
        return error_msg


def get_sample_choices(samples=None):
    """Get the list of ``(label, index)`` choices for the sample dropdown.

    Args:
        samples: Samples to build choices from. Defaults to
            ``STATE["samples"]`` when omitted.

    Returns:
        A list of ``(label, index)`` tuples, one per sample.
    """
    if samples is None:
        samples = STATE["samples"]
    return [
        (f"{sample['dataset']}: {sample['question'][:50]}...", i)
        for i, sample in enumerate(samples)
    ]


def fill_from_sample(sample_idx):
    """Fill the preview and the main inputs from the selected sample.

    Args:
        sample_idx: Index into ``STATE["samples"]``, or None.

    Returns:
        A 5-tuple ``(sample_image, ground_truth, sample_info, image_input,
        question_input)``. Empty values are returned when the index is
        missing or out of range; when only the image fails to load, the text
        fields are still filled in.
    """
    samples = STATE["samples"]
    # Reject negative indices too: they would otherwise silently wrap around.
    if not samples or sample_idx is None or not 0 <= sample_idx < len(samples):
        return None, "", "", None, ""

    sample = samples[sample_idx]

    # Load the sample image
    try:
        sample_img = Image.open(sample["image"])
        # Force decoding now so the underlying file is released and decode
        # errors are reported by the handler below.
        sample_img.load()
        return (
            sample_img,  # sample_image_display
            sample["ground_truth"],  # ground_truth_display
            f"Dataset: {sample['dataset']}\nDescription: {sample['description']}",  # sample_info_display
            sample_img,  # image_input (copy to main input)
            sample["question"],  # question_input (copy to main input)
        )
    except Exception as e:
        print(f"Error loading sample image {sample['image']}: {e}")
        return None, sample["ground_truth"], f"Error loading image: {e}", None, sample["question"]


def compare_models(image, question, max_tokens):
    """Run both models on the same input and format the results.

    Args:
        image: A PIL image or a path to an image file.
        question: The question to ask about the image.
        max_tokens: Forwarded unchanged to both models' ``predict``.

    Returns:
        A 5-tuple ``(comparison_md, original_md, sliding_md, vote_figure,
        voting_info_md)``. On any failure the first element carries the
        error message and the remaining elements are empty.
    """
    if not _models_loaded():
        return "❌ Models not loaded. Please wait for models to initialize.", "", "", None, ""

    if image is None:
        return "❌ Please provide an image", "", "", None, ""

    if not question or not question.strip():
        return "❌ Please provide a question", "", "", None, ""

    try:
        # Convert to PIL Image if needed
        if isinstance(image, str):
            img = Image.open(image)
        elif hasattr(image, "save"):  # PIL Image
            img = image
        else:
            return "❌ Invalid image format", "", "", None, ""

        # DAM Original prediction
        original_answer, original_time = STATE["dam_original"].predict(
            img, question, max_tokens
        )

        # DAM Sliding Window prediction
        sliding_answer, sliding_time, voting_details = STATE["dam_sliding"].predict(
            img, question, max_tokens
        )
        # The chart/format helpers accept empty details; do the same here so
        # missing voting data does not discard two successful predictions.
        voting_details = voting_details or {}

        # Format results. Strings are assembled line by line so that source
        # indentation can never leak into the rendered Markdown.
        original_result = "\n\n".join([
            "### 🔍 DAM Original (Full Image)",
            f"**Answer:** {original_answer}",
            f"**Inference Time:** {original_time:.2f}s",
            "**Method:** Processes the entire image at once",
        ])

        sliding_result = "\n\n".join([
            "### 🧩 DAM-QA (Sliding Window + Voting)",
            f"**Answer:** {sliding_answer}",
            f"**Inference Time:** {sliding_time:.2f}s",
            "**Method:** Sliding windows with weighted voting",
            f"**Total Windows:** {voting_details.get('total_windows', 'N/A')}",
        ])

        # Create comparison summary
        faster = "DAM Original" if original_time < sliding_time else "DAM-QA"
        comparison = "\n".join([
            "## 📊 Comparison Summary",
            "",
            "| Method | Answer | Time (s) | Approach |",
            "|--------|--------|----------|----------|",
            f"| DAM Original | {_table_cell(original_answer)} | {original_time:.2f} | Full image |",
            f"| DAM-QA Sliding | {_table_cell(sliding_answer)} | {sliding_time:.2f} | Window + voting |",
            "",
            f"**Speed Difference:** {abs(original_time - sliding_time):.2f}s",
            "",
            f"**Faster Method:** {faster}",
        ])

        # Create voting visualization
        vote_fig = create_voting_chart(voting_details)

        # Detailed voting info
        voting_info = format_voting_details(voting_details)

        return comparison, original_result, sliding_result, vote_fig, voting_info

    except Exception as e:  # UI boundary: surface the error in the interface
        error_msg = f"❌ Error during inference: {str(e)}"
        print(error_msg)
        return error_msg, "", "", None, ""


def create_voting_chart(voting_details):
    """Create a bar chart of the vote weight per candidate answer.

    Args:
        voting_details: Mapping that may contain ``"vote_summary"``
            (answer -> weight) and ``"final_answer"``.

    Returns:
        A plotly figure, or None when there are no votes to show.
    """
    if not voting_details or "vote_summary" not in voting_details:
        return None

    votes = voting_details["vote_summary"]
    if not votes:
        return None

    answers = list(votes.keys())
    weights = list(votes.values())
    final_answer = voting_details.get("final_answer", "")

    # Create bar chart; the winning answer gets its own colour.
    fig = go.Figure(data=[
        go.Bar(
            x=answers,
            y=weights,
            text=[f"{w:.3f}" for w in weights],
            textposition="auto",
            marker_color=[
                "#C4314B" if ans == final_answer else "#0F6CBD"
                for ans in answers
            ],
        )
    ])

    fig.update_layout(
        title="DAM-QA Voting Results",
        xaxis_title="Answers",
        yaxis_title="Vote Weight",
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(color="black", size=12),
        height=400,
        margin=dict(l=30, r=20, t=60, b=40),
    )

    return fig


def format_voting_details(voting_details):
    """Format detailed voting information as Markdown text.

    Args:
        voting_details: Mapping that may contain ``"full_image"``,
            ``"windows"``, ``"vote_summary"`` and ``"final_answer"``.

    Returns:
        The formatted text, or a placeholder message when the mapping is
        empty or None.
    """
    if not voting_details:
        return "No voting details available."

    details = []

    # Full image vote
    if "full_image" in voting_details and voting_details["full_image"]:
        full_vote = voting_details["full_image"]
        details.append("**Full Image Vote:**")
        details.append(f"- Answer: {full_vote['answer']}")
        details.append(f"- Weight: {full_vote['weight']:.3f}")
        details.append("")

    # Window votes summary
    if "windows" in voting_details:
        windows = voting_details["windows"]
        details.append(f"**Window Votes:** {len(windows)} windows processed")

        # Group by answer (insertion order keeps first-seen answers first)
        answer_groups = defaultdict(list)
        for window in windows:
            answer_groups[window["answer"]].append(window)

        for answer, windows_for_ans in answer_groups.items():
            total_weight = sum(w["weight"] for w in windows_for_ans)
            details.append(
                f"- **{answer}**: {len(windows_for_ans)} windows, total weight: {total_weight:.3f}"
            )
        details.append("")

    # Final summary
    if "vote_summary" in voting_details:
        final_answer = voting_details.get("final_answer", "")
        details.append("**Final Vote Tally:**")
        for answer, weight in voting_details["vote_summary"].items():
            marker = "🏆" if answer == final_answer else " "
            details.append(f"{marker} {answer}: {weight:.3f}")

    return "\n".join(details)


# Force light theme: reload once with ?__theme=light when it is not set.
force_light_theme_js = """
() => {
  const params = new URLSearchParams(window.location.search);
  if (!params.has('__theme')) {
    params.set('__theme', 'light');
    window.location.search = params.toString();
  }
}
"""

# Main Gradio interface
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
    vlai_template.create_header()

    gr.HTML(vlai_template.render_info_card(
        icon="🤖",
        title="About this Demo",
        description="This demo compares two approaches for Visual Question Answering: DAM (original) processes the full image, while DAM-QA uses a sliding window approach with weighted voting to better handle text-rich images."
    ))

    gr.HTML(vlai_template.render_disclaimer(
        text=(
            "This demo is for research and educational purposes only. "
            "The models are designed for visual question answering on text-rich images. "
            "Results may vary based on image quality and question complexity."
        )
    ))

    gr.Markdown("### 🎯 **How to Use**: Select a sample or upload your image → Ask a question → Compare both models → Analyze the voting results!")

    # Model Status at top
    with gr.Accordion("🤖 Model Status", open=True):
        with gr.Row():
            status_display = gr.Markdown("Loading models...")
            refresh_btn = gr.Button("🔄 Refresh Status", variant="secondary", scale=1)

    with gr.Row(equal_height=False, variant="panel"):
        # LEFT: Input Section
        with gr.Column(scale=35):
            with gr.Accordion("📤 Upload Image & Question", open=True):
                image_input = gr.Image(label="Upload Image", type="pil", height=300)
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask a question about the image...",
                    lines=3
                )
                with gr.Row():
                    max_tokens_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Max Tokens",
                        scale=2
                    )
                    compare_btn = gr.Button("🔍 Compare Models", variant="primary", size="lg", scale=1)

            with gr.Accordion("📋 Try Sample Images", open=True):
                sample_dropdown = gr.Dropdown(
                    label="Select Sample Dataset",
                    choices=[],
                    value=None,
                    info="Choose a sample to auto-fill the inputs above"
                )
                sample_image_display = gr.Image(label="Sample Preview", interactive=False, height=200)
                with gr.Row():
                    ground_truth_display = gr.Textbox(label="Expected Answer", interactive=False, scale=2)
                    sample_info_display = gr.Textbox(label="Dataset Info", interactive=False, lines=3, scale=1)

        # MIDDLE: Results Comparison
        with gr.Column(scale=40):
            with gr.Accordion("📊 Model Comparison Results", open=True):
                comparison_output = gr.Markdown("Click 'Compare Models' to see results...")

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### 🔍 DAM Original")
                        original_output = gr.Markdown("Results will appear here...")
                    with gr.Column():
                        gr.Markdown("#### 🧩 DAM-QA Sliding Window")
                        sliding_output = gr.Markdown("Results will appear here...")

        # RIGHT: Voting Analysis
        with gr.Column(scale=25):
            with gr.Accordion("🗳️ DAM-QA Voting Analysis", open=True):
                voting_chart = gr.Plot(label="Vote Weights")
                voting_details = gr.Markdown("Voting details will appear here...", max_height=200)

    gr.Markdown("\n".join([
        "## 📋 **Key Differences**",
        "- **DAM Original**: Processes the entire image at once, faster but may miss fine details",
        "- **DAM-QA Sliding Window**: Divides image into overlapping windows, slower but better for text-rich images",
        "- **Voting Mechanism**: DAM-QA aggregates predictions from multiple windows using weighted voting",
        "- **Use Cases**: DAM-QA typically performs better on documents, charts, and infographics",
    ]))

    vlai_template.create_footer()

    # Event handlers
    def on_load():
        """Populate the sample dropdown and make sure the models are loaded.

        Returns:
            ``(status_message, dropdown_update)`` for ``status_display`` and
            ``sample_dropdown``.
        """
        # Load samples first
        samples = load_samples()
        choices = get_sample_choices(samples)

        # This handler runs for every page load, so only pay the model
        # loading cost when the models are not already in STATE.
        if _models_loaded():
            status = _MODELS_READY_MSG
        else:
            # Load models immediately (this will take time but ensures they're ready)
            print("Loading DAM models...")
            status = init_models()
            print(f"Model initialization complete: {status}")

        return status, gr.Dropdown(choices=choices, value=0 if choices else None)

    def refresh_status():
        """Check current model status."""
        if _models_loaded():
            return _MODELS_READY_MSG
        return "🔄 Models not loaded. Click to retry."

    # NOTE(review): retry_loading is not connected to any component, although
    # the status text above says "Click to retry." — confirm the intended wiring.
    def retry_loading():
        """Retry loading models."""
        return init_models()

    demo.load(
        fn=on_load,
        outputs=[status_display, sample_dropdown]
    )

    # Add refresh button functionality
    refresh_btn.click(
        fn=refresh_status,
        outputs=[status_display]
    )

    sample_dropdown.change(
        fn=fill_from_sample,
        inputs=[sample_dropdown],
        outputs=[sample_image_display, ground_truth_display, sample_info_display, image_input, question_input]
    )

    compare_btn.click(
        fn=compare_models,
        inputs=[image_input, question_input, max_tokens_slider],
        outputs=[comparison_output, original_output, sliding_output, voting_chart, voting_details]
    )

if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
        allowed_paths=["sample_images", "static"]
    )