|
|
"""
|
|
|
Utilities for safely accessing configuration values regardless of whether they're stored
|
|
|
in dictionaries or objects.
|
|
|
"""
|
|
|
|
|
|
def safe_get_config(config_obj, key, default=None):
    """
    Safely get a configuration value whether the config is a mapping or an object.

    Args:
        config_obj: Configuration source. This is a mapping (``dict`` or any
            other ``collections.abc.Mapping``), an object that exposes settings
            as attributes, or ``None``.
        key: Key (for mappings) or attribute name (for objects) to access.
        default: Value returned when ``config_obj`` is ``None`` or the
            key/attribute does not exist.

    Returns:
        The value of the key/attribute, or ``default`` if not found. A key or
        attribute that exists with value ``None`` yields ``None``, not
        ``default``.
    """
    from collections.abc import Mapping

    if config_obj is None:
        return default

    # Read every Mapping by key, not just real dicts. Before, non-dict
    # mappings (e.g. MappingProxyType or custom Mapping config classes) fell
    # through to getattr() and silently returned `default`.
    if isinstance(config_obj, Mapping):
        return config_obj.get(key, default)

    return getattr(config_obj, key, default)
|
|
|
|
|
|
def get_model_name(config_obj, *, default="bert-base-uncased"):
    """
    Get the transformer model name from a configuration, with fallbacks.

    Reads ``TRANSFORMER_CONFIG`` from ``config_obj`` and then ``MODEL_NAME``
    from it, using ``safe_get_config`` at each level. Each level may be a
    dict or an attribute-style object.

    Args:
        config_obj: Configuration object, dict, or ``None``.
        default: Model name returned when TRANSFORMER_CONFIG or MODEL_NAME is
            missing, or when MODEL_NAME is set to an empty value (``None`` or
            ``""``). Keyword-only; defaults to ``"bert-base-uncased"``, the
            previously hard-coded fallback.

    Returns:
        Model name string.
    """
    # safe_get_config(None, ...) returns its default, so a missing or None
    # TRANSFORMER_CONFIG needs no placeholder dict.
    transformer_config = safe_get_config(config_obj, "TRANSFORMER_CONFIG")
    model_name = safe_get_config(transformer_config, "MODEL_NAME")

    # An explicit null/empty MODEL_NAME (e.g. `MODEL_NAME: null` in a YAML
    # config) is not a usable model name. Treat it like a missing key rather
    # than returning None to the model loader.
    return model_name if model_name else default
|
|
|
|
|
|
def get_embedding_model(config_obj):
    """
    Return the embedding model named by a configuration.

    The model name is resolved with ``get_model_name``, and the result of
    ``utils.transformer_utils.get_sentence_transformer`` for that name is
    returned unchanged. Whether that call builds a new model or reuses a
    cached one is up to ``get_sentence_transformer``.

    Args:
        config_obj: Configuration object (dict or attribute-style object).

    Returns:
        Whatever ``get_sentence_transformer`` returns for the resolved name.
    """
    # Function-local import. Presumably this avoids loading the transformer
    # stack at module import time; TODO confirm intent.
    from utils.transformer_utils import get_sentence_transformer as load_transformer

    resolved_name = get_model_name(config_obj)
    embedding_model = load_transformer(resolved_name)
    return embedding_model
|
|
|
|
|
|
# NOTE(review): this string is not the first statement in the module, so
# Python does not use it as the module docstring (the triple-quoted string at
# the top of the file is). Here it is evaluated as a no-op expression. It
# looks like the header of a second module that was merged into this file.
"""Utilities for configuration validation and fixes"""

import os
import logging
from typing import Dict, Any

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _apply_transformer_defaults(tc) -> None:
    """Fill keys missing from the TRANSFORMER_CONFIG mapping ``tc``, in place.

    Keys that are already present are left untouched, even if their value is
    ``None``. Each key that gets added is logged at INFO level.
    """
    defaults = {
        "MAX_SEQ_LENGTH": 512,
        "MODEL_NAME": "bert-base-uncased",
        "NUM_LAYERS": 6,
        "EMBEDDING_DIM": 768,
        "NUM_HEADS": 12,
        "HIDDEN_DIM": 768,
        "DROPOUT": 0.1,
        "POOLING_MODE": "mean",
    }
    for key, value in defaults.items():
        if key not in tc:
            tc[key] = value
            logger.info("Added default %s=%s to TRANSFORMER_CONFIG", key, value)


def validate_config(config: Any) -> None:
    """Validate and fix configuration objects to prevent errors.

    Settings are read through ``safe_get_config``, so ``config`` may be a dict
    or an attribute-style object.

    Steps:
      * If ``TRANSFORMER_CONFIG`` is a mutable mapping, add any missing
        default keys to it in place.
      * Create the data directory. The source, in order, is ``DATA_DIR``,
        then the ``TLM_DATA_DIR`` environment variable, then
        ``/tmp/tlm_data``.
      * Create the model directory. The source is ``MODEL_DIR``, otherwise
        ``<data_dir>/models``.

    The resolved directories are not written back to ``config``.

    Args:
        config: Configuration dict or object (``None`` is tolerated).

    Raises:
        OSError: If a directory cannot be created (propagated from
            ``os.makedirs``).
    """
    from collections.abc import MutableMapping

    tc = safe_get_config(config, "TRANSFORMER_CONFIG")
    if isinstance(tc, MutableMapping):
        _apply_transformer_defaults(tc)

    # Use `or` chains rather than getattr defaults so that a setting (or env
    # var) that is present but None/"" also falls back. os.makedirs would
    # otherwise raise on None or "".
    data_dir = (
        safe_get_config(config, "DATA_DIR")
        or os.environ.get("TLM_DATA_DIR")
        or "/tmp/tlm_data"
    )
    os.makedirs(data_dir, exist_ok=True)

    model_dir = safe_get_config(config, "MODEL_DIR") or os.path.join(data_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    logger.info("Configuration validated and fixed")
|
|
|
|