""" Model Versioning and Input Caching System Tracks model versions, performance, and implements intelligent caching Features: - Model version tracking with metadata - Performance metrics per model version - A/B testing framework - Automated rollback capabilities - SHA256 input fingerprinting - Intelligent caching with invalidation - Cache performance analytics Author: MiniMax Agent Date: 2025-10-29 Version: 1.0.0 """ import hashlib import json import logging from typing import Dict, List, Any, Optional, Tuple from datetime import datetime, timedelta from dataclasses import dataclass, asdict from collections import defaultdict, deque from enum import Enum import os logger = logging.getLogger(__name__) class ModelStatus(Enum): """Model deployment status""" ACTIVE = "active" TESTING = "testing" DEPRECATED = "deprecated" RETIRED = "retired" @dataclass class ModelVersion: """Model version metadata""" model_id: str version: str model_name: str model_path: str deployment_date: str status: ModelStatus metadata: Dict[str, Any] performance_metrics: Dict[str, float] def to_dict(self) -> Dict[str, Any]: data = asdict(self) data["status"] = self.status.value return data @dataclass class CacheEntry: """Cache entry with metadata""" cache_key: str input_hash: str result_data: Dict[str, Any] created_at: str last_accessed: str access_count: int model_version: str size_bytes: int def to_dict(self) -> Dict[str, Any]: return asdict(self) class ModelRegistry: """ Registry for tracking model versions and performance Supports version comparison and automated rollback """ def __init__(self): self.models: Dict[str, Dict[str, ModelVersion]] = defaultdict(dict) self.active_versions: Dict[str, str] = {} # model_id -> version self.performance_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000)) logger.info("Model Registry initialized") def register_model( self, model_id: str, version: str, model_name: str, model_path: str, metadata: Optional[Dict[str, Any]] = None, set_active: bool = False ) -> ModelVersion: """Register a new model version""" model_version = ModelVersion( model_id=model_id, version=version, model_name=model_name, model_path=model_path, deployment_date=datetime.utcnow().isoformat(), status=ModelStatus.TESTING if not set_active else ModelStatus.ACTIVE, metadata=metadata or {}, performance_metrics={} ) self.models[model_id][version] = model_version if set_active: self.set_active_version(model_id, version) logger.info(f"Registered model {model_id} v{version}") return model_version def set_active_version(self, model_id: str, version: str): """Set active version for a model""" if model_id not in self.models or version not in self.models[model_id]: raise ValueError(f"Model {model_id} v{version} not found") # Update previous active version status if model_id in self.active_versions: prev_version = self.active_versions[model_id] if prev_version in self.models[model_id]: self.models[model_id][prev_version].status = ModelStatus.DEPRECATED # Set new active version self.active_versions[model_id] = version self.models[model_id][version].status = ModelStatus.ACTIVE logger.info(f"Set active version: {model_id} -> v{version}") def get_active_version(self, model_id: str) -> Optional[ModelVersion]: """Get currently active model version""" if model_id not in self.active_versions: return None version = self.active_versions[model_id] return self.models[model_id].get(version) def record_performance( self, model_id: str, version: str, metrics: Dict[str, float] ): """Record performance metrics for a model version""" if model_id not in self.models or version not in self.models[model_id]: logger.warning(f"Cannot record performance for unknown model {model_id} v{version}") return performance_record = { "timestamp": datetime.utcnow().isoformat(), "model_id": model_id, "version": version, "metrics": metrics } self.performance_history[f"{model_id}:{version}"].append(performance_record) # Update model version metrics (running average) model_version = self.models[model_id][version] for metric_name, value in metrics.items(): if metric_name in model_version.performance_metrics: # Running average current = model_version.performance_metrics[metric_name] model_version.performance_metrics[metric_name] = (current + value) / 2 else: model_version.performance_metrics[metric_name] = value def compare_versions( self, model_id: str, version1: str, version2: str, metric: str = "accuracy" ) -> Dict[str, Any]: """Compare performance between two model versions""" if model_id not in self.models: return {"error": f"Model {model_id} not found"} v1 = self.models[model_id].get(version1) v2 = self.models[model_id].get(version2) if not v1 or not v2: return {"error": "One or both versions not found"} v1_metric = v1.performance_metrics.get(metric, 0.0) v2_metric = v2.performance_metrics.get(metric, 0.0) return { "model_id": model_id, "versions": { version1: v1_metric, version2: v2_metric }, "difference": v2_metric - v1_metric, "improvement_percent": ((v2_metric - v1_metric) / v1_metric * 100) if v1_metric > 0 else 0.0, "metric": metric } def rollback_to_version(self, model_id: str, version: str) -> bool: """Rollback to a previous model version""" if model_id not in self.models or version not in self.models[model_id]: logger.error(f"Cannot rollback: model {model_id} v{version} not found") return False logger.warning(f"Rolling back {model_id} to v{version}") self.set_active_version(model_id, version) return True def get_model_inventory(self) -> Dict[str, Any]: """Get complete model inventory""" inventory = {} for model_id, versions in self.models.items(): inventory[model_id] = { "active_version": self.active_versions.get(model_id, "none"), "total_versions": len(versions), "versions": { ver: model.to_dict() for ver, model in versions.items() } } return inventory def auto_rollback_if_degraded( self, model_id: str, metric: str = "accuracy", threshold_drop: float = 0.05 # 5% drop ) -> bool: """Automatically rollback if performance degraded significantly""" if model_id not in self.active_versions: return False current_version = self.active_versions[model_id] current_model = self.models[model_id][current_version] # Find previous active version previous_versions = [ (ver, model) for ver, model in self.models[model_id].items() if model.status == ModelStatus.DEPRECATED ] if not previous_versions: return False # Get most recent deprecated version previous_versions.sort( key=lambda x: x[1].deployment_date, reverse=True ) prev_version, prev_model = previous_versions[0] # Compare performance current_metric = current_model.performance_metrics.get(metric, 0.0) prev_metric = prev_model.performance_metrics.get(metric, 0.0) if prev_metric == 0.0: return False drop_percent = (prev_metric - current_metric) / prev_metric if drop_percent > threshold_drop: logger.warning( f"Performance degradation detected for {model_id}: " f"{metric} dropped {drop_percent*100:.1f}%. " f"Rolling back to v{prev_version}" ) return self.rollback_to_version(model_id, prev_version) return False class InputCache: """ Intelligent caching system with SHA256 fingerprinting Caches analysis results to avoid reprocessing identical files """ def __init__( self, max_cache_size_mb: int = 1000, ttl_hours: int = 24 ): self.cache: Dict[str, CacheEntry] = {} self.max_cache_size_bytes = max_cache_size_mb * 1024 * 1024 self.current_cache_size = 0 self.ttl_hours = ttl_hours # Cache statistics self.hits = 0 self.misses = 0 self.evictions = 0 logger.info(f"Input Cache initialized (max size: {max_cache_size_mb}MB, TTL: {ttl_hours}h)") def compute_hash(self, file_path: str) -> str: """Compute SHA256 hash of file""" sha256_hash = hashlib.sha256() try: with open(file_path, "rb") as f: # Read file in chunks for memory efficiency for byte_block in iter(lambda: f.read(4096), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() except Exception as e: logger.error(f"Failed to compute hash for {file_path}: {str(e)}") return "" def compute_data_hash(self, data: bytes) -> str: """Compute SHA256 hash of data bytes""" return hashlib.sha256(data).hexdigest() def get( self, input_hash: str, model_version: str ) -> Optional[Dict[str, Any]]: """Retrieve cached result""" cache_key = f"{input_hash}:{model_version}" if cache_key not in self.cache: self.misses += 1 return None entry = self.cache[cache_key] # Check TTL created_time = datetime.fromisoformat(entry.created_at) if datetime.utcnow() - created_time > timedelta(hours=self.ttl_hours): # Expired self._evict(cache_key) self.misses += 1 return None # Update access tracking entry.last_accessed = datetime.utcnow().isoformat() entry.access_count += 1 self.hits += 1 logger.info(f"Cache hit: {cache_key[:16]}...") return entry.result_data def put( self, input_hash: str, model_version: str, result_data: Dict[str, Any] ): """Store result in cache""" cache_key = f"{input_hash}:{model_version}" # Estimate size size_bytes = len(json.dumps(result_data).encode()) # Check if we need to evict while self.current_cache_size + size_bytes > self.max_cache_size_bytes: self._evict_lru() entry = CacheEntry( cache_key=cache_key, input_hash=input_hash, result_data=result_data, created_at=datetime.utcnow().isoformat(), last_accessed=datetime.utcnow().isoformat(), access_count=0, model_version=model_version, size_bytes=size_bytes ) self.cache[cache_key] = entry self.current_cache_size += size_bytes logger.info(f"Cache stored: {cache_key[:16]}... ({size_bytes} bytes)") def invalidate_model_version(self, model_version: str): """Invalidate all cache entries for a model version""" keys_to_remove = [ key for key, entry in self.cache.items() if entry.model_version == model_version ] for key in keys_to_remove: self._evict(key) logger.info(f"Invalidated {len(keys_to_remove)} cache entries for model v{model_version}") def _evict(self, cache_key: str): """Evict a specific cache entry""" if cache_key in self.cache: entry = self.cache.pop(cache_key) self.current_cache_size -= entry.size_bytes self.evictions += 1 def _evict_lru(self): """Evict least recently used entry""" if not self.cache: return # Find LRU entry lru_key = min( self.cache.keys(), key=lambda k: self.cache[k].last_accessed ) self._evict(lru_key) logger.debug(f"LRU eviction: {lru_key[:16]}...") def get_statistics(self) -> Dict[str, Any]: """Get cache performance statistics""" total_requests = self.hits + self.misses hit_rate = self.hits / total_requests if total_requests > 0 else 0.0 return { "total_entries": len(self.cache), "cache_size_mb": self.current_cache_size / (1024 * 1024), "max_size_mb": self.max_cache_size_bytes / (1024 * 1024), "utilization_percent": (self.current_cache_size / self.max_cache_size_bytes * 100), "total_requests": total_requests, "hits": self.hits, "misses": self.misses, "hit_rate_percent": hit_rate * 100, "evictions": self.evictions, "ttl_hours": self.ttl_hours } def clear(self): """Clear all cache entries""" entry_count = len(self.cache) self.cache.clear() self.current_cache_size = 0 logger.info(f"Cache cleared: {entry_count} entries removed") class ModelVersioningSystem: """ Complete model versioning and caching system Integrates model registry with input caching """ def __init__( self, cache_size_mb: int = 1000, cache_ttl_hours: int = 24 ): self.model_registry = ModelRegistry() self.input_cache = InputCache(cache_size_mb, cache_ttl_hours) # Initialize default models self._initialize_default_models() logger.info("Model Versioning System initialized") def _initialize_default_models(self): """Initialize default model versions""" default_models = [ ("document_classifier", "1.0.0", "Bio_ClinicalBERT", "emilyalsentzer/Bio_ClinicalBERT"), ("clinical_ner", "1.0.0", "Biomedical NER", "d4data/biomedical-ner-all"), ("clinical_generation", "1.0.0", "BioGPT-Large", "microsoft/BioGPT-Large"), ("medical_qa", "1.0.0", "RoBERTa-SQuAD2", "deepset/roberta-base-squad2"), ("general_medical", "1.0.0", "PubMedBERT", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"), ("drug_interaction", "1.0.0", "SciBERT", "allenai/scibert_scivocab_uncased"), ("clinical_summarization", "1.0.0", "BigBird-Pegasus", "google/bigbird-pegasus-large-pubmed") ] for model_id, version, name, path in default_models: self.model_registry.register_model( model_id=model_id, version=version, model_name=name, model_path=path, metadata={"initialized": "2025-10-29"}, set_active=True ) def process_with_cache( self, input_path: str, model_id: str, process_func: callable ) -> Tuple[Dict[str, Any], bool]: """ Process input with caching Returns: (result, from_cache) """ # Get active model version active_model = self.model_registry.get_active_version(model_id) if not active_model: logger.warning(f"No active version for model {model_id}") return process_func(input_path), False # Compute input hash input_hash = self.input_cache.compute_hash(input_path) if not input_hash: # Hash failed, process without cache return process_func(input_path), False # Check cache cached_result = self.input_cache.get(input_hash, active_model.version) if cached_result is not None: logger.info(f"Returning cached result for {model_id}") return cached_result, True # Process and cache result = process_func(input_path) self.input_cache.put(input_hash, active_model.version, result) return result, False def get_system_status(self) -> Dict[str, Any]: """Get complete system status""" return { "model_registry": { "total_models": len(self.model_registry.models), "active_models": len(self.model_registry.active_versions), "inventory": self.model_registry.get_model_inventory() }, "cache": self.input_cache.get_statistics(), "timestamp": datetime.utcnow().isoformat() } # Global instance _versioning_system = None def get_versioning_system() -> ModelVersioningSystem: """Get singleton versioning system instance""" global _versioning_system if _versioning_system is None: _versioning_system = ModelVersioningSystem() return _versioning_system