# Source: gemini-chat / router.py
# Uploaded by theScottyBe with huggingface_hub (commit d2972df, verified)
"""Smart model routing logic."""
from typing import List, Optional, Dict, Any
from models import BaseModel, TaskConfig, ProcessingResult
class ModelRouter:
    """Routes image processing tasks to optimal models.

    Models are duck-typed: the router calls ``get_capabilities()``,
    ``health_check()``, ``process_image(image_path, task_config)`` and
    ``get_stats()`` on them. ``get_capabilities()`` must return a dict with the
    keys ``'tasks'``, ``'cost'`` (compared against ``'free'`` / ``'paid'``),
    ``'quality_score'`` and ``'avg_time'``; ``'cost_per_image'`` is optional.
    """

    def __init__(self, models: List[BaseModel], prefer_free: bool = True):
        """
        Args:
            models: Candidate models. On score ties, earlier models win.
            prefer_free: Router-wide default for favouring free models. It is
                used whenever ``task_config.prefer_free`` is missing or None.
        """
        self.models = models
        self.prefer_free = prefer_free
        # Quality threshold per quality mode. Models below the threshold are
        # still eligible, but their quality counts at half weight when scoring.
        self.quality_thresholds = {
            'low': 0.85,
            'medium': 0.90,
            'high': 0.95,
            'premium': 0.98,
        }

    def _capable_models(self, task_config: TaskConfig) -> List[BaseModel]:
        """Return the models whose capabilities list ``task_config.task_type``."""
        return [m for m in self.models
                if task_config.task_type in m.get_capabilities()['tasks']]

    def _prefers_free(self, task_config: TaskConfig) -> bool:
        """Resolve the free-model preference: the task setting wins, else the router default."""
        task_pref = getattr(task_config, 'prefer_free', None)
        if task_pref is None:
            return bool(self.prefer_free)
        return bool(task_pref)

    def select_model(self, task_config: TaskConfig) -> Optional[BaseModel]:
        """
        Select the best model for the task.

        Args:
            task_config: Task configuration. Reads ``task_type`` and
                ``quality_mode``, plus ``prefer_free`` when present.

        Returns:
            The highest-scoring healthy model that supports the task, or None
            if no suitable model is found.
        """
        capable = self._capable_models(task_config)
        if not capable:
            return None

        # Narrow to free models when preferred, but only if at least one free
        # model can do the task; otherwise keep the paid candidates.
        if self._prefers_free(task_config):
            free_models = [m for m in capable
                           if m.get_capabilities()['cost'] == 'free']
            if free_models:
                capable = free_models

        healthy = [m for m in capable if m.health_check()]
        if not healthy:
            return None

        return self._score_and_select(healthy, task_config)

    def _score_and_select(self, models: List[BaseModel], task_config: TaskConfig) -> BaseModel:
        """Score models and return the highest-scoring one.

        Score components:
          * quality: ``quality_score * 100`` (0-100) if the model meets the
            quality-mode threshold (default 0.95 for unknown modes), otherwise
            ``quality_score * 50``;
          * speed: ``10 / max(avg_time, 0.1)`` (0-100) weighted at 20%, i.e.
            at most 20 points, so speed refines but never overrides quality;
          * cost: paid models lose ``cost_per_image * 1000`` points
            (``cost_per_image`` defaults to 0.01).

        Ties go to the model that appears first in ``models``.

        Raises:
            ValueError: if ``models`` is empty (callers only pass non-empty lists).
        """
        quality_threshold = self.quality_thresholds.get(task_config.quality_mode, 0.95)

        def score(model):
            caps = model.get_capabilities()
            quality_score = caps['quality_score']
            # Quality is the dominant term. Missing the threshold halves it.
            weight = 100 if quality_score >= quality_threshold else 50
            total = quality_score * weight
            # Speed bonus is the inverse of average time, capped at 100 points
            # via the 0.1s floor. It is weighted at 20% (0.2); a multiplier of 20
            # would let speed outweigh quality by up to 20x.
            total += (10 / max(caps['avg_time'], 0.1)) * 0.2
            if caps['cost'] == 'paid':
                # NOTE(review): originally described as "cost in cents"; * 1000
                # is not a cents conversion if cost_per_image is in dollars.
                # Confirm the intended weight.
                total -= caps.get('cost_per_image', 0.01) * 1000
            return total

        # max() returns the first maximal item, matching a stable descending sort.
        return max(models, key=score)

    def process_with_fallback(self, image_path: str, task_config: TaskConfig) -> ProcessingResult:
        """
        Process an image with automatic fallback on failure.

        Models from ``_build_fallback_chain`` are tried in order. A model fails
        if it raises or returns a result whose ``success`` is falsy; the next
        model is then tried.

        Args:
            image_path: Path to input image
            task_config: Task configuration

        Returns:
            The first successful result. If every model fails, returns an
            unsuccessful ProcessingResult carrying the last error seen, or
            'All models failed' if none was recorded.
        """
        last_error = None
        for model in self._build_fallback_chain(task_config):
            try:
                print(f"🤖 Trying {model.__class__.__name__}...")
                result = model.process_image(image_path, task_config)
                if result.success:
                    return result
                last_error = result.error
            except Exception as e:
                # Deliberate best-effort: any model failure moves on to the next one.
                last_error = str(e)

        error = last_error or 'All models failed'
        return ProcessingResult(
            model_used='None',
            confidence=0.0,
            processing_time=0.0,
            cost=0.0,
            output_path='',
            metadata={'error': error},
            success=False,
            error=error,
        )

    def _build_fallback_chain(self, task_config: TaskConfig) -> List[BaseModel]:
        """Build the ordered list of healthy, capable models to try.

        When free models are preferred, every free model comes before any paid
        one. Within each group, models are ordered by quality (highest first),
        then by average time (fastest first). Example chain for remove-bg:
          1. Nano Banana Pro (98%, free, 7.9s)
          2. Gemini 2.5 Flash (95%, free, 8.4s)
          3. remove.bg (100%, paid, 2.8s)
        When free models are not preferred, the order is quality descending,
        then free before paid, then fastest first.
        """
        prefer_free = self._prefers_free(task_config)

        def sort_key(model):
            caps = model.get_capabilities()
            cost_value = 0 if caps['cost'] == 'free' else 1
            if prefer_free:
                return (cost_value, -caps['quality_score'], caps['avg_time'])
            return (-caps['quality_score'], cost_value, caps['avg_time'])

        healthy = [m for m in self._capable_models(task_config) if m.health_check()]
        healthy.sort(key=sort_key)
        return healthy

    def get_model_status(self) -> Dict[str, Any]:
        """Return health, capabilities and stats for every registered model.

        Entries are keyed by the model's class name, so two instances of the
        same class overwrite each other.
        """
        status = {}
        for model in self.models:
            caps = model.get_capabilities()
            stats = model.get_stats()
            status[model.__class__.__name__] = {
                'healthy': model.health_check(),
                'capabilities': caps,
                'stats': stats,
            }
        return status