File size: 8,974 Bytes
e546039 b7150d9 3b90c99 b7150d9 3b90c99 b7150d9 3b90c99 fca189c 3b90c99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
---
datasets:
- yaolily/GenS-Video-150K
license: other
pipeline_tag: video-text-to-text
library_name: transformers
---
<p align="center">
🌐 <a href="https://generative-sampler.github.io/" target="_blank">Project Page</a> · 📄 <a href="https://arxiv.org/abs/2503.09146" target="_blank">Paper</a> · ⭐ <a href="https://github.com/yaolinli/GenS" target="_blank">GitHub</a> · 📊 <a href="https://huggingface.co/datasets/yaolily/GenS-Video-150K" target="_blank">Dataset</a> · 🤗 <a href="https://huggingface.co/yaolily/GenS" target="_blank">Checkpoints</a>
</p>
## Model Description
**GenS** (Generative Frame Sampler) is a novel approach that identifies question-relevant frames from long videos spanning minutes to hours. Given a long video and a user question, GenS effectively searches through the original massive collection of frames to produce a concise selection and enhances the performance of downstream VideoQA Assistants (such as Qwen2-VL, LLaVA-Video, VILA-v1.5, and Aria) by providing fewer but more informative frames.
GenS is built upon advanced long-context VideoLLMs (such as Aria and Qwen2.5VL), transforming key frame sampling into a generative task.
<img src="https://generative-sampler.github.io/static/images/teaser.png" alt="GenS Framework" style="width: 100%;">
## Key Features of GenS
✨ **Temporal Understanding:**
GenS effectively captures temporal relationships between successive frames, enabling complex reasoning about temporal sequences such as "immediately after" events in videos.
🧠 **Complex Instruction Understanding:**
Powered by built-in LLMs, GenS comprehends complex and flexible textual instructions, allowing it to interpret nuanced queries and identify the most relevant visual content.
⚡ **Effective Video-Text Alignment:**
Its native multi-modal architecture enables sophisticated multi-hop reasoning by seamlessly aligning long-range temporal cues with language semantics, resulting in more accurate frame selection.
🏆 **State-of-the-Art Performance:**
GenS significantly boosts the performance of various VideoQA models, achieving SOTA results on long-form video benchmarks when integrated with open-source models.
## Performance Highlights
- 📈 **LongVideoBench**: LLaVA-Video-72B w/ GenS achieves **66.8** accuracy (+4.3)
- 📈 **MLVU**: LLaVA-Video-72B w/ GenS achieves **77.0** accuracy (+2.7)
- 📈 **HourVideo**: Aria w/ GenS obtains **39.2** accuracy, while Gemini-1.5-pro w/ GenS obtains **40.7** accuracy
<img src="https://generative-sampler.github.io/static/images/table_main.png" alt="Main Results Table" style="width: 100%;">
<img src="https://generative-sampler.github.io/static/images/hourvideo.png" alt="HourVideo Results Table" style="width: 100%;">
## Quick Start
### Installation
After creating your conda environment, install the required dependencies:
```bash
pip install transformers==4.45.0 accelerate==0.34.1 sentencepiece==0.2.0 torchvision requests torch Pillow
pip install flash-attn --no-build-isolation
```
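### Usage
The example below also imports the project modules `yivl` and `deepseekv1moe`, which the commands above do not install. Make sure they are importable (e.g., on your `PYTHONPATH`) before running it.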
```python
import torch
from PIL import Image
import sys
import os
from typing import List
# Import required libraries
from transformers import AutoProcessor, AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
from yivl.yivl_model_hf import YiVLForConditionalGeneration, YiVLConfig
from yivl.siglip_navit_490 import NaViTProcessor
from yivl.constants import (
DEFAULT_IMAGE_END_TOKEN,
DEFAULT_IMAGE_START_TOKEN,
DEFAULT_IMAGE_TOKEN,
IMAGE_TOKEN_INDEX,
)
from deepseekv1moe.modeling_deepseek import DeepseekConfig, DeepseekForCausalLM
def setup_model(model_id: str = "yaolily/GenS", device: str = "cuda"):
    """Set up and load the GenS model and its components.

    Args:
        model_id: Hugging Face Hub id (or local path) of the GenS checkpoint.
        device: Device the loaded model is moved to.

    Returns:
        A ``(model, tokenizer, processor)`` tuple, ready to pass to
        ``gens_frame_sampler``.
    """
    # Register the custom model classes so the Auto* factories can resolve
    # the "yi_vl" and "deepseek" model types.
    AutoConfig.register("yi_vl", YiVLConfig)
    AutoModel.register(YiVLConfig, YiVLForConditionalGeneration)
    AutoConfig.register("deepseek", DeepseekConfig)
    AutoModelForCausalLM.register(DeepseekConfig, DeepseekForCausalLM)

    # Load in bfloat16 with FlashAttention-2 (needs the flash-attn package from
    # the installation step). AutoModel reads the checkpoint config itself, so
    # no separate AutoConfig.from_pretrained call is needed.
    model = AutoModel.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    ).to(torch.device(device))

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)

    # Make sure a usable pad token exists. pad_token_id may be None even when
    # pad_token is set, so test for None before comparing it with 0.
    pad_id = tokenizer.pad_token_id
    if not tokenizer.pad_token or pad_id is None or pad_id < 0:
        # NOTE(review): if the chosen token is not already in the vocabulary it
        # is added as a new token, and the model embeddings are not resized
        # here -- confirm these tokens already exist in the GenS vocabulary.
        try:
            tokenizer.add_special_tokens({"pad_token": "<unk>"})
            if tokenizer.pad_token_id is None:
                tokenizer.add_special_tokens({"pad_token": "<mask>"})
        except ValueError:
            tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})

    # Custom image processor; 490 is the maximum image side it is built for.
    processor = NaViTProcessor(image_max_size=490)

    print("GenS Model loaded successfully!")
    return model, tokenizer, processor
def gens_frame_sampler(
    question: str,
    frame_paths: List[str],
    model,
    tokenizer,
    processor,
    max_new_tokens: int = 256,
):
    """
    Use the GenS model to identify and score the frames relevant to a question.

    Args:
        question: The question to answer about the video.
        frame_paths: Paths to the video frames; they are given to the model
            in the order listed.
        model: Pre-loaded GenS model (e.g. from ``setup_model``).
        tokenizer: Pre-loaded tokenizer.
        processor: Pre-loaded image processor.
        max_new_tokens: Maximum number of tokens the model may generate.

    Returns:
        The model's decoded answer (requested in the prompt as a JSON object
        of relevance scores). If no frame could be loaded, returns the string
        ``"Error: No valid frames could be loaded"`` instead.
    """
    max_side = 490  # same limit the NaViTProcessor is built with

    # Load frames as RGB PIL images, skipping any that cannot be read.
    frames = []
    for path in frame_paths:
        try:
            with Image.open(path) as src:
                img = src.convert("RGB")
            # Optional: downscale so neither side exceeds max_side, keeping the
            # aspect ratio and never truncating a side to 0 pixels.
            if img.width > max_side or img.height > max_side:
                ratio = min(max_side / img.width, max_side / img.height)
                new_size = (
                    max(1, int(img.width * ratio)),
                    max(1, int(img.height * ratio)),
                )
                img = img.resize(new_size)
            frames.append(img)
        except Exception as e:
            # Deliberate best effort: report and skip unreadable frames.
            print(f"Error loading image {path}: {e}")
    if not frames:
        return "Error: No valid frames could be loaded"

    # NOTE(review): the prompt asks for timestamps in seconds, but frames are
    # passed without timing information, and the README describes the output
    # keys as frame indices -- confirm how keys map to frames.
    prompt = """Please identify the video frames most relevant to the given question and provide
their timestamps in seconds along with a relevance score. The score should be on a
scale from 1 to 5, where higher scores indicate greater relevance. Return the output
strictly in the following JSON format: {"timestamp": score, ...}."""

    # One image placeholder per loaded frame, followed by the question and the
    # instruction on a new line.
    # NOTE(review): the literal "<image1>" is used for every frame, while the
    # imported DEFAULT_IMAGE_TOKEN is unused -- confirm this is the placeholder
    # the processor expects.
    placeholders = "<image1>" * len(frames)
    content = "{}Question: {}\n{}".format(placeholders, question, prompt)
    question_data = [{"role": "user", "content": content}]

    formatted_question = tokenizer.apply_chat_template(
        question_data, add_generation_prompt=True, tokenize=False
    )

    inputs = processor(
        text=[formatted_question],
        images=frames,
        padding=True,
        return_tensors="pt",
    )
    # Move tensors to the model's device; leave any non-tensor entries as-is.
    inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}

    # Greedy decoding. `temperature` is not passed because it is ignored (and
    # triggers a warning) when do_sample=False.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # The decoded text may still include the chat prompt: keep only the text
    # after the last "assistant\n" marker and drop any end-of-turn token.
    result = response.split("assistant\n")[-1].split("<|im_end|>")[0].strip()
    return result
# Example usage: score a handful of frames for a single question.
if __name__ == "__main__":
    # Build the model, tokenizer and image processor once up front.
    model, tokenizer, processor = setup_model()

    # Replace these placeholder paths with your own extracted frames
    # (append as many as you need).
    frame_paths = list((
        "/path/to/video/frames/00001.jpg",
        "/path/to/video/frames/00002.jpg",
    ))

    question = "Which frames show a person opening the door?"

    # Ask GenS which frames matter for the question, then report the result.
    result = gens_frame_sampler(question, frame_paths, model, tokenizer, processor)
    for line in (f"Question: {question}", f"Relevant frames with scores: {result}"):
        print(line)
```
**Output Format:**
The model returns relevance scores for frames as a JSON object.
Example output: `{"15": 5, "16": 4, "45-46": 3, ...}` means that frame 15 has a relevance score of 5, frame 16 has a score of 4, frames 45–46 have a score of 3, and so on.
## Citation
If you find our work helpful, please consider citing our paper:
```bibtex
@article{yao2025generative,
title={Generative Frame Sampler for Long Video Understanding},
author={Yao, Linli and Wu, Haoning and Ouyang, Kun and Zhang, Yuanxing and Xiong, Caiming and Chen, Bei and Sun, Xu and Li, Junnan},
journal={arXiv preprint arXiv:2503.09146},
year={2025}
}
``` |