---
license: cc-by-4.0
pipeline_tag: image-text-to-text
library_name: transformers
---
# LEGO: A Model for Multi-View 3D Scene Understanding
This repository contains the official weights for **LEGO**, a baseline method for multi-view reasoning in 3D scene understanding. LEGO leverages knowledge from pre-trained 2D LVLMs (specifically, it fine-tunes Fuyu-8B) and is trained on the **TripAlign** pre-training dataset. It is evaluated on **MV-ScanQA**, a novel 3D question answering dataset designed to rigorously test multi-view compositional reasoning.
LEGO achieves state-of-the-art performance on MV-ScanQA, as well as on existing benchmarks for 3D dense captioning and question answering.
This model was presented in the paper [Advancing 3D Scene Understanding with MV-ScanQA Multi-View Reasoning Evaluation and TripAlign Pre-training Dataset](https://huggingface.co/papers/2508.11058).
- 🏠 [Project Page](https://matthewdm0816.github.io/tripalign-mvscanqa)
- 💻 [GitHub Repository](https://github.com/matthewdm0816/MV-ScanQA-TripAlign)
<div align="center">
<img src="https://raw.githubusercontent.com/matthewdm0816/MV-ScanQA-TripAlign/main/docs/teasor-mm-lego.svg" alt="LEGO Teaser Image" width="70%"/>
</div>
## Overview of LEGO, MV-ScanQA, and TripAlign
The **MV-ScanQA** dataset addresses limitations in existing 3D vision-language datasets by introducing questions that explicitly require integrating information from multiple views, thus rigorously testing multi-view compositional reasoning over distant objects.
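To make "requires integrating information from multiple views" concrete, here is a hypothetical sketch. It is not the MV-ScanQA construction procedure. Given which objects are visible in each view, a question needs more than one view when no single view contains every object it mentions. The view IDs, object names, and the greedy set-cover heuristic are illustrative assumptions.

```python
def single_view_covers(question_objects, view_visibility):
    """Return the IDs of views that each contain all objects the question mentions."""
    targets = set(question_objects)
    return [view for view, visible in view_visibility.items() if targets <= set(visible)]


def min_views_greedy(question_objects, view_visibility):
    """Greedy set cover: add views until every question object is visible.

    Returns the chosen view IDs, or None if some object appears in no view.
    """
    remaining = set(question_objects)
    chosen = []
    while remaining:
        # Pick the view covering the most still-uncovered objects (first one wins ties)
        best = max(view_visibility, key=lambda v: len(remaining & set(view_visibility[v])), default=None)
        if best is None or not remaining & set(view_visibility[best]):
            return None
        chosen.append(best)
        remaining -= set(view_visibility[best])
    return chosen


# Hypothetical per-view visibility for one scene
views = {
    "view_0": {"sofa", "lamp"},
    "view_1": {"table", "chair"},
    "view_2": {"chair", "window"},
}
question = ["lamp", "chair"]  # e.g. "Is the lamp taller than the chair?"
print(single_view_covers(question, views))  # []
print(min_views_greedy(question, views))    # ['view_0', 'view_1']
```

Here no single view shows both the lamp and the chair, so answering the question forces the model to combine at least two views.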
To facilitate training for such demanding scenarios, the **TripAlign** dataset is introduced. This large-scale, low-cost 2D-3D-language pre-training corpus contains 1M `<2D view, set of 3D objects, text>` triplets, providing richer, view-grounded multi-object multimodal alignment signals than previous single-object annotations.
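As a hypothetical sketch (not the released TripAlign schema), one `<2D view, set of 3D objects, text>` triplet could be represented as a record pairing a single 2D view with the set of 3D objects it grounds and a text description. All class names, field names, and values below are assumptions for illustration.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Object3D:
    """One 3D object linked to a view (hypothetical fields)."""
    object_id: int
    label: str
    center: Tuple[float, float, float]  # assumed (x, y, z) box center in scene coordinates
    size: Tuple[float, float, float]    # assumed axis-aligned box extents


@dataclass
class TripAlignTriplet:
    """Hypothetical record for one <2D view, set of 3D objects, text> triplet."""
    scene_id: str
    view_image: str  # path to the 2D view image
    objects: List[Object3D] = field(default_factory=list)
    text: str = ""

    def object_labels(self) -> List[str]:
        # Sorted, de-duplicated class names of the grounded objects
        return sorted({obj.label for obj in self.objects})


# Illustrative values only; not taken from the released dataset
sample = TripAlignTriplet(
    scene_id="scene_0001",
    view_image="scene_0001/views/000.jpg",
    objects=[
        Object3D(3, "chair", (1.0, 0.5, 0.4), (0.5, 0.5, 0.9)),
        Object3D(7, "table", (1.4, 0.6, 0.4), (1.2, 0.8, 0.75)),
    ],
    text="A chair is pushed under the wooden table.",
)
print(sample.object_labels())  # ['chair', 'table']
```

Storing a list of objects per view, rather than a single object, is what separates this kind of record from single-object annotations.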
**LEGO** (Large-scale Multi-View Grounding Objective) is the baseline method developed to tackle the multi-view reasoning challenge in MV-ScanQA. It transfers knowledge from pre-trained 2D LVLMs (like Fuyu-8B, which this model fine-tunes) to the 3D domain with TripAlign.
## Usage
This model is a PEFT (Parameter-Efficient Fine-Tuning) LoRA adapter built on top of the `adept/fuyu-8b` base model. You can load and use it with the `transformers` and `peft` libraries.
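For readers unfamiliar with LoRA, the NumPy sketch below illustrates the idea behind the adapter: the base weight `W` stays frozen and a trainable low-rank update `B @ A`, scaled by `alpha / r`, is added to it. The rank, scaling, and initialization are generic LoRA choices picked for illustration, not the settings used to train LEGO, and the class is a toy, not PEFT's implementation.

```python
import numpy as np


class LoRALinear:
    """Toy LoRA layer: y = W x + (alpha / r) * B (A x), with W frozen."""

    def __init__(self, weight, r=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        out_features, in_features = weight.shape
        self.weight = weight                                    # frozen base weight
        self.A = rng.normal(scale=0.01, size=(r, in_features))  # trainable, small random init
        self.B = np.zeros((out_features, r))                    # trainable, zero init
        self.scaling = alpha / r

    def __call__(self, x):
        return self.weight @ x + self.scaling * (self.B @ (self.A @ x))

    def merged_weight(self):
        # Folding the low-rank update into W yields a plain linear layer
        return self.weight + self.scaling * (self.B @ self.A)


rng = np.random.default_rng(1)
W = rng.normal(size=(32, 64))
layer = LoRALinear(W, r=4, alpha=8)
x = rng.normal(size=64)

# With B initialized to zero, the adapter starts as an exact no-op
print(np.allclose(layer(x), W @ x))  # True

# Pretend B was trained; the merged weight reproduces the adapted output
layer.B = rng.normal(size=(32, 4))
print(np.allclose(layer(x), layer.merged_weight() @ x))  # True
print(layer.A.size + layer.B.size, W.size)  # 384 2048
```

Only the small `A` and `B` matrices are trained and shipped, which is why the adapter is much smaller than the 8B base model. In PEFT, `merge_and_unload()` performs the same folding on a loaded LoRA model.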
First, ensure you have the necessary libraries installed:
```bash
pip install transformers accelerate peft torch torchvision pillow
```
Below is sample code for 2D image inference. It loads the base model with `FuyuForCausalLM`, attaches the LoRA adapter with `PeftModel`, and relies on `FuyuProcessor` for image pre-processing (patching and normalization) and prompt tokenization. The prompts are illustrative; the prompt format used to train LEGO may differ, so check the GitHub repository for the exact inference pipeline.
```python
from pathlib import Path

import torch
from PIL import Image
from peft import PeftModel
from transformers import FuyuForCausalLM, FuyuProcessor

# Define the base model and the LoRA adapter ID
base_model_name_or_path = "adept/fuyu-8b"
# Replace 'your-org/your-repo' with the actual model ID on the Hugging Face Hub
peft_model_id = "your-org/your-repo"  # e.g., kmichiru/LEGO

# Load the base model
print(f"Loading base model: {base_model_name_or_path}...")
base_model = FuyuForCausalLM.from_pretrained(
    base_model_name_or_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",  # spread the weights across available devices
)
# FuyuProcessor bundles the tokenizer and the Fuyu image processor
processor = FuyuProcessor.from_pretrained(base_model_name_or_path)

# Load the PEFT adapter weights on top of the base model
print(f"Loading LoRA adapter: {peft_model_id}...")
model = PeftModel.from_pretrained(base_model, peft_model_id).eval()
print("Model loaded successfully!")

# Example usage (replace with your image path and question).
# You might need to download a sample image, e.g., from the GitHub repo.
# A dummy image for testing:
# from PIL import ImageDraw
# dummy_image = Image.new("RGB", (800, 600), color="red")
# draw = ImageDraw.Draw(dummy_image)
# draw.text((10, 10), "Sample Image", fill=(0, 0, 0))
# dummy_image.save("sample_image.png")
image_path = "sample_image.png"  # Replace with path to a real image
if not Path(image_path).exists():
    raise SystemExit(f"Image '{image_path}' not found. Please provide a valid image path or create a dummy image.")
image = Image.open(image_path).convert("RGB")

generation_config = dict(max_new_tokens=1024, do_sample=True)


def ask(question):
    # Following the Fuyu-8B examples, the plain-text prompt ends with a newline.
    # Refer to the Fuyu documentation for task-specific prompt formats.
    inputs = processor(text=question + "\n", images=image, return_tensors="pt")
    # Move tensors to the model's device and cast floating-point inputs (image patches) to bfloat16
    inputs = inputs.to(model.device, torch.bfloat16)
    output_ids = model.generate(**inputs, **generation_config)
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()


question = "Describe the main objects in this 3D scene."  # Example question
response = ask(question)
print(f"User: {question}\nAssistant: {response}")

# Example for 3D question answering (assuming the model outputs bounding box coordinates)
question_with_bbox = "What is the bounding box of the chair in this scene?"
response_bbox = ask(question_with_bbox)
print(f"User: {question_with_bbox}\nAssistant: {response_bbox}")
```
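The bounding-box question in the usage example assumes the model emits box coordinates as text. This card documents neither the output tag format nor the coordinate convention, so the helper below is a hypothetical post-processing sketch: it assumes four comma-separated numbers wrapped in `<box>...</box>` tags and leaves interpreting their order and units to the caller.

```python
import re
from typing import List, Tuple

# Assumed output format: "<box>a, b, c, d</box>" containing four numbers
BOX_PATTERN = re.compile(r"<box>\s*([^<]*?)\s*</box>")


def parse_boxes(response: str) -> List[Tuple[float, ...]]:
    """Extract every 4-number box from a model response, skipping malformed ones."""
    boxes = []
    for body in BOX_PATTERN.findall(response):
        parts = [p.strip() for p in body.split(",")]
        try:
            values = tuple(float(p) for p in parts)
        except ValueError:
            continue  # non-numeric content inside the tags
        if len(values) == 4:
            boxes.append(values)
    return boxes


print(parse_boxes("The chair is at <box>120, 45, 310, 200</box>."))  # [(120.0, 45.0, 310.0, 200.0)]
print(parse_boxes("No box here."))  # []
```

Returning an empty list instead of raising keeps the helper safe to call on free-form answers that contain no box.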
## Citation
If you find this codebase useful, please consider citing our work:
```bibtex
@inproceedings{mo2025mvscanqa,
title={Advancing 3D Scene Understanding with MV-ScanQA Multi-View Reasoning Evaluation and TripAlign Pre-training Dataset},
author={Mo, Wentao and Chen, QingChao and Peng, Yuxin and Huang, Siyuan and Liu, Yang},
booktitle={Proceedings of the 33rd ACM International Conference on Multimedia},
year={2025},
}
```
## License
This code repository and datasets are licensed under a [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) license.
Copyright (c) 2025 Wentao Mo.