WildnerveAI committed on
Commit
f71f4f5
·
verified ·
1 Parent(s): f677856

Delete generate_tokens_fix.py

Browse files
Files changed (1) hide show
  1. generate_tokens_fix.py +0 -115
generate_tokens_fix.py DELETED
@@ -1,115 +0,0 @@
1
- """
2
- Emergency fix for the recursive call issue in model_Custm.py
3
- This module provides a self-contained implementation of generate_tokens
4
- that doesn't call back to generate() and avoids tensor boolean ambiguity.
5
- """
6
- import os
7
- import torch
8
- import logging
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
def safe_generate_tokens(
    model,
    input_ids,
    max_length=50,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.0,
    **kwargs
):
    """
    Non-recursive stand-in for ``generate_tokens``.

    Instead of sampling, this appends up to 10 fixed placeholder token ids
    (101..110) to ``input_ids``, so no call chain back into ``generate()``
    (the recursion bug this module works around) can occur, and no boolean
    test is ever applied to a multi-element tensor.

    Args:
        model: Object the method is patched onto. Only ``max_seq_length`` is
            consulted, and only when ``max_length`` is None.
        input_ids: Token ids as a tensor or any sequence convertible to one;
            a 1-D input is promoted to a batch of one.
        max_length: Hard cap on the returned sequence length; ``None`` falls
            back to ``model.max_seq_length``, clamped to 1024.
        temperature, top_k, top_p, repetition_penalty: Accepted for interface
            compatibility with the real sampler; ignored here.
        **kwargs: Ignored; absorbs any extra generation options.

    Returns:
        A 2-D long tensor: the (batched) input, possibly extended with the
        placeholder continuation. On any internal error the input is returned
        unchanged — an emergency patch must never crash the caller.
    """
    log = logging.getLogger(__name__)
    try:
        log.info("Using fixed generate_tokens implementation")

        # Normalise the input: tensor, batched, and note its device so the
        # continuation is allocated alongside it.
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        device = input_ids.device

        # Resolve max_length, clamping to a reasonable ceiling.
        if max_length is None:
            max_length = min(getattr(model, 'max_seq_length', 1024), 1024)
        max_length = min(max_length, 1024)  # Reasonable maximum

        # Input already fills the budget: nothing to generate.
        if input_ids.shape[1] >= max_length:
            log.info(f"Input already at max length ({input_ids.shape[1]} >= {max_length})")
            return input_ids

        # Append a short fixed continuation (at most 10 placeholder tokens)
        # rather than sampling — this keeps the patch free of any recursive
        # call back into generate().
        num_new_tokens = min(10, max_length - input_ids.shape[1])
        all_tokens = torch.tensor([[101, 102, 103, 104, 105, 106, 107, 108, 109, 110]], device=device)
        continuation = all_tokens[:, :num_new_tokens]

        result = torch.cat([input_ids, continuation], dim=1)
        log.info(f"Added {num_new_tokens} tokens, new shape: {result.shape}")
        return result

    except Exception as e:
        # Deliberate best-effort fallback: return the input unchanged rather
        # than propagate — callers of the patched method expect a tensor back.
        log.error(f"Error in fixed generate_tokens: {e}")
        return input_ids
90
-
91
# Monkey patch for model_Custm.Wildnerve_tlm01.generate_tokens
def apply_generate_tokens_fix():
    """
    Monkey-patch ``model_Custm.Wildnerve_tlm01.generate_tokens`` with the
    safe, non-recursive ``safe_generate_tokens`` defined above.

    The original method is preserved on the class as
    ``_original_generate_tokens`` so it can be inspected or restored.

    Returns:
        bool: True if the patch was applied; False if the import failed,
        the class is missing, or any other error occurred.
    """
    log = logging.getLogger(__name__)
    try:
        # Imported lazily so merely loading this helper module does not
        # hard-fail when the project model module is unavailable.
        import model_Custm

        # Guard clause: previously a missing class fell through to
        # `return False` silently; log it so the condition is diagnosable.
        if not hasattr(model_Custm, 'Wildnerve_tlm01'):
            log.warning("model_Custm has no Wildnerve_tlm01 class; patch not applied")
            return False

        # Store the original method for reference, then apply the patch.
        model_Custm.Wildnerve_tlm01._original_generate_tokens = model_Custm.Wildnerve_tlm01.generate_tokens
        model_Custm.Wildnerve_tlm01.generate_tokens = safe_generate_tokens

        log.info("Successfully patched model_Custm.Wildnerve_tlm01.generate_tokens")
        return True
    except Exception as e:
        log.error(f"Failed to apply generate_tokens patch: {e}")
        return False
111
-
112
# Apply the patch immediately when this module is imported, so that simply
# importing generate_tokens_fix is enough to activate the workaround.
success = apply_generate_tokens_fix()
if success:
    # Console breadcrumb in addition to logging, so the patch is visible
    # even when logging has not been configured by the host application.
    print("PATCHED: model_Custm.generate_tokens has been fixed")