Spaces:
Sleeping
Sleeping
File size: 5,681 Bytes
1377fb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
"""
Model presets for both fine-tuning and zero-shot classification.
Provides configuration for various HuggingFace models optimized for text classification.
"""
# Registry of model presets, keyed by a short preset identifier.
#
# Every entry carries the same fields:
#   name               - human-readable display name
#   model_id           - model identifier string (Hugging Face model per the
#                        module docstring)
#   max_length         - maximum sequence length used for this model
#   size               - approximate parameter count, as a display string
#   speed              - qualitative speed label, as a display string
#   best_for           - short summary of the intended use case
#   description        - one-sentence description
#   recommended_lr     - suggested learning rate
#   recommended_batch  - suggested batch size
#   supports_zero_shot - True when the model can be used for zero-shot
#                        classification
#
# Insertion order is meaningful: the listing helpers in this module iterate
# the dict and return entries in this order.
MODEL_PRESETS = {
    # Zero-shot capable models (NLI-trained)
    'bart-large-mnli': {
        'name': 'BART-large-MNLI',
        'model_id': 'facebook/bart-large-mnli',
        'max_length': 1024,
        'size': '400M',
        'speed': 'Slow',
        'best_for': 'Zero-shot + Fine-tuning',
        'description': 'Large sequence-to-sequence model, excellent zero-shot performance',
        'recommended_lr': 2e-5,
        'recommended_batch': 4,
        'supports_zero_shot': True,
    },
    'deberta-v3-base-mnli': {
        'name': 'DeBERTa-v3-base-MNLI',
        'model_id': 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli',
        'max_length': 512,
        'size': '86M',
        'speed': 'Fast',
        'best_for': 'Fast zero-shot classification',
        'description': 'DeBERTa trained on NLI datasets, excellent zero-shot with better speed',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': True,
    },
    'distilbart-mnli': {
        'name': 'DistilBART-MNLI',
        'model_id': 'valhalla/distilbart-mnli-12-3',
        'max_length': 1024,
        'size': '134M',
        'speed': 'Medium',
        'best_for': 'Balanced zero-shot',
        'description': 'Distilled BART for zero-shot, good balance of speed and accuracy',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': True,
    },

    # Fine-tuning only models
    'deberta-v3-small': {
        'name': 'DeBERTa-v3-small',
        'model_id': 'microsoft/deberta-v3-small',
        'max_length': 512,
        'size': '44M',
        'speed': 'Very Fast',
        'best_for': 'Fine-tuning with small datasets',
        'description': 'State-of-the-art efficient model, excellent for small datasets',
        'recommended_lr': 3e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False,
    },
    'deberta-v3-base': {
        'name': 'DeBERTa-v3-base',
        'model_id': 'microsoft/deberta-v3-base',
        'max_length': 512,
        'size': '86M',
        'speed': 'Fast',
        'best_for': 'High accuracy fine-tuning',
        'description': 'Larger DeBERTa model with better accuracy',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False,
    },
    'distilbert-base': {
        'name': 'DistilBERT-base',
        'model_id': 'distilbert-base-uncased',
        'max_length': 512,
        'size': '66M',
        'speed': 'Fast',
        'best_for': 'Balanced speed and accuracy',
        'description': 'Distilled BERT, 60% faster with 97% performance retention',
        'recommended_lr': 5e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False,
    },
    'roberta-base': {
        'name': 'RoBERTa-base',
        'model_id': 'roberta-base',
        'max_length': 512,
        'size': '125M',
        'speed': 'Medium',
        'best_for': 'Maximum accuracy',
        'description': 'Robustly optimized BERT, excellent classification performance',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False,
    },
    'electra-small': {
        'name': 'ELECTRA-small',
        'model_id': 'google/electra-small-discriminator',
        'max_length': 512,
        'size': '14M',
        'speed': 'Fastest',
        'best_for': 'Speed-critical applications',
        'description': 'Very fast and lightweight, good for production',
        'recommended_lr': 5e-5,
        'recommended_batch': 16,
        'supports_zero_shot': False,
    },
    'minilm': {
        'name': 'MiniLM-L12',
        'model_id': 'microsoft/MiniLM-L12-H384-uncased',
        'max_length': 512,
        'size': '33M',
        'speed': 'Very Fast',
        'best_for': 'Lightweight production deployment',
        'description': 'Compact model optimized for speed',
        'recommended_lr': 4e-5,
        'recommended_batch': 12,
        'supports_zero_shot': False,
    },
}
def get_model_preset(preset_key):
    """Look up the preset configuration registered under ``preset_key``.

    Unknown keys do not raise; the 'bart-large-mnli' preset is returned
    instead. The returned dict is the entry stored in ``MODEL_PRESETS``
    itself, not a copy.
    """
    # Resolve the fallback up front, exactly as the dict.get() default would be.
    fallback = MODEL_PRESETS['bart-large-mnli']
    if preset_key in MODEL_PRESETS:
        return MODEL_PRESETS[preset_key]
    return fallback
def get_available_models():
    """Return a summary dict for every preset, in registry order.

    Each summary holds the preset's 'key' followed by its 'name', 'size',
    'speed', 'best_for' and 'supports_zero_shot' values.
    """
    # Fields copied verbatim from each preset, in output order.
    summary_fields = ('name', 'size', 'speed', 'best_for', 'supports_zero_shot')
    return [
        {'key': preset_key, **{field: preset[field] for field in summary_fields}}
        for preset_key, preset in MODEL_PRESETS.items()
    ]
def get_zero_shot_models():
    """Return summaries of the presets flagged as zero-shot capable.

    A preset is included when its 'supports_zero_shot' value is truthy; a
    missing flag counts as False. Each summary holds 'key', 'name',
    'model_id', 'size', 'speed' and 'description', in registry order.
    """
    zero_shot = []
    for preset_key, preset in MODEL_PRESETS.items():
        if not preset.get('supports_zero_shot', False):
            continue
        zero_shot.append({
            'key': preset_key,
            'name': preset['name'],
            'model_id': preset['model_id'],
            'size': preset['size'],
            'speed': preset['speed'],
            'description': preset['description'],
        })
    return zero_shot
def get_recommended_hyperparams(preset_key, training_mode='lora'):
    """Return the recommended hyperparameters for a model preset.

    The preset is resolved through ``get_model_preset``. The result is a new
    dict with 'learning_rate', 'batch_size' and 'max_length'. When
    ``training_mode`` is 'head_only' the preset's learning rate is doubled;
    any other mode uses the preset's values unchanged.
    """
    preset = get_model_preset(preset_key)

    learning_rate = preset['recommended_lr']
    if training_mode == 'head_only':
        # Higher learning rate for head-only training
        learning_rate = learning_rate * 2

    return {
        'learning_rate': learning_rate,
        'batch_size': preset['recommended_batch'],
        'max_length': preset['max_length'],
    }
|