miike-ai committed
Commit d9fe3f3 · verified · 1 Parent(s): ac833d9

Fix compression to handle all new tokens (chunked prefill support)

Files changed (1):
  1. modeling_lean_llama.py +73 -41
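Context for the fix (an illustration, not part of the commit): with chunked prefill, a single forward pass can append an entire chunk of tokens to the KV cache, but the previous code only ran the compressor over the final cache position. A minimal sketch of the difference, with illustrative shapes:

```python
import torch

# Illustrative cache slice: [batch, heads, seq, head_dim], 4 tokens
# appended by one chunked-prefill forward pass.
v = torch.randn(1, 8, 4, 64)

# Before: only the last position was compressed, leaving the first 3
# tokens of the chunk uncompressed in the cache.
v_last = v[:, :, -1:, :]      # [1, 8, 1, 64]

# After: everything past the last compressed offset is handled.
start = 0                     # value of self._compressed_up_to[layer_idx]
v_new = v[:, :, start:, :]    # [1, 8, 4, 64], the whole chunk
```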
modeling_lean_llama.py CHANGED
@@ -137,6 +137,8 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
         kv_granularity: dict[str, str] = dict(getattr(config, "leanllm_kv_granularity", {}))
 
         self._kv_granularity_map: dict[int, str] = {}
+        # Track how many tokens have been compressed per layer to avoid re-compressing
+        self._compressed_up_to: dict[int, int] = {}
 
         for layer_idx in self._kv_layers:
             key = str(layer_idx)
@@ -148,6 +150,7 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
             if f_alpha is not None:
                 f_alpha = float(f_alpha)
             self._kv_granularity_map[layer_idx] = kv_granularity.get(key, "per_token")
+            self._compressed_up_to[layer_idx] = 0
 
             # Attach value compressor at the same path used by convert_to_hf_model.py
             # so that from_pretrained auto-loads the saved weights.
@@ -162,7 +165,6 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
             )
             self.model.layers[layer_idx].self_attn.leanllm_v_compressor = v_mod
 
-
             # Key compressor (if not values-only)
             if not self._values_only and key in kv_key_dims:
                 k_dim = int(kv_key_dims[key])
@@ -180,6 +182,48 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
     # KV cache compression (applied after each forward pass)
     # ------------------------------------------------------------------
 
+    def _compress_values(
+        self,
+        v: torch.Tensor,
+        layer_idx: int,
+        start: int,
+    ) -> torch.Tensor:
+        """Compress values from position `start` onward, return full tensor."""
+        if start >= v.shape[2]:
+            return v
+        v_new = v[:, :, start:, :]
+        gran = self._kv_granularity_map[layer_idx]
+        v_dtype = v.dtype
+        compressor = self.model.layers[layer_idx].self_attn.leanllm_v_compressor
+        comp_dtype = compressor.encoder.weight.dtype
+        v_vec = _kv_to_vec(v_new, gran).to(comp_dtype)
+        v_rec = compressor.decode(compressor.encode(v_vec), x_orig=None)
+        v_compressed = _vec_to_kv(v_rec, v_new).to(dtype=v_dtype)
+        if start == 0:
+            return v_compressed
+        return torch.cat([v[:, :, :start, :], v_compressed], dim=2)
+
+    def _compress_keys(
+        self,
+        k: torch.Tensor,
+        layer_idx: int,
+        start: int,
+    ) -> torch.Tensor:
+        """Compress keys from position `start` onward, return full tensor."""
+        if start >= k.shape[2]:
+            return k
+        k_new = k[:, :, start:, :]
+        gran = self._kv_granularity_map[layer_idx]
+        k_dtype = k.dtype
+        k_comp = self.model.layers[layer_idx].self_attn.leanllm_k_compressor
+        comp_dtype = k_comp.encoder.weight.dtype
+        k_vec = _kv_to_vec(k_new, gran).to(comp_dtype)
+        k_rec = k_comp.decode(k_comp.encode(k_vec), x_orig=None)
+        k_compressed = _vec_to_kv(k_rec, k_new).to(dtype=k_dtype)
+        if start == 0:
+            return k_compressed
+        return torch.cat([k[:, :, :start, :], k_compressed], dim=2)
+
     def _compress_past(self, past_key_values: Any) -> Any:
         if past_key_values is None or not self._kv_layers:
             return past_key_values
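The two new helpers share the same control flow: slice off the uncompressed suffix, run it through the encoder/decoder round trip, and splice it back in front of the already-compressed prefix. A quick sanity check of that splice logic, with the model-specific pieces (`_kv_to_vec`, `_vec_to_kv`, the compressor) stubbed out by an identity function:

```python
import torch

def compress_suffix(x: torch.Tensor, start: int, round_trip) -> torch.Tensor:
    # Mirrors _compress_values/_compress_keys with the compressor stubbed out.
    if start >= x.shape[2]:
        return x                                 # no new tokens since last call
    suffix = round_trip(x[:, :, start:, :])      # compress only the new positions
    if start == 0:
        return suffix                            # entire cache was new
    return torch.cat([x[:, :, :start, :], suffix], dim=2)

x = torch.randn(1, 8, 6, 64)
out = compress_suffix(x, start=4, round_trip=lambda t: t)
assert out.shape == x.shape
assert torch.equal(out[:, :, :4, :], x[:, :, :4, :])  # prefix left untouched
```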
@@ -189,36 +233,19 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
         for layer_idx in self._kv_layers:
             layer_cache = past_key_values.layers[layer_idx]
             v = layer_cache.values  # [B, H, T, D]
-            gran = self._kv_granularity_map[layer_idx]
-            v_last = v[:, :, -1:, :]
-            v_dtype = v.dtype
-
-            compressor = self.model.layers[layer_idx].self_attn.leanllm_v_compressor
-            comp_dtype = compressor.encoder.weight.dtype
-            v_vec = _kv_to_vec(v_last, gran).to(comp_dtype)
-            v_rec = compressor.decode(compressor.encode(v_vec), x_orig=None)
-            v_new_last = _vec_to_kv(v_rec, v_last).to(dtype=v_dtype)
-
-            if v.shape[2] == 1:
-                layer_cache.values = v_new_last
-            else:
-                layer_cache.values = torch.cat([v[:, :, :-1, :], v_new_last], dim=2)
+            total_tokens = v.shape[2]
+            start = self._compressed_up_to[layer_idx]
+
+            layer_cache.values = self._compress_values(v, layer_idx, start)
 
             if not self._values_only and hasattr(
                 self.model.layers[layer_idx].self_attn, "leanllm_k_compressor"
             ):
-                k = layer_cache.keys
-                k_last = k[:, :, -1:, :]
-                k_dtype = k.dtype
-                k_comp = self.model.layers[layer_idx].self_attn.leanllm_k_compressor
-                k_comp_dtype = k_comp.encoder.weight.dtype
-                k_vec = _kv_to_vec(k_last, gran).to(k_comp_dtype)
-                k_rec = k_comp.decode(k_comp.encode(k_vec), x_orig=None)
-                k_new_last = _vec_to_kv(k_rec, k_last).to(dtype=k_dtype)
-                if k.shape[2] == 1:
-                    layer_cache.keys = k_new_last
-                else:
-                    layer_cache.keys = torch.cat([k[:, :, :-1, :], k_new_last], dim=2)
+                layer_cache.keys = self._compress_keys(
+                    layer_cache.keys, layer_idx, start
+                )
+
+            self._compressed_up_to[layer_idx] = total_tokens
 
         return past_key_values
 
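The `_compressed_up_to` counter is what makes this safe to call after every forward pass: positions below the stored offset are never re-compressed, so a token only takes one encode/decode round trip (a second pass would compound the reconstruction error). A sketch of the offset progression across a chunked prefill followed by one decode step, with illustrative chunk sizes:

```python
compressed_up_to = 0
for new_tokens in (512, 512, 1):          # two prefill chunks, then one decode step
    total = compressed_up_to + new_tokens  # cache length after this forward
    start = compressed_up_to               # compress only positions [start, total)
    compressed_up_to = total               # nothing below `total` is touched again
assert compressed_up_to == 1025
```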
@@ -227,26 +254,31 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
             past_list = list(past_key_values)
             for layer_idx in self._kv_layers:
                 k, v = past_list[layer_idx]
-                gran = self._kv_granularity_map[layer_idx]
-                v_last = v[:, :, -1:, :]
-                v_dtype = v.dtype
-
-                compressor = self.model.layers[layer_idx].self_attn.leanllm_v_compressor
-                comp_dtype = compressor.encoder.weight.dtype
-                v_vec = _kv_to_vec(v_last, gran).to(comp_dtype)
-                v_rec = compressor.decode(compressor.encode(v_vec), x_orig=None)
-                v_new_last = _vec_to_kv(v_rec, v_last).to(dtype=v_dtype)
-
-                if v.shape[2] == 1:
-                    past_list[layer_idx] = (k, v_new_last)
+                total_tokens = v.shape[2]
+                start = self._compressed_up_to[layer_idx]
+
+                v_new = self._compress_values(v, layer_idx, start)
+                if not self._values_only and hasattr(
+                    self.model.layers[layer_idx].self_attn, "leanllm_k_compressor"
+                ):
+                    k_new = self._compress_keys(k, layer_idx, start)
                 else:
-                    v_new = torch.cat([v[:, :, :-1, :], v_new_last], dim=2)
-                    past_list[layer_idx] = (k, v_new)
+                    k_new = k
+                past_list[layer_idx] = (k_new, v_new)
+
+                self._compressed_up_to[layer_idx] = total_tokens
             return tuple(past_list)
 
         return past_key_values
 
     def forward(self, *args: Any, **kwargs: Any) -> CausalLMOutputWithPast:
+        # Reset compression tracking when there's no cache (new sequence)
+        past = kwargs.get("past_key_values", None)
+        if past is None and len(args) < 5:
+            # No KV cache passed — starting fresh
+            for layer_idx in self._kv_layers:
+                self._compressed_up_to[layer_idx] = 0
+
         outputs = super().forward(*args, **kwargs)
         if hasattr(outputs, "past_key_values"):
             outputs.past_key_values = self._compress_past(outputs.past_key_values)
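The reset in `forward` keeps offsets from leaking across sequences: when no `past_key_values` arrives (the `len(args) < 5` guard presumably covers it being passed positionally in the parent signature), a fresh sequence starts from offset 0. A hedged usage sketch; the checkpoint path below is a placeholder, not from this repo:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/lean-llama")  # placeholder path
model = AutoModelForCausalLM.from_pretrained(
    "path/to/lean-llama", trust_remote_code=True           # placeholder path
)

# First sequence: forward() sees no cache, resets _compressed_up_to to 0,
# so the whole prefill (chunked or not) is compressed exactly once.
out1 = model.generate(**tok("first prompt", return_tensors="pt"), max_new_tokens=8)

# Second, independent sequence: the reset fires again, so stale offsets from
# the previous generation cannot skip or double-compress positions.
out2 = model.generate(**tok("second prompt", return_tensors="pt"), max_new_tokens=8)
```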
 