Spaces:
Sleeping
Sleeping
"""
Model presets for both fine-tuning and zero-shot classification.
Provides configuration for various HuggingFace models optimized for text classification.
"""

# Registry of supported model presets, keyed by a short preset identifier.
# Each entry carries:
#   name               - human-readable display name
#   model_id           - HuggingFace Hub model identifier
#   max_length         - maximum input sequence length (tokens)
#   size               - approximate parameter count (display string)
#   speed              - relative inference-speed label (display string)
#   best_for           - suggested use case (display string)
#   description        - longer human-readable description
#   recommended_lr     - suggested fine-tuning learning rate
#   recommended_batch  - suggested fine-tuning batch size
#   supports_zero_shot - True for NLI-trained models usable without fine-tuning
MODEL_PRESETS = {
    # Zero-shot capable models (NLI-trained)
    'bart-large-mnli': {
        'name': 'BART-large-MNLI',
        'model_id': 'facebook/bart-large-mnli',
        'max_length': 1024,
        'size': '400M',
        'speed': 'Slow',
        'best_for': 'Zero-shot + Fine-tuning',
        'description': 'Large sequence-to-sequence model, excellent zero-shot performance',
        'recommended_lr': 2e-5,
        'recommended_batch': 4,
        'supports_zero_shot': True
    },
    'deberta-v3-base-mnli': {
        'name': 'DeBERTa-v3-base-MNLI',
        'model_id': 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli',
        'max_length': 512,
        'size': '86M',
        'speed': 'Fast',
        'best_for': 'Fast zero-shot classification',
        'description': 'DeBERTa trained on NLI datasets, excellent zero-shot with better speed',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': True
    },
    'distilbart-mnli': {
        'name': 'DistilBART-MNLI',
        'model_id': 'valhalla/distilbart-mnli-12-3',
        'max_length': 1024,
        'size': '134M',
        'speed': 'Medium',
        'best_for': 'Balanced zero-shot',
        'description': 'Distilled BART for zero-shot, good balance of speed and accuracy',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': True
    },
    # Fine-tuning only models
    'deberta-v3-small': {
        'name': 'DeBERTa-v3-small',
        'model_id': 'microsoft/deberta-v3-small',
        'max_length': 512,
        'size': '44M',
        'speed': 'Very Fast',
        'best_for': 'Fine-tuning with small datasets',
        'description': 'State-of-the-art efficient model, excellent for small datasets',
        'recommended_lr': 3e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False
    },
    'deberta-v3-base': {
        'name': 'DeBERTa-v3-base',
        'model_id': 'microsoft/deberta-v3-base',
        'max_length': 512,
        'size': '86M',
        'speed': 'Fast',
        'best_for': 'High accuracy fine-tuning',
        'description': 'Larger DeBERTa model with better accuracy',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False
    },
    'distilbert-base': {
        'name': 'DistilBERT-base',
        'model_id': 'distilbert-base-uncased',
        'max_length': 512,
        'size': '66M',
        'speed': 'Fast',
        'best_for': 'Balanced speed and accuracy',
        'description': 'Distilled BERT, 60% faster with 97% performance retention',
        'recommended_lr': 5e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False
    },
    'roberta-base': {
        'name': 'RoBERTa-base',
        'model_id': 'roberta-base',
        'max_length': 512,
        'size': '125M',
        'speed': 'Medium',
        'best_for': 'Maximum accuracy',
        'description': 'Robustly optimized BERT, excellent classification performance',
        'recommended_lr': 2e-5,
        'recommended_batch': 8,
        'supports_zero_shot': False
    },
    'electra-small': {
        'name': 'ELECTRA-small',
        'model_id': 'google/electra-small-discriminator',
        'max_length': 512,
        'size': '14M',
        'speed': 'Fastest',
        'best_for': 'Speed-critical applications',
        'description': 'Very fast and lightweight, good for production',
        'recommended_lr': 5e-5,
        'recommended_batch': 16,
        'supports_zero_shot': False
    },
    'minilm': {
        'name': 'MiniLM-L12',
        'model_id': 'microsoft/MiniLM-L12-H384-uncased',
        'max_length': 512,
        'size': '33M',
        'speed': 'Very Fast',
        'best_for': 'Lightweight production deployment',
        'description': 'Compact model optimized for speed',
        'recommended_lr': 4e-5,
        'recommended_batch': 12,
        'supports_zero_shot': False
    }
}
def get_model_preset(preset_key):
    """Return the preset config for *preset_key*, falling back to BART-large-MNLI."""
    try:
        return MODEL_PRESETS[preset_key]
    except KeyError:
        # Unknown keys fall back to the most capable default preset.
        return MODEL_PRESETS['bart-large-mnli']
def get_available_models():
    """Return a summary record for every configured model preset."""
    summaries = []
    for preset_key, preset in MODEL_PRESETS.items():
        summaries.append({
            'key': preset_key,
            'name': preset['name'],
            'size': preset['size'],
            'speed': preset['speed'],
            'best_for': preset['best_for'],
            'supports_zero_shot': preset['supports_zero_shot'],
        })
    return summaries
def get_zero_shot_models():
    """Return descriptors for the presets that can run zero-shot classification."""
    zero_shot = []
    for preset_key, preset in MODEL_PRESETS.items():
        # Skip fine-tuning-only presets; missing flag counts as unsupported.
        if not preset.get('supports_zero_shot', False):
            continue
        zero_shot.append({
            'key': preset_key,
            'name': preset['name'],
            'model_id': preset['model_id'],
            'size': preset['size'],
            'speed': preset['speed'],
            'description': preset['description'],
        })
    return zero_shot
def get_recommended_hyperparams(preset_key, training_mode='lora'):
    """Return recommended hyperparameters for a model preset and training mode."""
    preset = get_model_preset(preset_key)
    learning_rate = preset['recommended_lr']
    if training_mode == 'head_only':
        # Head-only training updates far fewer weights, so a higher LR is used.
        learning_rate = learning_rate * 2
    return {
        'learning_rate': learning_rate,
        'batch_size': preset['recommended_batch'],
        'max_length': preset['max_length'],
    }