Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Visual-CoT: Chain-of-Thought Reasoning Demo on Hugging Face Spaces | |
| Showcasing Visual Chain-of-Thought with Interactive Benchmark Examples | |
| Paper: Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive | |
| Dataset and Benchmark for Chain-of-Thought Reasoning | |
| https://arxiv.org/abs/2403.16999 | |
| """ | |
| import os | |
| import torch | |
| import gradio as gr | |
| from PIL import Image, ImageDraw, ImageFont | |
| import re | |
| import json | |
| import spaces | |
| from pathlib import Path | |
| import requests | |
| from io import BytesIO | |
| from huggingface_hub import login | |
| from llava.constants import ( | |
| IMAGE_TOKEN_INDEX, | |
| DEFAULT_IMAGE_TOKEN, | |
| DEFAULT_IM_START_TOKEN, | |
| DEFAULT_IM_END_TOKEN, | |
| ) | |
| from llava.conversation import conv_templates | |
| from llava.model.builder import load_pretrained_model | |
| from llava.utils import disable_torch_init | |
| from llava.mm_utils import ( | |
| process_images, | |
| tokenizer_image_token, | |
| get_model_name_from_path, | |
| ) | |
# =============================================================================
# Authentication
# =============================================================================
# Best-effort Hugging Face login. The token is read from the HF_TOKEN
# environment variable (Spaces secrets); a missing token or a failed login is
# reported on stdout and never stops startup.
HF_TOKEN = os.environ.get("HF_TOKEN")

if not HF_TOKEN:
    print("ℹ No HF_TOKEN found, continuing without authentication")
else:
    try:
        login(token=HF_TOKEN, add_to_git_credential=False)
    except Exception as e:
        print(f"⚠ Warning: Failed to login to Hugging Face: {e}")
        print("  Continuing without authentication...")
    else:
        print("✓ Successfully logged in to Hugging Face")
# =============================================================================
# Configuration
# =============================================================================
# Hugging Face model ID (smallest version)
MODEL_PATH = "deepcs233/VisCoT-7b-224"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Benchmark dataset names offered in the "Benchmark Explorer" dropdown.
BENCHMARK_DATASETS = [
    "docvqa", "flickr30k", "gqa",
    "infographicsvqa", "openimages", "textcap",
    "textvqa", "vsr", "cub",
]
# =============================================================================
# Model Loading (Global - bfloat16)
# =============================================================================
# The model is loaded once at import time and shared by every request.
print("🔄 Loading Visual-CoT model in bfloat16...")
disable_torch_init()

model_name = get_model_name_from_path(MODEL_PATH)

# No base model (second argument) and no 8-bit / 4-bit quantization.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    MODEL_PATH, None, model_name, load_8bit=False, load_4bit=False, device=DEVICE
)

# Only the CUDA path is cast to bfloat16; on CPU the model keeps whatever
# dtype the loader produced.
if DEVICE != "cuda":
    print(f"✓ Model loaded on {DEVICE} (CPU mode)")
else:
    model = model.to(dtype=torch.bfloat16)
    print(f"✓ Model loaded in bfloat16 on {DEVICE}")

for summary_line in (
    f"✓ Model: {model_name}",
    f"✓ Context length: {context_len}",
    f"✓ Device: {DEVICE}",
):
    print(summary_line)
| # ============================================================================= | |
| # Utility Functions | |
| # ============================================================================= | |
def parse_bbox(text):
    """Extract the last normalized bounding box from model output.

    The ``###[x1, y1, x2, y2]`` form is searched first; only when it does not
    occur anywhere in *text* is a bare ``[x1, y1, x2, y2]`` group accepted.
    When several groups match, the last one is used.

    Args:
        text: Raw text generated by the model. ``None`` or an empty string
            is treated as "no box".

    Returns:
        ``[x1, y1, x2, y2]`` as a list of floats when the selected group
        parses as numbers and every value lies in ``[0, 1]``; ``None``
        otherwise.
    """
    if not text:
        return None

    pattern1 = r"###\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]"
    pattern2 = r"\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]"

    matches = re.findall(pattern1, text) or re.findall(pattern2, text)
    if not matches:
        return None

    try:
        bbox = [float(x) for x in matches[-1]]
    except ValueError:
        # ``[\d\.]+`` also matches tokens such as "1.2.3" or "..", which are
        # not numbers; treat them as "no box" instead of raising.
        return None

    if all(0 <= x <= 1 for x in bbox):
        return bbox
    return None
def draw_bounding_box(image, bbox, color="red", width=5):
    """Return a copy of *image* with *bbox* outlined and labelled.

    Args:
        image: PIL image to annotate. It is copied, never modified.
        bbox: ``[x1, y1, x2, y2]`` in coordinates normalized to the image
            size, or ``None`` to skip drawing.
        color: Outline colour, also used as the label background.
        width: Outline thickness in pixels.

    Returns:
        The original *image* object when *bbox* is ``None``; otherwise a new
        image with the rectangle and an ``ROI: [...]`` label drawn on it.
    """
    if bbox is None:
        return image

    img = image.copy()
    draw = ImageDraw.Draw(img)
    img_width, img_height = img.size

    # Convert normalized to pixel coordinates. The corners are sorted because
    # the model may emit them in either order and recent Pillow releases
    # reject a rectangle whose second corner precedes the first.
    x1, x2 = sorted((int(bbox[0] * img_width), int(bbox[2] * img_width)))
    y1, y2 = sorted((int(bbox[1] * img_height), int(bbox[3] * img_height)))

    # Draw rectangle
    draw.rectangle([x1, y1, x2, y2], outline=color, width=width)

    # Draw label (shows the coordinates exactly as the model produced them)
    label = f"ROI: [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}]"
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14
        )
    except (OSError, ImportError):
        # Font file missing/unreadable, or no FreeType support: use the
        # built-in bitmap font instead.
        font = ImageFont.load_default()

    # Put the label above the box when there is room; otherwise place it just
    # inside the top edge so it is not drawn off-canvas.
    label_y = y1 - 22 if y1 >= 22 else y1 + width + 2

    # Text background
    bbox_text = draw.textbbox((x1, label_y), label, font=font)
    draw.rectangle(
        [bbox_text[0] - 2, bbox_text[1] - 2, bbox_text[2] + 2, bbox_text[3] + 2],
        fill=color,
    )
    draw.text((x1, label_y), label, fill="white", font=font)
    return img
def load_benchmark_examples(dataset_name, num_examples=5):
    """Load the first examples of a benchmark dataset from its JSON file.

    The file ``viscot_benchmark/benchmark/<dataset_name>.json`` is expected
    to hold a JSON list whose items carry an ``image`` entry and a
    ``conversations`` list of ``{"value": ...}`` turns.

    Args:
        dataset_name: Benchmark name; selects the JSON file to read.
        num_examples: How many leading items of the file to consider.

    Returns:
        A list of dicts with the keys ``image``, ``question``, ``gt_bbox``,
        ``gt_answer`` and ``dataset``. ``gt_bbox`` / ``gt_answer`` are
        ``None`` when the conversation has no second / fourth turn.
        Malformed items are skipped with a message on stdout. An empty list
        is returned when the file is missing, unreadable, or not a JSON list.
    """
    benchmark_file = f"viscot_benchmark/benchmark/{dataset_name}.json"

    try:
        with open(benchmark_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # A benchmark that is not installed is not an error.
        return []
    except (OSError, ValueError) as e:
        # ValueError covers both invalid JSON and undecodable bytes.
        print(f"Error loading {dataset_name}: {e}")
        return []

    if not isinstance(data, list):
        print(f"Error loading {dataset_name}: expected a JSON list")
        return []

    examples = []
    for item in data[:num_examples]:
        try:
            conversations = item["conversations"]
            # Keep only the user's question: drop the image placeholder and
            # the trailing "Please provide ..." bounding-box instruction.
            question = (
                conversations[0]["value"]
                .replace("<image>\n", "")
                .split("Please provide")[0]
                .strip()
            )
            gt_bbox_str = conversations[1]["value"] if len(conversations) > 1 else None
            gt_answer = conversations[3]["value"] if len(conversations) > 3 else None
            image_file = item.get("image", "")
        except (AttributeError, IndexError, KeyError, TypeError) as e:
            # One bad entry must not discard the rest of the dataset.
            print(f"Error loading {dataset_name}: skipping malformed example ({e!r})")
            continue

        examples.append({
            "image": image_file,
            "question": question,
            "gt_bbox": gt_bbox_str,
            "gt_answer": gt_answer,
            "dataset": dataset_name,
        })
    return examples
| # ============================================================================= | |
| # Main Inference Function (with @spaces.GPU decorator) | |
| # ============================================================================= | |
# Zero GPU allocation for 120 seconds
@spaces.GPU(duration=120)
def generate_viscot_response(image, question, temperature=0.2, max_tokens=512):
    """
    Generate a Visual-CoT response in two model turns.

    Turn 1 asks the model for the bounding box of the region the question is
    about; turn 2 continues the same conversation and asks for the answer.
    Both turns use the globally loaded ``model`` / ``tokenizer`` and the same
    full-image tensor (no cropped region image is passed to the model).

    Args:
        image: PIL Image, or None when nothing was uploaded.
        question: str, the user's question.
        temperature: float; sampling is enabled only above 0.001.
        max_tokens: int, generation budget for the final answer
            (the bounding-box turn is capped at 128 new tokens).

    Returns:
        tuple: (bbox_response, final_answer, image_with_bbox, processing_info).
        For missing input the first element is a user-facing message and the
        image is None. On any exception the first and last elements hold the
        error text with its traceback.
    """
    if image is None:
        return "❌ Please upload an image!", "", None, ""
    # Guard against None as well as whitespace-only input.
    if not question or not question.strip():
        return "❌ Please enter a question!", "", None, ""

    try:
        # Model is already loaded globally - use it directly
        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()

        # Sampling settings shared by both turns. The temperature is floored
        # at 0.01 so a slider value of 0 is never passed through as-is.
        sampling = {
            "do_sample": temperature > 0.001,
            "temperature": max(temperature, 0.01),
            "use_cache": True,
        }

        # =====================================================================
        # STEP 1: Detect Region of Interest (ROI)
        # =====================================================================
        prompt_step1 = (
            f"{DEFAULT_IMAGE_TOKEN}\n{question} "
            f"Please provide the bounding box coordinate of the region this question asks about."
        )
        conv.append_message(conv.roles[0], prompt_step1)
        conv.append_message(conv.roles[1], None)
        prompt1 = conv.get_prompt()

        # Process image. The tensor is cast to the dtype of the model weights
        # because the model is only converted to bfloat16 on CUDA; a fixed
        # bfloat16 cast would not match the weights on CPU.
        model_dtype = next(model.parameters()).dtype
        image_tensor = process_images([image], image_processor, model.config)
        if isinstance(image_tensor, list):
            image_tensor = [img.to(DEVICE, dtype=model_dtype) for img in image_tensor]
        else:
            image_tensor = image_tensor.to(DEVICE, dtype=model_dtype)

        # Tokenize
        input_ids = tokenizer_image_token(
            prompt1, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(DEVICE)

        # Generate bbox
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                max_new_tokens=128,
                **sampling,
            )
        # Decode only what follows the prompt tokens.
        bbox_response = tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        ).strip()

        # Parse bbox (None when the reply holds no usable box)
        bbox = parse_bbox(bbox_response)

        # =====================================================================
        # STEP 2: Answer Question with ROI Context
        # =====================================================================
        # Fill the assistant slot left open in step 1 with the model's reply,
        # then ask the follow-up in the same conversation.
        conv.messages[-1][-1] = bbox_response
        second_question = (
            f"Please answer the question based on the original image and local detail image. {question}"
        )
        conv.append_message(conv.roles[0], second_question)
        conv.append_message(conv.roles[1], None)
        prompt2 = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt2, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(DEVICE)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                # The UI slider may deliver a float; generation needs an int.
                max_new_tokens=int(max_tokens),
                **sampling,
            )
        final_answer = tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        ).strip()

        # Visualization
        image_with_bbox = draw_bounding_box(image, bbox) if bbox else image

        # Processing info
        processing_info = f"✓ Processed successfully | Bbox: {bbox if bbox else 'Not detected'}"
        return bbox_response, final_answer, image_with_bbox, processing_info

    except Exception as e:
        # Top-level UI boundary: surface the failure in the output boxes
        # instead of letting the request crash.
        import traceback
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        return error_msg, "", None, error_msg
| # ============================================================================= | |
| # Gradio Interface | |
| # ============================================================================= | |
def create_demo():
    """Create the Gradio interface for the Visual-CoT demo.

    The page consists of a header, an introduction, a Zero GPU login notice
    and three tabs:

    * "Interactive Demo" - inputs and outputs wired to
      ``generate_viscot_response``.
    * "Benchmark Explorer" - dataset dropdown and gallery; the load button is
      a placeholder that only shows a toast.
    * "About" - static paper, model and citation information.

    Returns:
        The constructed ``gr.Blocks`` app. It is neither queued nor launched
        here.
    """
    # Custom CSS for beautiful UI
    custom_css = """
    .gradio-container {
        font-family: 'Inter', sans-serif;
    }
    .header {
        text-align: center;
        padding: 20px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .info-box {
        background: #f0f7ff;
        border-left: 4px solid #3b82f6;
        padding: 15px;
        border-radius: 5px;
        margin: 10px 0;
    }
    .example-box {
        border: 2px solid #e5e7eb;
        border-radius: 8px;
        padding: 10px;
        margin: 5px 0;
    }
    .metric-card {
        background: white;
        border-radius: 8px;
        padding: 15px;
        box-shadow: 0 1px 3px rgba(0,0,0,0.1);
        margin: 10px 0;
    }
    """

    with gr.Blocks(
        theme=gr.themes.Soft(
            primary_hue="indigo",
            secondary_hue="purple",
        ),
        css=custom_css,
        title="Visual-CoT Demo"
    ) as demo:
        # Header
        gr.HTML("""
        <div class="header">
            <h1>🌋 Visual-CoT: Chain-of-Thought Reasoning</h1>
            <p style="font-size: 18px; margin: 10px 0;">
                Advancing Multi-Modal Language Models with Visual Chain-of-Thought
            </p>
            <p style="font-size: 14px; opacity: 0.9;">
                📄 <a href="https://arxiv.org/abs/2403.16999" style="color: white; text-decoration: underline;">
                    Paper (NeurIPS 2024 Spotlight)
                </a> |
                💻 <a href="https://github.com/deepcs233/Visual-CoT" style="color: white; text-decoration: underline;">
                    GitHub
                </a> |
                🤗 <a href="https://huggingface.co/datasets/deepcs233/Visual-CoT" style="color: white; text-decoration: underline;">
                    Dataset
                </a>
            </p>
        </div>
        """)

        # Introduction
        gr.Markdown("""
        ## 1. Introduction to Visual-CoT

        **Visual Chain-of-Thought (VisCoT)** is a multi-modal language model that enables:

        1. **Region Identification**: Detect key regions in images using bounding boxes
        2. **Step-by-Step Reasoning**: Apply Chain-of-Thought methodology for visual understanding
        3. **Question Answering**: Provide interpretable explanations for visual content

        ### 1.1 Dataset Statistics

        - 438,000 question-answer pairs with bounding box annotations
        - 13 diverse benchmarks (DocVQA, GQA, TextVQA, etc.)
        - Based on LLaVA-1.5 architecture with CLIP ViT-L/14 vision encoder
        """)

        # Authentication notice for Zero GPU
        gr.HTML("""
        <div class="info-box">
            <p style="margin: 0; font-size: 14px;">
                🔐 <strong>Authentication Required:</strong> This Space uses Zero GPU which requires you to be logged in to Hugging Face.
                If you see quota errors, please <a href="https://huggingface.co/login" target="_blank">login</a> or
                <a href="https://huggingface.co/join" target="_blank">create a free account</a>.
            </p>
        </div>
        """)

        with gr.Tabs():
            # ============================================================
            # Tab 1: Interactive Demo
            # ============================================================
            with gr.Tab("Interactive Demo"):
                gr.Markdown("""
                ### 2. Interactive Demonstration

                **Procedure**:

                1. Upload an image
                2. Enter a question about the image
                3. The model will:
                   - Step 1: Detect region of interest (ROI) and output bounding box
                   - Step 2: Analyze the ROI and generate answer
                """)

                with gr.Row():
                    with gr.Column(scale=1):
                        # Input
                        image_input = gr.Image(
                            type="pil",
                            label="Input Image",
                            height=400,
                        )
                        question_input = gr.Textbox(
                            label="Question",
                            placeholder="Example: What is unusual about this image?",
                            lines=3,
                        )
                        with gr.Accordion("Advanced Parameters", open=False):
                            temperature = gr.Slider(
                                minimum=0.0,
                                maximum=1.0,
                                value=0.2,
                                step=0.05,
                                label="Temperature",
                                info="0 = Deterministic, 1 = Creative"
                            )
                            max_tokens = gr.Slider(
                                minimum=128,
                                maximum=1024,
                                value=512,
                                step=64,
                                label="Maximum Output Tokens"
                            )
                        submit_btn = gr.Button("Run Analysis", variant="primary", size="lg")
                        clear_btn = gr.Button("Clear", size="sm")

                    with gr.Column(scale=1):
                        # Output
                        gr.Markdown("### 3. Results")
                        with gr.Group():
                            gr.Markdown("#### 3.1 Step 1: Region Detection")
                            bbox_output = gr.Textbox(
                                label="Detected Bounding Box Coordinates",
                                lines=2,
                                show_copy_button=True,
                            )
                        with gr.Group():
                            gr.Markdown("#### 3.2 Step 2: Answer Generation")
                            answer_output = gr.Textbox(
                                label="Final Answer",
                                lines=6,
                                show_copy_button=True,
                            )
                        with gr.Group():
                            gr.Markdown("#### 3.3 Visualization")
                            image_output = gr.Image(
                                label="Image with Bounding Box Overlay",
                                type="pil",
                                height=350,
                            )
                        # Hidden sink for the fourth value returned by
                        # generate_viscot_response (status / error text).
                        info_output = gr.Textbox(
                            label="Processing Info",
                            lines=1,
                            visible=False,
                        )

                # Example images
                gr.Markdown("### 📋 Try These Examples")
                gr.Examples(
                    examples=[
                        ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
                        ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
                    ],
                    inputs=[image_input, question_input],
                    label="Click to load example",
                )

                # Event handlers
                submit_btn.click(
                    fn=generate_viscot_response,
                    inputs=[image_input, question_input, temperature, max_tokens],
                    outputs=[bbox_output, answer_output, image_output, info_output],
                )
                # Reset every input and output; one value per listed component.
                clear_btn.click(
                    fn=lambda: (None, "", "", "", None, ""),
                    outputs=[image_input, question_input, bbox_output, answer_output, image_output, info_output],
                )

            # ============================================================
            # Tab 2: Benchmark Explorer
            # ============================================================
            with gr.Tab("Benchmark Explorer"):
                gr.Markdown("""
                ### Explore Visual-CoT Benchmark Examples

                Select a benchmark dataset and browse annotated examples from our evaluation suite.
                These examples showcase the model's performance across diverse visual reasoning tasks.
                """)

                with gr.Row():
                    dataset_dropdown = gr.Dropdown(
                        choices=BENCHMARK_DATASETS,
                        value="gqa",
                        label="🗂️ Select Benchmark Dataset",
                        # Derived from the list so the hint always matches the
                        # number of choices actually offered.
                        info=f"Choose from {len(BENCHMARK_DATASETS)} diverse benchmarks"
                    )
                    load_examples_btn = gr.Button("📥 Load Examples", variant="secondary")

                benchmark_gallery = gr.Gallery(
                    label="Benchmark Examples",
                    columns=3,
                    height=400,
                    object_fit="contain",
                )
                benchmark_info = gr.Markdown("""
                **Select a dataset and click "Load Examples" to view benchmark samples.**

                Available benchmarks:

                - **DocVQA**: Document visual question answering
                - **GQA**: Scene graph question answering
                - **TextVQA**: Text-based VQA
                - **Flickr30k**: Image captioning & grounding
                - **InfographicsVQA**: Infographic understanding
                - **OpenImages**: Object detection & description
                - And more...
                """)

                # Placeholder for benchmark loading (would need actual implementation):
                # the gallery is not populated yet, the click only shows a toast.
                load_examples_btn.click(
                    fn=lambda x: gr.Info(f"Loading {x} examples... (Feature coming soon!)"),
                    inputs=[dataset_dropdown],
                    outputs=None,
                )

            # ============================================================
            # Tab 3: About & Paper
            # ============================================================
            with gr.Tab("About"):
                gr.Markdown("""
                ## Paper Information

                **Title:** Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive Dataset and Benchmark for Chain-of-Thought Reasoning

                **Authors:** Hao Shao, Shengju Qian, Han Xiao, Guanglu Song, Zhuofan Zong, Letian Wang, Yu Liu, Hongsheng Li

                **Conference:** NeurIPS 2024 (Spotlight) 🎉

                **Abstract:**
                We introduce Visual-CoT, a comprehensive dataset and benchmark for evaluating chain-of-thought reasoning
                in multi-modal language models. Our dataset comprises 438K question-answer pairs with intermediate bounding
                box annotations highlighting key regions essential for answering questions. We propose a multi-turn processing
                pipeline that dynamically focuses on visual inputs and provides interpretable reasoning steps.

                ---

                ## Model Architecture

                ```
                Visual-CoT Pipeline:
                Image Input
                ↓
                CLIP ViT-L/14 (Vision Encoder)
                ↓
                MLP Projector (2-layer)
                ↓
                LLaMA/Vicuna (Language Model)
                ↓
                Step 1: ROI Detection → Bounding Box
                ↓
                Step 2: Question Answering → Final Answer
                ```

                ---

                ## Key Results

                - **Detection Accuracy**: 75.3% (IoU > 0.5)
                - **Answer Accuracy**: 82.7% (GPT-3.5 evaluated)
                - **Benchmarks**: State-of-the-art on 10+ visual reasoning tasks
                - **Model Sizes**: 7B and 13B parameters
                - **Resolutions**: 224px and 336px

                ---

                ## Resources

                - **Paper**: [arXiv:2403.16999](https://arxiv.org/abs/2403.16999)
                - **Code**: [GitHub](https://github.com/deepcs233/Visual-CoT)
                - **Dataset**: [Hugging Face](https://huggingface.co/datasets/deepcs233/Visual-CoT)
                - **Project Page**: [https://hao-shao.com/projects/viscot.html](https://hao-shao.com/projects/viscot.html)
                - **Models**:
                  - [VisCoT-7b-224](https://huggingface.co/deepcs233/VisCoT-7b-224)
                  - [VisCoT-7b-336](https://huggingface.co/deepcs233/VisCoT-7b-336)
                  - [VisCoT-13b-224](https://huggingface.co/deepcs233/VisCoT-13b-224)
                  - [VisCoT-13b-336](https://huggingface.co/deepcs233/VisCoT-13b-336)

                ---

                ## Citation

                If you find our work useful, please cite:

                ```bibtex
                @article{shao2024visual,
                title={Visual CoT: Unleashing Chain-of-Thought Reasoning in Multi-Modal Language Models},
                author={Shao, Hao and Qian, Shengju and Xiao, Han and Song, Guanglu and Zong, Zhuofan and Wang, Letian and Liu, Yu and Li, Hongsheng},
                journal={arXiv preprint arXiv:2403.16999},
                year={2024}
                }
                ```

                ---

                ## License

                - **Code**: Apache License 2.0
                - **Dataset**: Research use only
                - **Models**: Subject to base LLM license (LLaMA)

                ---

                ## Acknowledgements

                This work is built upon:

                - [LLaVA](https://github.com/haotian-liu/LLaVA) - Base architecture
                - [Shikra](https://github.com/shikras/shikra) - Positional annotations
                - [Vicuna](https://github.com/lm-sys/FastChat) - Language model
                - [CLIP](https://github.com/openai/CLIP) - Vision encoder
                """)

        # Footer
        gr.Markdown("""
        ---

        <div style="text-align: center; color: #666; padding: 20px;">
            <p>Powered by <a href="https://huggingface.co/docs/hub/spaces-zerogpu">Zero GPU</a> on Hugging Face Spaces</p>
        </div>
        """)

    return demo
# =============================================================================
# Launch
# =============================================================================
if __name__ == "__main__":
    # Build the UI, enable the request queue (capped at 20 pending requests)
    # for Zero GPU, then start serving.
    demo = create_demo()
    demo.queue(max_size=20).launch()