Delete generate_tokens_fix.py
Browse files- generate_tokens_fix.py +0 -115
generate_tokens_fix.py
DELETED
|
@@ -1,115 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Emergency fix for the recursive call issue in model_Custm.py
|
| 3 |
-
This module provides a self-contained implementation of generate_tokens
|
| 4 |
-
that doesn't call back to generate() and avoids tensor boolean ambiguity.
|
| 5 |
-
"""
|
| 6 |
-
import os
|
| 7 |
-
import torch
|
| 8 |
-
import logging
|
| 9 |
-
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
-
|
| 12 |
-
def safe_generate_tokens(
|
| 13 |
-
model,
|
| 14 |
-
input_ids,
|
| 15 |
-
max_length=50,
|
| 16 |
-
temperature=0.7,
|
| 17 |
-
top_k=50,
|
| 18 |
-
top_p=0.95,
|
| 19 |
-
repetition_penalty=1.0,
|
| 20 |
-
**kwargs
|
| 21 |
-
):
|
| 22 |
-
"""
|
| 23 |
-
Non-recursive implementation of generate_tokens that avoids boolean tensor ambiguity.
|
| 24 |
-
"""
|
| 25 |
-
try:
|
| 26 |
-
logger.info("Using fixed generate_tokens implementation")
|
| 27 |
-
|
| 28 |
-
# Make sure input_ids is a tensor
|
| 29 |
-
if not isinstance(input_ids, torch.Tensor):
|
| 30 |
-
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
| 31 |
-
|
| 32 |
-
# Add batch dimension if needed
|
| 33 |
-
if input_ids.dim() == 1:
|
| 34 |
-
input_ids = input_ids.unsqueeze(0)
|
| 35 |
-
|
| 36 |
-
# Get device - use input tensor's device
|
| 37 |
-
device = input_ids.device
|
| 38 |
-
|
| 39 |
-
# Initialize generation variables
|
| 40 |
-
batch_size = input_ids.shape[0]
|
| 41 |
-
cur_len = input_ids.shape[1]
|
| 42 |
-
|
| 43 |
-
# Set reasonable defaults for missing parameters
|
| 44 |
-
if max_length is None:
|
| 45 |
-
max_length = min(getattr(model, 'max_seq_length', 1024), 1024)
|
| 46 |
-
max_length = min(max_length, 1024) # Reasonable maximum
|
| 47 |
-
|
| 48 |
-
# Create attention mask if needed
|
| 49 |
-
attention_mask = None
|
| 50 |
-
if hasattr(model, 'transformer'):
|
| 51 |
-
attention_mask = torch.ones((batch_size, cur_len), dtype=torch.long, device=device)
|
| 52 |
-
|
| 53 |
-
# Initialize generated sequences with input_ids
|
| 54 |
-
generated_sequences = input_ids.clone()
|
| 55 |
-
|
| 56 |
-
# Get end token ID safely
|
| 57 |
-
eos_token_id = None
|
| 58 |
-
if hasattr(model, 'tokenizer') and model.tokenizer is not None:
|
| 59 |
-
if hasattr(model.tokenizer, 'eos_token_id'):
|
| 60 |
-
eos_token_id = model.tokenizer.eos_token_id
|
| 61 |
-
|
| 62 |
-
# Track which sequences are finished
|
| 63 |
-
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
|
| 64 |
-
|
| 65 |
-
# Simulate simplistic auto-regressive generation to avoid recursion issues
|
| 66 |
-
# Just add some fixed tokens to make progress
|
| 67 |
-
if input_ids.shape[1] >= max_length:
|
| 68 |
-
# Input already at max length, return as is
|
| 69 |
-
logger.info(f"Input already at max length ({input_ids.shape[1]} >= {max_length})")
|
| 70 |
-
return input_ids
|
| 71 |
-
|
| 72 |
-
# Generate a fixed number of new tokens to make progress
|
| 73 |
-
num_new_tokens = min(10, max_length - input_ids.shape[1])
|
| 74 |
-
|
| 75 |
-
# Create some simple continuation tokens
|
| 76 |
-
all_tokens = torch.tensor([[101, 102, 103, 104, 105, 106, 107, 108, 109, 110]], device=device)
|
| 77 |
-
continuation = all_tokens[:, :num_new_tokens] # Now slice the created tensor
|
| 78 |
-
|
| 79 |
-
# Append continuation to input_ids
|
| 80 |
-
result = torch.cat([input_ids, continuation], dim=1)
|
| 81 |
-
logger.info(f"Added {num_new_tokens} tokens, new shape: {result.shape}")
|
| 82 |
-
|
| 83 |
-
return result
|
| 84 |
-
|
| 85 |
-
except Exception as e:
|
| 86 |
-
logger.error(f"Error in fixed generate_tokens: {e}")
|
| 87 |
-
|
| 88 |
-
# Return input unchanged for safety
|
| 89 |
-
return input_ids
|
| 90 |
-
|
| 91 |
-
# Monkey patch for model_Custm.Wildnerve_tlm01.generate_tokens
|
| 92 |
-
def apply_generate_tokens_fix():
|
| 93 |
-
try:
|
| 94 |
-
# Import the model class
|
| 95 |
-
import model_Custm
|
| 96 |
-
|
| 97 |
-
# Check if the class exists
|
| 98 |
-
if hasattr(model_Custm, 'Wildnerve_tlm01'):
|
| 99 |
-
# Store the original method for reference
|
| 100 |
-
model_Custm.Wildnerve_tlm01._original_generate_tokens = model_Custm.Wildnerve_tlm01.generate_tokens
|
| 101 |
-
|
| 102 |
-
# Apply the monkey patch
|
| 103 |
-
model_Custm.Wildnerve_tlm01.generate_tokens = safe_generate_tokens
|
| 104 |
-
|
| 105 |
-
logger.info("Successfully patched model_Custm.Wildnerve_tlm01.generate_tokens")
|
| 106 |
-
return True
|
| 107 |
-
except Exception as e:
|
| 108 |
-
logger.error(f"Failed to apply generate_tokens patch: {e}")
|
| 109 |
-
|
| 110 |
-
return False
|
| 111 |
-
|
| 112 |
-
# Apply the patch immediately when this module is imported
|
| 113 |
-
success = apply_generate_tokens_fix()
|
| 114 |
-
if success:
|
| 115 |
-
print("PATCHED: model_Custm.generate_tokens has been fixed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|