# Scrape residue from the GitHub page this file was copied from
# (repo: thadillo; commit 1377fb1,
#  "Deploy to HF Spaces: Model selection + Fine-tuning updates").
# Commented out so the module parses as valid Python.
"""
Model presets for both fine-tuning and zero-shot classification.
Provides configuration for various HuggingFace models optimized for text classification.
"""
# Registry of supported HuggingFace checkpoints for classification.
#
# Each entry maps a short preset key to a config dict with:
#   name               - human-readable display name
#   model_id           - HuggingFace Hub identifier (passed to transformers)
#   max_length         - tokenizer max sequence length in tokens
#   size               - approximate parameter count (display string)
#   speed              - relative inference-speed label (display string)
#   best_for           - short usage recommendation (display string)
#   description        - longer description for the UI
#   recommended_lr     - suggested fine-tuning learning rate
#   recommended_batch  - suggested fine-tuning batch size
#   supports_zero_shot - True for NLI-trained models usable without fine-tuning
MODEL_PRESETS = {
    # Zero-shot capable models (NLI-trained)
    'bart-large-mnli': {
        'name': 'BART-large-MNLI',
        'model_id': 'facebook/bart-large-mnli',
        'max_length': 1024,
        'size': '400M',
        'speed': 'Slow',
        'best_for': 'Zero-shot + Fine-tuning',
        'description': 'Large sequence-to-sequence model, excellent zero-shot performance',
        'recommended_lr': 2e-5,
        'recommended_batch': 4,
        'supports_zero_shot': True
    },
    'deberta-v3-base-mnli': {
        'name': 'DeBERTa-v3-base-MNLI',
        'model_id': 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli',
        'max_length': 512,
        'size': '86M',
        'speed': 'Fast',
        'best_for': 'Fast zero-shot classification',
        'description': 'DeBERTa trained on NLI datasets, excellent zero-shot with better speed',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': True
    },
    'distilbart-mnli': {
        'name': 'DistilBART-MNLI',
        'model_id': 'valhalla/distilbart-mnli-12-3',
        'max_length': 1024,
        'size': '134M',
        'speed': 'Medium',
        'best_for': 'Balanced zero-shot',
        'description': 'Distilled BART for zero-shot, good balance of speed and accuracy',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': True
    },
    # Fine-tuning only models
    'deberta-v3-small': {
        'name': 'DeBERTa-v3-small',
        'model_id': 'microsoft/deberta-v3-small',
        'max_length': 512,
        'size': '44M',
        'speed': 'Very Fast',
        'best_for': 'Fine-tuning with small datasets',
        'description': 'State-of-the-art efficient model, excellent for small datasets',
        'recommended_lr': 3e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False
    },
    'deberta-v3-base': {
        'name': 'DeBERTa-v3-base',
        'model_id': 'microsoft/deberta-v3-base',
        'max_length': 512,
        'size': '86M',
        'speed': 'Fast',
        'best_for': 'High accuracy fine-tuning',
        'description': 'Larger DeBERTa model with better accuracy',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False
    },
    'distilbert-base': {
        'name': 'DistilBERT-base',
        'model_id': 'distilbert-base-uncased',
        'max_length': 512,
        'size': '66M',
        'speed': 'Fast',
        'best_for': 'Balanced speed and accuracy',
        'description': 'Distilled BERT, 60% faster with 97% performance retention',
        'recommended_lr': 5e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False
    },
    'roberta-base': {
        'name': 'RoBERTa-base',
        'model_id': 'roberta-base',
        'max_length': 512,
        'size': '125M',
        'speed': 'Medium',
        'best_for': 'Maximum accuracy',
        'description': 'Robustly optimized BERT, excellent classification performance',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False
    },
    'electra-small': {
        'name': 'ELECTRA-small',
        'model_id': 'google/electra-small-discriminator',
        'max_length': 512,
        'size': '14M',
        'speed': 'Fastest',
        'best_for': 'Speed-critical applications',
        'description': 'Very fast and lightweight, good for production',
        'recommended_lr': 5e-5,
        'recommended_batch': 16,
        'supports_zero_shot': False
    },
    'minilm': {
        'name': 'MiniLM-L12',
        'model_id': 'microsoft/MiniLM-L12-H384-uncased',
        'max_length': 512,
        'size': '33M',
        'speed': 'Very Fast',
        'best_for': 'Lightweight production deployment',
        'description': 'Compact model optimized for speed',
        'recommended_lr': 4e-5,
        'recommended_batch': 12,
        'supports_zero_shot': False
    }
}
def get_model_preset(preset_key):
    """Return the preset config for *preset_key*.

    Unknown keys silently fall back to the 'bart-large-mnli' preset, so
    callers always receive a usable configuration dict.
    """
    try:
        return MODEL_PRESETS[preset_key]
    except KeyError:
        return MODEL_PRESETS['bart-large-mnli']
def get_available_models():
    """Return a summary dict per preset, for model-selection UIs.

    Each summary carries the preset key plus the display fields
    (name, size, speed, best_for, supports_zero_shot).
    """
    summary_fields = ('name', 'size', 'speed', 'best_for', 'supports_zero_shot')
    models = []
    for preset_key, config in MODEL_PRESETS.items():
        entry = {'key': preset_key}
        for field in summary_fields:
            entry[field] = config[field]
        models.append(entry)
    return models
def get_zero_shot_models():
    """Return summaries of the presets usable for zero-shot classification.

    Only entries flagged ``supports_zero_shot`` (the NLI-trained models)
    are included; missing flags are treated as False.
    """
    zero_shot = []
    for preset_key, config in MODEL_PRESETS.items():
        if not config.get('supports_zero_shot', False):
            continue
        zero_shot.append({
            'key': preset_key,
            'name': config['name'],
            'model_id': config['model_id'],
            'size': config['size'],
            'speed': config['speed'],
            'description': config['description'],
        })
    return zero_shot
def get_recommended_hyperparams(preset_key, training_mode='lora'):
    """Return recommended hyperparameters for a model preset.

    Produces a dict with 'learning_rate', 'batch_size', and 'max_length'
    taken from the preset. When *training_mode* is 'head_only' the
    learning rate is doubled (only the classifier head is trained).
    """
    preset = get_model_preset(preset_key)
    learning_rate = preset['recommended_lr']
    if training_mode == 'head_only':
        # Head-only training tolerates (and benefits from) a larger step size.
        learning_rate = learning_rate * 2
    return {
        'learning_rate': learning_rate,
        'batch_size': preset['recommended_batch'],
        'max_length': preset['max_length'],
    }