Replace `model.generate` with a custom generation function to optimize KV-cache handling
Replaced the problematic high-level generate call
```
model_outputs = self.model.generate(
    input_ids=current_input_id,
    attention_mask=current_attention_mask,  # This mask is for (history in KV cache + current_input_id)
    eos_token_id=[eos_id, code_id[1]],  # code_id[1] is assumed to be </code>'s last token ID
    past_key_values=current_kv,
    generation_config=self.generation_config,  # Ensure this has return_dict_in_generate=True, use_cache=True
    # max_new_tokens should be set in self.generation_config appropriately for a segment
)
```
with custom generation and decoding functions.
This is because the high-level `model.generate` cannot carry a stateful KV cache across calls.
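What we need instead is a manual decode loop that keeps the cache alive between generation segments, so interpreter feedback can be appended without re-encoding the whole history. A minimal sketch of that pattern, assuming a generic Hugging Face causal LM, batch size 1, and greedy decoding for brevity (not the trainer code itself):

```
import torch

@torch.no_grad()
def generate_segment(model, input_ids, past_key_values=None, max_new_tokens=32, stop_ids=()):
    """Decode one segment, reusing and returning the KV cache."""
    tokens = []
    ids = input_ids  # full prompt on the first step, one token per step afterwards
    for _ in range(max_new_tokens):
        out = model(input_ids=ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values  # the stateful cache survives the call
        ids = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy for brevity
        tokens.append(ids)
        if ids.item() in stop_ids:
            break
    return torch.cat(tokens, dim=1), past_key_values

# Segment 1 prefills the prompt and decodes until </code>; segment 2 then feeds
# only the tokenized interpreter feedback, because the cache already covers
# everything generated before it.
```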
- src/retool_trainer.py (+168 -101)

Imports:

```
@@ -25,6 +25,7 @@ from transformers import (
     is_wandb_available,
     PreTrainedTokenizer,
 )
+from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
```
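The newly imported `DynamicCache` is used below to normalize the cache object between forward passes. Roughly, the conversion behaves like this (a sketch with a hypothetical `as_dynamic_cache` helper; `from_legacy_cache` exists in recent transformers releases, but exact cache types vary by version):

```
from transformers import DynamicCache

def as_dynamic_cache(past_key_values):
    """Wrap a legacy tuple-of-tuples KV cache in a DynamicCache.

    Older transformers versions return past_key_values as a tuple of
    (key, value) tensor pairs; newer forward passes expect Cache objects.
    Normalizing keeps a hand-rolled decode loop version-agnostic.
    """
    if past_key_values is None or isinstance(past_key_values, DynamicCache):
        return past_key_values
    return DynamicCache.from_legacy_cache(past_key_values)
```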
Generation loop:

```
@@ -257,128 +258,194 @@ class ReToolTrainer(Trainer): # Change this line
```

Removed: the old `_retool_generate_with_interpreter`, built around per-turn `model.generate` calls that tried to resume from `past_key_values` (abridged; elided parts are marked with `...`):

```
        return advantages

    def _retool_generate_with_interpreter(self, prompt_ids_batch, attention_mask_batch, eos_id, interpreter_id, code_id, max_turns=10):
        batch_size = prompt_ids_batch.size(0)
        batch_completion = []
        batch_interpreter_positions = []
        for i in range(batch_size):
            current_input_id = prompt_ids_batch[i:i+1]
            current_attention_mask = attention_mask_batch[i:i+1]
            current_kv = None
            cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)
            interpreter_positions = []
            for turn_idx in range(max_turns):
                model_outputs = self.model.generate(
                    input_ids=current_input_id,
                    attention_mask=current_attention_mask,
                    eos_token_id=[eos_id, code_id[1]],  # code_id[1] is assumed to be </code>'s last token ID
                    past_key_values=current_kv,
                    generation_config=self.generation_config,
                )
                # ... extract completion_id and current_full_ids from model_outputs ...
                cumulative_completion_ids = torch.cat([cumulative_completion_ids, completion_id], dim=1)
                # Update current_input_id for the next generation step
                # Update current_attention_mask: it was for (history + current_input_id),
                # now append 1s for completion_id
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones_like(completion_id)
                ], dim=1)
                current_kv = model_outputs.past_key_values  # Cache for the new current_full_ids
                last_token_id = current_full_ids[0, -1].item()
                if last_token_id == eos_id or turn_idx == max_turns - 1:
                    batch_completion.append(cumulative_completion_ids.squeeze(0))
                    batch_interpreter_positions.append(interpreter_positions)
                    break
                # Extract code from the full text
                full_text = self.processing_class.decode(...)
                code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)
                if code_match:
                    code_block = code_match.group(1)
                    interpreter_text = self._execute_code(code_block)
                else:
                    ...
                formatted_feedback_text = f"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}"
                interpreter_feedback_id = self.processing_class(
                    formatted_feedback_text,
                    return_tensors="pt",
                    add_special_tokens=False
                ).input_ids.to(current_full_ids.device)
                # Record positions relative to cumulative_completion_ids *before* appending feedback
                interpreter_start_idx = cumulative_completion_ids.size(1)
                cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_feedback_id], dim=1)  # Use cumulative, not current
                interpreter_end_idx = cumulative_completion_ids.size(1) - 1
                interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))
                # Update attention mask for the appended tool feedback
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones_like(interpreter_feedback_id)
                ], dim=1)
                # Prepare for the next LM generation step:
                # The model needs to "process" the tool_output_tokens to update its KV cache.
                # The `current_input_id` for the next generate call will be `interpreter_feedback_id`.
                # `current_kv` already holds the cache for `current_full_ids` *before* the tool feedback was appended.
                # The `current_attention_mask` now correctly covers `current_full_ids` (which includes tool feedback).
                current_input_id = interpreter_feedback_id
                # `current_kv` is correct (it's for the prefix before `interpreter_feedback_id`).
                # The next `model.generate` call will use this `current_input_id`, `current_attention_mask`, and `current_kv`.
            else:  # Executed if the loop finished due to max_turns without a break
                batch_completion.append(cumulative_completion_ids.squeeze(0))
                batch_interpreter_positions.append(interpreter_positions)
        ...
        return padded_sequences, batch_interpreter_positions
```
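The new implementation samples each token manually: temperature scaling, then top-k/top-p filtering, then a multinomial draw. In isolation, that sampling step looks like this (a toy, self-contained sketch, not the committed code):

```
import torch

# Toy version of the sampling step in _custom_generate: temperature-scale
# the last-position logits, then draw a single token id.
logits = torch.tensor([[2.0, 1.0, 0.1]])  # stand-in for outputs.logits[:, -1, :]
temperature = 0.7
probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
```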
Added: the new `_custom_generate` and `_filter_logits` helpers:

```
    def _custom_generate(self, input_ids, attention_mask=None, past_key_values=None, max_new_tokens=50, eos_token_ids=None):
        """Custom generation function that avoids KV cache issues"""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if eos_token_ids is None:
            eos_token_ids = [self.processing_class.eos_token_id]

        # Initialize
        current_ids = input_ids.clone()
        current_mask = attention_mask.clone()
        current_kv = past_key_values

        # Generate tokens in batches for efficiency
        all_tokens = []
        batch_size = 10  # Process this many tokens at once

        for start_idx in range(0, max_new_tokens, batch_size):
            # How many tokens to generate in this batch
            batch_tokens = min(batch_size, max_new_tokens - start_idx)

            # Accumulate new tokens
            new_tokens = []

            for _ in range(batch_tokens):
                # Forward pass with proper cache handling
                with torch.no_grad():
                    outputs = self.model(
                        input_ids=current_ids if current_kv is None else current_ids[:, -1:],
                        attention_mask=current_mask if current_kv is None else current_mask[:, -1:],
                        past_key_values=DynamicCache.from_legacy_cache(current_kv) if current_kv is not None else None,
                        use_cache=True
                    )

                # Sample next token
                next_token_logits = outputs.logits[:, -1, :] / self.temperature
                filtered_logits = self._filter_logits(next_token_logits)
                probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Add to accumulated tokens
                token_id = next_token.item()
                new_tokens.append(token_id)

                # Update for next iteration
                current_ids = torch.cat([current_ids, next_token], dim=1)
                token_mask = torch.ones((1, 1), device=current_mask.device, dtype=current_mask.dtype)
                current_mask = torch.cat([current_mask, token_mask], dim=1)
                current_kv = outputs.past_key_values

                # Check for stop tokens - include both EOS and code_end
                if token_id in eos_token_ids:
                    break

            # Add batch tokens to overall result
            all_tokens.extend(new_tokens)

            # Check if we hit a stop token
            if len(new_tokens) < batch_tokens:
                break

        # Convert to tensor
        result = torch.tensor([all_tokens], device=input_ids.device)
        return result, current_kv

    def _filter_logits(self, logits):
        """Apply top-k and top-p filtering"""
        if self.top_k > 0:
            top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
            logits[0, :] = torch.full_like(logits[0, :], float('-inf'))
            logits[0, top_k_indices[0]] = top_k_logits[0]

        if self.top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above threshold
            sorted_indices_to_remove = cumulative_probs > self.top_p
            # Shift the indices to the right to keep the first token above threshold
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = 0

            # Scatter sorted tensors to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

        return logits
```
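For instance, with `top_k=2` the filter leaves only the two largest logits finite, so the multinomial draw is restricted to those tokens (a toy check mirroring the top-k branch above):

```
import torch

# Toy check of the top-k branch: all but the two largest logits become -inf.
logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
top_k_logits, top_k_indices = torch.topk(logits, 2, dim=-1)
filtered = torch.full_like(logits, float('-inf'))
filtered[0, top_k_indices[0]] = top_k_logits[0]
print(filtered)  # tensor([[-inf, 3., 2., -inf]])
```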
And the rewritten `_retool_generate_with_interpreter` built on top of them:

```
    def _retool_generate_with_interpreter(self, prompt_ids_batch, attention_mask_batch, eos_id, interpreter_id, code_id, max_turns=10):
        """Implementation with custom generation to avoid KV cache issues"""
        batch_size = prompt_ids_batch.size(0)
        batch_completion = []
        batch_interpreter_positions = []

        for i in range(batch_size):
            # Initialize
            current_input_id = prompt_ids_batch[i:i+1]
            current_attention_mask = attention_mask_batch[i:i+1]
            current_kv = None

            # Track completion (excludes prompt)
            cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)
            interpreter_positions = []

            for turn_idx in range(max_turns):
                # Check if input is empty
                if current_input_id.size(1) == 0:
                    break

                # Generate with custom function
                newly_generated_tokens, current_kv = self._custom_generate(
                    input_ids=current_input_id,
                    attention_mask=current_attention_mask,
                    past_key_values=current_kv,
                    max_new_tokens=self.max_completion_length,  # Use class attribute
                    eos_token_ids=[eos_id, code_id[1]]
                )

                # Add to completion
                cumulative_completion_ids = torch.cat([cumulative_completion_ids, newly_generated_tokens], dim=1)

                # Check last token
                last_token_id = newly_generated_tokens[0, -1].item() if newly_generated_tokens.size(1) > 0 else None

                # Check for end conditions
                if last_token_id == eos_id or turn_idx == max_turns - 1:
                    batch_completion.append(cumulative_completion_ids.squeeze(0))
                    batch_interpreter_positions.append(interpreter_positions)
                    break

                # Check for code end token
                if last_token_id == code_id[1]:
                    # Extract code from the full text
                    full_text = self.processing_class.decode(
                        torch.cat([prompt_ids_batch[i], cumulative_completion_ids[0]], dim=0)
                    )
                    code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)

                    if code_match:
                        code_block = code_match.group(1).strip()
                        interpreter_text = self._execute_code(code_block)

                        # Format and add interpreter output
                        formatted_feedback = f"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}"
                        interpreter_ids = self.processing_class(
                            formatted_feedback,
                            return_tensors="pt",
                            add_special_tokens=False
                        ).input_ids.to(prompt_ids_batch.device)

                        # Record positions
                        interpreter_start_idx = cumulative_completion_ids.size(1)
                        cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_ids], dim=1)
                        interpreter_end_idx = cumulative_completion_ids.size(1) - 1
                        interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))

                        # Set up for next turn
                        current_input_id = interpreter_ids
                        current_attention_mask = torch.ones_like(current_input_id)
                        # Keep current_kv from previous generation
                    else:
                        # No code block found despite </code> token
                        break
                else:
                    # Continue with the newly generated tokens
                    current_input_id = newly_generated_tokens
                    current_attention_mask = torch.ones_like(current_input_id)
            else:
                # Loop finished due to max_turns without a break
                batch_completion.append(cumulative_completion_ids.squeeze(0))
                batch_interpreter_positions.append(interpreter_positions)

        # Pad sequences
        if len(batch_completion) > 0:
            # Ensure padding_value is a valid integer
            padding_value = self.processing_class.pad_token_id
            if padding_value is None:
                padding_value = 0  # Use 0 as a default if pad_token_id is None

            padded_sequences = torch.nn.utils.rnn.pad_sequence(
                batch_completion,
                batch_first=True,
                padding_value=padding_value
            )
        else:
            padded_sequences = torch.empty((0, 0), dtype=torch.long, device=prompt_ids_batch.device)

        return padded_sequences, batch_interpreter_positions


    def _create_interpreter_mask(
```
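A call site would then look roughly like this (hypothetical snippet: `trainer` stands in for a constructed `ReToolTrainer`, and the marker IDs are placeholders that depend on the tokenizer):

```
# Hypothetical usage; look up the real marker IDs with, e.g.,
# tokenizer.convert_tokens_to_ids("<code>").
prompt = tokenizer("Compute 12 * 34 using Python.", return_tensors="pt")
completions, positions = trainer._retool_generate_with_interpreter(
    prompt_ids_batch=prompt.input_ids,
    attention_mask_batch=prompt.attention_mask,
    eos_id=tokenizer.eos_token_id,
    interpreter_id=(interpreter_open_id, interpreter_close_id),  # <interpreter> ... </interpreter>
    code_id=(code_open_id, code_close_id),                       # <code> ... </code>
    max_turns=10,
)
# `positions` records the interpreter-feedback spans inside each completion,
# presumably consumed by _create_interpreter_mask below.
```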