"""Configuration helpers.

Utilities for safely accessing configuration values regardless of whether
they're stored in dictionaries or objects, plus utilities for configuration
validation and fixes.
"""

import os
import logging
from typing import Dict, Any

logger = logging.getLogger(__name__)

# Values inserted into a dict-style TRANSFORMER_CONFIG when the key is absent.
_TRANSFORMER_DEFAULTS = {
    "MAX_SEQ_LENGTH": 512,
    "MODEL_NAME": "bert-base-uncased",
    "NUM_LAYERS": 6,
    "EMBEDDING_DIM": 768,
    "NUM_HEADS": 12,
    "HIDDEN_DIM": 768,
    "DROPOUT": 0.1,
    "POOLING_MODE": "mean",
}


def safe_get_config(config_obj: Any, key: str, default: Any = None) -> Any:
    """Safely read one configuration value from a dict or an object.

    Args:
        config_obj: Configuration container: a dict, an object exposing
            attributes, or None.
        key: Dictionary key or attribute name to look up.
        default: Value returned when the container is None or the
            key/attribute does not exist.

    Returns:
        The stored value, or ``default`` if it is not found.
    """
    if config_obj is None:
        return default
    if isinstance(config_obj, dict):
        return config_obj.get(key, default)
    return getattr(config_obj, key, default)


def get_model_name(config_obj: Any) -> str:
    """Get the model name from the config with proper fallbacks.

    Args:
        config_obj: Configuration object (dict or object with attributes).

    Returns:
        The ``MODEL_NAME`` entry of ``TRANSFORMER_CONFIG``, or
        ``"bert-base-uncased"`` when either level is missing.
    """
    transformer_config = safe_get_config(config_obj, "TRANSFORMER_CONFIG", {})
    return safe_get_config(transformer_config, "MODEL_NAME", "bert-base-uncased")


def get_embedding_model(config_obj: Any) -> Any:
    """Create or retrieve an embedding model based on configuration.

    Args:
        config_obj: Configuration object (dict or object with attributes).

    Returns:
        Whatever ``get_sentence_transformer`` returns for the configured
        model name.
    """
    # Imported at call time rather than module import time; presumably to
    # keep this module importable without the transformer stack (unverified).
    from utils.transformer_utils import get_sentence_transformer

    model_name = get_model_name(config_obj)
    return get_sentence_transformer(model_name)


def validate_config(config: Any) -> None:
    """Validate and fix configuration objects to prevent errors.

    Missing keys in a dict-valued ``TRANSFORMER_CONFIG`` are filled in place
    with defaults, and the data and model directories are created if they do
    not exist yet.

    Lookups go through ``safe_get_config`` so dict-style and attribute-style
    configs are treated the same way. A ``DATA_DIR`` or ``MODEL_DIR`` that is
    present but set to None is treated as not configured.

    Args:
        config: Configuration object (dict, object with attributes, or None).

    Raises:
        OSError: If a directory cannot be created.
    """
    # Only dict-valued transformer configs are patched; other types are left
    # untouched.
    tc = safe_get_config(config, "TRANSFORMER_CONFIG")
    if isinstance(tc, dict):
        for key, value in _TRANSFORMER_DEFAULTS.items():
            if key not in tc:
                tc[key] = value
                logger.info(f"Added default {key}={value} to TRANSFORMER_CONFIG")

    # Data directory: explicit config value, then TLM_DATA_DIR, then /tmp.
    data_dir = safe_get_config(config, "DATA_DIR")
    if data_dir is None:
        data_dir = os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data")
    os.makedirs(data_dir, exist_ok=True)

    # Model directory defaults to a "models" folder inside the data directory.
    model_dir = safe_get_config(config, "MODEL_DIR")
    if model_dir is None:
        model_dir = os.path.join(data_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    logger.info("Configuration validated and fixed")