Fix compression to handle all new tokens (chunked prefill support)
Browse files — modeling_lean_llama.py (+73 −41)
modeling_lean_llama.py
CHANGED
|
@@ -137,6 +137,8 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
|
|
| 137 |
kv_granularity: dict[str, str] = dict(getattr(config, "leanllm_kv_granularity", {}))
|
| 138 |
|
| 139 |
self._kv_granularity_map: dict[int, str] = {}
|
|
|
|
|
|
|
| 140 |
|
| 141 |
for layer_idx in self._kv_layers:
|
| 142 |
key = str(layer_idx)
|
|
@@ -148,6 +150,7 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
|
|
| 148 |
if f_alpha is not None:
|
| 149 |
f_alpha = float(f_alpha)
|
| 150 |
self._kv_granularity_map[layer_idx] = kv_granularity.get(key, "per_token")
|
|
|
|
| 151 |
|
| 152 |
# Attach value compressor at the same path used by convert_to_hf_model.py
|
| 153 |
# so that from_pretrained auto-loads the saved weights.
|
|
@@ -162,7 +165,6 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
|
|
| 162 |
)
|
| 163 |
self.model.layers[layer_idx].self_attn.leanllm_v_compressor = v_mod
|
| 164 |
|
| 165 |
-
|
| 166 |
# Key compressor (if not values-only)
|
| 167 |
if not self._values_only and key in kv_key_dims:
|
| 168 |
k_dim = int(kv_key_dims[key])
|
|
@@ -180,6 +182,48 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
|
|
| 180 |
# KV cache compression (applied after each forward pass)
|
| 181 |
# ------------------------------------------------------------------
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
def _compress_past(self, past_key_values: Any) -> Any:
|
| 184 |
if past_key_values is None or not self._kv_layers:
|
| 185 |
return past_key_values
|
|
@@ -189,36 +233,19 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
|
|
| 189 |
for layer_idx in self._kv_layers:
|
| 190 |
layer_cache = past_key_values.layers[layer_idx]
|
| 191 |
v = layer_cache.values # [B, H, T, D]
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
compressor = self.model.layers[layer_idx].self_attn.leanllm_v_compressor
|
| 197 |
-
comp_dtype = compressor.encoder.weight.dtype
|
| 198 |
-
v_vec = _kv_to_vec(v_last, gran).to(comp_dtype)
|
| 199 |
-
v_rec = compressor.decode(compressor.encode(v_vec), x_orig=None)
|
| 200 |
-
v_new_last = _vec_to_kv(v_rec, v_last).to(dtype=v_dtype)
|
| 201 |
-
|
| 202 |
-
if v.shape[2] == 1:
|
| 203 |
-
layer_cache.values = v_new_last
|
| 204 |
-
else:
|
| 205 |
-
layer_cache.values = torch.cat([v[:, :, :-1, :], v_new_last], dim=2)
|
| 206 |
|
| 207 |
if not self._values_only and hasattr(
|
| 208 |
self.model.layers[layer_idx].self_attn, "leanllm_k_compressor"
|
| 209 |
):
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
k_vec = _kv_to_vec(k_last, gran).to(k_comp_dtype)
|
| 216 |
-
k_rec = k_comp.decode(k_comp.encode(k_vec), x_orig=None)
|
| 217 |
-
k_new_last = _vec_to_kv(k_rec, k_last).to(dtype=k_dtype)
|
| 218 |
-
if k.shape[2] == 1:
|
| 219 |
-
layer_cache.keys = k_new_last
|
| 220 |
-
else:
|
| 221 |
-
layer_cache.keys = torch.cat([k[:, :, :-1, :], k_new_last], dim=2)
|
| 222 |
|
| 223 |
return past_key_values
|
| 224 |
|
|
@@ -227,26 +254,31 @@ class LeanLlamaForCausalLM(LlamaForCausalLM):
|
|
| 227 |
past_list = list(past_key_values)
|
| 228 |
for layer_idx in self._kv_layers:
|
| 229 |
k, v = past_list[layer_idx]
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
v_new_last = _vec_to_kv(v_rec, v_last).to(dtype=v_dtype)
|
| 239 |
-
|
| 240 |
-
if v.shape[2] == 1:
|
| 241 |
-
past_list[layer_idx] = (k, v_new_last)
|
| 242 |
else:
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
| 245 |
return tuple(past_list)
|
| 246 |
|
| 247 |
return past_key_values
|
| 248 |
|
| 249 |
def forward(self, *args: Any, **kwargs: Any) -> CausalLMOutputWithPast:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
outputs = super().forward(*args, **kwargs)
|
| 251 |
if hasattr(outputs, "past_key_values"):
|
| 252 |
outputs.past_key_values = self._compress_past(outputs.past_key_values)
|
|
|
|
| 137 |
kv_granularity: dict[str, str] = dict(getattr(config, "leanllm_kv_granularity", {}))
|
| 138 |
|
| 139 |
self._kv_granularity_map: dict[int, str] = {}
|
| 140 |
+
# Track how many tokens have been compressed per layer to avoid re-compressing
|
| 141 |
+
self._compressed_up_to: dict[int, int] = {}
|
| 142 |
|
| 143 |
for layer_idx in self._kv_layers:
|
| 144 |
key = str(layer_idx)
|
|
|
|
| 150 |
if f_alpha is not None:
|
| 151 |
f_alpha = float(f_alpha)
|
| 152 |
self._kv_granularity_map[layer_idx] = kv_granularity.get(key, "per_token")
|
| 153 |
+
self._compressed_up_to[layer_idx] = 0
|
| 154 |
|
| 155 |
# Attach value compressor at the same path used by convert_to_hf_model.py
|
| 156 |
# so that from_pretrained auto-loads the saved weights.
|
|
|
|
| 165 |
)
|
| 166 |
self.model.layers[layer_idx].self_attn.leanllm_v_compressor = v_mod
|
| 167 |
|
|
|
|
| 168 |
# Key compressor (if not values-only)
|
| 169 |
if not self._values_only and key in kv_key_dims:
|
| 170 |
k_dim = int(kv_key_dims[key])
|
|
|
|
| 182 |
# KV cache compression (applied after each forward pass)
|
| 183 |
# ------------------------------------------------------------------
|
| 184 |
|
| 185 |
+
def _compress_values(
|
| 186 |
+
self,
|
| 187 |
+
v: torch.Tensor,
|
| 188 |
+
layer_idx: int,
|
| 189 |
+
start: int,
|
| 190 |
+
) -> torch.Tensor:
|
| 191 |
+
"""Compress values from position `start` onward, return full tensor."""
|
| 192 |
+
if start >= v.shape[2]:
|
| 193 |
+
return v
|
| 194 |
+
v_new = v[:, :, start:, :]
|
| 195 |
+
gran = self._kv_granularity_map[layer_idx]
|
| 196 |
+
v_dtype = v.dtype
|
| 197 |
+
compressor = self.model.layers[layer_idx].self_attn.leanllm_v_compressor
|
| 198 |
+
comp_dtype = compressor.encoder.weight.dtype
|
| 199 |
+
v_vec = _kv_to_vec(v_new, gran).to(comp_dtype)
|
| 200 |
+
v_rec = compressor.decode(compressor.encode(v_vec), x_orig=None)
|
| 201 |
+
v_compressed = _vec_to_kv(v_rec, v_new).to(dtype=v_dtype)
|
| 202 |
+
if start == 0:
|
| 203 |
+
return v_compressed
|
| 204 |
+
return torch.cat([v[:, :, :start, :], v_compressed], dim=2)
|
| 205 |
+
|
| 206 |
+
def _compress_keys(
    self,
    k: torch.Tensor,
    layer_idx: int,
    start: int,
) -> torch.Tensor:
    """Compress the key cache from position `start` onward.

    Mirrors `_compress_values` for keys: entries at sequence index
    `start` and beyond are passed through the layer's key compressor
    (encode then decode) and spliced back into the full [B, H, T, D]
    tensor in the cache's original dtype. If `start` is at or past the
    end of the sequence axis, `k` is returned as-is.
    """
    if start >= k.shape[2]:
        return k

    codec = self.model.layers[layer_idx].self_attn.leanllm_k_compressor
    gran = self._kv_granularity_map[layer_idx]
    cache_dtype = k.dtype

    segment = k[:, :, start:, :]
    # Cast into the compressor's parameter dtype for the round trip.
    encoded = codec.encode(_kv_to_vec(segment, gran).to(codec.encoder.weight.dtype))
    restored = _vec_to_kv(codec.decode(encoded, x_orig=None), segment).to(dtype=cache_dtype)

    if start == 0:
        return restored
    return torch.cat([k[:, :, :start, :], restored], dim=2)
|
| 226 |
+
|
| 227 |
def _compress_past(self, past_key_values: Any) -> Any:
|
| 228 |
if past_key_values is None or not self._kv_layers:
|
| 229 |
return past_key_values
|
|
|
|
| 233 |
for layer_idx in self._kv_layers:
|
| 234 |
layer_cache = past_key_values.layers[layer_idx]
|
| 235 |
v = layer_cache.values # [B, H, T, D]
|
| 236 |
+
total_tokens = v.shape[2]
|
| 237 |
+
start = self._compressed_up_to[layer_idx]
|
| 238 |
+
|
| 239 |
+
layer_cache.values = self._compress_values(v, layer_idx, start)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
if not self._values_only and hasattr(
|
| 242 |
self.model.layers[layer_idx].self_attn, "leanllm_k_compressor"
|
| 243 |
):
|
| 244 |
+
layer_cache.keys = self._compress_keys(
|
| 245 |
+
layer_cache.keys, layer_idx, start
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
self._compressed_up_to[layer_idx] = total_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
return past_key_values
|
| 251 |
|
|
|
|
| 254 |
past_list = list(past_key_values)
|
| 255 |
for layer_idx in self._kv_layers:
|
| 256 |
k, v = past_list[layer_idx]
|
| 257 |
+
total_tokens = v.shape[2]
|
| 258 |
+
start = self._compressed_up_to[layer_idx]
|
| 259 |
+
|
| 260 |
+
v_new = self._compress_values(v, layer_idx, start)
|
| 261 |
+
if not self._values_only and hasattr(
|
| 262 |
+
self.model.layers[layer_idx].self_attn, "leanllm_k_compressor"
|
| 263 |
+
):
|
| 264 |
+
k_new = self._compress_keys(k, layer_idx, start)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
else:
|
| 266 |
+
k_new = k
|
| 267 |
+
past_list[layer_idx] = (k_new, v_new)
|
| 268 |
+
|
| 269 |
+
self._compressed_up_to[layer_idx] = total_tokens
|
| 270 |
return tuple(past_list)
|
| 271 |
|
| 272 |
return past_key_values
|
| 273 |
|
| 274 |
def forward(self, *args: Any, **kwargs: Any) -> CausalLMOutputWithPast:
|
| 275 |
+
# Reset compression tracking when there's no cache (new sequence)
|
| 276 |
+
past = kwargs.get("past_key_values", None)
|
| 277 |
+
if past is None and len(args) < 5:
|
| 278 |
+
# No KV cache passed — starting fresh
|
| 279 |
+
for layer_idx in self._kv_layers:
|
| 280 |
+
self._compressed_up_to[layer_idx] = 0
|
| 281 |
+
|
| 282 |
outputs = super().forward(*args, **kwargs)
|
| 283 |
if hasattr(outputs, "past_key_values"):
|
| 284 |
outputs.past_key_values = self._compress_past(outputs.past_key_values)
|