WildnerveAI committed on
Commit
1a8d9bc
·
verified ·
1 Parent(s): 84ad412

Upload 5 files

Browse files
Files changed (3) hide show
  1. config_utils.py +95 -0
  2. handler.py +9 -0
  3. main.py +23 -3
config_utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for safely accessing configuration values regardless of whether they're stored
3
+ in dictionaries or objects.
4
+ """
5
+
6
def safe_get_config(config_obj, key, default=None):
    """
    Safely read one configuration value from a mapping or an attribute-style object.

    Args:
        config_obj: Configuration container. May be ``None``, a ``dict``, any other
            ``collections.abc.Mapping`` (for example ``types.MappingProxyType``),
            or an object exposing its settings as attributes.
        key: Key or attribute name to look up.
        default: Value returned when the key/attribute does not exist.

    Returns:
        The stored value, or ``default`` if ``config_obj`` is ``None`` or the
        key/attribute is not found.
    """
    # Imported locally so this function does not depend on the position of the
    # module-level import block.
    from collections.abc import Mapping

    if config_obj is None:
        return default

    # Plain dicts are looked up by key only, so a missing key never resolves to
    # a dict method such as ``items``.
    if isinstance(config_obj, dict):
        return config_obj.get(key, default)

    # Other mappings were previously read with getattr() only, which hid their
    # keys. Try the key first, then fall through to attribute access so that
    # mapping-like config classes with real attributes keep working.
    if isinstance(config_obj, Mapping):
        try:
            return config_obj[key]
        except KeyError:
            pass

    return getattr(config_obj, key, default)
25
+
26
def get_model_name(config_obj, default_name="bert-base-uncased"):
    """
    Return the transformer model name stored in the configuration.

    The name is read from ``MODEL_NAME`` inside the ``TRANSFORMER_CONFIG``
    entry of ``config_obj``; both lookups go through ``safe_get_config``.

    Args:
        config_obj: Configuration object (anything ``safe_get_config`` accepts).
        default_name: Name returned when no usable model name is configured.
            Defaults to ``"bert-base-uncased"``.

    Returns:
        The configured model name, or ``default_name`` when
        ``TRANSFORMER_CONFIG`` or ``MODEL_NAME`` is missing, ``None`` or empty.
    """
    transformer_config = safe_get_config(config_obj, "TRANSFORMER_CONFIG", None)
    model_name = safe_get_config(transformer_config, "MODEL_NAME", None)

    # An explicit None or "" is treated like a missing entry so callers always
    # receive a usable name.
    return model_name or default_name
38
+
39
def get_embedding_model(config_obj):
    """
    Build the embedding model selected by the configuration.

    The model name is resolved with ``get_model_name`` and handed to
    ``get_sentence_transformer``.

    Args:
        config_obj: Configuration object passed through to ``get_model_name``.

    Returns:
        Whatever ``get_sentence_transformer`` returns for the resolved name.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably to avoid a heavy or circular import — confirm.
    from utils.transformer_utils import get_sentence_transformer

    return get_sentence_transformer(get_model_name(config_obj))
52
+
53
+ """Utilities for configuration validation and fixes"""
54
+ import os
55
+ import logging
56
+ from typing import Dict, Any
57
+
58
logger = logging.getLogger(__name__)

# Values written into TRANSFORMER_CONFIG for every key that is absent.
_TRANSFORMER_DEFAULTS = {
    "MAX_SEQ_LENGTH": 512,
    "MODEL_NAME": "bert-base-uncased",
    "NUM_LAYERS": 6,
    "EMBEDDING_DIM": 768,
    "NUM_HEADS": 12,
    "HIDDEN_DIM": 768,
    "DROPOUT": 0.1,
    "POOLING_MODE": "mean",
}


def _read_setting(config: Any, name: str, fallback: Any = None) -> Any:
    """Return setting ``name`` from a dict-style or attribute-style config.

    A missing entry and an entry explicitly set to ``None`` both yield
    ``fallback``.
    """
    if isinstance(config, dict):
        value = config.get(name)
    else:
        value = getattr(config, name, None)
    return fallback if value is None else value


def validate_config(config: Any) -> None:
    """Fill in missing transformer defaults and create the data/model directories.

    Args:
        config: Configuration held either as a dict or as an object with
            attributes. Its ``TRANSFORMER_CONFIG`` dict, if present, is
            updated in place.

    Raises:
        OSError: If a directory cannot be created (propagated from
            ``os.makedirs``).
    """
    transformer_config = _read_setting(config, "TRANSFORMER_CONFIG")

    # Only dict-valued transformer configs are patched; other types are left
    # untouched.
    if isinstance(transformer_config, dict):
        for key, value in _TRANSFORMER_DEFAULTS.items():
            if key not in transformer_config:
                transformer_config[key] = value
                logger.info("Added default %s=%s to TRANSFORMER_CONFIG", key, value)

    # DATA_DIR falls back to the TLM_DATA_DIR environment variable, then to a
    # fixed temporary location.
    data_dir = _read_setting(
        config, "DATA_DIR", os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data")
    )
    os.makedirs(data_dir, exist_ok=True)

    # MODEL_DIR defaults to a "models" folder inside the data directory.
    model_dir = _read_setting(config, "MODEL_DIR", os.path.join(data_dir, "models"))
    os.makedirs(model_dir, exist_ok=True)

    logger.info("Configuration validated and fixed")
handler.py CHANGED
@@ -76,6 +76,15 @@ except ImportError as e:
76
  def generate(self, text_input, **kwargs):
77
  return f"Model adapter unavailable. Received input: {text_input[:30]}..."
78
 
 
 
 
 
 
 
 
 
 
79
  class EndpointHandler:
80
  def __init__(self, model_dir: str = None):
81
  # HF toolkit passes model directory here; log or ignore
 
76
  def generate(self, text_input, **kwargs):
77
  return f"Model adapter unavailable. Received input: {text_input[:30]}..."
78
 
79
+ # After imports but before EndpointHandler class
80
+ try:
81
+ # Try to initialize the system first
82
+ from main import initialize_system
83
+ success = initialize_system()
84
+ logger.info(f"System initialization {'successful' if success else 'failed'}")
85
+ except Exception as e:
86
+ logger.error(f"Failed to initialize system: {e}")
87
+
88
  class EndpointHandler:
89
  def __init__(self, model_dir: str = None):
90
  # HF toolkit passes model directory here; log or ignore
main.py CHANGED
@@ -840,20 +840,40 @@ def initialize_system():
840
  registry.register(TOKENIZER, tokenizer, overwrite=True)
841
  logger.info("Tokenizer registered")
842
 
843
- # Now load model
844
  try:
845
  from model_Custm import Wildnerve_tlm01
846
- model = Wildnerve_tlm01(tokenizer=tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
847
 
848
  # Register model
849
  from service_registry import MODEL
850
  registry.register(MODEL, model, overwrite=True)
851
  logger.info("Model registered successfully")
 
 
852
  return True
853
  except Exception as e:
854
- logger.error(f"Failed to initialize model: {e}")
855
  return False
856
 
 
 
 
 
857
  if __name__ == "__main__":
858
  success = initialize_system()
859
  logger.info(f"Initialization {'successful' if success else 'failed'}")
 
840
  registry.register(TOKENIZER, tokenizer, overwrite=True)
841
  logger.info("Tokenizer registered")
842
 
843
+ # Now load model with safer error handling
844
  try:
845
  from model_Custm import Wildnerve_tlm01
846
+ model = Wildnerve_tlm01(
847
+ vocab_size=30522,
848
+ specialization="general",
849
+ dataset_path=None,
850
+ model_name="bert-base-uncased",
851
+ embedding_dim=768,
852
+ num_heads=12,
853
+ hidden_dim=768,
854
+ num_layers=2,
855
+ output_size=768,
856
+ dropout=0.1,
857
+ max_seq_length=128,
858
+ pooling_mode="mean",
859
+ tokenizer=tokenizer
860
+ )
861
 
862
  # Register model
863
  from service_registry import MODEL
864
  registry.register(MODEL, model, overwrite=True)
865
  logger.info("Model registered successfully")
866
+
867
+ # Optional: Register model manager if needed
868
  return True
869
  except Exception as e:
870
+ logger.error(f"Failed to initialize model: {e}", exc_info=True)
871
  return False
872
 
873
+ # Call initialization at the beginning
874
+ success = initialize_system()
875
+ logger.info(f"Initialization {'successful' if success else 'failed'}")
876
+
877
  if __name__ == "__main__":
878
  success = initialize_system()
879
  logger.info(f"Initialization {'successful' if success else 'failed'}")