WildnerveAI committed on
Commit
e9d9bd0
·
verified ·
1 Parent(s): c640e48

Upload 2 files

Browse files
Files changed (2) hide show
  1. adapter_layer.py +183 -43
  2. dependency_helpers.py +118 -0
adapter_layer.py CHANGED
@@ -6,12 +6,36 @@ import traceback
6
  from typing import Dict, Any, Optional, List
7
  import importlib.util
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  logger = logging.getLogger(__name__)
10
 
11
  class WildnerveModelAdapter:
12
  """
13
  Adapter layer that interfaces between HF inference endpoints and the model.
14
- Uses full tokenizer implementation with multiple fallbacks.
15
  """
16
 
17
  def __init__(self, model_path: str):
@@ -26,62 +50,95 @@ class WildnerveModelAdapter:
26
 
27
  logger.info(f"Model adapter initialized with path: {model_path}")
28
 
29
- # Initialize tokenizer using the proper TokenizerWrapper
30
  self._initialize_tokenizer()
31
 
32
  def _initialize_tokenizer(self):
33
- """Initialize tokenizer with multiple fallbacks"""
34
- # Try to use the full TokenizerWrapper from tokenizer.py
35
  try:
36
- # First check if TokenizerWrapper is already available
37
- tokenizer_spec = importlib.util.find_spec('tokenizer')
38
- if (tokenizer_spec):
39
- logger.info("Found tokenizer.py module, importing TokenizerWrapper")
40
- from tokenizer import TokenizerWrapper, get_tokenizer
 
 
 
 
 
41
 
42
- # Try using the get_tokenizer function
43
- self.tokenizer = get_tokenizer()
44
- if self.tokenizer:
45
- logger.info("Successfully initialized TokenizerWrapper")
46
  return
47
 
48
- # Next try to import it from our model path
49
- tokenizer_path = os.path.join(self.model_path, "tokenizer.py")
50
- if os.path.exists(tokenizer_path):
51
- spec = importlib.util.spec_from_file_location("tokenizer_module", tokenizer_path)
52
- tokenizer_module = importlib.util.module_from_spec(spec)
53
- spec.loader.exec_module(tokenizer_module)
 
54
 
55
- if hasattr(tokenizer_module, 'TokenizerWrapper'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  self.tokenizer = tokenizer_module.TokenizerWrapper()
57
- logger.info("Loaded TokenizerWrapper from tokenizer.py")
58
- return
 
 
 
 
 
 
 
 
59
  except Exception as e:
60
- logger.warning(f"Failed to load TokenizerWrapper: {e}")
61
 
62
- # Try to load from transformers if available
63
  try:
64
- from transformers import AutoTokenizer, BertTokenizer
 
 
 
 
 
 
65
 
66
- # Try multiple model names, starting with our own
67
- for model_name in ["Wildnerve-tlm01_Hybrid_Model", "bert-base-uncased", "gpt2"]:
68
  try:
69
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
70
- logger.info(f"Loaded {model_name} tokenizer via AutoTokenizer")
 
 
 
 
 
71
  return
72
  except Exception as e:
73
- logger.warning(f"Failed to load {model_name} tokenizer: {e}")
74
 
75
- # Direct attempt with BertTokenizer
76
- self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
77
- logger.info("Loaded bert-base-uncased via BertTokenizer")
78
- return
79
-
80
- except Exception as e:
81
- logger.warning(f"Failed to load transformers tokenizers: {e}")
82
 
83
- # If all else fails, use our SimpleTokenizer
84
- logger.warning("All tokenizer loading attempts failed, using SimpleTokenizer")
85
  self.tokenizer = SimpleTokenizer()
86
 
87
  def load_fallback_model(self):
@@ -90,12 +147,63 @@ class WildnerveModelAdapter:
90
  return self.fallback_model
91
 
92
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  self.fallback_model = SimpleFallbackModel(self.tokenizer)
94
- logger.info("Created fallback model")
95
  return self.fallback_model
 
96
  except Exception as e:
97
- logger.error(f"Failed to create fallback model: {e}")
98
- return None
 
 
99
 
100
  def generate(self, prompt: str, **kwargs) -> str:
101
  """Generate a response to the prompt"""
@@ -112,7 +220,38 @@ class WildnerveModelAdapter:
112
  # Try to use the fallback model if it exists or can be created
113
  model = self.load_fallback_model()
114
  if model is not None:
115
- return model.generate(prompt, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  # If fallback model failed, use a simple hardcoded response based on prompt
118
  logger.warning("Using hardcoded response as fallback")
@@ -153,7 +292,8 @@ Please try again later when these issues have been resolved."""
153
  return f"Error generating response: {str(e)}"
154
 
155
 
156
- # Fallback tokenizer implementation
 
157
  class SimpleTokenizer:
158
  """
159
  A minimal tokenizer implementation for fallback purposes.
 
6
  from typing import Dict, Any, Optional, List
7
  import importlib.util
8
 
9
+ # Import dependency helpers
10
+ try:
11
+ from dependency_helpers import safely_import, is_module_available, with_fallback
12
+ except ImportError:
13
+ # Inline implementation if module isn't available
14
+ def safely_import(module_name):
15
+ try:
16
+ return importlib.import_module(module_name)
17
+ except ImportError:
18
+ return None
19
+
20
+ def is_module_available(module_name):
21
+ try:
22
+ importlib.util.find_spec(module_name)
23
+ return True
24
+ except ImportError:
25
+ return False
26
+
27
+ def with_fallback(primary_func, fallback_func, *args, **kwargs):
28
+ try:
29
+ return primary_func(*args, **kwargs)
30
+ except Exception:
31
+ return fallback_func(*args, **kwargs)
32
+
33
  logger = logging.getLogger(__name__)
34
 
35
  class WildnerveModelAdapter:
36
  """
37
  Adapter layer that interfaces between HF inference endpoints and the model.
38
+ Compatible with the original architecture while providing robust fallbacks.
39
  """
40
 
41
  def __init__(self, model_path: str):
 
50
 
51
  logger.info(f"Model adapter initialized with path: {model_path}")
52
 
53
+ # Initialize tokenizer - try to use the original TokenizerWrapper
54
  self._initialize_tokenizer()
55
 
56
  def _initialize_tokenizer(self):
57
+ """Initialize tokenizer while respecting the original architecture"""
58
+ # First try loading config - use original implementation if available
59
  try:
60
+ # Check if we have a config module
61
+ has_config = is_module_available('config')
62
+
63
+ # Try to import from service_registry if available
64
+ has_registry = is_module_available('service_registry')
65
+
66
+ # Use the appropriate approach based on available modules
67
+ if has_registry:
68
+ # Use original registry approach
69
+ from service_registry import registry, TOKENIZER
70
 
71
+ if registry.has(TOKENIZER):
72
+ self.tokenizer = registry.get(TOKENIZER)
73
+ logger.info("Retrieved tokenizer from registry")
 
74
  return
75
 
76
+ # Try loading from the original tokenizer.py
77
+ tokenizer_module = None
78
+
79
+ # First check if it's directly importable
80
+ if is_module_available('tokenizer'):
81
+ tokenizer_module = safely_import('tokenizer')
82
+ logger.info("Imported tokenizer module from Python path")
83
 
84
+ # Next try to load it from model_path
85
+ if tokenizer_module is None:
86
+ tokenizer_path = os.path.join(self.model_path, "tokenizer.py")
87
+ if os.path.exists(tokenizer_path):
88
+ spec = importlib.util.spec_from_file_location("tokenizer_module", tokenizer_path)
89
+ tokenizer_module = importlib.util.module_from_spec(spec)
90
+ spec.loader.exec_module(tokenizer_module)
91
+ logger.info("Loaded tokenizer module from model path")
92
+
93
+ # Create tokenizer if module was loaded
94
+ if tokenizer_module is not None and hasattr(tokenizer_module, 'TokenizerWrapper'):
95
+ # Handle potential missing config_app
96
+ if hasattr(tokenizer_module, 'get_tokenizer'):
97
+ self.tokenizer = tokenizer_module.get_tokenizer()
98
+ else:
99
+ # Try direct instantiation
100
  self.tokenizer = tokenizer_module.TokenizerWrapper()
101
+
102
+ logger.info("Created TokenizerWrapper instance")
103
+
104
+ # Register in registry if available
105
+ if has_registry:
106
+ from service_registry import registry, TOKENIZER
107
+ registry.register(TOKENIZER, self.tokenizer)
108
+
109
+ return
110
+
111
  except Exception as e:
112
+ logger.warning(f"Error initializing original tokenizer: {e}")
113
 
114
+ # If we reach here, try the HuggingFace transformers approach
115
  try:
116
+ from transformers import AutoTokenizer
117
+
118
+ models_to_try = [
119
+ "bert-base-uncased", # Standard BERT model
120
+ "distilbert-base-uncased", # Smaller, faster alternative
121
+ "gpt2" # Another commonly available model
122
+ ]
123
 
124
+ for model_name in models_to_try:
 
125
  try:
126
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
127
+ logger.info(f"Using transformers AutoTokenizer with {model_name}")
128
+
129
+ # Register if registry is available
130
+ if 'registry' in locals() and 'TOKENIZER' in locals():
131
+ registry.register(TOKENIZER, self.tokenizer)
132
+
133
  return
134
  except Exception as e:
135
+ logger.warning(f"Failed to load {model_name}: {e}")
136
 
137
+ except ImportError:
138
+ logger.warning("transformers package not available")
 
 
 
 
 
139
 
140
+ # Last resort: use our SimpleTokenizer implementation
141
+ logger.warning("Using SimpleTokenizer as final fallback")
142
  self.tokenizer = SimpleTokenizer()
143
 
144
  def load_fallback_model(self):
 
147
  return self.fallback_model
148
 
149
  try:
150
+ # First try to import original model if available
151
+ model_module = None
152
+ models_to_try = ["model_Combn", "model_Custm", "model_PrTr"]
153
+
154
+ for model_name in models_to_try:
155
+ try:
156
+ if is_module_available(model_name):
157
+ model_module = safely_import(model_name)
158
+ if model_module:
159
+ logger.info(f"Imported {model_name} module")
160
+ break
161
+ except Exception as e:
162
+ logger.warning(f"Failed to import {model_name}: {e}")
163
+
164
+ # If we found a valid model module, try to instantiate it
165
+ if model_module:
166
+ model_classes = [
167
+ "Wildnerve_tlm01_Hybrid_Model",
168
+ "Wildnerve_tlm01"
169
+ ]
170
+
171
+ for class_name in model_classes:
172
+ if hasattr(model_module, class_name):
173
+ try:
174
+ # Try to instantiate with minimal parameters
175
+ model_class = getattr(model_module, class_name)
176
+ instance = model_class(
177
+ vocab_size=30522,
178
+ specialization="general",
179
+ dataset_path=None,
180
+ model_name="bert-base-uncased",
181
+ embedding_dim=768,
182
+ num_heads=12,
183
+ hidden_dim=768,
184
+ num_layers=6,
185
+ output_size=768,
186
+ dropout=0.1,
187
+ max_seq_length=512,
188
+ pooling_mode="mean",
189
+ tokenizer=self.tokenizer
190
+ )
191
+ logger.info(f"Created {class_name} instance from {model_module.__name__}")
192
+ self.fallback_model = instance
193
+ return self.fallback_model
194
+ except Exception as e:
195
+ logger.warning(f"Failed to instantiate {class_name}: {e}")
196
+
197
+ # If we couldn't use the original model, use our fallback
198
  self.fallback_model = SimpleFallbackModel(self.tokenizer)
199
+ logger.info("Created SimpleFallbackModel")
200
  return self.fallback_model
201
+
202
  except Exception as e:
203
+ logger.error(f"Failed to create any fallback model: {e}")
204
+ # As an absolute last resort, create a minimal model on the fly
205
+ self.fallback_model = SimpleFallbackModel(self.tokenizer)
206
+ return self.fallback_model
207
 
208
  def generate(self, prompt: str, **kwargs) -> str:
209
  """Generate a response to the prompt"""
 
220
  # Try to use the fallback model if it exists or can be created
221
  model = self.load_fallback_model()
222
  if model is not None:
223
+ # Try different generation methods the model might have
224
+ if hasattr(model, "generate_streaming"):
225
+ try:
226
+ # For streaming we need to collect all tokens
227
+ tokens = []
228
+ for token in model.generate_streaming(prompt, **kwargs):
229
+ tokens.append(token)
230
+ return "".join(tokens)
231
+ except Exception as e:
232
+ logger.warning(f"Streaming generation failed: {e}")
233
+
234
+ # Try standard generate methods
235
+ gen_methods = ["generate_with_decoding", "generate"]
236
+ for method_name in gen_methods:
237
+ if hasattr(model, method_name):
238
+ try:
239
+ # Tokenize the input if needed
240
+ if hasattr(self.tokenizer, "__call__"):
241
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
242
+ # Get the result
243
+ method = getattr(model, method_name)
244
+ result = method(input_ids, **kwargs)
245
+ if isinstance(result, str) and result:
246
+ return result
247
+ except Exception as e:
248
+ logger.warning(f"{method_name} failed: {e}")
249
+
250
+ # If we get here, try a final simple generate method
251
+ try:
252
+ return model.generate(prompt, **kwargs)
253
+ except Exception as e:
254
+ logger.warning(f"Direct generation failed: {e}")
255
 
256
  # If fallback model failed, use a simple hardcoded response based on prompt
257
  logger.warning("Using hardcoded response as fallback")
 
292
  return f"Error generating response: {str(e)}"
293
 
294
 
295
+ # Minimal implementations below - these are only used if absolutely necessary
296
+
297
  class SimpleTokenizer:
298
  """
299
  A minimal tokenizer implementation for fallback purposes.
dependency_helpers.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helper utilities for handling dependencies in a graceful manner.
3
+ This module provides functions to check for and load dependencies without crashing.
4
+ """
5
+ import importlib
6
+ import logging
7
+ import sys
8
+ import os
9
+ from typing import Optional, Any, Dict, Callable, List
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ def safely_import(module_name: str) -> Optional[Any]:
14
+ """
15
+ Safely import a module without crashing if it's not available.
16
+
17
+ Args:
18
+ module_name: Name of the module to import
19
+
20
+ Returns:
21
+ The imported module or None if import failed
22
+ """
23
+ try:
24
+ return importlib.import_module(module_name)
25
+ except ImportError as e:
26
+ logger.warning(f"Failed to import {module_name}: {e}")
27
+ return None
28
+
29
+ def is_module_available(module_name: str) -> bool:
30
+ """
31
+ Check if a module is available without importing it.
32
+
33
+ Args:
34
+ module_name: Name of the module to check
35
+
36
+ Returns:
37
+ True if module is available, False otherwise
38
+ """
39
+ try:
40
+ importlib.util.find_spec(module_name)
41
+ return True
42
+ except ImportError:
43
+ return False
44
+
45
+ def check_dependencies(dependencies: List[str]) -> Dict[str, bool]:
46
+ """
47
+ Check multiple dependencies at once.
48
+
49
+ Args:
50
+ dependencies: List of module names to check
51
+
52
+ Returns:
53
+ Dictionary mapping module names to availability (True/False)
54
+ """
55
+ return {dep: is_module_available(dep) for dep in dependencies}
56
+
57
+ def get_object_if_available(module_name: str, object_name: str) -> Optional[Any]:
58
+ """
59
+ Get an object from a module if the module is available.
60
+
61
+ Args:
62
+ module_name: Name of the module containing the object
63
+ object_name: Name of the object to get
64
+
65
+ Returns:
66
+ The requested object or None if not available
67
+ """
68
+ module = safely_import(module_name)
69
+ if module and hasattr(module, object_name):
70
+ return getattr(module, object_name)
71
+ return None
72
+
73
+ def with_fallback(primary_func: Callable, fallback_func: Callable, *args, **kwargs) -> Any:
74
+ """
75
+ Call primary_func with the given args/kwargs, falling back to fallback_func if it fails.
76
+
77
+ Args:
78
+ primary_func: Function to try first
79
+ fallback_func: Function to use if primary_func fails
80
+ args: Positional arguments to pass to both functions
81
+ kwargs: Keyword arguments to pass to both functions
82
+
83
+ Returns:
84
+ Result from either primary_func or fallback_func
85
+ """
86
+ try:
87
+ return primary_func(*args, **kwargs)
88
+ except Exception as e:
89
+ logger.warning(f"Primary function {primary_func.__name__} failed: {e}")
90
+ return fallback_func(*args, **kwargs)
91
+
92
+ def install_package(package_name: str) -> bool:
93
+ """
94
+ Attempt to install a package using pip.
95
+ Note: This is generally not recommended in production code but can be useful for development.
96
+
97
+ Args:
98
+ package_name: Name of the package to install
99
+
100
+ Returns:
101
+ True if installation was successful, False otherwise
102
+ """
103
+ try:
104
+ import subprocess
105
+ logger.info(f"Attempting to install {package_name}")
106
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
107
+ return True
108
+ except Exception as e:
109
+ logger.warning(f"Failed to install {package_name}: {e}")
110
+ return False
111
+
112
+ # Check critical dependencies used in the project
113
+ CRITICAL_DEPENDENCIES = ["torch", "transformers", "sentencepiece", "pydantic", "nltk"]
114
+ DEPENDENCY_STATUS = check_dependencies(CRITICAL_DEPENDENCIES)
115
+
116
+ def get_dependency_status() -> Dict[str, bool]:
117
+ """Get the status of critical dependencies."""
118
+ return DEPENDENCY_STATUS