WildnerveAI commited on
Commit
ea7b133
·
verified ·
1 Parent(s): ea02f44

Upload 3 files

Browse files
Files changed (3) hide show
  1. adapter_layer.py +6 -0
  2. config.py +28 -5
  3. model_List.py +7 -1
adapter_layer.py CHANGED
@@ -193,6 +193,12 @@ class WildnerveModelAdapter:
193
  # All attempts failed
194
  raise ImportError("No model registered in service registry")
195
 
 
 
 
 
 
 
196
  def _build_init_kwargs(self):
197
  return {
198
  "vocab_size": 30522,
 
193
  # All attempts failed
194
  raise ImportError("No model registered in service registry")
195
 
196
+ # When storing models/objects, make sure we don't create circular references
197
+ if registry.has(MODEL):
198
+ self.model = registry.get(MODEL)
199
+ # Don't add back-references to registry or other objects that might
200
+ # include this adapter, to avoid circular references
201
+
202
  def _build_init_kwargs(self):
203
  return {
204
  "vocab_size": 30522,
config.py CHANGED
@@ -364,6 +364,24 @@ class SerializableDict(dict):
364
  def __delattr__(self, key):
365
  if key in self:
366
  del self[key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
  class AppConfig(BaseModel):
369
  """Main application configuration with proper serialization handling"""
@@ -433,11 +451,16 @@ def load_config() -> Union[AppConfig, Dict[str, Any]]:
433
 
434
  # Process the TRANSFORMER_CONFIG section
435
  if isinstance(raw.get("TRANSFORMER_CONFIG"), dict):
436
- # Convert to SerializableDict instead of AttrDict
437
  transformer_config = SerializableDict(raw["TRANSFORMER_CONFIG"])
438
 
439
- # Crucial fix: Add config_data property that doesn't break serialization
440
- transformer_config["config_data"] = transformer_config
 
 
 
 
 
441
 
442
  # Replace the dict with our enhanced SerializableDict
443
  raw["TRANSFORMER_CONFIG"] = transformer_config
@@ -472,10 +495,10 @@ def load_config() -> Union[AppConfig, Dict[str, Any]]:
472
  try:
473
  cfg = AppConfig(**raw)
474
 
475
- # DON'T try to serialize the entire config - this was causing our issue
476
- # Just log that config loaded successfully
477
  logger.debug("Config loaded successfully")
478
  return cfg
 
479
  except ValidationError as ve:
480
  logger.error(f"Config validation error: {ve}", exc_info=True)
481
 
 
364
  def __delattr__(self, key):
365
  if key in self:
366
  del self[key]
367
+
368
+ # Add special methods to handle JSON serialization
369
+ def __getstate__(self):
370
+ """Return state for pickling - exclude config_data if it's self"""
371
+ state = dict(self)
372
+ if 'config_data' in state and id(state['config_data']) == id(self):
373
+ state['config_data'] = '__self__' # Replace self-reference with marker
374
+ return state
375
+
376
+ def __repr__(self):
377
+ """Safe representation that handles circular references"""
378
+ items = []
379
+ for k, v in self.items():
380
+ if k == "config_data" and v is self:
381
+ items.append(f"{k}=<self>")
382
+ else:
383
+ items.append(f"{k}={v!r}")
384
+ return f"{self.__class__.__name__}({', '.join(items)})"
385
 
386
  class AppConfig(BaseModel):
387
  """Main application configuration with proper serialization handling"""
 
451
 
452
  # Process the TRANSFORMER_CONFIG section
453
  if isinstance(raw.get("TRANSFORMER_CONFIG"), dict):
454
+ # Create SerializableDict with safe self-reference handling
455
  transformer_config = SerializableDict(raw["TRANSFORMER_CONFIG"])
456
 
457
+ # Store reference to self using a descriptor instead of direct reference
458
+ class ConfigDataDescriptor:
459
+ def __get__(self, obj, objtype=None):
460
+ return obj
461
+
462
+ # Add descriptor to class
463
+ type(transformer_config).config_data = ConfigDataDescriptor()
464
 
465
  # Replace the dict with our enhanced SerializableDict
466
  raw["TRANSFORMER_CONFIG"] = transformer_config
 
495
  try:
496
  cfg = AppConfig(**raw)
497
 
498
+ # Just log success message
 
499
  logger.debug("Config loaded successfully")
500
  return cfg
501
+
502
  except ValidationError as ve:
503
  logger.error(f"Config validation error: {ve}", exc_info=True)
504
 
model_List.py CHANGED
@@ -238,7 +238,12 @@ class PromptAnalyzer:
238
  self.attention = None
239
 
240
  def _track_model_performance(self, model_type: str, start_time: float) -> None:
241
- """Track model loading and performance metrics"""
 
 
 
 
 
242
  end_time = time.time()
243
  if model_type not in self._performance_metrics:
244
  self._performance_metrics[model_type] = {
@@ -247,6 +252,7 @@ class PromptAnalyzer:
247
  'avg_response_time': 0.0
248
  }
249
 
 
250
  metrics = self._performance_metrics[model_type]
251
  metrics['load_time'] = end_time - start_time
252
  metrics['usage_count'] += 1
 
238
  self.attention = None
239
 
240
  def _track_model_performance(self, model_type: str, start_time: float) -> None:
241
+ """Track model loading and performance metrics.
242
+
243
+ Args:
244
+ model_type: Type of model being tracked
245
+ start_time: Start time of operation
246
+ """
247
  end_time = time.time()
248
  if model_type not in self._performance_metrics:
249
  self._performance_metrics[model_type] = {
 
252
  'avg_response_time': 0.0
253
  }
254
 
255
+ # Ensure we're not creating circular references that might impact serialization
256
  metrics = self._performance_metrics[model_type]
257
  metrics['load_time'] = end_time - start_time
258
  metrics['usage_count'] += 1