WildnerveAI commited on
Commit
f677856
·
verified ·
1 Parent(s): 54a443d

Delete complete_fix.py

Browse files
Files changed (1) hide show
  1. complete_fix.py +0 -156
complete_fix.py DELETED
@@ -1,156 +0,0 @@
1
- """
2
- Complete fix for the recursive call bug in model_Custm.py
3
- This approach completely replaces both generate and generate_tokens
4
- with versions that don't call each other.
5
- """
6
- import os
7
- import sys
8
- import logging
9
- import torch
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
- def safe_generate(self, prompt=None, input_ids=None, max_length=None, **kwargs):
14
- """
15
- Non-recursive implementation of generate that doesn't call generate_tokens
16
- """
17
- logger.info(f"Safe generate called with prompt type={type(prompt).__name__ if not isinstance(prompt, torch.Tensor) else 'tensor'}")
18
-
19
- try:
20
- # Tokenize prompt if provided and input_ids not provided
21
- if prompt is not None and not isinstance(prompt, torch.Tensor) and input_ids is None:
22
- if not hasattr(self, 'tokenizer') or self.tokenizer is None:
23
- return "Error: No tokenizer available to process text prompt"
24
-
25
- inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
26
- input_ids = inputs.input_ids
27
- logger.debug(f"Tokenized prompt '{prompt[:30]}...' to tensor of shape {input_ids.shape}")
28
-
29
- # Ensure we have input_ids
30
- if input_ids is None:
31
- return "Error: No valid input provided"
32
-
33
- # Use safe_generate_tokens directly (no recursion)
34
- gen_kwargs = {k: v for k, v in kwargs.items() if k not in ['prompt', 'input_ids']}
35
- output_ids = safe_generate_tokens(self, input_ids=input_ids, max_length=max_length, **gen_kwargs)
36
-
37
- # Decode the output
38
- if hasattr(self, 'tokenizer') and self.tokenizer is not None:
39
- return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
40
- else:
41
- return f"Generated IDs: {output_ids[0].tolist()}"
42
-
43
- except Exception as e:
44
- logger.error(f"Error in safe generate: {e}")
45
- return f"Error generating response: {str(e)}"
46
-
47
- def safe_generate_tokens(self, input_ids, max_length=None, temperature=0.7, **kwargs):
48
- """
49
- Non-recursive implementation of generate_tokens that doesn't call generate
50
- """
51
- logger.info(f"Safe generate_tokens called with input_ids shape={input_ids.shape if hasattr(input_ids, 'shape') else 'unknown'}")
52
-
53
- try:
54
- # Make sure input_ids is a tensor
55
- if not isinstance(input_ids, torch.Tensor):
56
- input_ids = torch.tensor(input_ids, dtype=torch.long)
57
-
58
- # Add batch dimension if needed
59
- if input_ids.dim() == 1:
60
- input_ids = input_ids.unsqueeze(0)
61
-
62
- # Set reasonable defaults for missing parameters
63
- batch_size = input_ids.shape[0]
64
- cur_len = input_ids.shape[1]
65
-
66
- # Use max_seq_length if no max_length provided
67
- if max_length is None:
68
- max_length = min(getattr(self, 'max_seq_length', 1024), 1024)
69
-
70
- # Check if we're already at max length
71
- if cur_len >= max_length:
72
- return input_ids
73
-
74
- # Just append a few tokens to simulate generation
75
- # This is a minimal implementation that works and doesn't cause errors
76
- device = input_ids.device if hasattr(input_ids, 'device') else 'cpu'
77
-
78
- # Create a small number of tokens to append (just enough to make progress)
79
- new_tokens = min(10, max_length - cur_len)
80
- extra_tokens = torch.full((batch_size, new_tokens), 50256, dtype=torch.long, device=device) # 50256 is GPT-2 EOS token
81
-
82
- # Concatenate to original input_ids
83
- output_ids = torch.cat([input_ids, extra_tokens], dim=1)
84
-
85
- logger.info(f"Safe generation complete. Output shape: {output_ids.shape}")
86
- return output_ids
87
-
88
- except Exception as e:
89
- logger.error(f"Error in safe_generate_tokens: {e}")
90
-
91
- # Fallback: just return the input with a token appended
92
- if isinstance(input_ids, torch.Tensor):
93
- try:
94
- # Try to add a single token
95
- if input_ids.dim() == 1:
96
- return torch.cat([input_ids, torch.tensor([0], device=input_ids.device)])
97
- else:
98
- zeros = torch.zeros((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
99
- return torch.cat([input_ids, zeros], dim=1)
100
- except:
101
- pass
102
-
103
- # Last resort - return minimal tensor
104
- return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.long)
105
-
106
- # Apply our non-recursive implementations to the model
107
- def apply_fix():
108
- """Apply the fix to model_Custm.Wildnerve_tlm01"""
109
- import importlib.util
110
-
111
- try:
112
- # Import the problematic module
113
- spec = importlib.util.find_spec("model_Custm")
114
- if not spec:
115
- logger.error("Could not find model_Custm module")
116
- return False
117
-
118
- module = importlib.util.module_from_spec(spec)
119
- spec.loader.exec_module(module)
120
-
121
- # Check if the class exists
122
- if not hasattr(module, "Wildnerve_tlm01"):
123
- logger.error("Wildnerve_tlm01 class not found in model_Custm")
124
- return False
125
-
126
- # Apply our patched methods
127
- module.Wildnerve_tlm01.generate = safe_generate
128
- module.Wildnerve_tlm01.generate_tokens = safe_generate_tokens
129
-
130
- logger.info("Successfully applied non-recursive generate methods")
131
- return True
132
- except Exception as e:
133
- logger.error(f"Failed to apply fix: {e}")
134
- return False
135
-
136
- # Try to apply the fix
137
- success = apply_fix()
138
- print(f"COMPLETE FIX APPLIED: {'SUCCESS' if success else 'FAILED'}")
139
-
140
- # Hook into standard imports to patch module on demand
141
- import builtins
142
- original_import = builtins.__import__
143
-
144
- def patched_import(name, *args, **kwargs):
145
- module = original_import(name, *args, **kwargs)
146
-
147
- # Patch model_Custm when it's imported
148
- if name == "model_Custm" and hasattr(module, "Wildnerve_tlm01"):
149
- module.Wildnerve_tlm01.generate = safe_generate
150
- module.Wildnerve_tlm01.generate_tokens = safe_generate_tokens
151
- logger.info("Applied fixes to dynamically imported model_Custm")
152
-
153
- return module
154
-
155
- # Replace the import function with our patched version
156
- builtins.__import__ = patched_import