WildnerveAI committed on
Commit
dbbc74a
·
verified ·
1 Parent(s): 5184350

Upload 6 files

Browse files
Files changed (4) hide show
  1. adapter_layer.py +23 -1
  2. config.py +65 -25
  3. handler.py +47 -10
  4. model_List.py +56 -8
adapter_layer.py CHANGED
@@ -8,7 +8,6 @@ import logging
8
  import pydantic # required
9
  import codecarbon
10
  import importlib.util # required
11
- from model_List import PromptAnalyzer
12
  from typing import Dict, Any, Optional, List, Tuple
13
  from service_registry import registry, MODEL, PRETRAINED_MODEL, TOKENIZER
14
 
@@ -27,6 +26,29 @@ def is_module_available(module_name):
27
  except ImportError:
28
  return False
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class WildnerveModelAdapter:
31
  """Adapter layer that interfaces between HF inference endpoints and the model."""
32
  RETRY_COUNT = 5
 
8
  import pydantic # required
9
  import codecarbon
10
  import importlib.util # required
 
11
  from typing import Dict, Any, Optional, List, Tuple
12
  from service_registry import registry, MODEL, PRETRAINED_MODEL, TOKENIZER
13
 
 
26
  except ImportError:
27
  return False
28
 
29
+ # More robust import for PromptAnalyzer
30
+ try:
31
+ from model_List import PromptAnalyzer
32
+ logger.info("Successfully imported PromptAnalyzer")
33
+ except ImportError as e:
34
+ logger.error(f"Error importing PromptAnalyzer: {e}")
35
+ # Create a minimal PromptAnalyzer class
36
+ class PromptAnalyzer:
37
+ def __init__(self, **kwargs):
38
+ self.logger = logging.getLogger(__name__)
39
+ self.predefined_topics = {
40
+ "programming": ["python", "java", "code"],
41
+ "general": ["weather", "hello", "chat"]
42
+ }
43
+
44
+ def analyze_prompt(self, prompt: str):
45
+ # Simple keyword-based routing
46
+ prompt_lower = prompt.lower()
47
+ for tech_word in self.predefined_topics.get("programming", []):
48
+ if tech_word in prompt_lower:
49
+ return "model_Custm", 0.8
50
+ return "model_PrTr", 0.6
51
+
52
  class WildnerveModelAdapter:
53
  """Adapter layer that interfaces between HF inference endpoints and the model."""
54
  RETRY_COUNT = 5
config.py CHANGED
@@ -6,7 +6,7 @@ import argparse
6
  import pydantic # prefer real import in main block
7
  import dependency_helpers # keep helper import early
8
  from pathlib import Path
9
- from typing import Optional, Dict, List, Literal, Any
10
 
11
  # flag indicating real pydantic is present
12
  pydantic_available = True
@@ -351,7 +351,22 @@ class STDPConfig(BaseModel):
351
  extra="allow"
352
  )
353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  class AppConfig(BaseModel):
 
355
  # which model files to load by default
356
  SELECTED_MODEL: List[str] = Field(
357
  default=["model_Custm.py", "model_PrTr.py"],
@@ -387,31 +402,47 @@ class AppConfig(BaseModel):
387
  TOP_K: int = Field(default=3)
388
  MAX_ACTIVE_MODELS: int = Field(default=2)
389
  MODEL_IDLE_THRESHOLD: int = Field(default=600)
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
- class AttrDict(dict):
392
- """Dictionary subclass with attribute-style access"""
393
- __getattr__ = dict.get
394
- __setattr__ = dict.__setitem__
395
- __delattr__ = dict.__delitem__
396
-
397
- def load_config() -> AppConfig:
398
  config_path = os.path.join(os.path.dirname(__file__), "config.json")
399
  logger.info(f"Loading config from {config_path}")
 
 
400
  try:
401
  with open(config_path, "r") as f:
402
- raw = json.load(f)
 
 
 
 
 
 
403
 
404
- # Fix 1: Create AttrDict with config_data attribute first
405
  if isinstance(raw.get("TRANSFORMER_CONFIG"), dict):
406
- transformer_config = AttrDict(raw["TRANSFORMER_CONFIG"])
 
407
 
408
- # Crucial fix: Set config_data immediately and explicitly
409
- transformer_config.config_data = transformer_config
410
 
411
- # Replace the dict with our enhanced AttrDict
412
  raw["TRANSFORMER_CONFIG"] = transformer_config
413
 
414
- # Ensure GPT-2 parameters - these come AFTER setting config_data
415
  if not isinstance(transformer_config.get("VOCAB_SIZE"), int) or transformer_config["VOCAB_SIZE"] != 50257:
416
  transformer_config["VOCAB_SIZE"] = 50257 # Standard GPT-2 vocab size
417
 
@@ -436,19 +467,28 @@ def load_config() -> AppConfig:
436
  logger.error(f"Failed to read config.json: {e}", exc_info=True)
437
  raise
438
 
439
- try:
440
- cfg = AppConfig(**raw)
441
- logger.debug(f"Config loaded: {cfg.json()}")
442
- except ValidationError as ve:
443
- logger.error(f"Config validation error: {ve}", exc_info=True)
444
- raise
445
-
446
- return cfg
 
 
 
 
 
 
 
 
 
 
447
 
448
  # Global application config
449
  app_config = load_config()
450
 
451
  if __name__ == "__main__":
452
  args = argparse.ArgumentParser(description="Tiny Language Model Configuration").parse_args()
453
- print("Configuration loaded:")
454
- print(app_config)
 
6
  import pydantic # prefer real import in main block
7
  import dependency_helpers # keep helper import early
8
  from pathlib import Path
9
+ from typing import Optional, Dict, List, Literal, Any, Union
10
 
11
  # flag indicating real pydantic is present
12
  pydantic_available = True
 
351
  extra="allow"
352
  )
353
 
354
+ class SerializableDict(dict):
355
+ """Dictionary subclass with attribute-style access that can be serialized safely"""
356
+ def __getattr__(self, key):
357
+ if key in self:
358
+ return self[key]
359
+ return None
360
+
361
+ def __setattr__(self, key, value):
362
+ self[key] = value
363
+
364
+ def __delattr__(self, key):
365
+ if key in self:
366
+ del self[key]
367
+
368
  class AppConfig(BaseModel):
369
+ """Main application configuration with proper serialization handling"""
370
  # which model files to load by default
371
  SELECTED_MODEL: List[str] = Field(
372
  default=["model_Custm.py", "model_PrTr.py"],
 
402
  TOP_K: int = Field(default=3)
403
  MAX_ACTIVE_MODELS: int = Field(default=2)
404
  MODEL_IDLE_THRESHOLD: int = Field(default=600)
405
+
406
+ # Add a new Pydantic model_config to fix serialization issues
407
+ model_config = ConfigDict(
408
+ extra="allow", # Allow extra fields not in the model
409
+ arbitrary_types_allowed=True, # Allow arbitrary types
410
+ populate_by_name=True, # Allow population by field name
411
+ json_encoders={
412
+ # Add custom encoders for non-serializable types
413
+ SerializableDict: lambda v: {k: v[k] for k in v if not k.startswith("_")}
414
+ },
415
+ validate_assignment=False # Don't validate on attribute assignment
416
+ )
417
 
418
+ def load_config() -> Union[AppConfig, Dict[str, Any]]:
419
+ """Load configuration from JSON file with robust error handling"""
 
 
 
 
 
420
  config_path = os.path.join(os.path.dirname(__file__), "config.json")
421
  logger.info(f"Loading config from {config_path}")
422
+ raw_config = {}
423
+
424
  try:
425
  with open(config_path, "r") as f:
426
+ try:
427
+ raw = json.load(f)
428
+ raw_config = raw # Save raw config in case Pydantic validation fails
429
+ except json.JSONDecodeError as e:
430
+ logger.error(f"JSON parsing error in config.json: {e}")
431
+ logger.error(f"Error at line {e.lineno}, column {e.colno}: {e.msg}")
432
+ raise
433
 
434
+ # Process the TRANSFORMER_CONFIG section
435
  if isinstance(raw.get("TRANSFORMER_CONFIG"), dict):
436
+ # Convert to SerializableDict instead of AttrDict
437
+ transformer_config = SerializableDict(raw["TRANSFORMER_CONFIG"])
438
 
439
+ # Crucial fix: Add config_data property that doesn't break serialization
440
+ transformer_config["config_data"] = transformer_config
441
 
442
+ # Replace the dict with our enhanced SerializableDict
443
  raw["TRANSFORMER_CONFIG"] = transformer_config
444
 
445
+ # Ensure GPT-2 parameters are set
446
  if not isinstance(transformer_config.get("VOCAB_SIZE"), int) or transformer_config["VOCAB_SIZE"] != 50257:
447
  transformer_config["VOCAB_SIZE"] = 50257 # Standard GPT-2 vocab size
448
 
 
467
  logger.error(f"Failed to read config.json: {e}", exc_info=True)
468
  raise
469
 
470
+ # Try to create AppConfig with pydantic validation
471
+ if pydantic_available:
472
+ try:
473
+ cfg = AppConfig(**raw)
474
+
475
+ # DON'T try to serialize the entire config - this was causing our issue
476
+ # Just log that config loaded successfully
477
+ logger.debug("Config loaded successfully")
478
+ return cfg
479
+ except ValidationError as ve:
480
+ logger.error(f"Config validation error: {ve}", exc_info=True)
481
+
482
+ # Fall back to returning the raw config as a dict
483
+ logger.warning("Using raw config dictionary due to validation failure")
484
+ return raw_config
485
+ else:
486
+ # If pydantic not available, just return the raw dict
487
+ return raw_config
488
 
489
  # Global application config
490
  app_config = load_config()
491
 
492
  if __name__ == "__main__":
493
  args = argparse.ArgumentParser(description="Tiny Language Model Configuration").parse_args()
494
+ print("Configuration loaded successfully!")
 
handler.py CHANGED
@@ -2,26 +2,44 @@
2
  import os
3
  import sys
4
  import time
5
- import torch # Add missing torch import!
6
  import logging
7
  import traceback
8
  from typing import Dict, Any, List
9
  import importlib.util
10
 
11
- # Add this near the top (after imports)
12
- from service_registry import ensure_models_registered
13
- ensure_models_registered()
14
-
15
- # --- DEBUG: confirm correct handler.py is loaded ---
16
- print("DEBUG: using Wildnerve-tlm_HF/handler.py — v4 with dependencies in place")
17
-
18
- # Set up logging
19
  logging.basicConfig(
20
  level=logging.INFO,
21
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
22
  )
23
  logger = logging.getLogger(__name__)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Safely check for required packages without crashing
26
  try:
27
  import pydantic
@@ -92,7 +110,26 @@ class EndpointHandler:
92
  if model_dir:
93
  logger.info(f"Handler init with path: {model_dir}")
94
  try:
95
- # supply model_dir as the adapter’s model_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  self.adapter = WildnerveModelAdapter(model_dir or "")
97
  except Exception as e:
98
  logger.error(f"Adapter init failed: {e}", exc_info=True)
 
2
  import os
3
  import sys
4
  import time
5
+ import torch
6
  import logging
7
  import traceback
8
  from typing import Dict, Any, List
9
  import importlib.util
10
 
11
+ # Configure logging first
 
 
 
 
 
 
 
12
  logging.basicConfig(
13
  level=logging.INFO,
14
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
15
  )
16
  logger = logging.getLogger(__name__)
17
 
18
+ # --- DEBUG: confirm correct handler.py is loaded ---
19
+ print("DEBUG: using Wildnerve-tlm_HF/handler.py — v5 with robust config handling")
20
+
21
+ # Safe config import that won't fail during initialization
22
+ try:
23
+ from config import app_config
24
+ logger.info("Successfully imported config")
25
+ except Exception as e:
26
+ logger.error(f"Error importing config: {e}")
27
+ # Create minimal config to avoid further errors
28
+ app_config = {
29
+ "MODEL_NAME": "Wildnerve-tlm01_Hybrid_Model",
30
+ "TRANSFORMER_CONFIG": {
31
+ "MODEL_NAME": "gpt2",
32
+ "VOCAB_SIZE": 50257
33
+ }
34
+ }
35
+
36
+ # Add this near the top (after imports)
37
+ try:
38
+ from service_registry import ensure_models_registered
39
+ ensure_models_registered()
40
+ except Exception as e:
41
+ logger.error(f"Error ensuring models are registered: {e}")
42
+
43
  # Safely check for required packages without crashing
44
  try:
45
  import pydantic
 
110
  if model_dir:
111
  logger.info(f"Handler init with path: {model_dir}")
112
  try:
113
+ # Try to import adapter layer
114
+ try:
115
+ # For more reliable importing
116
+ script_dir = os.path.dirname(os.path.abspath(__file__))
117
+ sys.path.insert(0, script_dir)
118
+
119
+ from adapter_layer import WildnerveModelAdapter
120
+ logger.info("Successfully imported adapter_layer module")
121
+ except ImportError as e:
122
+ logger.error(f"Could not import adapter_layer: {e}")
123
+ # Create a minimal placeholder adapter class
124
+ class WildnerveModelAdapter:
125
+ def __init__(self, model_path: str =""):
126
+ self.model_path = model_path
127
+ logger.info(f"Using fallback WildnerveModelAdapter with path: {model_path}")
128
+
129
+ def generate(self, text_input, **kwargs):
130
+ return f"Model adapter unavailable. Received input: {text_input[:30]}..."
131
+
132
+ # supply model_dir as the adapter's model_path
133
  self.adapter = WildnerveModelAdapter(model_dir or "")
134
  except Exception as e:
135
  logger.error(f"Adapter init failed: {e}", exc_info=True)
model_List.py CHANGED
@@ -17,7 +17,23 @@ try:
17
  except LookupError:
18
  nltk.download("punkt")
19
  from service_registry import registry, TOKENIZER, MODEL
20
- from config import app_config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Add SmartHybridAttention imports
22
  from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
23
 
@@ -35,14 +51,35 @@ class PromptAnalyzer:
35
  def __init__(self, model_name=None, dataset_path=None, specialization=None, hidden_dim=None):
36
  self.logger = logging.getLogger(__name__)
37
 
38
- # Load config
39
- self.config = load_config(config_file="config.json")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # Use provided values or config values
42
- self.model_name = model_name or self.config.PROMPT_ANALYZER_CONFIG.MODEL_NAME
43
- self.dataset_path = dataset_path or self.config.PROMPT_ANALYZER_CONFIG.DATASET_PATH
44
- self.specialization = specialization or self.config.PROMPT_ANALYZER_CONFIG.SPECIALIZATION
45
- self.hidden_dim = hidden_dim or self.config.PROMPT_ANALYZER_CONFIG.HIDDEN_DIM
46
 
47
  self.logger.info(f"Initialized PromptAnalyzer with {self.model_name}")
48
  self._model_cache: Dict[str, Type] = {}
@@ -91,6 +128,17 @@ class PromptAnalyzer:
91
  except Exception:
92
  pass
93
 
 
 
 
 
 
 
 
 
 
 
 
94
  def _load_predefined_topics(self):
95
  """Load topic keywords from config file or use defaults with caching"""
96
  # Try to load from config first
 
17
  except LookupError:
18
  nltk.download("punkt")
19
  from service_registry import registry, TOKENIZER, MODEL
20
+
21
+ # More robust config import
22
+ try:
23
+ from config import app_config
24
+ except ImportError:
25
+ logger.error("Failed to import app_config from config")
26
+ # Create minimal app_config
27
+ app_config = {
28
+ "PROMPT_ANALYZER_CONFIG": {
29
+ "MODEL_NAME": "gpt2",
30
+ "DATASET_PATH": None,
31
+ "SPECIALIZATION": None,
32
+ "HIDDEN_DIM": 768,
33
+ "MAX_CACHE_SIZE": 10
34
+ }
35
+ }
36
+
37
  # Add SmartHybridAttention imports
38
  from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
39
 
 
51
  def __init__(self, model_name=None, dataset_path=None, specialization=None, hidden_dim=None):
52
  self.logger = logging.getLogger(__name__)
53
 
54
+ # Load config with better error handling
55
+ try:
56
+ if hasattr(app_config, "PROMPT_ANALYZER_CONFIG"):
57
+ self.config_data = app_config.PROMPT_ANALYZER_CONFIG
58
+ elif isinstance(app_config, dict) and "PROMPT_ANALYZER_CONFIG" in app_config:
59
+ self.config_data = app_config["PROMPT_ANALYZER_CONFIG"]
60
+ else:
61
+ self.config_data = {
62
+ "MODEL_NAME": "gpt2",
63
+ "DATASET_PATH": None,
64
+ "SPECIALIZATION": None,
65
+ "HIDDEN_DIM": 768,
66
+ "MAX_CACHE_SIZE": 10
67
+ }
68
+ except Exception as e:
69
+ self.logger.warning(f"Error loading config: {e}, using defaults")
70
+ self.config_data = {
71
+ "MODEL_NAME": "gpt2",
72
+ "DATASET_PATH": None,
73
+ "SPECIALIZATION": None,
74
+ "HIDDEN_DIM": 768,
75
+ "MAX_CACHE_SIZE": 10
76
+ }
77
 
78
+ # Use provided values or config values with safe getters
79
+ self.model_name = model_name or self._safe_get("MODEL_NAME", "gpt2")
80
+ self.dataset_path = dataset_path or self._safe_get("DATASET_PATH")
81
+ self.specialization = specialization or self._safe_get("SPECIALIZATION")
82
+ self.hidden_dim = hidden_dim or self._safe_get("HIDDEN_DIM", 768)
83
 
84
  self.logger.info(f"Initialized PromptAnalyzer with {self.model_name}")
85
  self._model_cache: Dict[str, Type] = {}
 
128
  except Exception:
129
  pass
130
 
131
+ def _safe_get(self, key, default=None):
132
+ """Safely get a configuration value regardless of config type"""
133
+ try:
134
+ if isinstance(self.config_data, dict):
135
+ return self.config_data.get(key, default)
136
+ elif hasattr(self.config_data, key):
137
+ return getattr(self.config_data, key, default)
138
+ return default
139
+ except:
140
+ return default
141
+
142
  def _load_predefined_topics(self):
143
  """Load topic keywords from config file or use defaults with caching"""
144
  # Try to load from config first