---
license: mit
language:
- en
library_name: transformers
tags:
- dense-retrieval
- latent-reasoning
- embeddings
- information-retrieval
- feature-extraction
base_model: Qwen/Qwen3-8B
pipeline_tag: feature-extraction
datasets:
- jinjiajie/LaSER-Training
---

# LaSER-Qwen3-8B

**LaSER** (**La**tent **S**pace **E**xplicit **R**easoning) is a self-distillation framework that internalizes explicit Chain-of-Thought reasoning into the latent space of dense retrievers, enabling the model to "think silently" through continuous latent tokens.

**LaSER-Qwen3-8B** is the **flagship 8B-parameter** dense retriever built on [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B), achieving **state-of-the-art performance** on reasoning-intensive retrieval benchmarks.

> 📄 **Paper:** [LaSER: Internalizing Explicit Reasoning into Latent Space for Dense Retrieval](https://arxiv.org/abs/2603.01425)
>
> 💻 **Code:** [https://github.com/ignorejjj/LaSER](https://github.com/ignorejjj/LaSER)

## Model Summary

| Attribute | Detail |
|:---|:---|
| **Model Type** | Dense Retriever with Latent Thinking |
| **Base Model** | [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) |
| **Parameters** | 8B |
| **Embedding Dimension** | 4096 |
| **Max Sequence Length** | 8192 (training: 512) |
| **Similarity Function** | Cosine Similarity |
| **Latent Thinking Steps (K)** | 3 (default) |
| **Training Data** | 81K examples from [ReasonEmb](https://huggingface.co/datasets/reasonir/ReasonEmb) |
| **License** | MIT |

## Highlights

- **29.3 nDCG@10** on BRIGHT — surpasses computationally expensive rewrite-then-retrieve pipelines (28.1) while encoding queries **~80× faster** (~50 ms vs. ~4000 ms, see the latency analysis below)
- **State-of-the-art** average score on BRIGHT, with the joint-best overall score on FollowIR among the compared retrievers
- Only **~1.7× latency overhead** compared to standard single-pass dense retrievers

## How It Works

Unlike standard dense retrievers that encode queries in a single forward pass, LaSER generates **K continuous latent thinking tokens** autoregressively in the embedding space:

1. Encode the input text into token embeddings
2. At each thinking step, project the last hidden state through the LM head, apply softmax, and compute a probability-weighted soft token from the embedding table
3. Append the soft token and repeat for K steps (reusing the KV cache for efficiency)
4. Mean-pool the hidden states from all K thinking steps, then L2-normalize

This enables complex reasoning while maintaining the inference efficiency of standard dense retrievers (~1.7× latency overhead, roughly 1.3% of the latency of rewrite-then-retrieve pipelines).

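The soft-token construction in step 2 can be sketched in isolation. The dimensions and tensors below are toy stand-ins (random weights, `vocab_size=10`, `hidden_dim=4`), not the model's actual parameters:

```python
import torch

# Toy dimensions for illustration only (the real model uses vocab ~150K, dim 4096).
vocab_size, hidden_dim = 10, 4
torch.manual_seed(0)

embedding_table = torch.randn(vocab_size, hidden_dim)  # stand-in for model.get_input_embeddings().weight
lm_head = torch.randn(vocab_size, hidden_dim)          # stand-in for the LM head weight
last_hidden = torch.randn(1, hidden_dim)               # hidden state at the last position

# hidden state -> vocabulary logits -> probabilities -> soft token
logits = last_hidden @ lm_head.T          # (1, vocab_size)
probs = torch.softmax(logits, dim=-1)     # (1, vocab_size), rows sum to 1
soft_token = probs @ embedding_table      # (1, hidden_dim), probability-weighted mix

print(soft_token.shape)  # torch.Size([1, 4])
```

Because the soft token is a convex combination of real token embeddings, it stays in the span of the embedding table while remaining fully differentiable, which is what lets the "thinking" happen in continuous space rather than over discrete tokens.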
## Usage

### Direct Usage with Transformers

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


def laser_encode(model, tokenizer, texts, max_length=512, num_thinking_steps=3):
    """Encode texts using LaSER's latent thinking mechanism."""
    device = next(model.parameters()).device
    batch = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
    input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]

    batch_size = input_ids.size(0)
    thinking_slots = num_thinking_steps - 1
    eos_id = tokenizer.eos_token_id

    # Reserve slots (initialized with EOS) that will be overwritten by soft tokens.
    if thinking_slots > 0:
        eos_padding = torch.full((batch_size, thinking_slots), eos_id, dtype=input_ids.dtype, device=device)
        mask_padding = torch.ones((batch_size, thinking_slots), dtype=attention_mask.dtype, device=device)
        input_ids = torch.cat([input_ids, eos_padding], dim=1)
        attention_mask = torch.cat([attention_mask, mask_padding], dim=1)

    input_embeds = model.get_input_embeddings()(input_ids)
    embedding_table = model.get_input_embeddings().weight
    base_seq_len = input_embeds.size(1) - thinking_slots

    past_key_values = None
    hidden_steps = []

    for step_idx in range(thinking_slots):
        pos = base_seq_len + step_idx
        # With a KV cache, only the most recently inserted soft token is fed forward.
        step_embeds = input_embeds[:, :pos, :] if past_key_values is None else input_embeds[:, pos - 1:pos, :]
        step_mask = attention_mask[:, :pos]

        outputs = model(inputs_embeds=step_embeds, attention_mask=step_mask,
                        output_hidden_states=True, past_key_values=past_key_values,
                        use_cache=True, return_dict=True)
        hidden_steps.append(outputs.hidden_states[-1][:, -1, :])
        # Probability-weighted soft token over the embedding table.
        token_probs = torch.softmax(outputs.logits[:, -1, :], dim=-1)
        new_embed = token_probs @ embedding_table
        past_key_values = outputs.past_key_values
        pre = input_embeds[:, :pos, :]
        post = input_embeds[:, pos + 1:, :]
        input_embeds = torch.cat([pre, new_embed.unsqueeze(1), post], dim=1)

    final_embeds = input_embeds[:, -1:, :] if past_key_values else input_embeds
    outputs = model(inputs_embeds=final_embeds, attention_mask=attention_mask,
                    output_hidden_states=True, past_key_values=past_key_values,
                    use_cache=True, return_dict=True)
    hidden_steps.append(outputs.hidden_states[-1][:, -1, :])

    # Mean-pool the K thinking-step hidden states, then L2-normalize.
    embeddings = torch.stack(hidden_steps, dim=1).mean(dim=1)
    return F.normalize(embeddings, p=2, dim=-1)


# Load model
model_name = "Alibaba-NLP/LaSER-Qwen3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, trust_remote_code=True
).cuda().eval()

# Encode queries and documents
with torch.inference_mode():
    query_emb = laser_encode(model, tokenizer, ["why is the sky blue"], num_thinking_steps=3)
    doc_emb = laser_encode(model, tokenizer, ["Rayleigh scattering makes short wavelengths scatter more strongly"], num_thinking_steps=3)

# Compute similarity
similarity = (query_emb @ doc_emb.T).item()
print(f"Cosine similarity: {similarity:.4f}")
```

### Batch Encoding

```python
queries = [
    "What causes tides in the ocean?",
    "How does photosynthesis convert light to energy?",
    "Why do metals conduct electricity?",
]

with torch.inference_mode():
    query_embeddings = laser_encode(model, tokenizer, queries, num_thinking_steps=3)
print(f"Batch embeddings shape: {query_embeddings.shape}")  # (3, 4096)
```

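Since `laser_encode` returns L2-normalized vectors, corpus search reduces to a dot-product ranking. A self-contained sketch with random unit vectors standing in for real query and document embeddings (in practice you would embed an actual corpus and keep the index):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Mock L2-normalized embeddings standing in for laser_encode outputs.
q_emb = F.normalize(torch.randn(2, 4096), p=2, dim=-1)  # 2 queries
d_emb = F.normalize(torch.randn(5, 4096), p=2, dim=-1)  # 5 documents

# On unit vectors, cosine similarity is a plain dot product.
scores = q_emb @ d_emb.T                          # (2, 5) similarity matrix
top_scores, top_idx = torch.topk(scores, k=3, dim=-1)

for q, (idx, vals) in enumerate(zip(top_idx, top_scores)):
    print(f"query {q}: top docs {idx.tolist()}, scores {[round(v, 3) for v in vals.tolist()]}")
```

For large corpora the same dot-product scoring can be served from any vector index (e.g. FAISS inner-product search) since the vectors are already normalized.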
## Evaluation Results

### BRIGHT Benchmark (nDCG@10) — In-Domain

| Model | Size | Bio. | Earth. | Econ. | Psy. | Rob. | Stack. | Sus. | Leet. | Pony | AoPS | TheoQ. | TheoT. | **Avg.** |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| Qwen3-Embedding-8B | 8B | 14.7 | 17.9 | 15.5 | 19.9 | 9.1 | 12.9 | 16.5 | **17.4** | 0.8 | 2.5 | **16.8** | 24.5 | 14.0 |
| Fair Baseline (Qwen3-8B) | 8B | 49.7 | 51.2 | 26.9 | 37.4 | **23.4** | 28.0 | **34.1** | 3.7 | 3.2 | **2.8** | **16.8** | **31.8** | 25.7 |
| Rewrite-then-Retrieve (Qwen3-8B) † | 8B | 53.1 | 54.3 | 32.1 | 34.8 | 20.5 | 31.1 | 32.2 | 3.2 | 15.2 | 4.1 | 17.4 | 38.8 | 28.1 |
| GIRCSE (Qwen3-8B) | 8B | **59.0** | **56.5** | 27.2 | 40.3 | 19.0 | 28.5 | 31.4 | 3.2 | 3.6 | 1.7 | 14.0 | 27.2 | 26.0 |
| **LaSER-Qwen3-8B (Ours)** | 8B | 58.4 | 48.1 | **28.0** | **40.9** | 17.0 | **29.9** | 28.3 | 1.7 | **5.9** | 1.5 | 14.6 | 19.2 | **29.3** |

† Requires an additional LLM query-rewriting pass at inference time (see the latency analysis below); excluded from per-column bolding.

### FollowIR Benchmark — Out-of-Domain

| Model | Size | Robust04 MAP@5 | News21 nDCG@5 | Core17 MAP@5 | Score | p-MRR |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| Fair Baseline (Qwen3-8B) | 8B | 2.8 | 18.9 | 11.2 | 11.0 | 1.7 |
| GIRCSE (Qwen3-8B) | 8B | 3.0 | **22.6** | 8.5 | **11.4** | **2.0** |
| **LaSER-Qwen3-8B (Ours)** | 8B | **4.1** | 21.8 | **11.4** | **11.4** | 1.3 |

### BrowseComp-Plus Benchmark — Out-of-Domain

| Model | Size | R@5 | R@100 | R@1000 |
|:---|:---:|:---:|:---:|:---:|
| Fair Baseline (Qwen3-8B) | 8B | 11.3 | 37.4 | 63.2 |
| GIRCSE (Qwen3-8B) | 8B | **13.0** | **40.8** | **68.1** |
| **LaSER-Qwen3-8B (Ours)** | 8B | 6.8 | 26.8 | 54.9 |

### Latency Analysis (Single A100, Batch Size 8)

| Method | Latency (ms) | BRIGHT nDCG@10 |
|:---|:---:|:---:|
| Basic Retriever (8B) | ~30 ms | 25.7 |
| Rewrite-then-Retrieve (8B) | ~4000 ms | 28.1 |
| **LaSER (8B)** | ~50 ms | **29.3** |

> LaSER achieves the best performance while incurring only **~1.7× latency** over the basic retriever, compared to **~130×** for rewrite-then-retrieve pipelines.

## Training Details

- **Training Data:** 81K query-document pairs from [ReasonEmb](https://huggingface.co/datasets/reasonir/ReasonEmb), each with a CoT reasoning path generated by GPT-4o-mini
- **Method:** LoRA fine-tuning (r=64, α=32) for 1 epoch on 4×A100 GPUs
- **Loss:** Contrastive learning + output-level KL distillation (λ₂=10) + process-level trajectory alignment (λ₃=0.1)
- **Temperature:** τ=0.02
- **Thinking Steps:** K=3

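The contrastive term with the listed temperature can be sketched with mock embeddings; the two distillation terms are only indicated as placeholder comments here, since their exact form is defined in the paper and not reproduced below:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Mock L2-normalized query / positive-document embeddings (batch 4, toy dim 8).
q = F.normalize(torch.randn(4, 8), dim=-1)
d = F.normalize(torch.randn(4, 8), dim=-1)

tau = 0.02                                 # temperature from the training details
logits = (q @ d.T) / tau                   # (4, 4): in-batch positives on the diagonal
labels = torch.arange(4)                   # i-th query matches i-th document
contrastive_loss = F.cross_entropy(logits, labels)  # InfoNCE over in-batch negatives

# Full objective combines the terms with the listed weights (distillation terms
# are placeholders; see the paper for their definitions):
# total = contrastive_loss + 10.0 * kl_distill_loss + 0.1 * trajectory_align_loss
print(float(contrastive_loss))
```

The small temperature (τ=0.02) sharpens the softmax over in-batch candidates, a common choice for retrieval-style contrastive training.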
## Model Family

| Model | Parameters | BRIGHT Avg. | Link |
|:---|:---:|:---:|:---:|
| LaSER-Qwen3-0.6B | 0.6B | 23.1 | [🤗 Link](https://huggingface.co/Alibaba-NLP/LaSER-Qwen3-0.6B) |
| LaSER-Qwen3-4B | 4B | 28.0 | [🤗 Link](https://huggingface.co/Alibaba-NLP/LaSER-Qwen3-4B) |
| **LaSER-Qwen3-8B** | 8B | 29.3 | [🤗 This model](https://huggingface.co/Alibaba-NLP/LaSER-Qwen3-8B) |

## Citation

```bibtex
@article{jin2026laser,
  title={LaSER: Internalizing Explicit Reasoning into Latent Space for Dense Retrieval},
  author={Jin, Jiajie and Zhang, Yanzhao and Li, Mingxin and Long, Dingkun and Xie, Pengjun and Zhu, Yutao and Dou, Zhicheng},
  year={2026},
  journal={arXiv preprint},
  url={https://arxiv.org/abs/2603.01425},
}
```