---
datasets:
- yaolily/GenS-Video-150K
license: other
pipeline_tag: video-text-to-text
library_name: transformers
---

<p align="center">
🔗 <a href="https://generative-sampler.github.io/" target="_blank">Project Page</a> · 📖 <a href="https://arxiv.org/abs/2503.09146" target="_blank">Paper</a> · ⭐ <a href="https://github.com/yaolinli/GenS" target="_blank">GitHub</a> · 📊 <a href="https://huggingface.co/datasets/yaolily/GenS-Video-150K" target="_blank">Dataset</a> · 🤗 <a href="https://huggingface.co/yaolily/GenS" target="_blank">Checkpoints</a>
</p>

## Model Description

**GenS** (Generative Frame Sampler) is a novel approach for identifying question-relevant frames in long videos spanning minutes to hours. Given a long video and a user question, GenS searches the full set of original frames, produces a concise selection, and thereby improves downstream VideoQA assistants (such as Qwen2-VL, LLaVA-Video, VILA-v1.5, and Aria) by supplying fewer but more informative frames.

GenS is built on advanced long-context VideoLLMs (such as Aria and Qwen2.5-VL) and recasts key frame sampling as a generative task.

<img src="https://generative-sampler.github.io/static/images/teaser.png" alt="GenS Framework" style="width: 100%;">

## Key Features of GenS

✨ **Temporal Understanding:**
GenS effectively captures temporal relationships between successive frames, enabling complex reasoning about temporal sequences such as "immediately after" events in videos.

πŸ“ **Complex Instruction Understanding:**
Powered by built-in LLMs, GenS comprehends complex and flexible textual instructions, allowing it to interpret nuanced queries and identify the most relevant visual content.

⚑ **Effective Video-Text Alignment:**
Its native multi-modal architecture enables sophisticated multi-hop reasoning by seamlessly aligning long-range temporal cues with language semantics, resulting in more accurate frame selection.

🎉 **State-of-the-Art Performance:**
GenS significantly boosts the performance of various VideoQA models, achieving SOTA results on long-form video benchmarks when integrated with open-source models.

## Performance Highlights
-   πŸ† **LongVideoBench**: LLaVA-Video-72B w/ GenS achieves **66.8** accuracy (+4.3)
-   πŸ† **MLVU**: LLaVA-Video-72B w/ GenS achieves **77.0** accuracy (+2.7)
-   πŸ† **HourVideo**: Aria w/ GenS obtains **39.2** accuracy, while Gemini-1.5-pro w/ GenS obtains **40.7** accuracy


<img src="https://generative-sampler.github.io/static/images/table_main.png" alt="Main Results Table" style="width: 100%;">
<img src="https://generative-sampler.github.io/static/images/hourvideo.png" alt="HourVideo Results Table" style="width: 100%;">
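The gains in parentheses in the highlights list can be read back to the accuracy each assistant reaches on its own. As a small sketch, assuming each gain is an absolute difference in accuracy points against the same model evaluated without GenS (the list does not state the baseline explicitly), the subtraction is shown below; `implied_baseline` is a hypothetical helper, not part of GenS.

```python
# Accuracy with GenS and the gain shown in parentheses, copied from the highlights list.
REPORTED = {
    "LongVideoBench (LLaVA-Video-72B)": (66.8, 4.3),
    "MLVU (LLaVA-Video-72B)": (77.0, 2.7),
}


def implied_baseline(with_gens: float, gain: float) -> float:
    """Accuracy without GenS, assuming `gain` is an absolute difference in points."""
    return round(with_gens - gain, 1)


for name, (acc, gain) in REPORTED.items():
    print(f"{name}: {implied_baseline(acc, gain)} without GenS -> {acc} with GenS")
```

Under that assumption, LLaVA-Video-72B scores 62.5 on LongVideoBench and 74.3 on MLVU without GenS.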

## Quick Start

### Installation
After creating your conda environment, install the required dependencies:
```
pip install transformers==4.45.0 accelerate==0.34.1 sentencepiece==0.2.0 torchvision requests torch Pillow
pip install flash-attn --no-build-isolation
```
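Because the install command pins exact versions, it can help to confirm that the active environment matches before loading the custom model classes. The sketch below uses only the standard-library `importlib.metadata`; `check_pins` and `PINNED` are illustrative names, not part of GenS, and only the three exactly pinned packages are checked.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict

# Exact versions pinned by the pip command above.
PINNED = {
    "transformers": "4.45.0",
    "accelerate": "0.34.1",
    "sentencepiece": "0.2.0",
}


def check_pins(pins: Dict[str, str], lookup: Callable[[str], str] = version) -> Dict[str, str]:
    """Return {package: problem} for each pinned package that is missing or at another version."""
    problems: Dict[str, str] = {}
    for name, wanted in pins.items():
        try:
            found = lookup(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if found != wanted:
            problems[name] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    print(check_pins(PINNED) or "All pinned versions match.")
```

The `lookup` parameter defaults to `importlib.metadata.version` and exists only so the check can be exercised without touching the real environment.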

### Usage

```
import torch
from PIL import Image
import sys
import os
from typing import List

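# Hugging Face Auto classes (the custom GenS classes are registered with them below)
from transformers import AutoProcessor, AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM

# GenS model code: `yivl` and `deepseekv1moe` are not pip packages and must be importable
# from your Python path (presumably from a checkout of the GenS GitHub repository)
from yivl.yivl_model_hf import YiVLForConditionalGeneration, YiVLConfig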
from yivl.siglip_navit_490 import NaViTProcessor
from yivl.constants import (
    DEFAULT_IMAGE_END_TOKEN,
    DEFAULT_IMAGE_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from deepseekv1moe.modeling_deepseek import DeepseekConfig, DeepseekForCausalLM


def setup_model():
    """Set up and load the GenS model and its components."""
    
    # Register custom models with the Auto classes
    AutoConfig.register("yi_vl", YiVLConfig)
    AutoModel.register(YiVLConfig, YiVLForConditionalGeneration)
    AutoConfig.register("deepseek", DeepseekConfig)
    AutoModelForCausalLM.register(DeepseekConfig, DeepseekForCausalLM)
    
    # Load model from Hugging Face
    model_id = "yaolily/GenS"
    
    # Load configuration
    config = AutoConfig.from_pretrained(model_id)
    
    # Load model with optimizations
    model = AutoModel.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16
    ).to(torch.device("cuda"))
    
    # Load tokenizer with special token handling
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
    if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
        try:
            tokenizer.add_special_tokens({"pad_token": "<unk>"})
            if tokenizer.pad_token_id is None:
                tokenizer.add_special_tokens({"pad_token": "<mask>"})
        except ValueError:
            tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
    
    # Initialize the custom image processor
    processor = NaViTProcessor(image_max_size=490)
    
    print("GenS Model loaded successfully!")
    return model, tokenizer, processor


def gens_frame_sampler(question: str, frame_paths: List[str], model, tokenizer, processor):
    """
    Use GenS model to identify and score relevant frames for a video question.
    
    Args:
        question: The question to answer about the video
        frame_paths: List of paths to video frames
        model: Pre-loaded GenS model
        tokenizer: Pre-loaded tokenizer
        processor: Pre-loaded image processor
        
    Returns:
        The model's response with relevance scores for frames
    """
    # Load frames as PIL images
    frames = []
    for path in frame_paths:
        try:
            img = Image.open(path).convert("RGB")
            # Optional: resize images to expected size
            if img.width > 490 or img.height > 490:
                ratio = min(490/img.width, 490/img.height)
                new_size = (int(img.width * ratio), int(img.height * ratio))
                img = img.resize(new_size)
            frames.append(img)
        except Exception as e:
            print(f"Error loading image {path}: {e}")
    
    if not frames:
        return "Error: No valid frames could be loaded"
    
    # Create the instruction prompt as one line, without the newlines and
    # indentation that a triple-quoted string would embed in the model input
    prompt = (
        "Please identify the video frames most relevant to the given question and provide "
        "their timestamps in seconds along with a relevance score. The score should be on a "
        "scale from 1 to 5, where higher scores indicate greater relevance. Return the output "
        'strictly in the following JSON format: {"timestamp": score, ...}.'
    )
    
    # Format the input as expected by the model: one image placeholder per frame,
    # followed by the question and the instruction prompt
    frm_placeholders = ["<image1>" for _ in range(len(frames))]
    content = "{}Question: {}\n{}".format("".join(frm_placeholders), question, prompt)
    question_data = [{"role": "user", "content": content}]
    
    # Apply chat template
    formatted_question = tokenizer.apply_chat_template(question_data, add_generation_prompt=True, tokenize=False)
    
    # Process the images and text
    inputs = processor(
        text=[formatted_question],
        images=frames,
        padding=True,
        return_tensors="pt"
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate the response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,  # greedy decoding; a temperature setting would be ignored
        )
    
    # Decode and keep only the assistant's reply
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    result = response.split("assistant\n")[-1].split("<|im_end|>")[0].strip()
    
    return result


# Example usage
if __name__ == "__main__":
    # Load model components
    model, tokenizer, processor = setup_model()
    
    # Example video frames (replace with your actual paths)
    frame_paths = [
        "/path/to/video/frames/00001.jpg", 
        "/path/to/video/frames/00002.jpg",
        # Add more frames...
    ]
    
    # Example question
    question = "Which frames show a person opening the door?"
    
    # Get frame relevance scores
    result = gens_frame_sampler(question, frame_paths, model, tokenizer, processor)
    
    print(f"Question: {question}")
    print(f"Relevant frames with scores: {result}")
```
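The example above hard-codes `frame_paths`. If the video has already been decoded into one directory of frame images with zero-padded names like `00001.jpg`, a small helper can build the list in temporal order. This is a hypothetical sketch: `list_frame_paths` and its `stride` parameter are not part of GenS, and how densely to extract frames from the video is left to you.

```python
from pathlib import Path
from typing import List


def list_frame_paths(frame_dir: str, pattern: str = "*.jpg", stride: int = 1) -> List[str]:
    """Return frame image paths in temporal order, keeping every `stride`-th frame.

    Sorting by file name gives temporal order only when names are zero-padded
    (00001.jpg, 00002.jpg, ...), as in the usage example.
    """
    if stride < 1:
        raise ValueError("stride must be >= 1")
    paths = sorted(p for p in Path(frame_dir).glob(pattern) if p.is_file())
    return [str(p) for p in paths][::stride]
```

For instance, `frame_paths = list_frame_paths("/path/to/video/frames")` would replace the hand-written list before calling `gens_frame_sampler`.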
**Output Format:**
The model returns relevance scores for frames in JSON format.
Example output: `{"15": 5, "16": 4, "45-46": 3, ...}` means that frame index 15 has a relevance score of 5, frame index 16 has a score of 4, frame indices 45 to 46 have a score of 3, and so on.
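To hand the selected frames to a downstream VideoQA assistant, the JSON string has to be turned into frame indices. The sketch below is one way to do it, assuming the output format described above: it extracts the `{...}` object from the response, expands range keys such as `"45-46"` inclusively, and keeps the `k` highest-scoring frames. `parse_gens_output` and `select_top_frames` are hypothetical helpers, not part of the GenS release; how an index maps back to an entry of `frame_paths` (including whether it is 0- or 1-based) should be checked against your own frame naming.

```python
import json
import re
from typing import Dict, List


def parse_gens_output(text: str) -> Dict[int, int]:
    """Turn a response like '{"15": 5, "45-46": 3}' into {frame_index: score}.

    Range keys such as "45-46" are expanded inclusively. If an index appears
    more than once, the higher score wins. Malformed entries are skipped.
    """
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)  # first "{" to last "}"
    if match is None:
        return {}
    try:
        raw = json.loads(match.group(0))
    except json.JSONDecodeError:
        return {}

    scores: Dict[int, int] = {}
    for key, value in raw.items():
        bounds = str(key).split("-")
        try:
            start, end = int(bounds[0]), int(bounds[-1])
            score = int(value)
        except (TypeError, ValueError):
            continue  # non-integer key or score
        for idx in range(start, end + 1):
            scores[idx] = max(score, scores.get(idx, score))
    return scores


def select_top_frames(scores: Dict[int, int], k: int) -> List[int]:
    """Keep the k highest-scoring indices (ties go to earlier frames), in temporal order."""
    ranked = sorted(scores, key=lambda idx: (-scores[idx], idx))
    return sorted(ranked[:k])


if __name__ == "__main__":
    scores = parse_gens_output('{"15": 5, "16": 4, "45-46": 3}')
    print(scores)                          # {15: 5, 16: 4, 45: 3, 46: 3}
    print(select_top_frames(scores, k=3))  # [15, 16, 45]
```

The `result` string returned by `gens_frame_sampler` can be passed straight to `parse_gens_output`. Returning the selection in temporal order, rather than by score, keeps the frames in the order they occur in the video.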



## Citation
If you find our work helpful, please consider citing.
```
@article{yao2025generative,
    title={Generative Frame Sampler for Long Video Understanding},
    author={Yao, Linli and Wu, Haoning and Ouyang, Kun and Zhang, Yuanxing and Xiong, Caiming and Chen, Bei and Sun, Xu and Li, Junnan},
    journal={arXiv preprint arXiv:2503.09146},
    year={2025}
}
```