nielsr (HF Staff) committed
Commit d312f24 · verified · 1 Parent(s): d6bfff1

Add pipeline tag, paper link, and improve documentation


Hi, I'm Niels from the Hugging Face community science team. This PR improves the model card for EmbeddingRWKV by:
- Adding the `text-retrieval` pipeline tag for better discoverability.
- Linking the research paper: [EmbeddingRWKV: State-Centric Retrieval with Reusable States](https://huggingface.co/papers/2601.07861).
- Linking the official [GitHub repository](https://github.com/howard-hou/EmbeddingRWKV).
- Updating usage snippets to match the code found in the project's repository.

Files changed (1)
  1. README.md +37 -65
README.md CHANGED
@@ -1,9 +1,15 @@
---
license: apache-2.0
+ pipeline_tag: text-retrieval
---
+
# EmbeddingRWKV

- A high-efficiency text embedding and reranking model based on RWKV architecture.
+ EmbeddingRWKV is a high-efficiency text embedding and reranking model based on the RWKV architecture, introduced in the paper [EmbeddingRWKV: State-Centric Retrieval with Reusable States](https://huggingface.co/papers/2601.07861).
+
+ It utilizes **State-Centric Retrieval**, a unified retrieval paradigm that uses "states" as a bridge to connect embedding models and rerankers, significantly improving inference speed for reranking tasks.
+
+ [**Paper**](https://huggingface.co/papers/2601.07861) | [**GitHub**](https://github.com/howard-hou/EmbeddingRWKV)

## 📦 Installation

@@ -48,7 +54,7 @@ documents = [
]
# Encode Query
q_tokens = tokenizer.encode(query, add_eos=True)
- q_emb, _ = emb_model.forward(q_tokens, None) # shape: [1, Dim]
+ q_emb, _ = emb_model.forward_text_only(q_tokens, None) # shape: [1, Dim]

# Encode Documents (Batch)
doc_batch = [tokenizer.encode(doc, add_eos=True) for doc in documents]
@@ -58,17 +64,16 @@ for i in range(len(doc_batch)):
    # Prepend 0s (Left Padding)
    doc_batch[i] = [0] * pad_len + doc_batch[i]

- d_embs, _ = emb_model.forward(doc_batch, None)
+ d_embs, _ = emb_model.forward_text_only(doc_batch, None)

# Calculate Cosine Similarity
scores_emb = F.cosine_similarity(q_emb, d_embs)
print("\nEmbeddingRWKV Cosine Similarity:")
for doc, score in zip(documents, scores_emb):
    print(f"[{score.item():.4f}] {doc}")
```

- For production use cases, running inference in batches is significantly faster.
-
### ⚠️ Critical Performance Tip: Pad to Same Length

While the model supports batches with variable sequence lengths, **we strongly recommend padding all sequences to the same length** for maximum GPU throughput.
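
The left-padding step that recurs in these snippets can be factored into a small helper. A minimal sketch; the `pad_batch` name is illustrative and not part of the repository, and token id 0 is assumed as the pad id, as in the snippets above:

```python
def pad_batch(token_lists, pad_id=0):
    """Left-pad every token list to the length of the longest one."""
    max_len = max(len(t) for t in token_lists)
    # Prepend pad_id so the real tokens stay right-aligned,
    # matching the left-padding convention used in the README.
    return [[pad_id] * (max_len - len(t)) + t for t in token_lists]

# Usage: doc_batch = pad_batch([tokenizer.encode(d, add_eos=True) for d in documents])
```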
@@ -81,14 +86,7 @@ While the model supports batches with variable sequence lengths, **we strongly r

The `RWKVReRanker` utilizes the final hidden state produced by the main `EmbeddingRWKV` model to score the relevance between a query and a document.

- ### Online Mode
-
- #### Workflow
- 1. **Format** Query and Document based on Online template.
- 2. Run the **Embedding Model** to generate the final State.
- 3. Feed the **TimeMixing State** (`state[1]`) into the **ReRanker** to get a relevance score.
-
- #### 📝 Online Mode Usage Example
+ ### Online Mode Usage Example

```python
import torch
@@ -96,7 +94,6 @@ from rwkv_emb.tokenizer import RWKVTokenizer
from rwkv_emb.model import EmbeddingRWKV, RWKVReRanker

# 1. Load Models
- # The ReRanker weights are stored in a different checkpoint
emb_model = EmbeddingRWKV(model_path='/path/to/EmbeddingRWKV.pth')
reranker = RWKVReRanker(model_path='/path/to/RWKVReRanker.pth')

@@ -111,34 +108,30 @@ documents = [
]

# 3. Construct Input Pairs
- # We treat the Query and Document as a single sequence.
pairs = []
online_template = "Instruct: Given a query, retrieve documents that answer the query\nDocument: {document}\nQuery: {query}"
for doc in documents:
-     # Format: Instruct + Document + Query
    text = online_template.format(document=doc, query=query)
    pairs.append(text)

- # 4. Tokenize & Pad (Critical for Batch Performance)
+ # 4. Tokenize & Pad
batch_tokens = [tokenizer.encode(p, add_eos=True) for p in pairs]
-
- # Left pad to same length for efficiency
max_len = max(len(t) for t in batch_tokens)
for i in range(len(batch_tokens)):
    batch_tokens[i] = [0] * (max_len - len(batch_tokens[i])) + batch_tokens[i]

# 5. Get States from Embedding Model
- # We don't need the embedding output here, we only need the final 'state'
_, state = emb_model.forward(batch_tokens, None)

# 6. Score with ReRanker
- # The ReRanker expects the TimeMixing State: state[1]
- # state[1] shape: [Layers, Batch, Heads, HeadSize, HeadSize]
logits = reranker.forward(state[1])
- scores = torch.sigmoid(logits) # Convert logits to probabilities (0-1)
+ scores = torch.sigmoid(logits)

# 7. Print Results
print("\nRWKVReRanker Online Scores:")
for doc, score in zip(documents, scores):
    print(f"[{score:.4f}] {doc}")
```
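
To turn the online-mode scores above into an ordering rather than a printout, the sigmoid outputs can be sorted directly. A short illustrative follow-up, assuming `scores` is a 1-D tensor aligned with `documents`:

```python
# Rank documents by reranker score, highest first.
ranked = sorted(zip(documents, scores.tolist()), key=lambda pair: pair[1], reverse=True)
for rank, (doc, score) in enumerate(ranked, start=1):
    print(f"{rank}. [{score:.4f}] {doc}")
```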
@@ -146,71 +139,39 @@ for doc, score in zip(documents, scores):
### Offline Mode (Cached Doc State)
For scenarios where documents are static but queries change (e.g., Search Engines, RAG), you can **pre-compute and cache the document states**. This reduces query-time latency from O(L_doc + L_query) to just O(L_query).

- #### Workflow
-
- 1. **Indexing**: Process `Instruct + Document` -> Save State.
- 2. **Querying**: Load State -> Process `Query` -> Score.
-
- #### 📝 Offline Mode Usage Example
-
```python
# --- Phase 1: Indexing (Pre-computation) ---
- # Note: Do NOT add EOS here, because the sequence continues with the query later.
doc_template = "Instruct: Given a query, retrieve documents that answer the query\nDocument: {document}\n"
cached_states = []

- print("Indexing documents...")
for doc in documents:
    text = doc_template.format(document=doc)
-     # add_eos=False is CRITICAL here
    tokens = tokenizer.encode(text, add_eos=False)
-
-     # Forward pass
    _, state = emb_model.forward(tokens, None)
-
-     # Move state to CPU to save GPU memory during storage
-     # State structure: [Tensor(Tokenshift), Tensor(TimeMix)]
    cpu_state = [s.cpu() for s in state]
    cached_states.append(cpu_state)
- # Save cached states to disk (optional)
- torch.save(cached_states, 'cached_doc_states.pth')

# --- Phase 2: Querying (Fast Retrieval) ---
query_template = "Query: {query}"
query_text = query_template.format(query=query)
- # Now we add EOS to mark the end of the full sequence
query_tokens = tokenizer.encode(query_text, add_eos=True)

- print(f"Processing query: '{query}' against {len(cached_states)} cached docs...")
-
- # We can batch the query processing against multiple document states
- # 1. Prepare a batch of states (Move back to GPU)
- # Note: We must CLONE/DEEPCOPY because RWKV modifies state in-place!
batch_states = [[], []]
for cpu_s in cached_states:
-     batch_states[0].append(cpu_s[0].clone().cuda()) # Tokenshift State
-     batch_states[1].append(cpu_s[1].clone().cuda()) # TimeMix State
+     batch_states[0].append(cpu_s[0].clone().cuda())
+     batch_states[1].append(cpu_s[1].clone().cuda())

- # Stack into batch tensors
- # State[0]: [Layers, 2, 1, Hidden] -> Stack dim 2 -> [Layers, 2, Batch, Hidden]
- # State[1]: [Layers, 1, Heads, HeadSize, HeadSize] -> Stack dim 1 -> [Layers, Batch, Heads, ...]
state_input = [
    torch.stack(batch_states[0], dim=2).squeeze(3),
    torch.stack(batch_states[1], dim=1).squeeze(2)
]

- # 2. Prepare query tokens (Broadcast query to batch size)
- batch_size = len(documents)
- batch_query_tokens = [query_tokens] * batch_size
-
- # 3. Fast Forward (Only processing query tokens!)
+ batch_query_tokens = [query_tokens] * len(documents)
_, final_state = emb_model.forward(batch_query_tokens, state_input)
logits = reranker.forward(final_state[1])
scores = torch.sigmoid(logits)
-
- print("\nRWKVReRanker Offline Scores:")
- for doc, score in zip(documents, scores):
-     print(f"[{score:.4f}] {doc}")
```

## Summary of Differences
@@ -221,4 +182,15 @@ for doc, score in zip(documents, scores):
| **Latency** | Extremely Fast | Slow O(L_doc + L_query) | Fast O(L_query) only |
| **Input** | Query & Doc separate | `Instruct + Doc + Query` | `Query` (on top of cached Doc) |
| **Storage** | Low (Vector only) | None | High (Stores Hidden States) |
- | **Best For** | Initial Retrieval (Top-k) | Reranking few candidates | Reranking many candidates |
+ | **Best For** | Initial Retrieval (Top-k) | Reranking few candidates | Reranking many candidates |
+
+ ## Citation
+
+ ```bibtex
+ @article{hou2025embeddingrwkv,
+   title={EmbeddingRWKV: State-Centric Retrieval with Reusable States},
+   author={Hou, Howard and others},
+   journal={arXiv preprint arXiv:2601.07861},
+   year={2026}
+ }
+ ```
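
The summary table suggests a natural two-stage deployment: embedding-based top-k retrieval followed by state-based reranking of the shortlist. A hypothetical sketch of that glue code, reusing only the calls shown in the snippets above; the `search` helper and its `k` parameter are illustrative names, not repository APIs:

```python
import torch
import torch.nn.functional as F

def search(query, documents, emb_model, reranker, tokenizer, k=10):
    # Stage 1: embed query and documents, keep the top-k by cosine similarity.
    q_emb, _ = emb_model.forward_text_only(tokenizer.encode(query, add_eos=True), None)
    doc_tokens = [tokenizer.encode(d, add_eos=True) for d in documents]
    max_len = max(len(t) for t in doc_tokens)
    doc_tokens = [[0] * (max_len - len(t)) + t for t in doc_tokens]  # left pad
    d_embs, _ = emb_model.forward_text_only(doc_tokens, None)
    top = F.cosine_similarity(q_emb, d_embs).topk(min(k, len(documents))).indices.tolist()

    # Stage 2: rerank the shortlist in online mode.
    template = ("Instruct: Given a query, retrieve documents that answer the query\n"
                "Document: {document}\nQuery: {query}")
    pair_tokens = [tokenizer.encode(template.format(document=documents[i], query=query),
                                    add_eos=True) for i in top]
    max_len = max(len(t) for t in pair_tokens)
    pair_tokens = [[0] * (max_len - len(t)) + t for t in pair_tokens]  # left pad
    _, state = emb_model.forward(pair_tokens, None)
    scores = torch.sigmoid(reranker.forward(state[1])).flatten()

    # Return the shortlist ordered by reranker score, highest first.
    order = scores.argsort(descending=True).tolist()
    return [(documents[top[i]], scores[i].item()) for i in order]
```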