# Page residue from the file viewer this source was copied from (Spaces status: Sleeping; file size: 13,898 bytes; revision 3fd9d26; line-number gutter 1-381 omitted).
import os
import json
import gradio as gr
import plotly.graph_objects as go
import pandas as pd
import time
from PIL import Image
import vlai_template
from src.dam_models import get_dam_original, get_dam_sliding
# App configuration: register this demo's display metadata with the shared
# vlai_template module before any UI is built.
# NOTE(review): how set_meta stores or renders these values is defined in
# vlai_template, which is not visible here.
vlai_template.set_meta(
project_name="DAM-QA Demo",
year="2025",
module="DAM",
description="DAM-QA performance on Visual Question Answering tasks",
# Each meta item is a (label, text) pair.
meta_items=[
("Original DAM", "Full image processing"),
("DAM-QA", "Sliding window + voting"),
("Datasets", "DocVQA, InfographicVQA, TextVQA, ChartQA, VQAv2"),
],
)
# Global, process-wide state shared by the handlers in this file.
STATE = {
# Set by init_models() from get_dam_original(); None until loaded.
"dam_original": None,
# Set by init_models() from get_dam_sliding(); None until loaded.
"dam_sliding": None,
# Parsed content of samples.json, filled in by load_samples().
"samples": []
}
# Load sample data
def load_samples(path="samples.json"):
    """Load the sample manifest and cache it in ``STATE["samples"]``.

    Args:
        path: Location of the JSON manifest. Defaults to ``"samples.json"``
            (relative to the working directory), the path used historically.

    Returns:
        The parsed JSON content on success, or an empty list when the file
        cannot be read or parsed. ``STATE["samples"]`` is only updated on
        success.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            samples = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: file missing or unreadable.
        # ValueError: invalid JSON (JSONDecodeError) or invalid UTF-8.
        print(f"Error loading samples: {e}")
        return []
    STATE["samples"] = samples
    return samples
def init_models():
    """Initialize both DAM models and store them in ``STATE``.

    A model that is already present in ``STATE`` is kept, so calling this
    again (for example on another page load or a retry) only loads what is
    still missing.

    Returns:
        A human-readable status string: a success message when both models
        are available, otherwise an error message containing the exception
        text. Exceptions are not propagated.
    """
    # NOTE(review): the status emoji were reconstructed from mis-decoded text
    # (the success literal was split across two lines) - confirm against the
    # upstream file.
    try:
        if STATE["dam_original"] is None:
            STATE["dam_original"] = get_dam_original()
        if STATE["dam_sliding"] is None:
            STATE["dam_sliding"] = get_dam_sliding()
        return "✅ Both DAM models loaded successfully!"
    except Exception as e:
        # Top-level boundary for model loading: report instead of crashing
        # the UI handler that called us.
        error_msg = f"❌ Error loading models: {str(e)}"
        print(error_msg)
        return error_msg
def get_sample_choices(samples=None, max_question_len=50):
    """Build ``(label, index)`` pairs for the sample dropdown.

    Args:
        samples: Sequence of sample dicts, each with ``"dataset"`` and
            ``"question"`` keys. Defaults to ``STATE["samples"]``.
        max_question_len: Maximum number of question characters shown in a
            label before it is truncated.

    Returns:
        A list of ``(label, index)`` tuples, where ``index`` is the sample's
        position in ``samples``.
    """
    if samples is None:
        samples = STATE["samples"]
    choices = []
    for i, sample in enumerate(samples):
        question = sample["question"]
        preview = question[:max_question_len]
        # Only signal truncation when something was actually cut off.
        if len(question) > max_question_len:
            preview += "..."
        choices.append((f"{sample['dataset']}: {preview}", i))
    return choices
def fill_from_sample(sample_idx):
    """Fill the preview widgets and the main inputs from a selected sample.

    Args:
        sample_idx: Index into ``STATE["samples"]``, or ``None`` when nothing
            is selected.

    Returns:
        A 5-tuple ``(sample image, ground truth, sample info, image input,
        question input)``. When no valid sample is selected every slot is
        empty (``None`` / ``""``). When the image cannot be opened, both image
        slots are ``None`` and the info slot carries the error text.
    """
    samples = STATE["samples"]
    # Reject negative indices too: they would otherwise silently select a
    # sample counted from the end of the list.
    if not samples or sample_idx is None or not 0 <= sample_idx < len(samples):
        return None, "", "", None, ""
    sample = samples[sample_idx]
    # Load the sample image
    try:
        sample_img = Image.open(sample["image"])
        return (
            sample_img,  # sample_image_display
            sample["ground_truth"],  # ground_truth_display
            f"Dataset: {sample['dataset']}\nDescription: {sample['description']}",  # sample_info_display
            sample_img,  # image_input (copy to main input)
            sample["question"],  # question_input (copy to main input)
        )
    except Exception as e:
        print(f"Error loading sample image {sample['image']}: {e}")
        return None, sample["ground_truth"], f"Error loading image: {e}", None, sample["question"]
def _md_table_cell(value):
    """Return *value* as text that is safe inside one Markdown table cell."""
    # A newline would end the table row and a raw "|" would start a new column.
    return " ".join(str(value).splitlines()).replace("|", "\\|")


def compare_models(image, question, max_tokens):
    """Run both DAM models on the same input and format the results.

    Args:
        image: A PIL image (anything exposing ``save``) or a path string.
        question: The question to ask about the image.
        max_tokens: Forwarded unchanged to each model's ``predict``.

    Returns:
        A 5-tuple ``(comparison markdown, original-model markdown,
        sliding-window markdown, voting figure, voting details markdown)``.
        On any validation or inference failure the first slot carries the
        error message and the remaining slots are empty (``""`` / ``None``).
    """
    # NOTE(review): emoji in the messages below were reconstructed from
    # mis-decoded text - confirm the exact glyphs against the upstream file.

    def failure(message):
        return message, "", "", None, ""

    if STATE["dam_original"] is None or STATE["dam_sliding"] is None:
        return failure("❌ Models not loaded. Please wait for models to initialize.")
    if image is None:
        return failure("❌ Please provide an image")
    if not question or not question.strip():
        return failure("❌ Please provide a question")
    try:
        # Convert to PIL Image if needed
        if isinstance(image, str):
            img = Image.open(image)
        elif hasattr(image, 'save'):  # PIL Image
            img = image
        else:
            return failure("❌ Invalid image format")
        # DAM Original prediction
        original_answer, original_time = STATE["dam_original"].predict(
            img, question, max_tokens
        )
        # DAM Sliding Window prediction
        sliding_answer, sliding_time, voting_details = STATE["dam_sliding"].predict(
            img, question, max_tokens
        )
        # Format results. Blank lines keep each field in its own Markdown
        # paragraph instead of letting them run together.
        original_result = f"""
### 🔍 DAM Original (Full Image)

**Answer:** {original_answer}

**Inference Time:** {original_time:.2f}s

**Method:** Processes the entire image at once
"""
        sliding_result = f"""
### 🧩 DAM-QA (Sliding Window + Voting)

**Answer:** {sliding_answer}

**Inference Time:** {sliding_time:.2f}s

**Method:** Sliding windows with weighted voting

**Total Windows:** {voting_details.get('total_windows', 'N/A')}
"""
        # Create comparison summary. Answers are escaped so model output
        # cannot break the table, and the blank line after the table stops
        # the following text from being absorbed into it as an extra row.
        comparison = f"""
## 📊 Comparison Summary

| Method | Answer | Time (s) | Approach |
|--------|--------|----------|----------|
| DAM Original | {_md_table_cell(original_answer)} | {original_time:.2f} | Full image |
| DAM-QA Sliding | {_md_table_cell(sliding_answer)} | {sliding_time:.2f} | Window + voting |

**Speed Difference:** {abs(original_time - sliding_time):.2f}s

**Faster Method:** {'DAM Original' if original_time < sliding_time else 'DAM-QA'}
"""
        # Create voting visualization
        vote_fig = create_voting_chart(voting_details)
        # Detailed voting info
        voting_info = format_voting_details(voting_details)
        return comparison, original_result, sliding_result, vote_fig, voting_info
    except Exception as e:
        # UI boundary: surface the failure in the comparison pane rather
        # than raising into the event handler.
        return failure(f"❌ Error during inference: {str(e)}")
def create_voting_chart(voting_details):
    """Build a bar chart of the per-answer vote weights.

    Returns ``None`` when ``voting_details`` is empty, has no
    ``"vote_summary"`` entry, or that entry is empty; otherwise a Plotly
    figure with one bar per answer, the final answer drawn in a distinct
    colour.
    """
    if not voting_details:
        return None
    if "vote_summary" not in voting_details:
        return None
    tally = voting_details["vote_summary"]
    if not tally:
        return None

    labels = list(tally.keys())
    totals = list(tally.values())
    bar_text = [f"{total:.3f}" for total in totals]

    # Highlight the bar whose answer matches the final answer.
    winner = voting_details.get('final_answer', '')
    bar_colors = []
    for label in labels:
        bar_colors.append('#C4314B' if label == winner else '#0F6CBD')

    # Create bar chart
    bars = go.Bar(
        x=labels,
        y=totals,
        text=bar_text,
        textposition='auto',
        marker_color=bar_colors,
    )
    fig = go.Figure(data=[bars])
    fig.update_layout(
        title="DAM-QA Voting Results",
        xaxis_title="Answers",
        yaxis_title="Vote Weight",
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(color="black", size=12),
        height=400,
        margin=dict(l=30, r=20, t=60, b=40),
    )
    return fig
def format_voting_details(voting_details):
    """Format the voting breakdown as Markdown text.

    Args:
        voting_details: Mapping that may contain ``"full_image"`` (a dict with
            ``"answer"`` and ``"weight"``), ``"windows"`` (a list of such
            dicts), ``"vote_summary"`` (answer -> weight) and
            ``"final_answer"``. May be ``None`` or empty.

    Returns:
        The formatted text, or ``"No voting details available."`` when
        ``voting_details`` is empty.
    """
    if not voting_details:
        return "No voting details available."
    details = []
    # Full image vote
    full_vote = voting_details.get("full_image")
    if full_vote:
        details.append("**Full Image Vote:**")
        details.append(f"- Answer: {full_vote['answer']}")
        details.append(f"- Weight: {full_vote['weight']:.3f}")
        details.append("")
    # Window votes summary
    if "windows" in voting_details:
        # Treat an explicit None the same as "no windows".
        windows = voting_details["windows"] or []
        details.append(f"**Window Votes:** {len(windows)} windows processed")
        # Group by answer, keeping first-seen order.
        answer_groups = {}
        for window in windows:
            answer_groups.setdefault(window["answer"], []).append(window)
        for answer, windows_for_ans in answer_groups.items():
            total_weight = sum(w["weight"] for w in windows_for_ans)
            details.append(f"- **{answer}**: {len(windows_for_ans)} windows, total weight: {total_weight:.3f}")
        details.append("")
    # Final summary
    if "vote_summary" in voting_details:
        details.append("**Final Vote Tally:**")
        # Only mark a winner when a final answer was actually reported;
        # otherwise an empty-string answer would be flagged by accident.
        has_final = "final_answer" in voting_details
        final_answer = voting_details.get("final_answer")
        # NOTE(review): the winner marker was reconstructed from mis-decoded
        # text - confirm the exact glyph against the upstream file.
        for answer, weight in (voting_details["vote_summary"] or {}).items():
            marker = "🏆" if has_final and answer == final_answer else " "
            details.append(f"{marker} {answer}: {weight:.3f}")
    return "\n".join(details)
# Force light theme.
# JavaScript snippet passed as `js=` to gr.Blocks below. It reads the page's
# query string and, when no `__theme` parameter is present, sets
# `__theme=light` and assigns the updated query string back to
# window.location.search. The `has` check means it does nothing once the
# parameter is already in the URL.
force_light_theme_js = """
() => {
const params = new URLSearchParams(window.location.search);
if (!params.has('__theme')) {
params.set('__theme', 'light');
window.location.search = params.toString();
}
}
"""
# Main Gradio interface
# NOTE(review): this block was rebuilt from text whose indentation had been
# stripped and whose emoji/arrows were mis-decoded. The container nesting
# follows the `with` statements and the LEFT/MIDDLE/RIGHT comments, and the
# glyphs are best-effort reconstructions - confirm both against the upstream
# file.
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
    vlai_template.create_header()
    gr.HTML(vlai_template.render_info_card(
        icon="🤖",
        title="About this Demo",
        description="This demo compares two approaches for Visual Question Answering: DAM (original) processes the full image, while DAM-QA uses a sliding window approach with weighted voting to better handle text-rich images."
    ))
    gr.HTML(vlai_template.render_disclaimer(
        text=(
            "This demo is for research and educational purposes only. "
            "The models are designed for visual question answering on text-rich images. "
            "Results may vary based on image quality and question complexity."
        )
    ))
    gr.Markdown("### 🎯 **How to Use**: Select a sample or upload your image → Ask a question → Compare both models → Analyze the voting results!")

    # Model Status at top
    with gr.Accordion("🤖 Model Status", open=True):
        with gr.Row():
            status_display = gr.Markdown("Loading models...")
            refresh_btn = gr.Button("🔄 Refresh Status", variant="secondary", scale=1)

    with gr.Row(equal_height=False, variant="panel"):
        # LEFT: Input Section
        with gr.Column(scale=35):
            with gr.Accordion("📤 Upload Image & Question", open=True):
                image_input = gr.Image(label="Upload Image", type="pil", height=300)
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask a question about the image...",
                    lines=3
                )
                with gr.Row():
                    max_tokens_slider = gr.Slider(
                        minimum=10, maximum=200, value=100, step=10,
                        label="Max Tokens", scale=2
                    )
                    compare_btn = gr.Button("🚀 Compare Models", variant="primary", size="lg", scale=1)
            with gr.Accordion("📋 Try Sample Images", open=True):
                sample_dropdown = gr.Dropdown(
                    label="Select Sample Dataset",
                    choices=[],
                    value=None,
                    info="Choose a sample to auto-fill the inputs above"
                )
                sample_image_display = gr.Image(label="Sample Preview", interactive=False, height=200)
                with gr.Row():
                    ground_truth_display = gr.Textbox(label="Expected Answer", interactive=False, scale=2)
                    sample_info_display = gr.Textbox(label="Dataset Info", interactive=False, lines=3, scale=1)

        # MIDDLE: Results Comparison
        with gr.Column(scale=40):
            with gr.Accordion("📊 Model Comparison Results", open=True):
                comparison_output = gr.Markdown("Click 'Compare Models' to see results...")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### 🔍 DAM Original")
                        original_output = gr.Markdown("Results will appear here...")
                    with gr.Column():
                        gr.Markdown("#### 🧩 DAM-QA Sliding Window")
                        sliding_output = gr.Markdown("Results will appear here...")

        # RIGHT: Voting Analysis
        with gr.Column(scale=25):
            with gr.Accordion("🗳️ DAM-QA Voting Analysis", open=True):
                voting_chart = gr.Plot(label="Vote Weights")
                voting_details = gr.Markdown("Voting details will appear here...", max_height=200)

    gr.Markdown("""
## 🔍 **Key Differences**
- **DAM Original**: Processes the entire image at once, faster but may miss fine details
- **DAM-QA Sliding Window**: Divides image into overlapping windows, slower but better for text-rich images
- **Voting Mechanism**: DAM-QA aggregates predictions from multiple windows using weighted voting
- **Use Cases**: DAM-QA typically performs better on documents, charts, and infographics
""")
    vlai_template.create_footer()

    # Event handlers
    def on_load():
        """Load samples and models; return the status text and dropdown update."""
        # Load samples first
        samples = load_samples()
        # Reuse the shared label builder instead of duplicating its logic.
        # load_samples() stores what it returns in STATE["samples"], which is
        # what get_sample_choices() reads; skip it when loading failed.
        choices = get_sample_choices() if samples else []
        # Load models immediately (this will take time but ensures they're ready)
        print("Loading DAM models...")
        status = init_models()
        print(f"Model initialization complete: {status}")
        return status, gr.Dropdown(choices=choices, value=0 if choices else None)

    def refresh_status():
        """Check current model status without triggering a reload."""
        if STATE["dam_original"] is not None and STATE["dam_sliding"] is not None:
            return "✅ Both DAM models loaded successfully!"
        else:
            return "🔄 Models not loaded. Click to retry."

    def retry_loading():
        """Retry loading models."""
        # NOTE(review): not wired to any event in this block; the refresh
        # button only calls refresh_status().
        return init_models()

    demo.load(
        fn=on_load,
        outputs=[status_display, sample_dropdown]
    )
    # Add refresh button functionality
    refresh_btn.click(
        fn=refresh_status,
        outputs=[status_display]
    )
    sample_dropdown.change(
        fn=fill_from_sample,
        inputs=[sample_dropdown],
        outputs=[sample_image_display, ground_truth_display, sample_info_display, image_input, question_input]
    )
    compare_btn.click(
        fn=compare_models,
        inputs=[image_input, question_input, max_tokens_slider],
        outputs=[comparison_output, original_output, sliding_output, voting_chart, voting_details]
    )
# Script entry point: start the Gradio server only when run directly.
# The stray "|" that followed the closing parenthesis was page residue and a
# syntax error; it has been removed.
if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
        # Directories Gradio is allowed to serve files from.
        allowed_paths=["sample_images", "static"],
    )