Upload 5 files
Browse files- config_utils.py +95 -0
- handler.py +9 -0
- main.py +23 -3
config_utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for safely accessing configuration values regardless of whether they're stored
|
| 3 |
+
in dictionaries or objects.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
def safe_get_config(config_obj, key, default=None):
|
| 7 |
+
"""
|
| 8 |
+
Safely get a configuration value regardless of whether the config object is a dictionary or object.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
config_obj: Configuration object (dict or object with attributes)
|
| 12 |
+
key: Key or attribute name to access
|
| 13 |
+
default: Default value to return if key/attribute doesn't exist
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
The value of the key/attribute, or default if not found
|
| 17 |
+
"""
|
| 18 |
+
if config_obj is None:
|
| 19 |
+
return default
|
| 20 |
+
|
| 21 |
+
if isinstance(config_obj, dict):
|
| 22 |
+
return config_obj.get(key, default)
|
| 23 |
+
|
| 24 |
+
return getattr(config_obj, key, default)
|
| 25 |
+
|
| 26 |
+
def get_model_name(config_obj):
|
| 27 |
+
"""
|
| 28 |
+
Gets model name from config with proper fallbacks.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
config_obj: Configuration object
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
Model name string
|
| 35 |
+
"""
|
| 36 |
+
transformer_config = safe_get_config(config_obj, "TRANSFORMER_CONFIG", {})
|
| 37 |
+
return safe_get_config(transformer_config, "MODEL_NAME", "bert-base-uncased")
|
| 38 |
+
|
| 39 |
+
def get_embedding_model(config_obj):
|
| 40 |
+
"""
|
| 41 |
+
Creates or retrieves an embedding model based on configuration.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
config_obj: Configuration object
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Embedding model instance
|
| 48 |
+
"""
|
| 49 |
+
from utils.transformer_utils import get_sentence_transformer
|
| 50 |
+
model_name = get_model_name(config_obj)
|
| 51 |
+
return get_sentence_transformer(model_name)
|
| 52 |
+
|
| 53 |
+
"""Utilities for configuration validation and fixes"""
|
| 54 |
+
import os
|
| 55 |
+
import logging
|
| 56 |
+
from typing import Dict, Any
|
| 57 |
+
|
| 58 |
+
logger = logging.getLogger(__name__)
|
| 59 |
+
|
| 60 |
+
def validate_config(config: Any) -> None:
|
| 61 |
+
"""Validate and fix configuration objects to prevent errors"""
|
| 62 |
+
|
| 63 |
+
# Fix TRANSFORMER_CONFIG if it exists
|
| 64 |
+
if hasattr(config, 'TRANSFORMER_CONFIG'):
|
| 65 |
+
tc = config.TRANSFORMER_CONFIG
|
| 66 |
+
|
| 67 |
+
# If it's a dict, make sure it has necessary values
|
| 68 |
+
if isinstance(tc, dict):
|
| 69 |
+
# Ensure required keys have default values
|
| 70 |
+
defaults = {
|
| 71 |
+
"MAX_SEQ_LENGTH": 512,
|
| 72 |
+
"MODEL_NAME": "bert-base-uncased",
|
| 73 |
+
"NUM_LAYERS": 6,
|
| 74 |
+
"EMBEDDING_DIM": 768,
|
| 75 |
+
"NUM_HEADS": 12,
|
| 76 |
+
"HIDDEN_DIM": 768,
|
| 77 |
+
"DROPOUT": 0.1,
|
| 78 |
+
"POOLING_MODE": "mean"
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
# Add any missing defaults
|
| 82 |
+
for key, value in defaults.items():
|
| 83 |
+
if key not in tc:
|
| 84 |
+
tc[key] = value
|
| 85 |
+
logger.info(f"Added default {key}={value} to TRANSFORMER_CONFIG")
|
| 86 |
+
|
| 87 |
+
# Check data directories
|
| 88 |
+
data_dir = getattr(config, "DATA_DIR", os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data"))
|
| 89 |
+
os.makedirs(data_dir, exist_ok=True)
|
| 90 |
+
|
| 91 |
+
# Check model directory
|
| 92 |
+
model_dir = getattr(config, "MODEL_DIR", os.path.join(data_dir, "models"))
|
| 93 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 94 |
+
|
| 95 |
+
logger.info("Configuration validated and fixed")
|
handler.py
CHANGED
|
@@ -76,6 +76,15 @@ except ImportError as e:
|
|
| 76 |
def generate(self, text_input, **kwargs):
|
| 77 |
return f"Model adapter unavailable. Received input: {text_input[:30]}..."
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
class EndpointHandler:
|
| 80 |
def __init__(self, model_dir: str = None):
|
| 81 |
# HF toolkit passes model directory here; log or ignore
|
|
|
|
| 76 |
def generate(self, text_input, **kwargs):
|
| 77 |
return f"Model adapter unavailable. Received input: {text_input[:30]}..."
|
| 78 |
|
| 79 |
+
# After imports but before EndpointHandler class
|
| 80 |
+
try:
|
| 81 |
+
# Try to initialize the system first
|
| 82 |
+
from main import initialize_system
|
| 83 |
+
success = initialize_system()
|
| 84 |
+
logger.info(f"System initialization {'successful' if success else 'failed'}")
|
| 85 |
+
except Exception as e:
|
| 86 |
+
logger.error(f"Failed to initialize system: {e}")
|
| 87 |
+
|
| 88 |
class EndpointHandler:
|
| 89 |
def __init__(self, model_dir: str = None):
|
| 90 |
# HF toolkit passes model directory here; log or ignore
|
main.py
CHANGED
|
@@ -840,20 +840,40 @@ def initialize_system():
|
|
| 840 |
registry.register(TOKENIZER, tokenizer, overwrite=True)
|
| 841 |
logger.info("Tokenizer registered")
|
| 842 |
|
| 843 |
-
# Now load model
|
| 844 |
try:
|
| 845 |
from model_Custm import Wildnerve_tlm01
|
| 846 |
-
model = Wildnerve_tlm01(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
|
| 848 |
# Register model
|
| 849 |
from service_registry import MODEL
|
| 850 |
registry.register(MODEL, model, overwrite=True)
|
| 851 |
logger.info("Model registered successfully")
|
|
|
|
|
|
|
| 852 |
return True
|
| 853 |
except Exception as e:
|
| 854 |
-
logger.error(f"Failed to initialize model: {e}")
|
| 855 |
return False
|
| 856 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 857 |
if __name__ == "__main__":
|
| 858 |
success = initialize_system()
|
| 859 |
logger.info(f"Initialization {'successful' if success else 'failed'}")
|
|
|
|
| 840 |
registry.register(TOKENIZER, tokenizer, overwrite=True)
|
| 841 |
logger.info("Tokenizer registered")
|
| 842 |
|
| 843 |
+
# Now load model with safer error handling
|
| 844 |
try:
|
| 845 |
from model_Custm import Wildnerve_tlm01
|
| 846 |
+
model = Wildnerve_tlm01(
|
| 847 |
+
vocab_size=30522,
|
| 848 |
+
specialization="general",
|
| 849 |
+
dataset_path=None,
|
| 850 |
+
model_name="bert-base-uncased",
|
| 851 |
+
embedding_dim=768,
|
| 852 |
+
num_heads=12,
|
| 853 |
+
hidden_dim=768,
|
| 854 |
+
num_layers=2,
|
| 855 |
+
output_size=768,
|
| 856 |
+
dropout=0.1,
|
| 857 |
+
max_seq_length=128,
|
| 858 |
+
pooling_mode="mean",
|
| 859 |
+
tokenizer=tokenizer
|
| 860 |
+
)
|
| 861 |
|
| 862 |
# Register model
|
| 863 |
from service_registry import MODEL
|
| 864 |
registry.register(MODEL, model, overwrite=True)
|
| 865 |
logger.info("Model registered successfully")
|
| 866 |
+
|
| 867 |
+
# Optional: Register model manager if needed
|
| 868 |
return True
|
| 869 |
except Exception as e:
|
| 870 |
+
logger.error(f"Failed to initialize model: {e}", exc_info=True)
|
| 871 |
return False
|
| 872 |
|
| 873 |
+
# Call initialization at the beginning
|
| 874 |
+
success = initialize_system()
|
| 875 |
+
logger.info(f"Initialization {'successful' if success else 'failed'}")
|
| 876 |
+
|
| 877 |
if __name__ == "__main__":
|
| 878 |
success = initialize_system()
|
| 879 |
logger.info(f"Initialization {'successful' if success else 'failed'}")
|