# DAM-QA Demo — app.py
# Author: duongtruongbinh — Initial commit (3fd9d26)
import json
import os
import time
from collections import defaultdict

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from PIL import Image

import vlai_template
from src.dam_models import get_dam_original, get_dam_sliding
# App configuration — registers page metadata with the shared VLAI template.
vlai_template.set_meta(
    project_name="DAM-QA Demo",
    year="2025",
    module="DAM",
    description="DAM-QA performance on Visual Question Answering tasks",
    meta_items=[
        ("Original DAM", "Full image processing"),
        ("DAM-QA", "Sliding window + voting"),
        ("Datasets", "DocVQA, InfographicVQA, TextVQA, ChartQA, VQAv2"),
    ],
)

# Global mutable state shared by all event handlers:
#   "dam_original" — full-image DAM model wrapper (set by init_models)
#   "dam_sliding"  — sliding-window DAM-QA model wrapper (set by init_models)
#   "samples"      — list of sample dicts loaded from samples.json
STATE = {
    "dam_original": None,
    "dam_sliding": None,
    "samples": []
}
# Load sample data
def load_samples():
    """Load sample questions/images from ``samples.json`` into STATE.

    Returns:
        The list of sample dicts on success; ``[]`` when the file is
        missing, unreadable, or not valid JSON.
    """
    try:
        with open("samples.json", "r", encoding="utf-8") as f:
            samples = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # Narrow catch: only file-access and parse failures are expected here.
        print(f"Error loading samples: {e}")
        return []
    STATE["samples"] = samples
    return samples
def init_models():
    """Instantiate both DAM model wrappers and cache them in STATE.

    Returns a human-readable status string for the UI: success banner
    when both loads complete, otherwise the error message.
    """
    try:
        STATE["dam_original"] = get_dam_original()
        STATE["dam_sliding"] = get_dam_sliding()
    except Exception as e:
        msg = f"❌ Error loading models: {str(e)}"
        print(msg)
        return msg
    return "βœ… Both DAM models loaded successfully!"
def get_sample_choices():
    """Build ``(label, index)`` pairs for the sample dropdown.

    Labels combine the dataset name with the first 50 characters of
    the question, truncated with an ellipsis.
    """
    return [
        (f"{s['dataset']}: {s['question'][:50]}...", idx)
        for idx, s in enumerate(STATE["samples"])
    ]
def fill_from_sample(sample_idx):
    """Populate the demo inputs from the selected sample.

    Args:
        sample_idx: Integer index into ``STATE["samples"]`` (dropdown value).

    Returns:
        Tuple of (preview image, ground-truth text, info text, main image
        input, question text). Image slots are ``None`` when nothing is
        selected or the sample image cannot be opened.
    """
    samples = STATE["samples"]
    # Guard against no samples, no selection, or an out-of-range index.
    # The explicit lower bound rejects negative indices, which would
    # otherwise silently wrap around to the end of the list.
    if not samples or sample_idx is None or not 0 <= sample_idx < len(samples):
        return None, "", "", None, ""
    sample = samples[sample_idx]
    # Keep the try body minimal: only the image load is expected to fail.
    try:
        sample_img = Image.open(sample["image"])
    except Exception as e:
        print(f"Error loading sample image {sample['image']}: {e}")
        return None, sample["ground_truth"], f"Error loading image: {e}", None, sample["question"]
    return (
        sample_img,                                                            # sample_image_display
        sample["ground_truth"],                                                # ground_truth_display
        f"Dataset: {sample['dataset']}\nDescription: {sample['description']}", # sample_info_display
        sample_img,                                                            # image_input (copy to main input)
        sample["question"]                                                     # question_input (copy to main input)
    )
def compare_models(image, question, max_tokens):
    """Run both DAM variants on the same (image, question) pair.

    Args:
        image: A PIL image, or a file-path string (opened if so).
        question: Natural-language question about the image.
        max_tokens: Generation budget forwarded to both models' predict().

    Returns:
        Tuple of (comparison markdown, original-model markdown,
        sliding-model markdown, plotly voting figure or None,
        voting-details markdown). On validation or inference failure the
        first element carries an error message and the rest are blank.
    """
    # Both models must have finished loading before inference.
    if STATE["dam_original"] is None or STATE["dam_sliding"] is None:
        return "❌ Models not loaded. Please wait for models to initialize.", "", "", None, ""
    if image is None:
        return "❌ Please provide an image", "", "", None, ""
    if not question or not question.strip():
        return "❌ Please provide a question", "", "", None, ""
    try:
        # Convert to PIL Image if needed
        if isinstance(image, str):
            img = Image.open(image)
        elif hasattr(image, 'save'):  # duck-typed check for a PIL Image
            img = image
        else:
            return "❌ Invalid image format", "", "", None, ""
        # DAM Original: one forward pass over the full image.
        original_answer, original_time = STATE["dam_original"].predict(
            img, question, max_tokens
        )
        # DAM-QA: sliding windows + weighted voting; also returns the
        # per-window vote breakdown used for visualization below.
        sliding_answer, sliding_time, voting_details = STATE["dam_sliding"].predict(
            img, question, max_tokens
        )
        # Markdown card for the original model's result.
        original_result = f"""
### πŸ” DAM Original (Full Image)
**Answer:** {original_answer}
**Inference Time:** {original_time:.2f}s
**Method:** Processes the entire image at once
"""
        # Markdown card for the sliding-window result.
        sliding_result = f"""
### 🧩 DAM-QA (Sliding Window + Voting)
**Answer:** {sliding_answer}
**Inference Time:** {sliding_time:.2f}s
**Method:** Sliding windows with weighted voting
**Total Windows:** {voting_details.get('total_windows', 'N/A')}
"""
        # Side-by-side summary table plus timing comparison.
        comparison = f"""
## πŸ“Š Comparison Summary
| Method | Answer | Time (s) | Approach |
|--------|--------|----------|----------|
| DAM Original | {original_answer} | {original_time:.2f} | Full image |
| DAM-QA Sliding | {sliding_answer} | {sliding_time:.2f} | Window + voting |
**Speed Difference:** {abs(original_time - sliding_time):.2f}s
**Faster Method:** {'DAM Original' if original_time < sliding_time else 'DAM-QA'}
"""
        # Bar chart of vote weights (None when no vote summary is available).
        vote_fig = create_voting_chart(voting_details)
        # Human-readable voting breakdown.
        voting_info = format_voting_details(voting_details)
        return comparison, original_result, sliding_result, vote_fig, voting_info
    except Exception as e:
        error_msg = f"❌ Error during inference: {str(e)}"
        return error_msg, "", "", None, ""
def create_voting_chart(voting_details):
    """Render the DAM-QA vote tally as a plotly bar chart.

    The winning answer's bar is highlighted in red; all others are blue.
    Returns ``None`` when no vote summary is available.
    """
    if not voting_details:
        return None
    votes = voting_details.get("vote_summary")
    if not votes:
        return None
    winner = voting_details.get('final_answer', '')
    labels = list(votes)
    weights = [votes[label] for label in labels]
    colors = ['#C4314B' if label == winner else '#0F6CBD' for label in labels]
    bar = go.Bar(
        x=labels,
        y=weights,
        text=[f"{w:.3f}" for w in weights],
        textposition='auto',
        marker_color=colors
    )
    fig = go.Figure(data=[bar])
    fig.update_layout(
        title="DAM-QA Voting Results",
        xaxis_title="Answers",
        yaxis_title="Vote Weight",
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(color="black", size=12),
        height=400,
        margin=dict(l=30, r=20, t=60, b=40)
    )
    return fig
def format_voting_details(voting_details):
    """Format the DAM-QA voting breakdown as a Markdown string.

    Args:
        voting_details: Dict that may contain "full_image" (single vote
            dict with "answer"/"weight"), "windows" (list of per-window
            vote dicts), "vote_summary" (answer -> total weight) and
            "final_answer" keys.

    Returns:
        Markdown text; a placeholder message when the dict is empty/None.
    """
    if not voting_details:
        return "No voting details available."
    details = []
    # Vote cast by the full-image pass, if present.
    full_vote = voting_details.get("full_image")
    if full_vote:
        details.append("**Full Image Vote:**")
        details.append(f"- Answer: {full_vote['answer']}")
        details.append(f"- Weight: {full_vote['weight']:.3f}")
        details.append("")
    # Per-window votes, grouped by predicted answer (insertion order kept).
    if "windows" in voting_details:
        windows = voting_details["windows"]
        details.append(f"**Window Votes:** {len(windows)} windows processed")
        answer_groups = defaultdict(list)
        for window in windows:
            answer_groups[window["answer"]].append(window)
        for answer, group in answer_groups.items():
            total_weight = sum(w["weight"] for w in group)
            details.append(f"- **{answer}**: {len(group)} windows, total weight: {total_weight:.3f}")
        details.append("")
    # Final tally; the winning answer gets a trophy marker.
    if "vote_summary" in voting_details:
        details.append("**Final Vote Tally:**")
        final = voting_details.get("final_answer", "")
        for answer, weight in voting_details["vote_summary"].items():
            marker = "πŸ†" if answer == final else " "
            details.append(f"{marker} {answer}: {weight:.3f}")
    return "\n".join(details)
# Force light theme
force_light_theme_js = """
() => {
const params = new URLSearchParams(window.location.search);
if (!params.has('__theme')) {
params.set('__theme', 'light');
window.location.search = params.toString();
}
}
"""
# Main Gradio interface: header/info cards, a three-column workspace
# (inputs | comparison results | voting analysis), and event wiring.
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
    vlai_template.create_header()
    # Info card: one-paragraph overview of the two compared approaches.
    gr.HTML(vlai_template.render_info_card(
        icon="πŸ€–",
        title="About this Demo",
        description="This demo compares two approaches for Visual Question Answering: DAM (original) processes the full image, while DAM-QA uses a sliding window approach with weighted voting to better handle text-rich images."
    ))
    # Research/educational-use disclaimer.
    gr.HTML(vlai_template.render_disclaimer(
        text=(
            "This demo is for research and educational purposes only. "
            "The models are designed for visual question answering on text-rich images. "
            "Results may vary based on image quality and question complexity."
        )
    ))
    gr.Markdown("### 🎯 **How to Use**: Select a sample or upload your image β†’ Ask a question β†’ Compare both models β†’ Analyze the voting results!")
    # Model Status at top
    with gr.Accordion("πŸ€– Model Status", open=True):
        with gr.Row():
            status_display = gr.Markdown("Loading models...")
            refresh_btn = gr.Button("πŸ”„ Refresh Status", variant="secondary", scale=1)
    # Three-column layout: inputs (left) | comparison (middle) | voting (right).
    with gr.Row(equal_height=False, variant="panel"):
        # LEFT: Input Section
        with gr.Column(scale=35):
            with gr.Accordion("πŸ“€ Upload Image & Question", open=True):
                image_input = gr.Image(label="Upload Image", type="pil", height=300)
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask a question about the image...",
                    lines=3
                )
                with gr.Row():
                    max_tokens_slider = gr.Slider(
                        minimum=10, maximum=200, value=100, step=10,
                        label="Max Tokens", scale=2
                    )
                    compare_btn = gr.Button("πŸ” Compare Models", variant="primary", size="lg", scale=1)
            with gr.Accordion("πŸ“‹ Try Sample Images", open=True):
                sample_dropdown = gr.Dropdown(
                    label="Select Sample Dataset",
                    choices=[],
                    value=None,
                    info="Choose a sample to auto-fill the inputs above"
                )
                sample_image_display = gr.Image(label="Sample Preview", interactive=False, height=200)
                with gr.Row():
                    ground_truth_display = gr.Textbox(label="Expected Answer", interactive=False, scale=2)
                    sample_info_display = gr.Textbox(label="Dataset Info", interactive=False, lines=3, scale=1)
        # MIDDLE: Results Comparison
        with gr.Column(scale=40):
            with gr.Accordion("πŸ“Š Model Comparison Results", open=True):
                comparison_output = gr.Markdown("Click 'Compare Models' to see results...")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### πŸ” DAM Original")
                        original_output = gr.Markdown("Results will appear here...")
                    with gr.Column():
                        gr.Markdown("#### 🧩 DAM-QA Sliding Window")
                        sliding_output = gr.Markdown("Results will appear here...")
        # RIGHT: Voting Analysis
        with gr.Column(scale=25):
            with gr.Accordion("πŸ—³οΈ DAM-QA Voting Analysis", open=True):
                voting_chart = gr.Plot(label="Vote Weights")
                voting_details = gr.Markdown("Voting details will appear here...", max_height=200)
    gr.Markdown("""
## πŸ“‹ **Key Differences**
- **DAM Original**: Processes the entire image at once, faster but may miss fine details
- **DAM-QA Sliding Window**: Divides image into overlapping windows, slower but better for text-rich images
- **Voting Mechanism**: DAM-QA aggregates predictions from multiple windows using weighted voting
- **Use Cases**: DAM-QA typically performs better on documents, charts, and infographics
""")
    vlai_template.create_footer()
    # Event handlers
    def on_load():
        """Page-load hook: load samples, fill the dropdown, then load models."""
        # Load samples first
        samples = load_samples()
        choices = [(f"{s['dataset']}: {s['question'][:50]}...", i) for i, s in enumerate(samples)]
        # Load models immediately (this will take time but ensures they're ready)
        print("Loading DAM models...")
        status = init_models()
        print(f"Model initialization complete: {status}")
        # Preselect the first sample when any are available.
        return status, gr.Dropdown(choices=choices, value=0 if choices else None)
    def refresh_status():
        """Check current model status."""
        if STATE["dam_original"] is not None and STATE["dam_sliding"] is not None:
            return "βœ… Both DAM models loaded successfully!"
        else:
            return "πŸ”„ Models not loaded. Click to retry."
    def retry_loading():
        """Retry loading models."""
        # NOTE(review): defined but not wired to any event below — confirm intent.
        return init_models()
    # Run on_load once when the page opens.
    demo.load(
        fn=on_load,
        outputs=[status_display, sample_dropdown]
    )
    # Add refresh button functionality
    refresh_btn.click(
        fn=refresh_status,
        outputs=[status_display]
    )
    # Selecting a sample fills both the preview panel and the main inputs.
    sample_dropdown.change(
        fn=fill_from_sample,
        inputs=[sample_dropdown],
        outputs=[sample_image_display, ground_truth_display, sample_info_display, image_input, question_input]
    )
    # Main action: run both models and render every result pane.
    compare_btn.click(
        fn=compare_models,
        inputs=[image_input, question_input, max_tokens_slider],
        outputs=[comparison_output, original_output, sliding_output, voting_chart, voting_details]
    )
# Script entry point. `allowed_paths` whitelists local directories that
# Gradio may serve files from (sample images and static assets).
if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
        allowed_paths=["sample_images", "static"]
    )