Integrate with Transformers v5 and Sentence Transformers v5.4

#2
by tomaarsen HF Staff - opened
1_Pooling/config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "embedding_dimension": 2560,
3
+ "pooling_mode": "lasttoken",
4
+ "include_prompt": true
5
+ }
README.md CHANGED
@@ -4,6 +4,8 @@ license: apache-2.0
4
  base_model:
5
  - Qwen/Qwen3-VL-4B-Instruct
6
  pipeline_tag: visual-document-retrieval
 
 
7
  ---
8
 
9
  # Eager Embed V1
@@ -28,10 +30,67 @@ Compared to multi-vector (ColBERT-like) architectures, eager-embed-v1 offers a s
28
 
29
  ## How to Get Started with the Model
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  Load the model and define a helper function to encode messages:
32
  ```python
33
  import torch
34
- from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
35
  from transformers.utils.import_utils import is_flash_attn_2_available
36
  from qwen_vl_utils import process_vision_info
37
 
@@ -44,12 +103,13 @@ elif torch.backends.mps.is_available():
44
  DTYPE = torch.bfloat16
45
 
46
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
47
- model = Qwen3VLForConditionalGeneration.from_pretrained(
48
  MODEL_NAME,
49
  attn_implementation=(
50
  "flash_attention_2" if is_flash_attn_2_available() else None
51
  ),
52
- dtype=DTYPE
 
53
  ).to(DEVICE).eval()
54
 
55
  # Function to Encode Message
@@ -87,6 +147,7 @@ sim1 = torch.cosine_similarity(encode_message(query), encode_message(text_1))
87
  sim2 = torch.cosine_similarity(encode_message(query), encode_message(text_2))
88
 
89
  print("Similarities:", sim1.item(), sim2.item())
 
90
  ```
91
 
92
  📈 Image Document Retrieval (Image, Chart, PDF)
@@ -103,6 +164,7 @@ sim1 = torch.cosine_similarity(encode_message(query), encode_message(image_1))
103
  sim2 = torch.cosine_similarity(encode_message(query), encode_message(image_2))
104
 
105
  print("Similarities:", sim1.item(), sim2.item())
 
106
  ```
107
 
108
  ## Training Details
 
4
  base_model:
5
  - Qwen/Qwen3-VL-4B-Instruct
6
  pipeline_tag: visual-document-retrieval
7
+ tags:
8
+ - sentence-transformers
9
  ---
10
 
11
  # Eager Embed V1
 
30
 
31
  ## How to Get Started with the Model
32
 
33
+ ### Using Sentence Transformers
34
+
35
+ Install Sentence Transformers (this integration targets v5.4 or newer):
36
+ ```bash
37
+ pip install sentence_transformers
38
+ ```
39
+
40
+ ```python
41
+ import requests
42
+ from io import BytesIO
43
+ from PIL import Image
44
+ from sentence_transformers import SentenceTransformer
45
+
46
+ model = SentenceTransformer("eagerworks/eager-embed-v1", trust_remote_code=True)
47
+
48
+ # Multilingual text retrieval
49
+ # `encode_query` automatically prepends the "Query: " prefix the model was trained on.
50
+ queries = ["What is the capital city of Uruguay?"]
51
+ documents = [
52
+ "Montevideo es la capital y la ciudad más poblada de la República Oriental del Uruguay, así como la capital del departamento homónimo",
53
+ "El río Uruguay es un río internacional que forma parte de la cuenca del Plata. Nace en Brasil, recorre unos 1.800 km y desemboca en el Río de la Plata",
54
+ ]
55
+
56
+ query_embeddings = model.encode_query(queries)
57
+ document_embeddings = model.encode_document(documents)
58
+ print(query_embeddings.shape, document_embeddings.shape)
59
+ # (1, 2560) (2, 2560)
60
+
61
+ similarities = model.similarity(query_embeddings, document_embeddings)
62
+ print(similarities)
63
+ # tensor([[0.2907, 0.1573]])
64
+
65
+ # Image document retrieval
66
+ MAX_IMAGE_SIZE = 784
67
+
68
+ def fetch_image(url):
69
+ img = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
70
+ return img.resize((MAX_IMAGE_SIZE, MAX_IMAGE_SIZE))
71
+
72
+ queries = ["Where can we find the animal llama?"]
73
+ documents = [
74
+ fetch_image("https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/animal-llama.png"),
75
+ fetch_image("https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/meta-llama.png"),
76
+ ]
77
+
78
+ query_embeddings = model.encode_query(queries)
79
+ document_embeddings = model.encode_document(documents)
80
+ print(query_embeddings.shape, document_embeddings.shape)
81
+ # (1, 2560) (2, 2560)
82
+
83
+ similarities = model.similarity(query_embeddings, document_embeddings)
84
+ print(similarities)
85
+ # tensor([[0.2709, 0.0930]])
86
+ ```
87
+
88
+ ### Using Transformers
89
+
90
  Load the model and define a helper function to encode messages:
91
  ```python
92
  import torch
93
+ from transformers import AutoProcessor, AutoModelForImageTextToText
94
  from transformers.utils.import_utils import is_flash_attn_2_available
95
  from qwen_vl_utils import process_vision_info
96
 
 
103
  DTYPE = torch.bfloat16
104
 
105
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
106
+ model = AutoModelForImageTextToText.from_pretrained(
107
  MODEL_NAME,
108
  attn_implementation=(
109
  "flash_attention_2" if is_flash_attn_2_available() else None
110
  ),
111
+ dtype=DTYPE,
112
+ trust_remote_code=True,
113
  ).to(DEVICE).eval()
114
 
115
  # Function to Encode Message
 
147
  sim2 = torch.cosine_similarity(encode_message(query), encode_message(text_2))
148
 
149
  print("Similarities:", sim1.item(), sim2.item())
150
+ # Similarities: 0.2907 0.1573
151
  ```
152
 
153
  📈 Image Document Retrieval (Image, Chart, PDF)
 
164
  sim2 = torch.cosine_similarity(encode_message(query), encode_message(image_2))
165
 
166
  print("Similarities:", sim1.item(), sim2.item())
167
+ # Similarities: 0.2709 0.0929
168
  ```
169
 
170
  ## Training Details
chat_template.jinja CHANGED
@@ -18,26 +18,18 @@
18
  {{- tool | tojson }}
19
  {%- endfor %}
20
  {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
21
- {%- else %}
22
- {%- if messages[0].role == 'system' %}
23
- {{- '<|im_start|>system\n' }}
24
- {%- if messages[0].content is string %}
25
- {{- messages[0].content }}
26
- {%- else %}
27
- {%- for content in messages[0].content %}
28
- {%- if 'text' in content %}
29
- {{- content.text }}
30
- {%- endif %}
31
- {%- endfor %}
32
- {%- endif %}
33
- {{- '<|im_end|>\n' }}
34
- {%- endif %}
35
  {%- endif %}
36
  {%- set image_count = namespace(value=0) %}
37
  {%- set video_count = namespace(value=0) %}
38
  {%- for message in messages %}
39
- {%- if message.role == "user" %}
40
- {{- '<|im_start|>' + message.role + '\n' }}
 
 
41
  {%- if message.content is string %}
42
  {{- message.content }}
43
  {%- else %}
@@ -118,3 +110,6 @@
118
  {%- if add_generation_prompt %}
119
  {{- '<|im_start|>assistant\n' }}
120
  {%- endif %}
 
 
 
 
18
  {{- tool | tojson }}
19
  {%- endfor %}
20
  {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
21
+ {%- endif %}
22
+ {%- set sys_prefix = '' %}
23
+ {%- if not tools and messages[0].role == 'system' %}
24
+ {%- set sys_prefix = messages[0].content[0].text %}
 
 
 
 
 
 
 
 
 
 
25
  {%- endif %}
26
  {%- set image_count = namespace(value=0) %}
27
  {%- set video_count = namespace(value=0) %}
28
  {%- for message in messages %}
29
+ {%- if message.role == "system" and not tools %}
30
+ {# system text is inlined into the user message below #}
31
+ {%- elif message.role == "user" %}
32
+ {{- '<|im_start|>' + message.role + '\n' + sys_prefix }}
33
  {%- if message.content is string %}
34
  {{- message.content }}
35
  {%- else %}
 
110
  {%- if add_generation_prompt %}
111
  {{- '<|im_start|>assistant\n' }}
112
  {%- endif %}
113
+ {%- if add_embedding_token %}
114
+ {{- '<|endoftext|>' }}
115
+ {%- endif %}
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "Qwen3VLForConditionalGeneration"
4
  ],
 
 
 
 
5
  "dtype": "float32",
6
  "image_token_id": 151655,
7
  "model_type": "qwen3_vl",
 
2
  "architectures": [
3
  "Qwen3VLForConditionalGeneration"
4
  ],
5
+ "auto_map": {
6
+ "AutoModel": "modeling_eager_embed.EagerEmbedModel",
7
+ "AutoModelForImageTextToText": "modeling_eager_embed.EagerEmbedForConditionalGeneration"
8
+ },
9
  "dtype": "float32",
10
  "image_token_id": 151655,
11
  "model_type": "qwen3_vl",
config_sentence_transformers.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "pytorch": "2.10.0+cu128",
4
+ "sentence_transformers": "5.4.0",
5
+ "transformers": "5.5.0"
6
+ },
7
+ "model_type": "SentenceTransformer",
8
+ "prompts": {
9
+ "query": "Query: ",
10
+ "document": ""
11
+ },
12
+ "similarity_fn_name": "cosine"
13
+ }
modeling_eager_embed.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import (
4
+ Qwen3VLForConditionalGeneration,
5
+ Qwen3VLModel,
6
+ )
7
+
8
+
9
+ # The model was trained with transformers==4.57.1, where
10
+ # `Qwen3VLForConditionalGeneration(...).hidden_states[-1]` was the pre-final-norm
11
+ # state of the text decoder. In transformers 5.x that field is now the post-norm
12
+ # `last_hidden_state`. Replacing the text model's final RMSNorm with a no-op
13
+ # restores the representation the model was trained on.
14
+ _NORM_KEY_PATTERN = r"^model\.language_model\.norm\.weight$"
15
+
16
+
17
+ class EagerEmbedModel(Qwen3VLModel):
18
+ _keys_to_ignore_on_load_unexpected = [_NORM_KEY_PATTERN]
19
+
20
+ def __init__(self, config):
21
+ super().__init__(config)
22
+ self.language_model.norm = nn.Identity()
23
+
24
+
25
+ class EagerEmbedForConditionalGeneration(Qwen3VLForConditionalGeneration):
26
+ _keys_to_ignore_on_load_unexpected = [_NORM_KEY_PATTERN]
27
+
28
+ def __init__(self, config):
29
+ super().__init__(config)
30
+ self.model.language_model.norm = nn.Identity()
modules.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.base.modules.transformer.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.sentence_transformer.modules.pooling.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.sentence_transformer.modules.normalize.Normalize"
19
+ }
20
+ ]
sentence_bert_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "transformer_task": "feature-extraction",
3
+ "modality_config": {
4
+ "text": {
5
+ "method": "forward",
6
+ "method_output_name": "last_hidden_state"
7
+ },
8
+ "image": {
9
+ "method": "forward",
10
+ "method_output_name": "last_hidden_state"
11
+ },
12
+ "video": {
13
+ "method": "forward",
14
+ "method_output_name": "last_hidden_state"
15
+ },
16
+ "message": {
17
+ "method": "forward",
18
+ "method_output_name": "last_hidden_state",
19
+ "format": "structured"
20
+ }
21
+ },
22
+ "module_output_name": "token_embeddings",
23
+ "processing_kwargs": {
24
+ "chat_template": {
25
+ "add_generation_prompt": true,
26
+ "add_embedding_token": true
27
+ }
28
+ },
29
+ "unpad_inputs": false
30
+ }