Final optimization: Update error_handling.py with production-ready enhancements

Commit 75c1496 (verified)

"""
Comprehensive error handling and recovery utilities for BitTransformerLM.

Provides robust error recovery mechanisms, graceful degradation, and detailed
error logging for production deployments.
"""

import logging
import traceback
import functools
from typing import Dict, Any, Optional, Callable, Union, Type
from contextlib import contextmanager

import torch
import numpy as np

from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike
class BitTransformerError(Exception):
    """Base exception class for BitTransformerLM errors."""

    def __init__(self, message: str, error_code: str = "BTLM_ERROR",
                 context: Optional[Dict[str, Any]] = None):
        self.message = message
        self.error_code = error_code
        self.context = context or {}
        super().__init__(f"[{error_code}] {message}")


class ModelError(BitTransformerError):
    """Errors related to model operations."""
    pass


class CompressionError(BitTransformerError):
    """Errors related to compression/decompression."""
    pass


class SafetyError(BitTransformerError):
    """Errors related to safety gates and telemetry."""
    pass


class DataError(BitTransformerError):
    """Errors related to data processing."""
    pass


class DistributedError(BitTransformerError):
    """Errors related to distributed training."""
    pass
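
# Usage sketch (illustrative): the typed exceptions above carry an error code and a
# context dict, so callers can attach structured detail at the raise site.  The code
# and file name below are arbitrary placeholders, not values defined elsewhere.
def _example_raise_data_error() -> None:
    raise DataError(
        "Bit stream contained non-binary symbols",
        error_code="BTLM_DATA_001",
        context={"source": "example_corpus.bits", "offending_values": [2, 7]},
    )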
class ErrorRecoveryManager:
    """Manages error recovery strategies and fallback mechanisms."""

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger or logging.getLogger(__name__)
        self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {}
        self.error_counts: Dict[str, int] = {}
        self.max_retries = 3

    def register_recovery_strategy(self,
                                   error_type: Type[Exception],
                                   strategy: RecoveryStrategy) -> None:
        """Register a recovery strategy for a specific error type."""
        self.recovery_strategies[error_type] = strategy

    def handle_error(self,
                     error: Exception,
                     context: Optional[Dict[str, Any]] = None,
                     allow_recovery: bool = True) -> Any:
        """Handle an error with potential recovery."""
        error_key = f"{type(error).__name__}:{str(error)}"
        self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1

        self.logger.error(
            f"Error occurred: {error}\n"
            f"Context: {context}\n"
            f"Traceback: {traceback.format_exc()}"
        )

        if allow_recovery and self.error_counts[error_key] <= self.max_retries:
            # Try recovery strategy
            for error_type, strategy in self.recovery_strategies.items():
                if isinstance(error, error_type):
                    try:
                        self.logger.info(f"Attempting recovery for {type(error).__name__}")
                        return strategy()
                    except Exception as recovery_error:
                        self.logger.error(f"Recovery failed: {recovery_error}")
                        break

        # If no recovery or recovery failed, raise the original error
        raise error


# Global error recovery manager instance
error_manager = ErrorRecoveryManager()
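
# Usage sketch (illustrative): register a fallback for CompressionError on the global
# manager, then route a caught exception through handle_error.  The zero tensor used
# as the fallback payload is an arbitrary choice for demonstration.
def _example_register_compression_fallback() -> None:
    def _fallback_payload() -> torch.Tensor:
        # Returned in place of the failed operation's result.
        return torch.zeros(8, dtype=torch.long)

    error_manager.register_recovery_strategy(CompressionError, _fallback_payload)
    try:
        raise CompressionError("Run-length decode failed")
    except CompressionError as exc:
        recovered = error_manager.handle_error(exc, context={"stage": "decode"})
        assert recovered.shape == (8,)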
def with_error_recovery(recovery_value: Any = None,
                        max_retries: int = 3,
                        error_types: Optional[tuple] = None):
    """Decorator for adding error recovery to functions."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    # Check if we should handle this error type
                    if error_types and not isinstance(e, error_types):
                        raise
                    if attempt < max_retries:
                        error_manager.logger.warning(
                            f"Function {func.__name__} failed (attempt {attempt + 1}), retrying..."
                        )
                        continue
                    # Final attempt failed
                    error_manager.logger.error(
                        f"Function {func.__name__} failed after {max_retries + 1} attempts"
                    )
                    break

            # Return recovery value or raise last error
            if recovery_value is not None:
                return recovery_value
            raise last_error

        return wrapper
    return decorator
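
# Usage sketch (illustrative): retry a flaky step up to two extra times and fall back
# to a neutral value instead of crashing.  The body is a stand-in for real data loading.
@with_error_recovery(recovery_value=torch.zeros(1, dtype=torch.long),
                     max_retries=2,
                     error_types=(RuntimeError,))
def _example_flaky_bit_load() -> torch.Tensor:
    # Pretend this occasionally raises a transient RuntimeError.
    return torch.randint(0, 2, (16,))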
@contextmanager
def safe_operation(operation_name: str,
                   context: Optional[Dict[str, Any]] = None,
                   recovery_value: Any = None):
    """Context manager for safe operations with error handling."""
    try:
        error_manager.logger.debug(f"Starting operation: {operation_name}")
        yield
        error_manager.logger.debug(f"Completed operation: {operation_name}")
    except Exception as e:
        error_context = {"operation": operation_name}
        if context:
            error_context.update(context)
        try:
            # A successful recovery suppresses the exception; contextlib discards the
            # returned value, so callers only observe that the block did not raise.
            return error_manager.handle_error(e, error_context)
        except Exception:
            if recovery_value is not None:
                error_manager.logger.warning(
                    f"Operation {operation_name} failed, using recovery value"
                )
                return recovery_value
            raise
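
# Usage sketch (illustrative): wrap a block of work so a failure is logged under the
# operation name; with recovery_value set, an unrecoverable error is swallowed rather
# than propagated, so the surrounding code keeps running.
def _example_safe_block() -> None:
    with safe_operation("telemetry_update", context={"step": 42}, recovery_value=0):
        pass  # real telemetry computation would go here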
def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor],
                          fallback_value: Optional[torch.Tensor] = None) -> Callable:
    """Wrapper for tensor operations with safety checks."""
    def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Validate input tensor
        if not isinstance(tensor, torch.Tensor):
            raise DataError("Input must be a torch.Tensor")
        if tensor.numel() == 0:
            if fallback_value is not None:
                return fallback_value
            raise DataError("Cannot operate on empty tensor")

        # Replace NaN and Inf values in a single pass so the op only sees finite numbers
        if not torch.isfinite(tensor).all():
            error_manager.logger.warning("Non-finite values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1e6, neginf=-1e6)

        try:
            return tensor_op(tensor, *args, **kwargs)
        except (RuntimeError, ValueError) as e:
            if "out of memory" in str(e).lower():
                # OOM recovery: try with smaller chunks
                error_manager.logger.warning("OOM detected, attempting chunked operation")
                return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs)
            elif "device" in str(e).lower():
                # Device mismatch recovery
                error_manager.logger.warning("Device mismatch, attempting CPU fallback")
                return tensor_op(tensor.cpu(), *args, **kwargs)
            else:
                raise

    return wrapper
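
# Usage sketch (illustrative): wrap an elementwise op so NaN/Inf inputs are scrubbed to
# finite values and the OOM / device-mismatch fallbacks above apply automatically.
def _example_safe_sigmoid() -> torch.Tensor:
    safe_sigmoid = safe_tensor_operation(torch.sigmoid)
    noisy = torch.tensor([0.0, float("nan"), float("inf")])
    return safe_sigmoid(noisy)  # non-finite entries are cleaned before sigmoid runs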
def _chunked_tensor_operation(tensor_op: Callable,
                              tensor: torch.Tensor,
                              *args,
                              chunk_size: int = 1024,
                              **kwargs) -> torch.Tensor:
    """Execute tensor operation in chunks to avoid OOM."""
    if tensor.size(0) <= chunk_size:
        return tensor_op(tensor, *args, **kwargs)

    results = []
    for i in range(0, tensor.size(0), chunk_size):
        chunk = tensor[i:i + chunk_size]
        chunk_result = tensor_op(chunk, *args, **kwargs)
        results.append(chunk_result)
    return torch.cat(results, dim=0)
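
# Usage sketch (illustrative): chunking along dim 0 produces the same result as a
# single call, just with smaller peak memory per step.
def _example_chunked_relu() -> bool:
    big = torch.randn(4096, 8)
    chunked = _chunked_tensor_operation(torch.relu, big, chunk_size=512)
    return torch.equal(chunked, torch.relu(big))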
def validate_model_inputs(inputs: torch.Tensor,
                          max_seq_len: int = 8192,
                          expected_dtype: torch.dtype = torch.long) -> torch.Tensor:
    """Validate and sanitize model inputs."""
    if not isinstance(inputs, torch.Tensor):
        raise DataError("Model inputs must be torch.Tensor")

    # Check dimensions
    if inputs.dim() == 1:
        inputs = inputs.unsqueeze(0)  # Add batch dimension
    elif inputs.dim() > 2:
        raise DataError(f"Input tensor has too many dimensions: {inputs.dim()}")

    # Check sequence length
    if inputs.size(-1) > max_seq_len:
        error_manager.logger.warning(f"Sequence length {inputs.size(-1)} exceeds max {max_seq_len}, truncating")
        inputs = inputs[:, :max_seq_len]

    # Check dtype
    if inputs.dtype != expected_dtype:
        error_manager.logger.warning(f"Converting input dtype from {inputs.dtype} to {expected_dtype}")
        inputs = inputs.to(expected_dtype)

    # Check value range for bit sequences
    if expected_dtype == torch.long:
        invalid_values = (inputs < 0) | (inputs > 1)
        if invalid_values.any():
            error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]")
            inputs = torch.clamp(inputs, 0, 1)

    return inputs
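
# Usage sketch (illustrative): a 1-D bit sequence gains a batch dimension, is cast to
# torch.long, and out-of-range values are clamped back into {0, 1}.
def _example_validate_inputs() -> torch.Tensor:
    raw = torch.tensor([0, 1, 1, 0, 3], dtype=torch.int32)
    return validate_model_inputs(raw, max_seq_len=16)  # shape (1, 5), values in {0, 1}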
def safe_model_forward(model: torch.nn.Module,
                       inputs: torch.Tensor,
                       **kwargs) -> torch.Tensor:
    """Safely execute model forward pass with error recovery."""
    inputs = validate_model_inputs(inputs)

    try:
        with safe_operation("model_forward"):
            return model(inputs, **kwargs)
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            # Try with gradient checkpointing
            error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing")
            from torch.utils.checkpoint import checkpoint
            return checkpoint(model, inputs, **kwargs)
        elif "device" in str(e).lower():
            # Device mismatch recovery
            device = next(model.parameters()).device
            inputs = inputs.to(device)
            return model(inputs, **kwargs)
        else:
            raise
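
# Usage sketch (illustrative): `model` is assumed to be any bit-level module whose
# forward accepts a (batch, seq_len) LongTensor, e.g. a BitTransformerLM instance;
# the call only illustrates the expected input shape.
def _example_safe_forward(model: torch.nn.Module) -> torch.Tensor:
    bits = torch.randint(0, 2, (1, 128))
    return safe_model_forward(model, bits)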
def recovery_checkpoint_save(model: torch.nn.Module,
                             path: str,
                             additional_data: Optional[Dict[str, Any]] = None) -> bool:
    """Save model checkpoint with error recovery."""
    # Build the payload before attempting any I/O so the backup path below
    # always has a valid checkpoint_data to fall back on.
    checkpoint_data = {
        'model_state_dict': model.state_dict(),
        'timestamp': torch.tensor(0),  # placeholder
    }
    if additional_data:
        checkpoint_data.update(additional_data)

    try:
        torch.save(checkpoint_data, path)
        error_manager.logger.info(f"Checkpoint saved successfully to {path}")
        return True
    except Exception as e:
        error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}")
        # Try backup location
        backup_path = path + ".backup"
        try:
            torch.save(checkpoint_data, backup_path)
            error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}")
            return True
        except Exception as backup_e:
            error_manager.logger.error(f"Backup save also failed: {backup_e}")
            return False
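
# Usage sketch (illustrative): save a tiny stand-in module to a temporary directory;
# a real call would pass the trained BitTransformerLM and a run-specific path.
def _example_checkpoint_save() -> bool:
    import os
    import tempfile

    model = torch.nn.Linear(4, 2)
    path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
    return recovery_checkpoint_save(model, path, additional_data={"step": 100})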
def setup_error_logging(log_level: LogLevel = "INFO",
                        log_file: Optional[str] = None) -> logging.Logger:
    """Set up comprehensive error logging."""
    logger = logging.getLogger("BitTransformerLM")
    logger.setLevel(getattr(logging, log_level))

    # Avoid stacking duplicate handlers when this is called more than once
    logger.handlers.clear()

    # Console handler
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    # File handler if specified
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    return logger
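
# Usage sketch (illustrative): send WARNING-and-above messages to the console and to a
# log file; the file name here is an arbitrary example.
def _example_configure_logging() -> logging.Logger:
    return setup_error_logging(log_level="WARNING", log_file="bitlm_errors.log")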
# Default recovery strategies
def default_tensor_recovery() -> torch.Tensor:
    """Default recovery strategy for tensor operations."""
    return torch.zeros(1, dtype=torch.long)


def default_model_recovery() -> Dict[str, torch.Tensor]:
    """Default recovery strategy for model operations."""
    return {"output": torch.zeros(1, dtype=torch.float32)}


# Register default recovery strategies
error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery)
error_manager.register_recovery_strategy(ModelError, default_model_recovery)