Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| import pandas as pd | |
| import time | |
| from PIL import Image | |
| import vlai_template | |
| from src.dam_models import get_dam_original, get_dam_sliding | |
# App configuration
# Registers page metadata with the shared VLAI template: project title,
# year, module name, and the key/value items rendered in the header card.
vlai_template.set_meta(
    project_name="DAM-QA Demo",
    year="2025",
    module="DAM",
    description="DAM-QA performance on Visual Question Answering tasks",
    meta_items=[
        ("Original DAM", "Full image processing"),
        ("DAM-QA", "Sliding window + voting"),
        ("Datasets", "DocVQA, InfographicVQA, TextVQA, ChartQA, VQAv2"),
    ],
)
# Global state for models
# Module-level mutable registry shared by all Gradio callbacks below.
STATE = {
    "dam_original": None,  # full-image DAM model, set by init_models()
    "dam_sliding": None,   # sliding-window DAM-QA model, set by init_models()
    "samples": []          # sample entries loaded from samples.json
}
# Load sample data
def load_samples():
    """Load sample questions and images from ``samples.json``.

    Reads the JSON file from the working directory, stores the parsed
    list in ``STATE["samples"]``, and returns it.  On any failure
    (missing file, malformed JSON) the error is logged and an empty
    list is returned so the UI can still start.

    Returns:
        list: The loaded sample entries, or ``[]`` on error.
    """
    try:
        # Explicit encoding avoids platform-dependent default codecs.
        with open("samples.json", "r", encoding="utf-8") as f:
            samples = json.load(f)
        STATE["samples"] = samples
        return samples
    except Exception as e:
        # Best-effort: a broken samples file must not crash the app.
        print(f"Error loading samples: {e}")
        return []
def init_models():
    """Load both DAM variants into STATE and return a status message.

    Returns a success string when both models load, or an error string
    (also printed to stdout) when either loader raises.
    """
    try:
        STATE["dam_original"] = get_dam_original()
        STATE["dam_sliding"] = get_dam_sliding()
    except Exception as e:
        # Surface the failure in the status panel instead of crashing.
        message = f"β Error loading models: {str(e)}"
        print(message)
        return message
    return "β Both DAM models loaded successfully!"
def get_sample_choices():
    """Build (label, index) dropdown choices from the loaded samples.

    Each label is the dataset name plus the first 50 characters of the
    question, so long questions stay readable in the dropdown.
    """
    return [
        (f"{entry['dataset']}: {entry['question'][:50]}...", idx)
        for idx, entry in enumerate(STATE["samples"])
    ]
def fill_from_sample(sample_idx):
    """Populate the input widgets from the sample selected in the dropdown.

    Returns a 5-tuple feeding (sample_image_display, ground_truth_display,
    sample_info_display, image_input, question_input).  The sample image is
    routed to both the preview widget and the main image input.
    """
    samples = STATE["samples"]
    # Guard: nothing loaded, no selection, or a stale out-of-range index.
    if not samples or sample_idx is None or sample_idx >= len(samples):
        return None, "", "", None, ""
    sample = samples[sample_idx]
    try:
        sample_img = Image.open(sample["image"])
        info_text = f"Dataset: {sample['dataset']}\nDescription: {sample['description']}"
        return sample_img, sample["ground_truth"], info_text, sample_img, sample["question"]
    except Exception as e:
        # Image failed to load: still fill the textual fields so the user
        # can see the question and expected answer.
        print(f"Error loading sample image {sample['image']}: {e}")
        return None, sample["ground_truth"], f"Error loading image: {e}", None, sample["question"]
def compare_models(image, question, max_tokens):
    """Run both DAM variants on the same input and format the results.

    Parameters:
        image: PIL image, file-path string, or None.
        question: question text about the image.
        max_tokens: generation budget forwarded to both models.

    Returns:
        Tuple of (comparison markdown, DAM-original markdown, DAM-QA
        markdown, plotly voting figure or None, voting detail text).
    """
    # Guard clauses: report problems in the comparison panel.
    if STATE["dam_original"] is None or STATE["dam_sliding"] is None:
        return "β Models not loaded. Please wait for models to initialize.", "", "", None, ""
    if image is None:
        return "β Please provide an image", "", "", None, ""
    if not question or not question.strip():
        return "β Please provide a question", "", "", None, ""
    try:
        # Accept either a file path or an object that quacks like a PIL image.
        if isinstance(image, str):
            img = Image.open(image)
        elif hasattr(image, 'save'):  # PIL Image
            img = image
        else:
            return "β Invalid image format", "", "", None, ""

        # Run both models on the identical (image, question) pair.
        original_answer, original_time = STATE["dam_original"].predict(
            img, question, max_tokens
        )
        sliding_answer, sliding_time, voting_details = STATE["dam_sliding"].predict(
            img, question, max_tokens
        )

        # Per-model markdown panels.
        original_result = f"""
### π DAM Original (Full Image)
**Answer:** {original_answer}
**Inference Time:** {original_time:.2f}s
**Method:** Processes the entire image at once
"""
        sliding_result = f"""
### π§© DAM-QA (Sliding Window + Voting)
**Answer:** {sliding_answer}
**Inference Time:** {sliding_time:.2f}s
**Method:** Sliding windows with weighted voting
**Total Windows:** {voting_details.get('total_windows', 'N/A')}
"""
        # Side-by-side summary table plus speed comparison.
        comparison = f"""
## π Comparison Summary
| Method | Answer | Time (s) | Approach |
|--------|--------|----------|----------|
| DAM Original | {original_answer} | {original_time:.2f} | Full image |
| DAM-QA Sliding | {sliding_answer} | {sliding_time:.2f} | Window + voting |
**Speed Difference:** {abs(original_time - sliding_time):.2f}s
**Faster Method:** {'DAM Original' if original_time < sliding_time else 'DAM-QA'}
"""
        # Visualize and describe the sliding-window voting process.
        vote_fig = create_voting_chart(voting_details)
        voting_info = format_voting_details(voting_details)
        return comparison, original_result, sliding_result, vote_fig, voting_info
    except Exception as e:
        return f"β Error during inference: {str(e)}", "", "", None, ""
def create_voting_chart(voting_details):
    """Render the DAM-QA vote tally as a plotly bar chart.

    The bar for the winning answer (``final_answer``) is highlighted in
    red; all other answers are blue.  Returns None when no non-empty
    ``vote_summary`` is available.
    """
    summary = (voting_details or {}).get("vote_summary")
    if not summary:
        return None

    answers = list(summary.keys())
    weights = list(summary.values())
    winner = voting_details.get('final_answer', '')
    bar_colors = ['#C4314B' if ans == winner else '#0F6CBD' for ans in answers]

    bars = go.Bar(
        x=answers,
        y=weights,
        text=[f"{w:.3f}" for w in weights],
        textposition='auto',
        marker_color=bar_colors,
    )
    fig = go.Figure(data=[bars])
    fig.update_layout(
        title="DAM-QA Voting Results",
        xaxis_title="Answers",
        yaxis_title="Vote Weight",
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(color="black", size=12),
        height=400,
        margin=dict(l=30, r=20, t=60, b=40),
    )
    return fig
def format_voting_details(voting_details):
    """Build a markdown breakdown of the DAM-QA voting process.

    Sections (each emitted only when its data is present): the full-image
    vote, per-window votes grouped by answer, and the final weighted tally
    with the winning answer marked.
    """
    if not voting_details:
        return "No voting details available."

    lines = []

    # Contribution of the whole-image prediction, when present.
    full_vote = voting_details.get("full_image")
    if full_vote:
        lines.append("**Full Image Vote:**")
        lines.append(f"- Answer: {full_vote['answer']}")
        lines.append(f"- Weight: {full_vote['weight']:.3f}")
        lines.append("")

    # Per-window votes, grouped by predicted answer.
    if "windows" in voting_details:
        windows = voting_details["windows"]
        lines.append(f"**Window Votes:** {len(windows)} windows processed")
        grouped = {}
        for win in windows:
            grouped.setdefault(win["answer"], []).append(win)
        for ans, group in grouped.items():
            weight_sum = sum(w["weight"] for w in group)
            lines.append(f"- **{ans}**: {len(group)} windows, total weight: {weight_sum:.3f}")
        lines.append("")

    # Aggregated tally; the winning answer gets the marker.
    if "vote_summary" in voting_details:
        lines.append("**Final Vote Tally:**")
        winner = voting_details.get("final_answer", "")
        for ans, weight in voting_details["vote_summary"].items():
            prefix = "π" if ans == winner else " "
            lines.append(f"{prefix} {ans}: {weight:.3f}")

    return "\n".join(lines)
# Force light theme
# JS snippet passed to gr.Blocks(js=...) and run on page load: if the URL
# has no explicit `__theme` query parameter, reload with `__theme=light`
# so Gradio renders in light mode regardless of the browser preference.
force_light_theme_js = """
() => {
    const params = new URLSearchParams(window.location.search);
    if (!params.has('__theme')) {
        params.set('__theme', 'light');
        window.location.search = params.toString();
    }
}
"""
# Main Gradio interface
# Layout: header/info cards, a model-status accordion, then three columns —
# inputs & samples (left), model comparison (middle), voting analysis (right) —
# followed by the event wiring for load/refresh/sample-select/compare.
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
    vlai_template.create_header()
    gr.HTML(vlai_template.render_info_card(
        icon="π€",
        title="About this Demo",
        description="This demo compares two approaches for Visual Question Answering: DAM (original) processes the full image, while DAM-QA uses a sliding window approach with weighted voting to better handle text-rich images."
    ))
    gr.HTML(vlai_template.render_disclaimer(
        text=(
            "This demo is for research and educational purposes only. "
            "The models are designed for visual question answering on text-rich images. "
            "Results may vary based on image quality and question complexity."
        )
    ))
    gr.Markdown("### π― **How to Use**: Select a sample or upload your image β Ask a question β Compare both models β Analyze the voting results!")
    # Model Status at top
    with gr.Accordion("π€ Model Status", open=True):
        with gr.Row():
            status_display = gr.Markdown("Loading models...")
            refresh_btn = gr.Button("π Refresh Status", variant="secondary", scale=1)
    with gr.Row(equal_height=False, variant="panel"):
        # LEFT: Input Section
        with gr.Column(scale=35):
            with gr.Accordion("π€ Upload Image & Question", open=True):
                image_input = gr.Image(label="Upload Image", type="pil", height=300)
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask a question about the image...",
                    lines=3
                )
                with gr.Row():
                    max_tokens_slider = gr.Slider(
                        minimum=10, maximum=200, value=100, step=10,
                        label="Max Tokens", scale=2
                    )
                    compare_btn = gr.Button("π Compare Models", variant="primary", size="lg", scale=1)
            with gr.Accordion("π Try Sample Images", open=True):
                sample_dropdown = gr.Dropdown(
                    label="Select Sample Dataset",
                    choices=[],
                    value=None,
                    info="Choose a sample to auto-fill the inputs above"
                )
                sample_image_display = gr.Image(label="Sample Preview", interactive=False, height=200)
                with gr.Row():
                    ground_truth_display = gr.Textbox(label="Expected Answer", interactive=False, scale=2)
                    sample_info_display = gr.Textbox(label="Dataset Info", interactive=False, lines=3, scale=1)
        # MIDDLE: Results Comparison
        with gr.Column(scale=40):
            with gr.Accordion("π Model Comparison Results", open=True):
                comparison_output = gr.Markdown("Click 'Compare Models' to see results...")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### π DAM Original")
                        original_output = gr.Markdown("Results will appear here...")
                    with gr.Column():
                        gr.Markdown("#### π§© DAM-QA Sliding Window")
                        sliding_output = gr.Markdown("Results will appear here...")
        # RIGHT: Voting Analysis
        with gr.Column(scale=25):
            with gr.Accordion("π³οΈ DAM-QA Voting Analysis", open=True):
                voting_chart = gr.Plot(label="Vote Weights")
                voting_details = gr.Markdown("Voting details will appear here...", max_height=200)
            gr.Markdown("""
## π **Key Differences**
- **DAM Original**: Processes the entire image at once, faster but may miss fine details
- **DAM-QA Sliding Window**: Divides image into overlapping windows, slower but better for text-rich images
- **Voting Mechanism**: DAM-QA aggregates predictions from multiple windows using weighted voting
- **Use Cases**: DAM-QA typically performs better on documents, charts, and infographics
""")
    vlai_template.create_footer()
    # Event handlers
    def on_load():
        """Startup hook: load the samples, then block while models load."""
        # Load samples first
        samples = load_samples()
        choices = [(f"{s['dataset']}: {s['question'][:50]}...", i) for i, s in enumerate(samples)]
        # Load models immediately (this will take time but ensures they're ready)
        print("Loading DAM models...")
        status = init_models()
        print(f"Model initialization complete: {status}")
        # Second value is a Dropdown update: populate choices, preselect first.
        return status, gr.Dropdown(choices=choices, value=0 if choices else None)
    def refresh_status():
        """Check current model status."""
        if STATE["dam_original"] is not None and STATE["dam_sliding"] is not None:
            return "β Both DAM models loaded successfully!"
        else:
            return "π Models not loaded. Click to retry."
    def retry_loading():
        """Retry loading models."""
        # NOTE(review): defined but never wired to any event below — confirm
        # whether a retry control was intended.
        return init_models()
    # On page load: show model status and populate the sample dropdown.
    demo.load(
        fn=on_load,
        outputs=[status_display, sample_dropdown]
    )
    # Add refresh button functionality
    refresh_btn.click(
        fn=refresh_status,
        outputs=[status_display]
    )
    # Selecting a sample fills both the preview widgets and the main inputs.
    sample_dropdown.change(
        fn=fill_from_sample,
        inputs=[sample_dropdown],
        outputs=[sample_image_display, ground_truth_display, sample_info_display, image_input, question_input]
    )
    # Run both models and populate every result panel.
    compare_btn.click(
        fn=compare_models,
        inputs=[image_input, question_input, max_tokens_slider],
        outputs=[comparison_output, original_output, sliding_output, voting_chart, voting_details]
    )
if __name__ == "__main__":
    # Launch the app locally; allowed_paths exposes the sample/static asset
    # directories to the Gradio web server so their files can be served.
    demo.launch(
        share=False,
        show_error=True,
        allowed_paths=["sample_images", "static"]
    )