import json
import logging
import os
from typing import Any, Dict, Optional, Tuple, Union

logger = logging.getLogger(__name__)

# Model-type substrings that this module accepts as valid.
VALIDATED_MODEL_TYPES = [
    'bert',
    'roberta',
    'distilbert',
    'gpt2',
    't5',
    'albert',
    'xlm-roberta',
    'bart',
    'electra',
    'xlnet',
]

# Model name used whenever the requested one is missing or not recognized.
_DEFAULT_MODEL_NAME = 'bert-base-uncased'

# Class-name prefixes for the validated model types. ``str.capitalize()``
# alone yields names such as "Xlm-robertaModel" or "Gpt2Model", which do not
# match the corresponding Transformers class names.
_ARCHITECTURE_PREFIXES = {
    'bert': 'Bert',
    'roberta': 'Roberta',
    'distilbert': 'DistilBert',
    'gpt2': 'GPT2',
    't5': 'T5',
    'albert': 'Albert',
    'xlm-roberta': 'XLMRoberta',
    'bart': 'Bart',
    'electra': 'Electra',
    'xlnet': 'XLNet',
}


def validate_model_name(model_name: str) -> Tuple[bool, Optional[str]]:
    """
    Check whether a model name contains one of the validated model types.

    The check is a case-insensitive substring match against
    ``VALIDATED_MODEL_TYPES``; no remote registry is consulted.

    Args:
        model_name: Name of the model to validate. Non-string values
            (e.g. ``None``) are treated as invalid instead of raising.

    Returns:
        Tuple containing:
            - Boolean indicating if the model name is valid
            - The fallback model name ('bert-base-uncased') if the original
              is invalid, None otherwise
    """
    # A missing or non-string name cannot match anything; report it as
    # invalid rather than failing on ``.lower()``.
    if not isinstance(model_name, str):
        return False, _DEFAULT_MODEL_NAME

    lowered = model_name.lower()
    if any(model_type in lowered for model_type in VALIDATED_MODEL_TYPES):
        return True, None

    # There is a single fallback regardless of why validation failed.
    return False, _DEFAULT_MODEL_NAME


def get_safe_model_name(config: Union[str, Dict[str, Any], None]) -> str:
    """
    Get a validated model name from a config mapping or a plain string.

    Args:
        config: Either a model name string, or a dictionary-like object whose
            'MODEL_NAME' entry holds the model name. ``None`` is accepted and
            resolves to the default model.

    Returns:
        str: The given model name when it passes ``validate_model_name``,
        otherwise the fallback model name ('bert-base-uncased').
    """
    if isinstance(config, str):
        # Handle string input directly.
        model_name = config
    elif config is None:
        # No configuration at all: use the default.
        model_name = _DEFAULT_MODEL_NAME
    else:
        # Handle dictionary input (original behavior).
        model_name = config.get('MODEL_NAME', _DEFAULT_MODEL_NAME)

    is_valid, fallback = validate_model_name(model_name)

    # Return the original name if valid, otherwise the fallback.
    return model_name if is_valid else fallback


def create_model_config_json(model_dir: str, model_type: str = 'bert') -> None:
    """
    Create a config.json file for a custom model with a model_type key.

    The file is written to ``<model_dir>/config.json`` and overwrites any
    existing file at that path. Missing directories are created.

    Args:
        model_dir: Directory where model is/will be stored
        model_type: The type of model (e.g., 'bert', 'roberta')
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_dir, exist_ok=True)

    config_path = os.path.join(model_dir, 'config.json')

    # Known types map to their real class-name prefix; unknown types keep the
    # simple capitalized form.
    prefix = _ARCHITECTURE_PREFIXES.get(
        model_type.lower(), model_type.capitalize()
    )

    # Create a minimal config with the required model_type key.
    config: Dict[str, Any] = {
        "model_type": model_type,
        "architectures": [f"{prefix}Model"],
        "hidden_size": 768,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
    }

    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    logger.info(
        "Created model config.json with model_type: %s in %s",
        model_type,
        model_dir,
    )