|
|
import logging
import os
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Lower-case model family identifiers accepted by validate_model_name().
# A model name counts as valid when one of these occurs as a substring of the
# lower-cased name, so e.g. 'bert' also covers 'bert-base-uncased'.
VALIDATED_MODEL_TYPES = [
    'bert',
    'roberta',
    'distilbert',
    'gpt2',
    't5',
    'albert',
    'xlm-roberta',
    'bart',
    'electra',
    'xlnet',
]
|
|
|
|
|
|
def validate_model_name(
    model_name: str,
    *,
    fallback: str = 'bert-base-uncased',
    model_types: Optional[Iterable[str]] = None,
) -> Tuple[bool, Optional[str]]:
    """
    Check whether a model name belongs to one of the supported model families.

    This is a local heuristic, not a lookup in the Hugging Face model registry.
    A name is considered valid when any entry of ``model_types`` (by default
    ``VALIDATED_MODEL_TYPES``) occurs as a case-insensitive substring of it.
    For example, 'bert' matches 'bert-base-uncased'. Because matching is by
    substring, short family names also match derived names (e.g. 'bert'
    matches 'distilbert').

    Args:
        model_name: Name or path of the model to validate. Non-string values
            (including None) are treated as invalid instead of raising.
        fallback: Model name to recommend when ``model_name`` is invalid.
        model_types: Family identifiers to accept; defaults to
            ``VALIDATED_MODEL_TYPES``. Empty entries are ignored.

    Returns:
        Tuple containing:
        - Boolean indicating if the model is valid
        - ``fallback`` if the original is invalid, None otherwise
    """
    # A None/non-str name (e.g. from a config with MODEL_NAME: None) cannot
    # be matched; report it as invalid rather than failing on `.lower()`.
    if not isinstance(model_name, str):
        return False, fallback

    if model_types is None:
        model_types = VALIDATED_MODEL_TYPES

    lowered = model_name.lower()
    # Skip empty identifiers: '' is a substring of every name and would
    # otherwise accept anything.
    if any(
        model_type and model_type.lower() in lowered
        for model_type in model_types
    ):
        return True, None

    return False, fallback
|
|
|
|
|
|
def get_safe_model_name(config: Union[str, Dict[str, Any], None]) -> str:
    """
    Get a validated and sanitized model name from config.

    Surrounding whitespace is stripped, and the name is checked with
    validate_model_name(). If it is not recognised, the fallback suggested by
    validate_model_name() is returned instead and a warning is logged.

    Args:
        config: Either a config dictionary (the name is read from its
            'MODEL_NAME' key) or a string model name. None, a missing key,
            or a falsy 'MODEL_NAME' value (None, '') selects
            'bert-base-uncased'.

    Returns:
        str: A sanitized model name
    """
    default_name = 'bert-base-uncased'

    if isinstance(config, str):
        raw_name = config
    elif config is None:
        raw_name = default_name
    else:
        # `or` also covers an explicit MODEL_NAME of None/'', which would
        # otherwise reach validate_model_name() and fail on `.lower()`.
        raw_name = config.get('MODEL_NAME') or default_name

    # str() accepts non-str values such as pathlib.Path; strip() removes
    # stray whitespace (e.g. from env files) that would break model loading.
    model_name = str(raw_name).strip()

    is_valid, fallback = validate_model_name(model_name)
    if is_valid:
        return model_name

    # Substituting a different model should not be silent.
    logger.warning(
        "Model name %r is not a recognised model type; using %r instead",
        model_name,
        fallback,
    )
    return fallback
|
|
|
|
|
|
def create_model_config_json(model_dir: str, model_type: str = 'bert') -> None:
    """
    Creates a config.json file for a custom model with proper model_type key.

    The directory (and any missing parents) is created if needed. An existing
    ``config.json`` in ``model_dir`` is overwritten. The written file contains
    ``model_type``, a single-entry ``architectures`` list naming the base
    model class, and fixed dimensions (hidden_size=768,
    num_attention_heads=12, num_hidden_layers=12).

    Args:
        model_dir: Directory where model is/will be stored
        model_type: The type of model (e.g., 'bert', 'roberta', 'xlm-roberta')

    Raises:
        OSError: If the directory cannot be created or the file cannot be
            written.
    """
    import json

    # Model types whose transformers class-name prefix is not a plain
    # capitalisation of the type (e.g. 'gpt2' -> GPT2Model, not Gpt2Model).
    irregular_prefixes = {
        'distilbert': 'DistilBert',
        'gpt2': 'GPT2',
        'xlm-roberta': 'XLMRoberta',
        'xlnet': 'XLNet',
    }
    # For other types, capitalise each '-'/'_'-separated part so that
    # single-word types behave like str.capitalize() (bert -> Bert) and
    # multi-part types do not keep a literal hyphen in the class name.
    default_prefix = ''.join(
        part.capitalize() for part in model_type.replace('_', '-').split('-')
    )
    prefix = irregular_prefixes.get(model_type.lower(), default_prefix)

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(model_dir, exist_ok=True)
    config_path = os.path.join(model_dir, 'config.json')

    config = {
        "model_type": model_type,
        "architectures": [f"{prefix}Model"],
        "hidden_size": 768,
        "num_attention_heads": 12,
        "num_hidden_layers": 12
    }

    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    logger.info(
        "Created model config.json with model_type: %s in %s",
        model_type,
        model_dir,
    )
|
|
|
|