Spaces:
Sleeping
Sleeping
| """ | |
| Model Versioning and Input Caching System | |
| Tracks model versions, performance, and implements intelligent caching | |
| Features: | |
| - Model version tracking with metadata | |
| - Performance metrics per model version | |
| - A/B testing framework | |
| - Automated rollback capabilities | |
| - SHA256 input fingerprinting | |
| - Intelligent caching with invalidation | |
| - Cache performance analytics | |
| Author: MiniMax Agent | |
| Date: 2025-10-29 | |
| Version: 1.0.0 | |
| """ | |
| import hashlib | |
| import json | |
| import logging | |
| from typing import Dict, List, Any, Optional, Tuple | |
| from datetime import datetime, timedelta | |
| from dataclasses import dataclass, asdict | |
| from collections import defaultdict, deque | |
| from enum import Enum | |
| import os | |
| logger = logging.getLogger(__name__) | |
class ModelStatus(Enum):
    """Deployment lifecycle stage of a registered model version.

    The string values are what ``ModelVersion.to_dict`` emits for ``status``.
    """

    # Version currently serving for its model id.
    ACTIVE = "active"
    # Registered but not promoted to active.
    TESTING = "testing"
    # Previously active version that has been superseded.
    DEPRECATED = "deprecated"
    # Declared for completeness; nothing in this module assigns it.
    RETIRED = "retired"
@dataclass
class ModelVersion:
    """Metadata record for a single registered version of a model.

    This must be a dataclass: the registry constructs it with keyword
    arguments and ``to_dict`` relies on ``dataclasses.asdict``.
    """

    model_id: str
    version: str
    model_name: str
    model_path: str
    # ISO-8601 timestamp string taken when the version was registered.
    deployment_date: str
    # Lifecycle stage; annotated as a string so it is resolved lazily.
    status: "ModelStatus"
    metadata: Dict[str, Any]
    # metric name -> aggregated value maintained by the registry.
    performance_metrics: Dict[str, float]

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict copy of this record.

        Returns:
            A dict of all fields, with ``status`` replaced by the enum's
            string value.
        """
        data = asdict(self)
        # asdict() leaves the enum member in place; store its string value instead.
        data["status"] = self.status.value
        return data
@dataclass
class CacheEntry:
    """A cached analysis result together with its bookkeeping metadata.

    This must be a dataclass: the cache constructs it with keyword arguments
    and ``to_dict`` relies on ``dataclasses.asdict``.
    """

    # "<input_hash>:<model_version>" as built by the cache.
    cache_key: str
    input_hash: str
    result_data: Dict[str, Any]
    # ISO-8601 timestamp strings.
    created_at: str
    last_accessed: str
    access_count: int
    model_version: str
    # Estimated serialized size of ``result_data``.
    size_bytes: int

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict (deep) copy of this entry."""
        return asdict(self)
class ModelRegistry:
    """
    Registry for tracking model versions and performance
    Supports version comparison and automated rollback
    """

    def __init__(self):
        # model_id -> version -> ModelVersion
        self.models: Dict[str, Dict[str, "ModelVersion"]] = defaultdict(dict)
        self.active_versions: Dict[str, str] = {}  # model_id -> version
        # "<model_id>:<version>" -> bounded history of raw performance records
        self.performance_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
        # (model_id, version, metric) -> number of samples folded into the mean
        self._metric_counts: Dict[Tuple[str, str, str], int] = defaultdict(int)
        logger.info("Model Registry initialized")

    def register_model(
        self,
        model_id: str,
        version: str,
        model_name: str,
        model_path: str,
        metadata: Optional[Dict[str, Any]] = None,
        set_active: bool = False
    ) -> "ModelVersion":
        """Register a new model version.

        Args:
            model_id: Identifier of the model family.
            version: Version string of this registration.
            model_name: Human-readable model name.
            model_path: Location/identifier of the model weights.
            metadata: Optional free-form metadata; defaults to an empty dict.
            set_active: When True the version is promoted immediately,
                otherwise it is stored with TESTING status.

        Returns:
            The newly created ModelVersion record.
        """
        model_version = ModelVersion(
            model_id=model_id,
            version=version,
            model_name=model_name,
            model_path=model_path,
            deployment_date=datetime.utcnow().isoformat(),
            status=ModelStatus.ACTIVE if set_active else ModelStatus.TESTING,
            metadata=metadata or {},
            performance_metrics={}
        )
        self.models[model_id][version] = model_version
        if set_active:
            self.set_active_version(model_id, version)
        logger.info(f"Registered model {model_id} v{version}")
        return model_version

    def set_active_version(self, model_id: str, version: str):
        """Promote ``version`` to be the active version of ``model_id``.

        The previously active version (if any) is marked DEPRECATED.

        Raises:
            ValueError: If the model/version pair is not registered.
        """
        if model_id not in self.models or version not in self.models[model_id]:
            raise ValueError(f"Model {model_id} v{version} not found")
        # Demote the previously active version first, so that re-activating
        # the same version still ends up ACTIVE below.
        prev_version = self.active_versions.get(model_id)
        if prev_version is not None and prev_version in self.models[model_id]:
            self.models[model_id][prev_version].status = ModelStatus.DEPRECATED
        self.active_versions[model_id] = version
        self.models[model_id][version].status = ModelStatus.ACTIVE
        logger.info(f"Set active version: {model_id} -> v{version}")

    def get_active_version(self, model_id: str) -> Optional["ModelVersion"]:
        """Return the active ModelVersion for ``model_id``, or None if there is none."""
        if model_id not in self.active_versions:
            return None
        version = self.active_versions[model_id]
        return self.models[model_id].get(version)

    def record_performance(
        self,
        model_id: str,
        version: str,
        metrics: Dict[str, float]
    ):
        """Record performance metrics for a model version.

        The raw sample is appended to the bounded history and each metric's
        stored value is updated to the arithmetic mean of all samples
        recorded so far. Unknown model/version pairs are logged and ignored.

        Args:
            model_id: Identifier of the model family.
            version: Version the metrics were measured on.
            metrics: metric name -> measured value for this sample.
        """
        if model_id not in self.models or version not in self.models[model_id]:
            logger.warning(f"Cannot record performance for unknown model {model_id} v{version}")
            return
        self.performance_history[f"{model_id}:{version}"].append({
            "timestamp": datetime.utcnow().isoformat(),
            "model_id": model_id,
            "version": version,
            # Copy so later mutation of the caller's dict cannot rewrite history.
            "metrics": dict(metrics),
        })
        stored = self.models[model_id][version].performance_metrics
        for metric_name, value in metrics.items():
            count_key = (model_id, version, metric_name)
            if metric_name in stored:
                # A value that was set without going through this method
                # counts as a single prior sample.
                samples = max(self._metric_counts[count_key], 1) + 1
                # Incremental mean: every sample carries equal weight, unlike
                # (current + value) / 2 which gives the newest sample 50%.
                stored[metric_name] += (value - stored[metric_name]) / samples
                self._metric_counts[count_key] = samples
            else:
                stored[metric_name] = value
                self._metric_counts[count_key] = 1

    def compare_versions(
        self,
        model_id: str,
        version1: str,
        version2: str,
        metric: str = "accuracy"
    ) -> Dict[str, Any]:
        """Compare one metric between two versions of a model.

        Returns:
            A dict with the per-version values, the absolute ``difference``
            (version2 - version1) and ``improvement_percent`` relative to
            version1 (0.0 when version1's value is not positive). If the model
            or either version is unknown, a dict with an ``"error"`` key.
            Versions that lack the metric are treated as 0.0.
        """
        if model_id not in self.models:
            return {"error": f"Model {model_id} not found"}
        v1 = self.models[model_id].get(version1)
        v2 = self.models[model_id].get(version2)
        if not v1 or not v2:
            return {"error": "One or both versions not found"}
        v1_metric = v1.performance_metrics.get(metric, 0.0)
        v2_metric = v2.performance_metrics.get(metric, 0.0)
        difference = v2_metric - v1_metric
        return {
            "model_id": model_id,
            "versions": {
                version1: v1_metric,
                version2: v2_metric
            },
            "difference": difference,
            "improvement_percent": (difference / v1_metric * 100) if v1_metric > 0 else 0.0,
            "metric": metric
        }

    def rollback_to_version(self, model_id: str, version: str) -> bool:
        """Make ``version`` active again.

        Returns:
            True on success, False if the model/version pair is unknown.
        """
        if model_id not in self.models or version not in self.models[model_id]:
            logger.error(f"Cannot rollback: model {model_id} v{version} not found")
            return False
        logger.warning(f"Rolling back {model_id} to v{version}")
        self.set_active_version(model_id, version)
        return True

    def get_model_inventory(self) -> Dict[str, Any]:
        """Return every model with its active version and all version records as dicts."""
        return {
            model_id: {
                "active_version": self.active_versions.get(model_id, "none"),
                "total_versions": len(versions),
                "versions": {
                    ver: model.to_dict() for ver, model in versions.items()
                }
            }
            for model_id, versions in self.models.items()
        }

    def auto_rollback_if_degraded(
        self,
        model_id: str,
        metric: str = "accuracy",
        threshold_drop: float = 0.05  # 5% drop
    ) -> bool:
        """Roll back to the latest deprecated version if ``metric`` degraded.

        Args:
            model_id: Identifier of the model family to check.
            metric: Name of the metric to compare.
            threshold_drop: Relative drop (fraction of the previous value)
                above which a rollback is triggered.

        Returns:
            True if a rollback was performed, False otherwise.
        """
        if model_id not in self.active_versions:
            return False
        current_version = self.active_versions[model_id]
        current_model = self.models[model_id][current_version]
        # No measurement on the active version yet: a missing metric must not
        # be read as 0.0 (a 100% drop), or every freshly promoted version
        # would be rolled back before it had a chance to report.
        if metric not in current_model.performance_metrics:
            return False
        # Candidate rollback targets are the versions that were active before.
        previous_versions = [
            (ver, model) for ver, model in self.models[model_id].items()
            if model.status == ModelStatus.DEPRECATED
        ]
        if not previous_versions:
            return False
        # Pick the deprecated version with the latest deployment date.
        prev_version, prev_model = max(
            previous_versions,
            key=lambda item: item[1].deployment_date
        )
        current_metric = current_model.performance_metrics[metric]
        prev_metric = prev_model.performance_metrics.get(metric, 0.0)
        if prev_metric == 0.0:
            # No usable baseline (and avoids dividing by zero).
            return False
        drop_fraction = (prev_metric - current_metric) / prev_metric
        if drop_fraction > threshold_drop:
            logger.warning(
                f"Performance degradation detected for {model_id}: "
                f"{metric} dropped {drop_fraction*100:.1f}%. "
                f"Rolling back to v{prev_version}"
            )
            return self.rollback_to_version(model_id, prev_version)
        return False
class InputCache:
    """
    Intelligent caching system with SHA256 fingerprinting
    Caches analysis results to avoid reprocessing identical files
    """

    def __init__(
        self,
        max_cache_size_mb: int = 1000,
        ttl_hours: int = 24
    ):
        """Create an empty cache.

        Args:
            max_cache_size_mb: Upper bound on the estimated total size of
                cached results, in MiB.
            ttl_hours: Age after which an entry is treated as expired.
        """
        self.cache: Dict[str, CacheEntry] = {}
        self.max_cache_size_bytes = max_cache_size_mb * 1024 * 1024
        self.current_cache_size = 0
        self.ttl_hours = ttl_hours
        # Cache statistics
        self.hits = 0
        self.misses = 0
        self.evictions = 0
        logger.info(f"Input Cache initialized (max size: {max_cache_size_mb}MB, TTL: {ttl_hours}h)")

    def compute_hash(self, file_path: str) -> str:
        """Compute the SHA256 hex digest of a file's contents.

        Returns:
            The hex digest, or an empty string if the file could not be read
            (the failure is logged rather than raised).
        """
        sha256_hash = hashlib.sha256()
        try:
            with open(file_path, "rb") as f:
                # Read in chunks so large files are never fully in memory.
                for byte_block in iter(lambda: f.read(65536), b""):
                    sha256_hash.update(byte_block)
            return sha256_hash.hexdigest()
        except Exception as e:
            # Deliberately best-effort: callers treat "" as "cannot cache".
            logger.error(f"Failed to compute hash for {file_path}: {str(e)}")
            return ""

    def compute_data_hash(self, data: bytes) -> str:
        """Compute the SHA256 hex digest of in-memory bytes."""
        return hashlib.sha256(data).hexdigest()

    def get(
        self,
        input_hash: str,
        model_version: str
    ) -> Optional[Dict[str, Any]]:
        """Retrieve a cached result.

        Returns:
            The stored result dict (the same object that was cached, not a
            copy), or None on a miss or when the entry has outlived the TTL.
        """
        cache_key = f"{input_hash}:{model_version}"
        entry = self.cache.get(cache_key)
        if entry is None:
            self.misses += 1
            return None
        # Expired entries are dropped and reported as a miss.
        created_time = datetime.fromisoformat(entry.created_at)
        if datetime.utcnow() - created_time > timedelta(hours=self.ttl_hours):
            self._evict(cache_key)
            self.misses += 1
            return None
        # Update access tracking (drives LRU eviction).
        entry.last_accessed = datetime.utcnow().isoformat()
        entry.access_count += 1
        self.hits += 1
        logger.info(f"Cache hit: {cache_key[:16]}...")
        return entry.result_data

    def put(
        self,
        input_hash: str,
        model_version: str,
        result_data: Dict[str, Any]
    ):
        """Store a result in the cache, evicting LRU entries if needed.

        A result whose estimated size exceeds the whole cache capacity is not
        stored. Storing under an existing key replaces the old entry.

        Raises:
            TypeError, ValueError: If ``result_data`` cannot be size-estimated
                by ``json.dumps`` (e.g. unsupported dict keys, circular refs).
        """
        cache_key = f"{input_hash}:{model_version}"
        # Size is only an estimate, so stringify values JSON cannot encode
        # instead of failing after the expensive processing already ran.
        size_bytes = len(json.dumps(result_data, default=str).encode())
        # Drop any entry already stored under this key so its size is not
        # counted twice and a stale result is never served.
        previous = self.cache.pop(cache_key, None)
        if previous is not None:
            self.current_cache_size -= previous.size_bytes
        if size_bytes > self.max_cache_size_bytes:
            # Evicting can never make room for this entry; without this guard
            # the eviction loop below would spin forever on an empty cache.
            logger.warning(
                f"Cache skip: {cache_key[:16]}... ({size_bytes} bytes exceeds cache capacity)"
            )
            return
        while self.cache and self.current_cache_size + size_bytes > self.max_cache_size_bytes:
            self._evict_lru()
        now = datetime.utcnow().isoformat()
        self.cache[cache_key] = CacheEntry(
            cache_key=cache_key,
            input_hash=input_hash,
            result_data=result_data,
            created_at=now,
            last_accessed=now,
            access_count=0,
            model_version=model_version,
            size_bytes=size_bytes
        )
        self.current_cache_size += size_bytes
        logger.info(f"Cache stored: {cache_key[:16]}... ({size_bytes} bytes)")

    def invalidate_model_version(self, model_version: str):
        """Remove every cache entry that was produced by ``model_version``."""
        # Collect first: the dict must not change size while being iterated.
        keys_to_remove = [
            key for key, entry in self.cache.items()
            if entry.model_version == model_version
        ]
        for key in keys_to_remove:
            self._evict(key)
        logger.info(f"Invalidated {len(keys_to_remove)} cache entries for model v{model_version}")

    def _evict(self, cache_key: str):
        """Remove one entry (if present) and update size/eviction counters."""
        entry = self.cache.pop(cache_key, None)
        if entry is not None:
            self.current_cache_size -= entry.size_bytes
            self.evictions += 1

    def _evict_lru(self):
        """Evict the entry with the oldest ``last_accessed`` timestamp."""
        if not self.cache:
            return
        # ISO-8601 strings from the same clock sort chronologically.
        lru_key = min(
            self.cache,
            key=lambda k: self.cache[k].last_accessed
        )
        self._evict(lru_key)
        logger.debug(f"LRU eviction: {lru_key[:16]}...")

    def get_statistics(self) -> Dict[str, Any]:
        """Return size, utilization and hit/miss statistics for the cache."""
        total_requests = self.hits + self.misses
        hit_rate = self.hits / total_requests if total_requests > 0 else 0.0
        bytes_per_mb = 1024 * 1024
        # A cache configured with 0 MB capacity has no meaningful utilization.
        if self.max_cache_size_bytes > 0:
            utilization = self.current_cache_size / self.max_cache_size_bytes * 100
        else:
            utilization = 0.0
        return {
            "total_entries": len(self.cache),
            "cache_size_mb": self.current_cache_size / bytes_per_mb,
            "max_size_mb": self.max_cache_size_bytes / bytes_per_mb,
            "utilization_percent": utilization,
            "total_requests": total_requests,
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate_percent": hit_rate * 100,
            "evictions": self.evictions,
            "ttl_hours": self.ttl_hours
        }

    def clear(self):
        """Remove all entries; hit/miss/eviction counters are left untouched."""
        entry_count = len(self.cache)
        self.cache.clear()
        self.current_cache_size = 0
        logger.info(f"Cache cleared: {entry_count} entries removed")
class ModelVersioningSystem:
    """
    Complete model versioning and caching system
    Integrates model registry with input caching
    """

    def __init__(
        self,
        cache_size_mb: int = 1000,
        cache_ttl_hours: int = 24
    ):
        """Build the registry and cache and register the default models.

        Args:
            cache_size_mb: Capacity passed to the input cache, in MiB.
            cache_ttl_hours: Entry time-to-live passed to the input cache.
        """
        self.model_registry = ModelRegistry()
        self.input_cache = InputCache(cache_size_mb, cache_ttl_hours)
        # Initialize default models
        self._initialize_default_models()
        logger.info("Model Versioning System initialized")

    def _initialize_default_models(self):
        """Register the built-in model set, each as active version 1.0.0."""
        # (model_id, version, display name, model path)
        default_models = [
            ("document_classifier", "1.0.0", "Bio_ClinicalBERT", "emilyalsentzer/Bio_ClinicalBERT"),
            ("clinical_ner", "1.0.0", "Biomedical NER", "d4data/biomedical-ner-all"),
            ("clinical_generation", "1.0.0", "BioGPT-Large", "microsoft/BioGPT-Large"),
            ("medical_qa", "1.0.0", "RoBERTa-SQuAD2", "deepset/roberta-base-squad2"),
            ("general_medical", "1.0.0", "PubMedBERT", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"),
            ("drug_interaction", "1.0.0", "SciBERT", "allenai/scibert_scivocab_uncased"),
            ("clinical_summarization", "1.0.0", "BigBird-Pegasus", "google/bigbird-pegasus-large-pubmed")
        ]
        for model_id, version, name, path in default_models:
            self.model_registry.register_model(
                model_id=model_id,
                version=version,
                model_name=name,
                model_path=path,
                metadata={"initialized": "2025-10-29"},
                set_active=True
            )

    def process_with_cache(
        self,
        input_path: str,
        model_id: str,
        process_func: callable
    ) -> Tuple[Dict[str, Any], bool]:
        """
        Process input with caching

        Cached results are keyed by model id, file content hash and active
        model version. The input is processed without caching when the model
        has no active version or the file cannot be hashed.

        Args:
            input_path: Path of the file to process.
            model_id: Model whose active version will process the file.
            process_func: Called as ``process_func(input_path)`` on a miss.

        Returns: (result, from_cache)
        """
        # Get active model version
        active_model = self.model_registry.get_active_version(model_id)
        if not active_model:
            logger.warning(f"No active version for model {model_id}")
            return process_func(input_path), False
        # Compute input hash
        input_hash = self.input_cache.compute_hash(input_path)
        if not input_hash:
            # Hash failed, process without cache
            return process_func(input_path), False
        # Scope the key by model id: different models frequently share a
        # version string (every default model is "1.0.0"), and without this
        # one model's result would be returned for another model's request.
        scoped_hash = f"{model_id}:{input_hash}"
        # Check cache
        cached_result = self.input_cache.get(scoped_hash, active_model.version)
        if cached_result is not None:
            logger.info(f"Returning cached result for {model_id}")
            return cached_result, True
        # Process and cache
        result = process_func(input_path)
        self.input_cache.put(scoped_hash, active_model.version, result)
        return result, False

    def get_system_status(self) -> Dict[str, Any]:
        """Return registry counts and inventory, cache statistics and a timestamp."""
        return {
            "model_registry": {
                "total_models": len(self.model_registry.models),
                "active_models": len(self.model_registry.active_versions),
                "inventory": self.model_registry.get_model_inventory()
            },
            "cache": self.input_cache.get_statistics(),
            "timestamp": datetime.utcnow().isoformat()
        }
# Module-wide instance, created on first request.
_versioning_system = None


def get_versioning_system() -> ModelVersioningSystem:
    """Return the shared ModelVersioningSystem, building it on first use."""
    global _versioning_system
    if _versioning_system is not None:
        return _versioning_system
    _versioning_system = ModelVersioningSystem()
    return _versioning_system