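# NOTE: the next three lines are web-page residue from the hosting site
# (avatar caption, commit message, commit id); they are not part of the program.
# Deva8's picture
# Update app.py
# 88597e4 verified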
import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download
from PIL import Image
import zipfile
import os
# ==================== Configuration ====================
# Hugging Face dataset repository that is downloaded from.
REPO_ID = "Deva8/Generative-VQA-V2-Curated"
# Local directory used as the hub download cache and as the extraction root.
CACHE_DIR = "./dataset_cache"

# Global state shared by the loader and the UI callbacks:
#   df         - metadata DataFrame once the CSV has been read, else None
#   images_dir - directory the image archive is extracted into, else None
#   loaded     - True only after initialisation has completed successfully
#   error      - text of the exception that aborted initialisation, else None
dataset_state = dict(
    df=None,
    images_dir=None,
    loaded=False,
    error=None,
)
# ==================== Dataset Loading ====================
def initialize_dataset():
    """Download the dataset metadata and images and publish them in ``dataset_state``.

    Does nothing if the dataset is already marked as loaded. Otherwise it
    downloads ``main_metadata.csv`` and ``gen_vqa_v2-images.zip`` from
    ``REPO_ID`` into ``CACHE_DIR``, reads the CSV into ``dataset_state["df"]``
    and makes sure the archive is extracted under ``CACHE_DIR/extracted``.

    The function never raises: any failure is printed and stored as text in
    ``dataset_state["error"]``, and ``dataset_state["loaded"]`` stays False.
    """
    if dataset_state["loaded"]:
        return

    # Local import: only this function needs it (clearing a stale staging dir).
    import shutil

    try:
        print("📥 Downloading metadata...")
        csv_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="main_metadata.csv",
            repo_type="dataset",
            cache_dir=CACHE_DIR,
        )
        dataset_state["df"] = pd.read_csv(csv_path)
        print(f"✓ Loaded {len(dataset_state['df']):,} examples")

        print("📦 Downloading images (10GB, please wait)...")
        zip_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="gen_vqa_v2-images.zip",
            repo_type="dataset",
            cache_dir=CACHE_DIR,
        )

        images_dir = os.path.join(CACHE_DIR, "extracted")
        dataset_state["images_dir"] = images_dir
        if not os.path.exists(images_dir):
            print("📂 Extracting images...")
            # Extract into a staging directory and rename it only once the
            # whole archive has been unpacked. The final directory therefore
            # exists only when extraction completed, so an interrupted run is
            # retried instead of leaving a partial image set that would be
            # treated as complete on the next start.
            staging_dir = images_dir + ".partial"
            shutil.rmtree(staging_dir, ignore_errors=True)
            os.makedirs(staging_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zf:
                zf.extractall(staging_dir)
            os.replace(staging_dir, images_dir)

        # Clear any error left over from an earlier failed attempt.
        dataset_state["error"] = None
        dataset_state["loaded"] = True
        print("✅ Dataset ready!")
    except Exception as e:
        # Top-level boundary: record the failure so the UI can report it
        # instead of crashing the app at import time.
        dataset_state["error"] = str(e)
        print(f"❌ Error: {e}")
# Load dataset on startup. This call runs at import time, before the UI below
# is built. initialize_dataset() catches its own exceptions and records them in
# dataset_state["error"], so a failed download does not stop the module from
# loading.
initialize_dataset()
# ==================== Helper Functions ====================
def load_image(file_path):
    """Load one dataset image as an RGB PIL image.

    Args:
        file_path: Image path relative to ``dataset_state["images_dir"]``.

    Returns:
        The image converted to RGB, or ``None`` if it could not be loaded
        (the failure is printed rather than raised).
    """
    try:
        full_path = os.path.join(dataset_state["images_dir"], file_path)
        # Image.open is lazy and keeps the file open. The context manager
        # closes it deterministically; convert() returns an independent
        # in-memory image, so the result stays valid after the file is closed.
        with Image.open(full_path) as img:
            return img.convert("RGB")
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
def check_dataset():
    """Report whether the dataset is ready to be queried.

    Returns:
        ``None`` when the dataset is loaded. Otherwise the 4-tuple
        ``(None, status, "", "")``, where ``status`` carries the recorded
        load error, or a generic "loading" notice when no error is recorded.
    """
    if dataset_state["loaded"]:
        return None

    reason = dataset_state["error"] or "Dataset is loading..."
    return None, f"⏳ {reason}", "", ""
# ==================== Main Functions ====================
def show_random():
    """Pick one random row of the dataset for display.

    Returns:
        ``(image, question, answer, info)`` for the chosen row, or the
        not-ready tuple from ``check_dataset()`` when the dataset is unavailable.
    """
    not_ready = check_dataset()
    if not_ready:
        return not_ready

    frame = dataset_state["df"]
    row = frame.sample(1).iloc[0]

    picture = load_image(row['file_name'])
    question, answer = row['question'], row['answer']
    info = f"Image ID: {row['image_id']} | Question ID: {row['question_id']}"
    return picture, question, answer, info
def search_question(query):
    """Show a random example whose question contains ``query``.

    The match is a case-insensitive *literal* substring test, so characters
    that are special in regular expressions (``?``, ``(``, ``*`` ...) are
    searched for as typed instead of raising or matching something else.

    Args:
        query: Text typed by the user; surrounding whitespace is ignored.

    Returns:
        ``(image, question, answer, info)`` for one randomly chosen match, a
        tuple with a message in the question slot when the query is too short
        or nothing matches, or the not-ready tuple from ``check_dataset()``.
    """
    check = check_dataset()
    if check:
        return check

    # Validate and search with the same normalised text; tolerate None.
    needle = (query or "").strip()
    if len(needle) < 2:
        return None, "Enter at least 2 characters", "", ""

    df = dataset_state["df"]
    # regex=False: the query is untrusted free text, not a pattern.
    mask = df['question'].str.contains(needle, case=False, na=False, regex=False)
    matches = df[mask]
    if matches.empty:
        return None, f"No matches for '{needle}'", "", ""

    sample = matches.sample(1).iloc[0]
    img = load_image(sample['file_name'])
    return (
        img,
        sample['question'],
        sample['answer'],
        f"Found {len(matches):,} matches | Showing random example",
    )
def search_answer(query):
    """Show a random example whose answer equals ``query``.

    The comparison is case-insensitive and ignores whitespace around the
    query.

    Args:
        query: Answer text typed by the user.

    Returns:
        ``(image, question, answer, info)`` for one randomly chosen match, a
        tuple with a message in the question slot when the query is empty or
        nothing matches, or the not-ready tuple from ``check_dataset()``.
    """
    check = check_dataset()
    if check:
        return check

    # Normalise before validating so whitespace-only input (or None) is
    # rejected instead of being searched for as an empty answer.
    wanted = (query or "").strip()
    if not wanted:
        return None, "Enter an answer", "", ""

    df = dataset_state["df"]
    matches = df[df['answer'].str.lower() == wanted.lower()]
    if matches.empty:
        return None, f"No examples with answer '{wanted}'", "", ""

    sample = matches.sample(1).iloc[0]
    img = load_image(sample['file_name'])
    return (
        img,
        sample['question'],
        sample['answer'],
        f"Found {len(matches):,} examples | Showing random",
    )
def get_stats():
    """Build a Markdown summary of the loaded dataset.

    Returns:
        Markdown text with the total number of rows, the number of unique
        ``image_id`` and ``answer`` values and the ten most frequent answers.
        When the dataset is not loaded, a short status message instead: the
        recorded load error if there is one, otherwise a "loading" notice.
    """
    if not dataset_state["loaded"]:
        # Surface a recorded failure instead of claiming to be loading forever.
        error = dataset_state["error"]
        if error:
            return f"❌ Dataset failed to load: {error}"
        return "Dataset loading..."

    df = dataset_state["df"]
    top_answers = df['answer'].value_counts().head(10)

    # Blank lines between the stat lines keep Markdown from merging them
    # into a single paragraph.
    lines = [
        "# 📊 Dataset Statistics",
        "",
        f"**Total Examples:** {len(df):,}",
        "",
        f"**Unique Images:** {df['image_id'].nunique():,}",
        "",
        f"**Unique Answers:** {df['answer'].nunique():,}",
        "",
        "## Top 10 Answers",
        "",
    ]
    lines.extend(
        f"{rank}. **{ans}** - {count:,} examples"
        for rank, (ans, count) in enumerate(top_answers.items(), 1)
    )
    return "\n".join(lines) + "\n"
# ==================== Gradio Interface ====================
# Top-level UI: a header, four tabs (random browsing, question search, answer
# search, statistics) and an "About" footer. Each tab wires its button (and,
# for the search tabs, the textbox submit event) to one callback above.
with gr.Blocks(
    title="VQA Dataset Explorer",
    theme=gr.themes.Soft(primary_hue="blue")
) as demo:
    gr.Markdown("""
# 🎯 Generative VQA-V2 Dataset Explorer
Explore 135K+ curated visual question-answer pairs from the
[Generative-VQA-V2-Curated](https://huggingface.co/datasets/Deva8/Generative-VQA-V2-Curated) dataset.
""")

    with gr.Tabs():
        # Random Samples Tab
        with gr.TabItem("🎲 Random"):
            gr.Markdown("### Browse random examples")
            btn_random = gr.Button("🔄 Show Random Example", variant="primary", size="lg")
            with gr.Row():
                img_random = gr.Image(label="Image", height=400)
                with gr.Column():
                    q_random = gr.Textbox(label="❓ Question", lines=3)
                    a_random = gr.Textbox(label="✅ Answer", lines=2)
                    m_random = gr.Textbox(label="ℹ️ Info", lines=1)
            btn_random.click(
                show_random,
                outputs=[img_random, q_random, a_random, m_random]
            )

        # Question Search Tab
        with gr.TabItem("🔍 Search Questions"):
            gr.Markdown("### Find questions containing keywords")
            with gr.Row():
                query_q = gr.Textbox(
                    label="Search",
                    placeholder="e.g., color, wearing, many, holding",
                    scale=4
                )
                btn_q = gr.Button("🔎 Search", variant="primary", scale=1)
            with gr.Row():
                img_q = gr.Image(label="Image", height=400)
                with gr.Column():
                    q_q = gr.Textbox(label="❓ Question", lines=3)
                    a_q = gr.Textbox(label="✅ Answer", lines=2)
                    m_q = gr.Textbox(label="ℹ️ Info", lines=1)
            # Both the button and pressing Enter in the textbox run the search.
            btn_q.click(search_question, inputs=[query_q], outputs=[img_q, q_q, a_q, m_q])
            query_q.submit(search_question, inputs=[query_q], outputs=[img_q, q_q, a_q, m_q])

        # Answer Search Tab
        with gr.TabItem("🎯 Search Answers"):
            gr.Markdown("### Find examples with specific answers")
            with gr.Row():
                query_a = gr.Textbox(
                    label="Answer",
                    placeholder="e.g., red, cat, pizza, 2",
                    scale=4
                )
                btn_a = gr.Button("🔎 Search", variant="primary", scale=1)
            gr.Markdown("**Popular:** white, black, blue, red, 2, 3, dog, cat, pizza")
            with gr.Row():
                img_a = gr.Image(label="Image", height=400)
                with gr.Column():
                    q_a = gr.Textbox(label="❓ Question", lines=3)
                    a_a = gr.Textbox(label="✅ Answer", lines=2)
                    m_a = gr.Textbox(label="ℹ️ Info", lines=1)
            # Both the button and pressing Enter in the textbox run the search.
            btn_a.click(search_answer, inputs=[query_a], outputs=[img_a, q_a, a_a, m_a])
            query_a.submit(search_answer, inputs=[query_a], outputs=[img_a, q_a, a_a, m_a])

        # Statistics Tab
        with gr.TabItem("📊 Stats"):
            gr.Markdown("### Dataset overview and analysis")
            btn_stats = gr.Button("📈 Load Statistics", variant="primary")
            stats_md = gr.Markdown()
            btn_stats.click(get_stats, outputs=[stats_md])

    # The blank line after the bullet list keeps the Dataset/License lines
    # from being folded into the last bullet.
    gr.Markdown("""
---
### About
This dataset is a curated version of VQA v2 with:
- ✅ No yes/no questions
- ✅ Balanced answer distribution
- ✅ Filtered ambiguous questions

**Dataset:** [Deva8/Generative-VQA-V2-Curated](https://huggingface.co/datasets/Deva8/Generative-VQA-V2-Curated)

**License:** CC BY 4.0
""")
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )