dung-vpt-uney committed on
Commit
31a530c
·
1 Parent(s): f39b78a

Update Visual-CoT demo - 2025-10-12 23:34:21

Browse files

Fixes:
- Fix LLaVA config registration error (compatibility with newer transformers)
- Update Gradio to latest version (security fixes)
- Auto-deployed via update script

Files changed (1)
  1. app.py +97 -61
app.py CHANGED
@@ -34,6 +34,21 @@ from llava.mm_utils import (
34
  get_model_name_from_path,
35
  )
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # =============================================================================
38
  # Authentication
39
  # =============================================================================
@@ -66,14 +81,16 @@ MODEL_PATH = "deepcs233/VisCoT-7b-224" # Default: smallest/fastest
66
  CURRENT_MODEL_NAME = "VisCoT-7B-224 (Fastest)"
67
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
68
 
69
- # Benchmark datasets from HF collection
70
- BENCHMARK_DATASETS = {
71
- "GQA": "tuandunghcmut/gqa_cot",
72
- "TextVQA": "tuandunghcmut/textvqa_cot",
73
- "DocVQA": "tuandunghcmut/docvqa_cot",
74
- "Flickr30K": "tuandunghcmut/flickr30k_cot",
75
- "InfographicsVQA": "tuandunghcmut/infographicsvqa_cot",
76
- }
 
 
77
 
78
  # =============================================================================
79
  # Model Loading (Global - bfloat16)
@@ -156,36 +173,13 @@ def switch_model(model_choice):
156
  # =============================================================================
157
 
158
  def load_benchmark_example(dataset_name, index=0):
159
- """Load an example from benchmark dataset"""
160
- try:
161
- from datasets import load_dataset
162
-
163
- dataset_path = BENCHMARK_DATASETS.get(dataset_name)
164
- if not dataset_path:
165
- return None, "Dataset not found", "", "", ""
166
-
167
- # Load dataset
168
- dataset = load_dataset(dataset_path, split="train")
169
-
170
- if index >= len(dataset):
171
- index = 0
172
-
173
- example = dataset[index]
174
-
175
- # Extract fields
176
- image = example.get("image")
177
- question = example.get("question", "")
178
- bbox = example.get("bbox", "")
179
- answer = example.get("answer", "")
180
-
181
- info = f"Dataset: {dataset_name} | Example {index + 1}/{len(dataset)}"
182
-
183
- return image, question, bbox, answer, info
184
-
185
- except Exception as e:
186
- error_msg = f"Error loading benchmark: {str(e)}"
187
- print(error_msg)
188
- return None, error_msg, "", "", ""
189
 
190
  # =============================================================================
191
  # Utility Functions
@@ -610,16 +604,42 @@ def create_demo():
610
  visible=False,
611
  )
612
 
613
- # Example images
614
- gr.Markdown("### 📋 Try These Examples")
615
- gr.Examples(
616
- examples=[
617
- ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
618
- ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
619
- ],
620
- inputs=[image_input, question_input],
621
- label="Click to load example",
622
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
 
624
  # Event handlers
625
  submit_btn.click(
@@ -698,19 +718,35 @@ def create_demo():
698
  interactive=False,
699
  )
700
 
701
- gr.Markdown("""
702
- ---
703
-
704
- ### Dataset Information
705
-
706
- 1. **GQA** - Scene graph question answering with compositional reasoning
707
- 2. **TextVQA** - Questions requiring reading and understanding text in images
708
- 3. **DocVQA** - Document understanding and information extraction
709
- 4. **Visual7W** - Visual question answering with pointing and telling tasks
710
- 5. **Flickr30k** - Image captioning and visual grounding
711
-
712
- **Note:** Examples are loaded directly from the [Visual-CoT Hugging Face Collection](https://huggingface.co/collections/tuandunghcmut/visual-chain-of-thought-reasoning-benchmarks-68e25b22c3c095c6f87baba0).
713
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
 
715
  # Event handlers
716
  def load_and_update(dataset_name, index):
 
34
  get_model_name_from_path,
35
  )
36
 
37
+ # Import benchmark loader for local datasets
38
+ try:
39
+ from benchmark_loader import (
40
+ get_all_dataset_names,
41
+ load_benchmark_example_for_gradio,
42
+ get_random_examples_for_gradio,
43
+ get_dataset_info,
44
+ get_dataset_stats,
45
+ )
46
+ BENCHMARK_LOADER_AVAILABLE = True
47
+ print("✅ Benchmark loader module imported successfully")
48
+ except ImportError as e:
49
+ BENCHMARK_LOADER_AVAILABLE = False
50
+ print(f"⚠️ Benchmark loader not available: {e}")
51
+
52
  # =============================================================================
53
  # Authentication
54
  # =============================================================================
 
81
  CURRENT_MODEL_NAME = "VisCoT-7B-224 (Fastest)"
82
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
83
 
84
+ # Benchmark datasets - will be loaded from benchmark_loader module
85
+ if BENCHMARK_LOADER_AVAILABLE:
86
+ BENCHMARK_DATASETS = get_all_dataset_names()
87
+ print(f" Loaded {len(BENCHMARK_DATASETS)} benchmark datasets")
88
+ stats = get_dataset_stats()
89
+ total_examples = sum(s.get("total_examples", 0) for s in stats.values() if "error" not in s)
90
+ print(f"📊 Total examples across all benchmarks: {total_examples:,}")
91
+ else:
92
+ BENCHMARK_DATASETS = ["GQA", "TextVQA", "DocVQA", "Visual7W", "Flickr30k"]
93
+ print("⚠️ Using fallback benchmark list")
94
 
95
  # =============================================================================
96
  # Model Loading (Global - bfloat16)
 
173
  # =============================================================================
174
 
175
  def load_benchmark_example(dataset_name, index=0):
176
+ """Load an example from benchmark dataset using benchmark_loader"""
177
+ if BENCHMARK_LOADER_AVAILABLE:
178
+ return load_benchmark_example_for_gradio(dataset_name, index)
179
+ else:
180
+ # Fallback for when benchmark_loader is not available
181
+ error_msg = "Benchmark loader module not available"
182
+ return None, error_msg, "", "", error_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  # =============================================================================
185
  # Utility Functions
 
604
  visible=False,
605
  )
606
 
607
+ # Example images from benchmarks
608
+ gr.Markdown("### 📋 Try These Examples from Benchmarks")
609
+
610
+ # Generate examples from multiple benchmarks if available
611
+ if BENCHMARK_LOADER_AVAILABLE:
612
+ try:
613
+ benchmark_examples = get_random_examples_for_gradio(count=6)
614
+ if benchmark_examples:
615
+ gr.Examples(
616
+ examples=benchmark_examples,
617
+ inputs=[image_input, question_input],
618
+ label="Click to load random benchmark examples",
619
+ )
620
+ else:
621
+ gr.Markdown("*Benchmark examples loading failed. Check if images are available.*")
622
+ except Exception as e:
623
+ gr.Markdown(f"*Could not load benchmark examples: {e}*")
624
+ # Fallback to default examples
625
+ gr.Examples(
626
+ examples=[
627
+ ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
628
+ ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
629
+ ],
630
+ inputs=[image_input, question_input],
631
+ label="Click to load example",
632
+ )
633
+ else:
634
+ # Fallback examples when benchmark loader not available
635
+ gr.Examples(
636
+ examples=[
637
+ ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
638
+ ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
639
+ ],
640
+ inputs=[image_input, question_input],
641
+ label="Click to load example",
642
+ )
643
 
644
  # Event handlers
645
  submit_btn.click(
 
718
  interactive=False,
719
  )
720
 
721
+ # Dataset information - dynamically generated
722
+ if BENCHMARK_LOADER_AVAILABLE:
723
+ dataset_info_md = "---\n\n### Available Benchmark Datasets\n\n"
724
+ stats = get_dataset_stats()
725
+ for i, (name, info) in enumerate(stats.items(), 1):
726
+ if "error" not in info:
727
+ dataset_info_md += f"{i}. **{name}** ({info['total_examples']:,} examples): {info['description']}\n"
728
+ else:
729
+ dataset_info_md += f"{i}. **{name}**: {info['error']}\n"
730
+
731
+ total_examples = sum(s.get("total_examples", 0) for s in stats.values() if "error" not in s)
732
+ dataset_info_md += f"\n**Total:** {total_examples:,} annotated examples across {len(stats)} benchmarks\n"
733
+ dataset_info_md += "\n**Source:** Local JSONL files from Visual-CoT dataset"
734
+
735
+ gr.Markdown(dataset_info_md)
736
+ else:
737
+ gr.Markdown("""
738
+ ---
739
+
740
+ ### Dataset Information
741
+
742
+ 1. **GQA** - Scene graph question answering with compositional reasoning
743
+ 2. **TextVQA** - Questions requiring reading and understanding text in images
744
+ 3. **DocVQA** - Document understanding and information extraction
745
+ 4. **Visual7W** - Visual question answering with pointing and telling tasks
746
+ 5. **Flickr30k** - Image captioning and visual grounding
747
+
748
+ **Note:** Benchmark loader module not available.
749
+ """)
750
 
751
  # Event handlers
752
  def load_and_update(dataset_name, index):