dung-vpt-uney committed on
Commit
69afdf8
·
1 Parent(s): 31a530c

Update Visual-CoT demo - 2025-10-12 23:40:03

Browse files

Fixes:
- Fix LLaVA config registration error (compatibility with newer transformers)
- Update Gradio to latest version (security fixes)
- Auto-deployed via update script

Files changed (1): app.py (+113, −100)
app.py CHANGED
@@ -34,20 +34,7 @@ from llava.mm_utils import (
34
  get_model_name_from_path,
35
  )
36
 
37
- # Import benchmark loader for local datasets
38
- try:
39
- from benchmark_loader import (
40
- get_all_dataset_names,
41
- load_benchmark_example_for_gradio,
42
- get_random_examples_for_gradio,
43
- get_dataset_info,
44
- get_dataset_stats,
45
- )
46
- BENCHMARK_LOADER_AVAILABLE = True
47
- print("✅ Benchmark loader module imported successfully")
48
- except ImportError as e:
49
- BENCHMARK_LOADER_AVAILABLE = False
50
- print(f"⚠️ Benchmark loader not available: {e}")
51
 
52
  # =============================================================================
53
  # Authentication
@@ -81,16 +68,48 @@ MODEL_PATH = "deepcs233/VisCoT-7b-224" # Default: smallest/fastest
81
  CURRENT_MODEL_NAME = "VisCoT-7B-224 (Fastest)"
82
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
83
 
84
- # Benchmark datasets - will be loaded from benchmark_loader module
85
- if BENCHMARK_LOADER_AVAILABLE:
86
- BENCHMARK_DATASETS = get_all_dataset_names()
87
- print(f" Loaded {len(BENCHMARK_DATASETS)} benchmark datasets")
88
- stats = get_dataset_stats()
89
- total_examples = sum(s.get("total_examples", 0) for s in stats.values() if "error" not in s)
90
- print(f"📊 Total examples across all benchmarks: {total_examples:,}")
91
- else:
92
- BENCHMARK_DATASETS = ["GQA", "TextVQA", "DocVQA", "Visual7W", "Flickr30k"]
93
- print("⚠️ Using fallback benchmark list")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # =============================================================================
96
  # Model Loading (Global - bfloat16)
@@ -173,12 +192,51 @@ def switch_model(model_choice):
173
  # =============================================================================
174
 
175
  def load_benchmark_example(dataset_name, index=0):
176
- """Load an example from benchmark dataset using benchmark_loader"""
177
- if BENCHMARK_LOADER_AVAILABLE:
178
- return load_benchmark_example_for_gradio(dataset_name, index)
179
- else:
180
- # Fallback for when benchmark_loader is not available
181
- error_msg = "Benchmark loader module not available"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  return None, error_msg, "", "", error_msg
183
 
184
  # =============================================================================
@@ -405,7 +463,7 @@ def create_demo():
405
  .header {
406
  text-align: center;
407
  padding: 20px;
408
- background: linear-gradient(135deg, #475569 0%, #334155 100%);
409
  color: white;
410
  border-radius: 10px;
411
  margin-bottom: 20px;
@@ -437,8 +495,8 @@ def create_demo():
437
 
438
  with gr.Blocks(
439
  theme=gr.themes.Soft(
440
- primary_hue="slate",
441
- secondary_hue="gray",
442
  neutral_hue="slate",
443
  ),
444
  css=custom_css,
@@ -604,42 +662,16 @@ def create_demo():
604
  visible=False,
605
  )
606
 
607
- # Example images from benchmarks
608
- gr.Markdown("### 📋 Try These Examples from Benchmarks")
609
-
610
- # Generate examples from multiple benchmarks if available
611
- if BENCHMARK_LOADER_AVAILABLE:
612
- try:
613
- benchmark_examples = get_random_examples_for_gradio(count=6)
614
- if benchmark_examples:
615
- gr.Examples(
616
- examples=benchmark_examples,
617
- inputs=[image_input, question_input],
618
- label="Click to load random benchmark examples",
619
- )
620
- else:
621
- gr.Markdown("*Benchmark examples loading failed. Check if images are available.*")
622
- except Exception as e:
623
- gr.Markdown(f"*Could not load benchmark examples: {e}*")
624
- # Fallback to default examples
625
- gr.Examples(
626
- examples=[
627
- ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
628
- ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
629
- ],
630
- inputs=[image_input, question_input],
631
- label="Click to load example",
632
- )
633
- else:
634
- # Fallback examples when benchmark loader not available
635
- gr.Examples(
636
- examples=[
637
- ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
638
- ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
639
- ],
640
- inputs=[image_input, question_input],
641
- label="Click to load example",
642
- )
643
 
644
  # Event handlers
645
  submit_btn.click(
@@ -668,9 +700,9 @@ def create_demo():
668
  with gr.Column(scale=2):
669
  dataset_dropdown = gr.Dropdown(
670
  choices=list(BENCHMARK_DATASETS.keys()),
671
- value="GQA",
672
  label="Select Benchmark Dataset",
673
- info="Choose from 5 core benchmarks"
674
  )
675
  with gr.Column(scale=1):
676
  example_index = gr.Number(
@@ -718,35 +750,16 @@ def create_demo():
718
  interactive=False,
719
  )
720
 
721
- # Dataset information - dynamically generated
722
- if BENCHMARK_LOADER_AVAILABLE:
723
- dataset_info_md = "---\n\n### Available Benchmark Datasets\n\n"
724
- stats = get_dataset_stats()
725
- for i, (name, info) in enumerate(stats.items(), 1):
726
- if "error" not in info:
727
- dataset_info_md += f"{i}. **{name}** ({info['total_examples']:,} examples): {info['description']}\n"
728
- else:
729
- dataset_info_md += f"{i}. **{name}**: {info['error']}\n"
730
-
731
- total_examples = sum(s.get("total_examples", 0) for s in stats.values() if "error" not in s)
732
- dataset_info_md += f"\n**Total:** {total_examples:,} annotated examples across {len(stats)} benchmarks\n"
733
- dataset_info_md += "\n**Source:** Local JSONL files from Visual-CoT dataset"
734
-
735
- gr.Markdown(dataset_info_md)
736
- else:
737
- gr.Markdown("""
738
- ---
739
-
740
- ### Dataset Information
741
-
742
- 1. **GQA** - Scene graph question answering with compositional reasoning
743
- 2. **TextVQA** - Questions requiring reading and understanding text in images
744
- 3. **DocVQA** - Document understanding and information extraction
745
- 4. **Visual7W** - Visual question answering with pointing and telling tasks
746
- 5. **Flickr30k** - Image captioning and visual grounding
747
-
748
- **Note:** Benchmark loader module not available.
749
- """)
750
 
751
  # Event handlers
752
  def load_and_update(dataset_name, index):
 
34
  get_model_name_from_path,
35
  )
36
 
37
+ # No need for local benchmark loader - using HF datasets directly
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # =============================================================================
40
  # Authentication
 
68
  CURRENT_MODEL_NAME = "VisCoT-7B-224 (Fastest)"
69
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
70
 
71
# Benchmark datasets from Visual Chain-of-Thought Reasoning Benchmarks Collection
# https://huggingface.co/collections/tuandunghcmut/visual-chain-of-thought-reasoning-benchmarks
# Maps a display name (shown in the Gradio dropdown) to:
#   "path"        - the Hugging Face Hub dataset repo id passed to load_dataset()
#   "description" - one-line blurb surfaced in the UI status / info markdown
BENCHMARK_DATASETS = {
    "Visual-CoT": {
        "path": "deepcs233/Visual-CoT",
        "description": "Main Visual-CoT dataset with 438K question-answer pairs",
    },
    "GQA": {
        "path": "lmms-lab/GQA",
        "description": "Scene graph question answering (24.2M examples)",
    },
    "RefCOCO": {
        "path": "lmms-lab/RefCOCO",
        "description": "Referring expression comprehension (17.6K examples)",
    },
    "RefCOCO+": {
        "path": "lmms-lab/RefCOCOplus",
        "description": "RefCOCO with no location words (7.58K examples)",
    },
    "RefCOCOg": {
        "path": "lmms-lab/RefCOCOg",
        "description": "RefCOCO with longer expressions (12.6K examples)",
    },
    "POPE": {
        "path": "lmms-lab/POPE",
        "description": "Polling-based Object Probing Evaluation (18K examples)",
    },
    "ScienceQA": {
        "path": "lmms-lab/ScienceQA",
        "description": "Science question answering (12.6K examples)",
    },
    "MM-GCoT": {
        "path": "AQUA6/MM-GCoT",
        "description": "Multi-Modal Graph Chain-of-Thought (64.9K examples)",
    },
    "VGR": {
        "path": "BytedanceDouyinContent/VGR",
        "description": "Visual Grounding & Reasoning (90K examples)",
    },
}

# Startup log line confirming how many benchmarks the UI will offer.
print(f"✅ Configured {len(BENCHMARK_DATASETS)} benchmark datasets from HF collection")
113
 
114
  # =============================================================================
115
  # Model Loading (Global - bfloat16)
 
192
  # =============================================================================
193
 
194
def load_benchmark_example(dataset_name, index=0):
    """Load one example from a Hugging Face benchmark dataset.

    Args:
        dataset_name: Key into the module-level ``BENCHMARK_DATASETS`` mapping.
        index: Zero-based position of the example within the streamed
            ``train`` split. ``gr.Number`` delivers floats, so the value is
            normalized to a non-negative int before use.

    Returns:
        Tuple of ``(image, question, bbox_str, answer, status)`` matching the
        Gradio output components; on failure ``image`` is ``None`` and the
        error message fills the text slots.
    """
    try:
        from itertools import islice

        from datasets import load_dataset

        dataset_info = BENCHMARK_DATASETS.get(dataset_name)
        if not dataset_info:
            return None, "Dataset not found", "", "", ""

        dataset_path = dataset_info["path"]
        # gr.Number yields floats; clamp negatives so we never stream the
        # whole split looking for an index that cannot exist.
        index = max(0, int(index))

        # Stream so a single example never triggers a full dataset download.
        print(f"Loading {dataset_name} from {dataset_path}...")
        dataset = load_dataset(dataset_path, split="train", streaming=True)

        # islice bounds iteration to exactly index+1 items. (The previous
        # manual loop's "i > index + 10" safety break was unreachable dead
        # code, since the loop returned precisely at i == index.)
        for i, example in enumerate(islice(dataset, index + 1)):
            if i == index:
                # Field names vary per dataset; fall back across common keys.
                image = example.get("image")
                question = example.get("question", example.get("text", ""))

                # Bounding boxes appear under several keys and formats.
                bbox = example.get("bbox", example.get("bboxes", ""))
                if isinstance(bbox, list) and bbox:
                    bbox_str = str(bbox)
                else:
                    bbox_str = "No bounding box available"

                answer = example.get("answer", example.get("label", ""))

                status = f"📊 Dataset: {dataset_name} | Example {index + 1}\n{dataset_info['description']}"

                return image, question, bbox_str, answer, status

        # Split exhausted before reaching the requested index.
        return None, "Index out of range", "", "", "Could not find example at this index"

    except Exception as e:
        # Surface the failure in the UI and keep a full traceback in the logs.
        error_msg = f"Error loading {dataset_name}: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, error_msg, "", "", error_msg
241
 
242
  # =============================================================================
 
463
  .header {
464
  text-align: center;
465
  padding: 20px;
466
+ background: linear-gradient(135deg, #1e3a8a 0%, #1e40af 100%);
467
  color: white;
468
  border-radius: 10px;
469
  margin-bottom: 20px;
 
495
 
496
  with gr.Blocks(
497
  theme=gr.themes.Soft(
498
+ primary_hue="blue",
499
+ secondary_hue="indigo",
500
  neutral_hue="slate",
501
  ),
502
  css=custom_css,
 
662
  visible=False,
663
  )
664
 
665
+ # Example images
666
+ gr.Markdown("### 📋 Try These Examples")
667
+ gr.Examples(
668
+ examples=[
669
+ ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
670
+ ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
671
+ ],
672
+ inputs=[image_input, question_input],
673
+ label="Click to load example",
674
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
 
676
  # Event handlers
677
  submit_btn.click(
 
700
  with gr.Column(scale=2):
701
  dataset_dropdown = gr.Dropdown(
702
  choices=list(BENCHMARK_DATASETS.keys()),
703
+ value="Visual-CoT",
704
  label="Select Benchmark Dataset",
705
+ info="Choose from 9 visual reasoning benchmarks"
706
  )
707
  with gr.Column(scale=1):
708
  example_index = gr.Number(
 
750
  interactive=False,
751
  )
752
 
753
+ # Dataset information - dynamically generated from BENCHMARK_DATASETS
754
+ dataset_info_md = "---\n\n### Available Benchmark Datasets\n\n"
755
+ for i, (name, info) in enumerate(BENCHMARK_DATASETS.items(), 1):
756
+ dataset_info_md += f"{i}. **{name}**: {info['description']}\n"
757
+ dataset_info_md += f" - Path: `{info['path']}`\n"
758
+
759
+ dataset_info_md += f"\n**Total:** {len(BENCHMARK_DATASETS)} benchmarks from Visual Chain-of-Thought Reasoning Collection\n"
760
+ dataset_info_md += "\n**Source:** [Hugging Face Collection](https://huggingface.co/collections/tuandunghcmut/visual-chain-of-thought-reasoning-benchmarks)"
761
+
762
+ gr.Markdown(dataset_info_md)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
763
 
764
  # Event handlers
765
  def load_and_update(dataset_name, index):