# contexto-api / src /model_selector.py
# Dev-ks04
# feat: Contexto FastAPI backend - intent-aware summarization engine
# 39028c9
"""
Intelligent model selection based on document characteristics
"""
import logging
from typing import Optional, Dict, Any
from enum import Enum
logger = logging.getLogger(__name__)
class ModelCategory(Enum):
    """Model categories by speed/complexity tradeoff.

    Each member's value is the lowercase string form of its name.
    """

    LIGHTWEIGHT = "lightweight"
    BALANCED = "balanced"
    ADVANCED = "advanced"
class DocumentComplexity(Enum):
    """Document complexity levels, listed from least to most complex.

    Each member's value is the lowercase string form of its name.
    """

    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"
    VERY_COMPLEX = "very_complex"
class ModelSelector:
    """Intelligent model selector based on document characteristics.

    A document is reduced to three surface metrics (word count, average
    sentence length, vocabulary richness), each metric is scored against
    ``COMPLEXITY_THRESHOLDS``, and the summed score is mapped to a
    ``DocumentComplexity`` level that decides which recommendation is
    returned.
    """

    # Model registries keyed by model name. Each entry records the parameter
    # count plus coarse speed/quality labels.
    LIGHTWEIGHT_MODELS = {
        't5-small': {'params': 60e6, 'speed': 'very_fast', 'quality': 'good'},
        'distilbart-cnn-6-6': {'params': 82e6, 'speed': 'very_fast', 'quality': 'good'},
        'bart-base': {'params': 140e6, 'speed': 'fast', 'quality': 'very_good'},
    }
    BALANCED_MODELS = {
        't5-base': {'params': 220e6, 'speed': 'fast', 'quality': 'excellent'},
        'pegasus-arxiv': {'params': 568e6, 'speed': 'moderate', 'quality': 'excellent'},
        'bart-large-cnn': {'params': 406e6, 'speed': 'moderate', 'quality': 'excellent'},
    }
    ADVANCED_MODELS = {
        'llama-7b': {'params': 7e9, 'speed': 'slow', 'quality': 'outstanding'},
        'gpt-3.5': {'params': 175e9, 'speed': 'slow', 'quality': 'outstanding'},
        't5-large': {'params': 770e6, 'speed': 'moderate', 'quality': 'outstanding'},
    }

    # Per-metric tier boundaries. A metric earns a tier's score when its value
    # is strictly greater than that tier's threshold. With these defaults the
    # 'very_complex' tier is never reached (nothing exceeds inf, and a
    # unique/total word ratio cannot exceed 1.0), so each metric contributes
    # at most 2.5 points.
    COMPLEXITY_THRESHOLDS = {
        'word_count': {
            'simple': 500,
            'moderate': 2000,
            'complex': 5000,
            'very_complex': float('inf'),
        },
        'sentence_length': {
            'simple': 15,
            'moderate': 25,
            'complex': 40,
            'very_complex': float('inf'),
        },
        'vocabulary_complexity': {
            'simple': 0.3,
            'moderate': 0.5,
            'complex': 0.7,
            'very_complex': 1.0,
        },
    }

    # (tier name, points) pairs, checked from the highest tier down; the first
    # threshold a value exceeds decides the points for that metric.
    _TIER_SCORES = (
        ('very_complex', 3),
        ('complex', 2.5),
        ('moderate', 1.5),
        ('simple', 0.5),
    )

    # Translation table that folds '!' and '?' into '.', so a single
    # split('.') sees every sentence terminator.
    _SENTENCE_END_TO_PERIOD = str.maketrans('!?', '..')

    # Values documented for recommend_settings(quality_preference=...).
    _KNOWN_PREFERENCES = ('speed', 'balanced', 'quality')

    def __init__(self):
        """Initialize model selector with no selection made yet."""
        self.current_model = None
        self.current_complexity = None
        self.use_rag = False

    def analyze_document(self, text: str) -> Dict[str, Any]:
        """
        Analyze document to determine complexity.

        Words are whitespace-separated tokens and keep any attached
        punctuation, so "cat" and "cat." count as different words. Sentences
        are the non-blank fragments between '.', '!' or '?' characters.

        Args:
            text: Document text

        Returns:
            Dictionary with 'word_count', 'sentence_count',
            'avg_sentence_length' and 'vocabulary_richness' (unique
            lowercased words divided by total words). Empty text yields
            zeros for every field.
        """
        words = text.split()
        word_count = len(words)

        # Splitting on '.' alone would count "Really? Yes!" as one sentence.
        fragments = text.translate(self._SENTENCE_END_TO_PERIOD).split('.')
        sentence_count = sum(1 for fragment in fragments if fragment.strip())

        # max(..., 1) keeps the ratios defined for empty input.
        avg_sentence_length = word_count / max(sentence_count, 1)
        unique_words = len({word.lower() for word in words})
        vocabulary_richness = unique_words / max(word_count, 1)

        return {
            'word_count': word_count,
            'sentence_count': sentence_count,
            'avg_sentence_length': avg_sentence_length,
            'vocabulary_richness': vocabulary_richness,
        }

    @classmethod
    def _score_metric(cls, value: float, thresholds: Dict[str, float]) -> float:
        """Return the points for one metric: the score of the highest tier exceeded, else 0."""
        for tier, points in cls._TIER_SCORES:
            if value > thresholds[tier]:
                return points
        return 0

    def determine_complexity(self, analysis: Dict[str, Any]) -> "DocumentComplexity":
        """
        Determine document complexity from analysis.

        Args:
            analysis: Document analysis dictionary; must contain
                'word_count', 'avg_sentence_length' and
                'vocabulary_richness' (as produced by analyze_document)

        Returns:
            DocumentComplexity enum: SIMPLE for a total score below 2,
            MODERATE below 4, COMPLEX below 6, otherwise VERY_COMPLEX

        Raises:
            KeyError: If a required key is missing from ``analysis``
        """
        thresholds = self.COMPLEXITY_THRESHOLDS
        # NOTE(review): vocabulary_richness is a type/token ratio, which is
        # close to 1.0 for short texts; a short document therefore already
        # scores 2.5 on that metric alone. Confirm this is intended.
        complexity_score = (
            self._score_metric(analysis['word_count'], thresholds['word_count'])
            + self._score_metric(analysis['avg_sentence_length'], thresholds['sentence_length'])
            + self._score_metric(analysis['vocabulary_richness'], thresholds['vocabulary_complexity'])
        )

        if complexity_score < 2:
            return DocumentComplexity.SIMPLE
        if complexity_score < 4:
            return DocumentComplexity.MODERATE
        if complexity_score < 6:
            return DocumentComplexity.COMPLEX
        return DocumentComplexity.VERY_COMPLEX

    def select_model(self, text: str, prefer_fast: bool = True) -> Dict[str, Any]:
        """
        Select best model based on document.

        Also records the outcome on the instance: ``current_complexity``,
        ``current_model`` and ``use_rag`` reflect the most recent call.

        Args:
            text: Document text
            prefer_fast: If True, prefer speed over quality

        Returns:
            Dictionary with model recommendation and settings
        """
        analysis = self.analyze_document(text)
        complexity = self.determine_complexity(analysis)
        self.current_complexity = complexity

        logger.info("Document complexity: %s", complexity.value)
        logger.info(
            "Analysis: word_count=%d, avg_sentence_length=%.1f, "
            "vocabulary_richness=%.2f",
            analysis['word_count'],
            analysis['avg_sentence_length'],
            analysis['vocabulary_richness'],
        )

        builders = {
            DocumentComplexity.SIMPLE: self._get_lightweight_recommendation,
            DocumentComplexity.MODERATE: self._get_balanced_recommendation,
            DocumentComplexity.COMPLEX: self._get_advanced_recommendation,
        }
        # Anything not listed above (i.e. VERY_COMPLEX) takes the RAG pipeline.
        build = builders.get(complexity, self._get_advanced_recommendation_with_rag)
        recommendation = build(prefer_fast)

        # Keep the instance state initialised in __init__ in sync with the
        # recommendation that is actually returned.
        self.current_model = recommendation['model']
        self.use_rag = recommendation['use_rag']
        return recommendation

    def _get_lightweight_recommendation(self, prefer_fast: bool) -> Dict[str, Any]:
        """Recommend lightweight model for simple documents."""
        model = 't5-small' if prefer_fast else 'bart-base'
        return {
            'model': model,
            'category': ModelCategory.LIGHTWEIGHT,
            'use_rag': False,
            'num_beams': 2,
            'max_length': 100,
            'reason': 'Simple document - lightweight model sufficient',
            'estimated_time': '1-2 seconds',
        }

    def _get_balanced_recommendation(self, prefer_fast: bool) -> Dict[str, Any]:
        """Recommend balanced model for moderate complexity documents."""
        model = 't5-base' if prefer_fast else 'bart-large-cnn'
        return {
            'model': model,
            'category': ModelCategory.BALANCED,
            'use_rag': False,
            'num_beams': 4,
            'max_length': 150,
            'reason': 'Moderate complexity - balanced model recommended',
            'estimated_time': '2-5 seconds',
        }

    def _get_advanced_recommendation(self, prefer_fast: bool) -> Dict[str, Any]:
        """Recommend advanced model for complex documents."""
        # NOTE(review): 'pegasus-arxiv' is registered under BALANCED_MODELS
        # but is reported here with category ADVANCED; left unchanged to
        # preserve the existing output - confirm which is intended.
        model = 'pegasus-arxiv' if prefer_fast else 't5-large'
        return {
            'model': model,
            'category': ModelCategory.ADVANCED,
            'use_rag': True,
            'num_beams': 4,
            'max_length': 200,
            'reason': 'Complex document - advanced model with RAG recommended',
            'estimated_time': '3-7 seconds',
        }

    def _get_advanced_recommendation_with_rag(self, prefer_fast: bool) -> Dict[str, Any]:
        """Recommend advanced model with RAG for very complex documents.

        ``prefer_fast`` is accepted for signature parity with the other
        recommendation builders but does not affect the result.
        """
        return {
            'model': 't5-large',
            'category': ModelCategory.ADVANCED,
            'use_rag': True,
            'num_beams': 4,
            'max_length': 250,
            'reason': 'Very complex document - RAG pipeline required for context awareness',
            'estimated_time': '5-10 seconds',
            'retrieval_chunks': 5,
            'chunk_overlap': 100,
        }

    def should_use_rag(self, complexity: "DocumentComplexity") -> bool:
        """
        Determine if RAG should be used.

        Args:
            complexity: Document complexity

        Returns:
            True for COMPLEX and VERY_COMPLEX documents, False otherwise
        """
        return complexity in (DocumentComplexity.COMPLEX, DocumentComplexity.VERY_COMPLEX)

    def get_model_info(self, model_name: str) -> Optional[Dict[str, Any]]:
        """
        Get information about a model.

        Args:
            model_name: Model name

        Returns:
            A copy of the model's registry entry, or None if not found
        """
        for registry in (self.LIGHTWEIGHT_MODELS, self.BALANCED_MODELS, self.ADVANCED_MODELS):
            if model_name in registry:
                # Copy so callers cannot mutate the shared class-level registry.
                return dict(registry[model_name])
        return None

    def recommend_settings(self, text: str, quality_preference: str = 'balanced') -> Dict[str, Any]:
        """
        Get recommended settings for document summarization.

        Only 'speed' switches to the faster model of each tier; 'balanced'
        and 'quality' currently select the same models. An unrecognized
        value is logged and handled like 'balanced'.

        Args:
            text: Document text
            quality_preference: 'speed', 'balanced', or 'quality'

        Returns:
            Recommended settings dictionary, with a 'model_info' entry added
            when the chosen model is present in one of the registries
        """
        if quality_preference not in self._KNOWN_PREFERENCES:
            logger.warning(
                "Unknown quality_preference %r; treating it like 'balanced'",
                quality_preference,
            )

        prefer_fast = quality_preference == 'speed'
        recommendation = self.select_model(text, prefer_fast)

        model_info = self.get_model_info(recommendation['model'])
        if model_info:
            recommendation['model_info'] = model_info
        return recommendation