|
|
--- |
|
|
language: |
|
|
- en |
|
|
library_name: transformers |
|
|
license: apache-2.0 |
|
|
pipeline_tag: image-text-to-text |
|
|
tags: |
|
|
- Sentence Similarity |
|
|
- Embedding |
|
|
- zero-shot-image-classification |
|
|
- video-text-to-text |
|
|
--- |
|
|
|
|
|
# UME-R1-7B |
|
|
|
|
|
## Model Summary |
|
|
|
|
|
UME-R1 is trained in two stages, a cold-start SFT stage followed by an RL stage, and can embed text, images, multiple images, and videos. In particular, UME-R1 can produce either discriminative or generative embeddings as needed, and its generative embeddings show potential for test-time scaling.
|
|
|
|
|
- **Repository:** [UME-R1](https://github.com/XMUDeepLIT/UME-R1) |
|
|
- **Paper:** [UME-R1](https://arxiv.org/abs/2511.00405) |
|
|
|
|
|
## Train/Eval Data |
|
|
- Train data: https://huggingface.co/datasets/zhibinlan/UME-sft-train |
|
|
- Eval data: https://huggingface.co/datasets/TIGER-Lab/MMEB-V2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
## Model Performance |
|
|
UME-R1 significantly outperforms purely discriminative embeddings and can provide discriminative or generative representations as needed. Its oracle performance, obtained by selecting the better of the discriminative and generative embeddings for each query, far exceeds using either mode alone.
|
|
|
|
|
|
|
|
<img src="./figures/main_result.png" alt="MMEB-V2" width="1200" height="auto"> |
|
|
|
|
|
|
|
In addition, UME-R1 can produce improved embedding representations through repeated sampling, indicating that generative embeddings also hold strong promise for inference-time scaling. |
|
|
|
|
|
<img src="./figures/scaling.png" alt="pass@k" width="1200" height="auto"> |
|
|
|
|
|
## Quick Start
|
|
|
|
|
First, clone our GitHub repository and run the setup script:
|
|
```bash |
|
|
git clone https://github.com/DeepLearnXMU/UME-R1 |
|
|
cd UME-R1 |
|
|
bash setup.sh |
|
|
``` |
|
|
|
|
|
Below, we provide simple examples to show how to use UME-R1 with 🤗 Transformers. |
|
|
|
|
|
Example of obtaining generative embeddings: |
|
|
|
|
|
```python |
|
|
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
|
from qwen_vl_utils import process_vision_info |
|
|
import torch |
|
|
|
|
|
model = Qwen2VLForConditionalGeneration.from_pretrained( |
|
|
"zhibinlan/UME-R1-7B", |
|
|
torch_dtype=torch.bfloat16, |
|
|
attn_implementation="flash_attention_2", |
|
|
device_map="cuda:0", |
|
|
) |
|
|
|
|
|
processor = AutoProcessor.from_pretrained("zhibinlan/UME-R1-7B") |
|
|
|
|
|
prompt = '''Represent the above input text, images, videos, or any combination of the three as embeddings. |
|
|
First output the thinking process in <think> </think> tags and then summarize the entire input in a word or sentence. |
|
|
Finally, use the <gen_emb> tag to represent the entire input.''' |
|
|
|
|
|
|
|
|
|
|
|
messages = [ |
|
|
{ |
|
|
"role": "user", |
|
|
"content": [ |
|
|
{ |
|
|
"type": "image", |
|
|
"image": "assets/example.jpg", |
|
|
}, |
|
|
{"type": "text", "text": "Represent the given image with the following question: What is in the image?\n<disc_emb>\n" + prompt}, |
|
|
], |
|
|
} |
|
|
] |
|
|
|
|
|
# Preparation for inference |
|
|
text = processor.apply_chat_template( |
|
|
messages, tokenize=False, add_generation_prompt=True |
|
|
) |
|
|
|
|
|
image_inputs, video_inputs = process_vision_info(messages) |
|
|
inputs = processor( |
|
|
text=[text], |
|
|
images=image_inputs, |
|
|
videos=video_inputs, |
|
|
padding=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
inputs = inputs.to(model.device) |
|
|
|
|
|
# Inference: Generation of the output |
|
|
generated_output = model.generate(**inputs, max_new_tokens=8192, output_hidden_states=True, return_dict_in_generate=True, use_cache=True) |
|
|
# Post-process the output |
|
|
generated_ids = generated_output.sequences |
|
|
hidden_states = generated_output.hidden_states |
|
|
|
|
|
generated_ids_trimmed = [ |
|
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) |
|
|
] |
|
|
|
|
|
def get_embedding_idx(generated_ids_trimmed, EMBEDDING_TOKEN_ID): |
|
|
|
|
|
embedding_idx = [] |
|
|
for i, out_ids in enumerate(generated_ids_trimmed): |
|
|
embed_exist = False |
|
|
for j in range(len(out_ids) - 1, -1, -1): |
|
|
if out_ids[j] == EMBEDDING_TOKEN_ID: |
|
|
                # The token generated at step j is fed back to the model at step j + 1,
                # so the hidden state of the <gen_emb> token itself lives at step j + 1
                # of generated_output.hidden_states.
                embedding_idx.append(j + 1)
|
|
embed_exist = True |
|
|
break |
|
|
if not embed_exist: |
|
|
embedding_idx.append(-1) |
|
|
|
|
|
return embedding_idx |
|
|
|
|
|
def normalize_reps(reps): |
|
|
reps = torch.nn.functional.normalize(reps, p=2, dim=-1) |
|
|
return reps |
|
|
|
|
|
# Get the last hidden state of the <gen_emb> token |
|
|
embedding_idx = get_embedding_idx(generated_ids_trimmed, processor.tokenizer.get_vocab()["<gen_emb>"]) |
|
|
embedding_reps = hidden_states[embedding_idx[0]][-1].squeeze(1) |
|
|
|
|
|
# Normalize the representations |
|
|
embedding_reps = normalize_reps(embedding_reps) |
|
|
|
|
|
output_text = processor.batch_decode( |
|
|
generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False |
|
|
) |
|
|
``` |
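
The resulting `embedding_reps` is an L2-normalized vector that can be scored directly against candidate embeddings. Below is a minimal retrieval sketch; the candidate matrix is a random placeholder for embeddings you would obtain by running the same procedure on each candidate.

```python
# Placeholder candidates: in practice, embed each candidate (text, image, or video)
# with the same procedure as above and stack the normalized <gen_emb> representations.
num_candidates = 4
candidate_reps = normalize_reps(
    torch.randn(num_candidates, embedding_reps.shape[-1],
                dtype=embedding_reps.dtype, device=embedding_reps.device)
)

# Cosine similarity reduces to a dot product because everything is L2-normalized
scores = embedding_reps @ candidate_reps.T   # [1, num_candidates]
print("Similarity scores:", scores)
print("Best matching candidate:", scores.argmax(dim=-1).item())
```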
|
|
|
|
|
<details> |
|
|
<summary>Example of obtaining discriminative embeddings</summary> |
|
|
|
|
|
```python |
|
|
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
|
from qwen_vl_utils import process_vision_info |
|
|
import torch |
|
|
|
|
|
pretrained_path = "zhibinlan/UME-R1-7B" |
|
|
|
|
|
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. |
|
|
model = Qwen2VLForConditionalGeneration.from_pretrained( |
|
|
pretrained_path, |
|
|
torch_dtype=torch.bfloat16, |
|
|
attn_implementation="flash_attention_2", |
|
|
device_map="cuda:0", |
|
|
) |
|
|
|
|
|
# default processor |
|
|
processor = AutoProcessor.from_pretrained(pretrained_path) |
|
|
|
|
|
messages = [ |
|
|
{ |
|
|
"role": "user", |
|
|
"content": [ |
|
|
{ |
|
|
"type": "image", |
|
|
"image": "UME-R1/assets/example.jpg", |
|
|
}, |
|
|
{"type": "text", "text": "Represent the given image with the following question: What is in the image?\n<disc_emb>\n"}, |
|
|
], |
|
|
} |
|
|
] |
|
|
|
|
|
# Preparation for inference |
|
|
text = processor.apply_chat_template( |
|
|
messages, tokenize=False, add_generation_prompt=True |
|
|
) |
|
|
|
|
|
image_inputs, video_inputs = process_vision_info(messages) |
|
|
inputs = processor( |
|
|
text=[text], |
|
|
images=image_inputs, |
|
|
videos=video_inputs, |
|
|
padding=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
inputs = inputs.to(model.device) |
|
|
|
|
|
def get_embedding_idx(generated_ids_trimmed, EMBEDDING_TOKEN_ID): |
|
|
|
|
|
embedding_idx = [] |
|
|
# Search from the last token forward |
|
|
for i, out_ids in enumerate(generated_ids_trimmed): |
|
|
embed_exist = False |
|
|
for j in range(len(out_ids) - 1, -1, -1): |
|
|
if out_ids[j] == EMBEDDING_TOKEN_ID: |
|
|
embedding_idx.append(j) |
|
|
embed_exist = True |
|
|
break |
|
|
if not embed_exist: |
|
|
embedding_idx.append(-1) |
|
|
|
|
|
return embedding_idx |
|
|
|
|
|
def normalize_reps(reps): |
|
|
# Normalize the representations |
|
|
reps = torch.nn.functional.normalize(reps, p=2, dim=-1) |
|
|
return reps |
|
|
|
|
|
output = model(**inputs, output_hidden_states=True, return_dict=True) |
|
|
hidden_states = output.hidden_states[-1][0] |
|
|
# print("output.hidden_states shape: ", hidden_states.shape) |
|
|
embedding_idx = get_embedding_idx(inputs['input_ids'], processor.tokenizer.get_vocab()["<disc_emb>"]) |
|
|
|
|
|
# Get the last hidden state of the <disc_emb> token
|
|
embedding_reps = hidden_states[embedding_idx[0]] |
|
|
|
|
|
# Normalize the representations |
|
|
embedding_reps = normalize_reps(embedding_reps) |
|
|
``` |
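
Because `get_embedding_idx` returns one index per sequence, the same forward pass extends naturally to a batch of inputs. A minimal sketch of the batched variant, reusing the objects defined above (the `<disc_emb>` position is looked up per row, so padding needs no extra bookkeeping):

```python
# Batched variant (sketch): gather the <disc_emb> hidden state for every row
last_hidden = output.hidden_states[-1]      # [batch, seq_len, hidden_dim]
idxs = get_embedding_idx(inputs["input_ids"], processor.tokenizer.get_vocab()["<disc_emb>"])
batch_reps = torch.stack([last_hidden[i, j] for i, j in enumerate(idxs)])
batch_reps = normalize_reps(batch_reps)     # [batch, hidden_dim], ready for cosine scoring
```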
|
|
|
|
|
</details> |
|
|
|
|
|
<details> |
|
|
<summary>Multi-image inference</summary>
|
|
|
|
|
```python |
|
|
# Messages containing multiple images and a text query |
|
|
messages = [ |
|
|
{ |
|
|
"role": "user", |
|
|
"content": [ |
|
|
{"type": "image", "image": "file:///path/to/image1.jpg"}, |
|
|
{"type": "image", "image": "file:///path/to/image2.jpg"}, |
|
|
{"type": "text", "text": "Represent the given images."}, |
|
|
], |
|
|
} |
|
|
] |
|
|
``` |
|
|
|
|
|
</details> |
|
|
|
|
|
<details> |
|
|
<summary>Video inference</summary> |
|
|
|
|
|
```python |
|
|
# Messages containing a list of image frames as a video and a text query
|
|
messages = [ |
|
|
{ |
|
|
"role": "user", |
|
|
"content": [ |
|
|
{ |
|
|
"type": "video", |
|
|
"video": [ |
|
|
"file:///path/to/frame1.jpg", |
|
|
"file:///path/to/frame2.jpg", |
|
|
"file:///path/to/frame3.jpg", |
|
|
"file:///path/to/frame4.jpg", |
|
|
], |
|
|
}, |
|
|
{"type": "text", "text": "Represent this video."}, |
|
|
], |
|
|
} |
|
|
] |
|
|
|
|
|
# Messages containing a local video path and a text query |
|
|
messages = [ |
|
|
{ |
|
|
"role": "user", |
|
|
"content": [ |
|
|
{ |
|
|
"type": "video", |
|
|
"video": "file:///path/to/video1.mp4", |
|
|
"max_pixels": 360 * 420, |
|
|
"fps": 1.0, |
|
|
}, |
|
|
{"type": "text", "text": "Represent this video."}, |
|
|
], |
|
|
} |
|
|
] |
|
|
|
|
|
# Messages containing a video url and a text query |
|
|
messages = [ |
|
|
{ |
|
|
"role": "user", |
|
|
"content": [ |
|
|
{ |
|
|
"type": "video", |
|
|
"video": "https://path/to/video.mp4", |
|
|
"min_pixels": 4 * 28 * 28, |
|
|
"max_pixels": 256 * 28 * 28, |
|
|
"total_pixels": 20480 * 28 * 28, |
|
|
}, |
|
|
{"type": "text", "text": "Represent this video."}, |
|
|
], |
|
|
} |
|
|
] |
|
|
image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True) |
|
|
inputs = processor( |
|
|
text=[text], |
|
|
images=image_inputs, |
|
|
videos=video_inputs, |
|
|
    # fps is supplied via **video_kwargs (returned by process_vision_info with return_video_kwargs=True)
|
|
padding=True, |
|
|
return_tensors="pt", |
|
|
**video_kwargs, |
|
|
) |
|
|
``` |
|
|
|
|
|
</details> |
|
|
|
|
|
|
|
|
For more usage tips, please refer to our [GitHub page](https://github.com/DeepLearnXMU/UME-R1).
|
|
|
|
|
|
|
|
## Citation |
|
|
If you find our work useful, please consider citing it. |
|
|
```bibtex
|
|
@article{lan2025ume, |
|
|
title={UME-R1: Exploring Reasoning-Driven Generative Multimodal Embeddings}, |
|
|
author={Lan, Zhibin and Niu, Liqiang and Meng, Fandong and Zhou, Jie and Su, Jinsong}, |
|
|
journal={arXiv preprint arXiv:2511.00405}, |
|
|
year={2025} |
|
|
} |
|
|
``` |