WildnerveAI committed on
Commit
0f72521
·
verified ·
1 Parent(s): 619154c

Upload 3 files

Browse files
Files changed (3) hide show
  1. model_Custm.py +34 -7
  2. model_PrTr.py +8 -1
  3. tokenizer.py +46 -106
model_Custm.py CHANGED
@@ -708,27 +708,54 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
708
  # Handle prompt if provided (convert to input_ids)
709
  if prompt is not None and input_ids is None:
710
  if self.tokenizer is not None:
711
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
712
- input_ids = inputs.input_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  else:
714
  # Try to get tokenizer from registry
715
  from service_registry import registry, TOKENIZER
 
 
 
716
  tokenizer = registry.get(TOKENIZER)
717
- if tokenizer:
 
 
 
 
 
 
718
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
719
  input_ids = inputs.input_ids
720
- else:
721
- raise ValueError("No tokenizer available to process prompt")
 
722
 
723
  # Check if we have valid input_ids at this point
724
  if input_ids is None:
725
  raise ValueError("Either prompt or input_ids must be provided")
726
 
727
  # Now continue with original generate implementation that uses input_ids
728
- # ...existing implementation...
729
 
730
  # Simple fallback if no implementation exists
731
- return f"I processed your request about '{prompt[:30]}...' successfully."
732
 
733
  except Exception as e:
734
  logger.error(f"Error in generate: {e}")
 
708
  # Handle prompt if provided (convert to input_ids)
709
  if prompt is not None and input_ids is None:
710
  if self.tokenizer is not None:
711
+ # Check if tokenizer is directly callable
712
+ if callable(self.tokenizer):
713
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
714
+ # Check for encode method (common in TokenizerWrapper implementations)
715
+ elif hasattr(self.tokenizer, "encode"):
716
+ tokens = self.tokenizer.encode(prompt)
717
+ # Convert to tensor if needed
718
+ if isinstance(tokens, list):
719
+ input_ids = torch.tensor([tokens], dtype=torch.long)
720
+ else:
721
+ input_ids = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens
722
+ # Check for tokenize method
723
+ elif hasattr(self.tokenizer, "tokenize"):
724
+ tokens = self.tokenizer.tokenize(prompt)
725
+ if hasattr(self.tokenizer, "convert_tokens_to_ids"):
726
+ token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
727
+ input_ids = torch.tensor([token_ids], dtype=torch.long)
728
+ else:
729
+ raise ValueError(f"Tokenizer type {type(self.tokenizer)} doesn't support required methods")
730
  else:
731
  # Try to get tokenizer from registry
732
  from service_registry import registry, TOKENIZER
733
+ from transformers import AutoTokenizer
734
+
735
+ # Try to get from registry first
736
  tokenizer = registry.get(TOKENIZER)
737
+
738
+ # If not available, create a new one
739
+ if not tokenizer:
740
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
741
+
742
+ # Now use the tokenizer safely
743
+ if callable(tokenizer):
744
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
745
  input_ids = inputs.input_ids
746
+ elif hasattr(tokenizer, "encode"):
747
+ tokens = tokenizer.encode(prompt)
748
+ input_ids = torch.tensor([tokens], dtype=torch.long) if isinstance(tokens, list) else tokens
749
 
750
  # Check if we have valid input_ids at this point
751
  if input_ids is None:
752
  raise ValueError("Either prompt or input_ids must be provided")
753
 
754
  # Now continue with original generate implementation that uses input_ids
755
+ # ...existing code...
756
 
757
  # Simple fallback if no implementation exists
758
+ return f"I processed your request about '{prompt[:30] if prompt else 'your input'}...' successfully."
759
 
760
  except Exception as e:
761
  logger.error(f"Error in generate: {e}")
model_PrTr.py CHANGED
@@ -25,8 +25,15 @@ logger = logging.getLogger(__name__)
25
  # Positional Encoding Module (for decoder)
26
  # ----------------------------
27
  class PositionalEncoding(nn.Module):
28
- def __init__(self, d_model: int, max_len: int = app_config.MAX_SEQ_LENGTH):
29
  super().__init__()
 
 
 
 
 
 
 
30
  pe = torch.zeros(max_len, d_model)
31
  position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
32
  div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
 
25
  # Positional Encoding Module (for decoder)
26
  # ----------------------------
27
  class PositionalEncoding(nn.Module):
28
+ def __init__(self, d_model: int, max_len: Optional[int] = None):
29
  super().__init__()
30
+ # Get MAX_SEQ_LENGTH safely from config
31
+ if max_len is None:
32
+ if hasattr(app_config, "TRANSFORMER_CONFIG") and isinstance(app_config.TRANSFORMER_CONFIG, dict):
33
+ max_len = app_config.TRANSFORMER_CONFIG.get("MAX_SEQ_LENGTH", 1024)
34
+ else:
35
+ max_len = 1024 # Safe default
36
+
37
  pe = torch.zeros(max_len, d_model)
38
  position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
39
  div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
tokenizer.py CHANGED
@@ -22,115 +22,55 @@ from service_registry import registry, TOKENIZER
22
  logger = logging.getLogger(__name__)
23
 
24
  class TokenizerWrapper:
25
- """A simple wrapper around AutoTokenizer to standardize tokenizer usage."""
26
- def __init__(self,
27
- primary_model: str = "Wildnerve-tlm01-0.05Bx12",
28
- fallback_model: str = "bert-base-uncased",
29
- fallback2_model: str = "gpt2",
30
- sp_model_path: str = None):
31
- # Use a robust, multi-fallback initialization
32
- self.primary_model = primary_model
33
- self.fallback_model = fallback_model
34
- self.fallback2_model = fallback2_model
35
- self.sp_model_path = sp_model_path
36
- self.sp = None
37
- self.tokenizer = None
38
- # Advanced feature flags
39
- self.features = {
40
- "normalize_text": True,
41
- "custom_preprocessing": True,
42
- "multi_fallback": True
43
- }
44
- self.initialize_tokenizer()
45
-
46
- def initialize_tokenizer(self):
47
- """Initialize the tokenizer with proper error handling"""
48
- # First, try to load SentencePiece model if provided
49
- if self.sp_model_path and SP_AVAILABLE:
50
- try:
51
- self.sp = spm.SentencePieceProcessor()
52
- self.sp.Load(self.sp_model_path)
53
- logger.info(f"Loaded SentencePiece model from {self.sp_model_path}")
54
- except Exception as e:
55
- logger.warning(f"Failed to load SentencePiece model: {e}")
56
- self.sp = None
57
- else:
58
- if not SP_AVAILABLE and self.sp_model_path:
59
- logger.warning("SentencePiece is not installed; skipping SP model loading")
60
- self.sp = None
61
-
62
- # Next, attempt to load the primary tokenizer
63
  try:
64
- self.tokenizer = AutoTokenizer.from_pretrained(self.primary_model)
65
- logger.info(f"Loaded primary tokenizer: {self.primary_model}")
 
 
 
66
  except Exception as e:
67
- logger.warning(f"Primary tokenizer '{self.primary_model}' load failed: {e}")
68
- try:
69
- self.tokenizer = BertTokenizer.from_pretrained(self.fallback_model)
70
- logger.info(f"Loaded fallback tokenizer: {self.fallback_model}")
71
- except Exception as e2:
72
- logger.warning(f"Fallback tokenizer '{self.fallback_model}' load failed: {e2}")
73
- try:
74
- self.tokenizer = AutoTokenizer.from_pretrained(self.fallback2_model)
75
- logger.info(f"Loaded second fallback tokenizer: {self.fallback2_model}")
76
- except Exception as e3:
77
- logger.error(f"All tokenizer loads failed: {e3}")
78
- self.tokenizer = None
79
-
80
- if self.tokenizer:
81
- registry.register(TOKENIZER, self.tokenizer)
82
-
83
- def advanced_normalize(self, text: str) -> str:
84
- # Advanced normalization: lowercasing and removing extra spaces
85
- normalized = text.strip().lower()
86
- normalized = " ".join(normalized.split())
87
- return normalized
88
-
89
- def tokenize(self, text: str, use_sentencepiece: bool = False) -> list:
90
- """
91
- Tokenize text robustly using SentencePiece if requested and available;
92
- Otherwise use the transformer tokenizer; fallback to simple split if needed.
93
- """
 
 
 
 
 
 
94
  try:
95
- # Apply text normalization if enabled
96
- if self.features["normalize_text"]:
97
- text = self.advanced_normalize(text)
98
- if use_sentencepiece and self.sp:
99
- tokens = self.sp.EncodeAsPieces(text)
100
- logger.debug("Tokenized text using SentencePiece")
101
- return tokens
102
- elif self.tokenizer:
103
- tokens = self.tokenizer.tokenize(text)
104
- logger.debug("Tokenized text using transformer tokenizer")
105
- # Optional custom preprocessing: filter out empty tokens
106
- if self.features["custom_preprocessing"]:
107
- tokens = [tok for tok in tokens if tok.strip()]
108
- return tokens
109
- else:
110
- raise ValueError("No tokenizer available")
111
- except Exception as ex:
112
- logger.error(f"Tokenization failed: {ex}")
113
- return text.split() # fallback
114
-
115
- def encode(self, text: str, **kwargs) -> list:
116
- try:
117
- if self.tokenizer:
118
- return self.tokenizer.encode(text, **kwargs)
119
- else:
120
- raise ValueError("Tokenizer not initialized")
121
- except Exception as e:
122
- logger.error(f"Encoding error: {e}")
123
- return []
124
-
125
- def decode(self, token_ids: list) -> str:
126
- try:
127
- return self.tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
128
- except Exception as ex:
129
- logger.error(f"Decoding failed: {ex}")
130
- return "Decoding error"
131
-
132
- def get_tokenizer(model_name: str = "bert-base-uncased") -> TokenizerWrapper:
133
- return TokenizerWrapper(model_name)
134
 
135
  if __name__ == "__main__":
136
  # Example usage showcasing advanced features
 
22
  logger = logging.getLogger(__name__)
23
 
24
  class TokenizerWrapper:
25
+ """A wrapper for transformer tokenizers with fallbacks"""
26
+
27
+ def __init__(self, model_name="gpt2"):
28
+ self.model_name = model_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  try:
30
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
31
+ # Add pad token if it doesn't exist (important for GPT-2)
32
+ if self.tokenizer.pad_token is None:
33
+ self.tokenizer.pad_token = self.tokenizer.eos_token
34
+ logger.info(f"Initialized tokenizer from {model_name}")
35
  except Exception as e:
36
+ logger.error(f"Error loading tokenizer: {e}")
37
+ self.tokenizer = None
38
+
39
+ def __call__(self, text, **kwargs):
40
+ """Make the wrapper callable like a standard HF tokenizer"""
41
+ if self.tokenizer is None:
42
+ raise ValueError("Tokenizer not initialized")
43
+ return self.tokenizer(text, **kwargs)
44
+
45
+ def encode(self, text, **kwargs):
46
+ """Encode text to token IDs"""
47
+ if self.tokenizer is None:
48
+ raise ValueError("Tokenizer not initialized")
49
+ return self.tokenizer.encode(text, **kwargs)
50
+
51
+ def decode(self, token_ids, **kwargs):
52
+ """Decode token IDs to text"""
53
+ if self.tokenizer is None:
54
+ raise ValueError("Tokenizer not initialized")
55
+ return self.tokenizer.decode(token_ids, **kwargs)
56
+
57
+ def tokenize(self, text, **kwargs):
58
+ """Tokenize text to tokens"""
59
+ if self.tokenizer is None:
60
+ raise ValueError("Tokenizer not initialized")
61
+ return self.tokenizer.tokenize(text, **kwargs)
62
+
63
+ def get_tokenizer(model_name="gpt2"):
64
+ """Get a tokenizer instance with proper fallback handling"""
65
+ try:
66
+ return TokenizerWrapper(model_name)
67
+ except Exception as e:
68
+ logger.error(f"Error creating TokenizerWrapper: {e}")
69
  try:
70
+ return AutoTokenizer.from_pretrained(model_name)
71
+ except Exception as e2:
72
+ logger.error(f"Error loading AutoTokenizer: {e2}")
73
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  if __name__ == "__main__":
76
  # Example usage showcasing advanced features