Wildnerve-tlm01_Hybrid_Model / config_utils.py
WildnerveAI's picture
Upload 5 files
1a8d9bc verified
"""
Utilities for safely accessing configuration values regardless of whether they're stored
in dictionaries or objects.
"""
def safe_get_config(config_obj, key, default=None):
    """
    Safely get a configuration value whether the config is a mapping or an object.

    Dictionaries and other mappings (anything that is a
    ``collections.abc.Mapping``, e.g. ``types.MappingProxyType``) are read
    with ``.get``; every other object is read with ``getattr``.

    Args:
        config_obj: Configuration container (mapping, object with attributes, or None)
        key: Key or attribute name to access
        default: Value returned when the container is None or lacks ``key``

    Returns:
        The value stored under ``key``, or ``default`` if not found
    """
    from collections.abc import Mapping

    if config_obj is None:
        return default
    # Checking only ``dict`` sent other mappings down the getattr path, which
    # ignores their keys (and returns bound methods such as ``.keys`` when the
    # requested key happens to match a method name).
    if isinstance(config_obj, Mapping):
        return config_obj.get(key, default)
    return getattr(config_obj, key, default)
def get_model_name(config_obj):
    """
    Gets the model name from config, falling back to "bert-base-uncased".

    Reads the ``TRANSFORMER_CONFIG`` section from ``config_obj`` and then
    ``MODEL_NAME`` from that section; each level may be a dict or an
    attribute-style object (both are handled by ``safe_get_config``).

    Args:
        config_obj: Configuration object (dict, object with attributes, or None)

    Returns:
        Model name string; "bert-base-uncased" when the section or key is
        missing, or when ``MODEL_NAME`` is set to None or an empty string.
    """
    transformer_config = safe_get_config(config_obj, "TRANSFORMER_CONFIG", {})
    model_name = safe_get_config(transformer_config, "MODEL_NAME", None)
    # A key that exists but holds None/"" (e.g. an unset optional field) would
    # otherwise be returned verbatim to callers such as get_embedding_model.
    return model_name or "bert-base-uncased"
def get_embedding_model(config_obj):
    """
    Build the sentence-transformer embedding model named by the config.

    The model name is resolved with ``get_model_name`` and handed to
    ``utils.transformer_utils.get_sentence_transformer``; whether that helper
    creates a fresh model or returns an existing one is decided there.

    Args:
        config_obj: Configuration object (dict or attribute-style object)

    Returns:
        Whatever ``get_sentence_transformer`` returns for the resolved name
    """
    # Imported inside the function so that importing this module does not
    # import utils.transformer_utils.
    from utils.transformer_utils import get_sentence_transformer as load_sentence_transformer

    resolved_name = get_model_name(config_obj)
    return load_sentence_transformer(resolved_name)
"""Utilities for configuration validation and fixes"""
import os
import logging
from typing import Dict, Any
# Module-level logger named after this module; validate_config reports each
# default it adds, and its completion, through this logger.
logger = logging.getLogger(__name__)
def validate_config(config: Any) -> None:
    """Validate and fix a configuration in place to prevent later lookup errors.

    Works for both dict configs and attribute-style config objects (values are
    read through ``safe_get_config``). Side effects:

    * If ``TRANSFORMER_CONFIG`` is a mutable mapping, any missing transformer
      keys are filled in with defaults, and each addition is logged.
    * The data directory (``DATA_DIR``, else ``$TLM_DATA_DIR``, else
      ``/tmp/tlm_data``) and the model directory (``MODEL_DIR``, else
      ``<data_dir>/models``) are created if they do not already exist.

    Args:
        config: Configuration dict or object.

    Raises:
        OSError: Propagated from ``os.makedirs`` if a directory cannot be created.
    """
    from collections.abc import MutableMapping

    # hasattr()/getattr() were used here before; they never see dict keys, so
    # dict configs got no defaults and their DATA_DIR/MODEL_DIR were ignored.
    tc = safe_get_config(config, "TRANSFORMER_CONFIG")
    if isinstance(tc, MutableMapping):
        defaults = {
            "MAX_SEQ_LENGTH": 512,
            "MODEL_NAME": "bert-base-uncased",
            "NUM_LAYERS": 6,
            "EMBEDDING_DIM": 768,
            "NUM_HEADS": 12,
            "HIDDEN_DIM": 768,
            "DROPOUT": 0.1,
            "POOLING_MODE": "mean",
        }
        # Only fill keys that are absent; existing values are never overwritten.
        for key, value in defaults.items():
            if key not in tc:
                tc[key] = value
                logger.info("Added default %s=%s to TRANSFORMER_CONFIG", key, value)

    # `or` (instead of a getter default) also covers keys that are present but
    # set to None or "", which os.makedirs would reject.
    data_dir = safe_get_config(config, "DATA_DIR") or os.environ.get(
        "TLM_DATA_DIR", "/tmp/tlm_data"
    )
    os.makedirs(data_dir, exist_ok=True)

    model_dir = safe_get_config(config, "MODEL_DIR") or os.path.join(data_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    logger.info("Configuration validated and fixed")