File size: 13,112 Bytes
75c1496 36c78b1 75c1496 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 |
"""
Comprehensive error handling and recovery utilities for BitTransformerLM.
Provides robust error recovery mechanisms, graceful degradation, and detailed
error logging for production deployments.
"""
import logging
import traceback
import functools
from typing import Dict, Any, Optional, Callable, Union, Type
from contextlib import contextmanager
import torch
import numpy as np
from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike
class BitTransformerError(Exception):
    """Base exception class for BitTransformerLM errors."""

    def __init__(self, message: str, error_code: str = "BTLM_ERROR",
                 context: Optional[Dict[str, Any]] = None):
        # Keep the raw pieces around so handlers can inspect them individually.
        self.message = message
        self.error_code = error_code
        self.context = context if context else {}
        formatted = f"[{error_code}] {message}"
        super().__init__(formatted)
class ModelError(BitTransformerError):
    """Errors related to model operations."""
class CompressionError(BitTransformerError):
    """Errors related to compression/decompression."""
class SafetyError(BitTransformerError):
    """Errors related to safety gates and telemetry."""
class DataError(BitTransformerError):
    """Errors related to data processing."""
class DistributedError(BitTransformerError):
    """Errors related to distributed training."""
class ErrorRecoveryManager:
    """Tracks error occurrences and dispatches registered recovery strategies.

    A strategy is a zero-argument callable registered per exception type.
    Recovery is attempted only while an identical error has been seen at
    most ``max_retries`` times; otherwise the original error is re-raised.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {}
        self.error_counts: Dict[str, int] = {}
        self.max_retries = 3

    def register_recovery_strategy(self,
                                   error_type: Type[Exception],
                                   strategy: RecoveryStrategy) -> None:
        """Register a recovery strategy for a specific error type."""
        self.recovery_strategies[error_type] = strategy

    def handle_error(self,
                     error: Exception,
                     context: Optional[Dict[str, Any]] = None,
                     allow_recovery: bool = True) -> Any:
        """Log *error*, try a matching recovery strategy, else re-raise it."""
        error_key = f"{type(error).__name__}:{str(error)}"
        seen = self.error_counts.get(error_key, 0) + 1
        self.error_counts[error_key] = seen
        self.logger.error(
            f"Error occurred: {error}\n"
            f"Context: {context}\n"
            f"Traceback: {traceback.format_exc()}"
        )
        if allow_recovery and seen <= self.max_retries:
            # First registered strategy whose type matches wins (same as the
            # original first-match-then-break loop).
            strategy = next(
                (s for t, s in self.recovery_strategies.items() if isinstance(error, t)),
                None,
            )
            if strategy is not None:
                try:
                    self.logger.info(f"Attempting recovery for {type(error).__name__}")
                    return strategy()
                except Exception as recovery_error:
                    self.logger.error(f"Recovery failed: {recovery_error}")
        # No recovery registered, retries exhausted, or recovery failed.
        raise error
# Global error recovery manager instance shared by every helper below;
# decorators and context managers in this module log through it.
error_manager = ErrorRecoveryManager()
def with_error_recovery(recovery_value: Any = None,
                        max_retries: int = 3,
                        error_types: Optional[tuple] = None):
    """Decorator that retries a function and optionally returns a fallback.

    Args:
        recovery_value: Returned after all attempts fail (only when not None);
            otherwise the final exception is re-raised.
        max_retries: Number of retries after the first attempt.
        error_types: When given, only these exception types are retried;
            anything else propagates immediately.
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempts_allowed = max_retries + 1
            final_error: Optional[Exception] = None
            for attempt in range(1, attempts_allowed + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    # Errors outside the retry filter propagate untouched.
                    if error_types and not isinstance(exc, error_types):
                        raise
                    final_error = exc
                    if attempt < attempts_allowed:
                        error_manager.logger.warning(
                            f"Function {func.__name__} failed (attempt {attempt}), retrying..."
                        )
                    else:
                        error_manager.logger.error(
                            f"Function {func.__name__} failed after {attempts_allowed} attempts"
                        )
            if recovery_value is not None:
                return recovery_value
            raise final_error
        return wrapper
    return decorator
@contextmanager
def safe_operation(operation_name: str,
                   context: Optional[Dict[str, Any]] = None,
                   recovery_value: Any = None):
    """Context manager that logs an operation and routes failures through
    the global error manager.

    If the managed block raises, the error goes to
    ``error_manager.handle_error``.  When a registered recovery strategy
    succeeds (handle_error returns instead of raising), the exception is
    suppressed.  Note: because this is a generator-based context manager,
    neither the strategy's result nor ``recovery_value`` can reach the
    caller — suppressing the exception is the only observable effect.

    Args:
        operation_name: Human-readable label used in log messages.
        context: Extra key/value pairs merged into the error context.
        recovery_value: When not None, swallow the error (after logging)
            instead of re-raising it.
    """
    try:
        error_manager.logger.debug(f"Starting operation: {operation_name}")
        yield
        error_manager.logger.debug(f"Completed operation: {operation_name}")
    except Exception as e:
        error_context = {"operation": operation_name}
        if context:
            error_context.update(context)
        try:
            # handle_error either returns a recovery result (we then simply
            # suppress the exception) or re-raises the original error.
            error_manager.handle_error(e, error_context)
            return
        except Exception:
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # stay fatal.  Only Exceptions can reach here anyway, since the
            # outer handler catches Exception and handle_error re-raises it.
            if recovery_value is not None:
                error_manager.logger.warning(
                    f"Operation {operation_name} failed, using recovery value"
                )
                return
            raise
def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor],
                          fallback_value: Optional[torch.Tensor] = None) -> Callable:
    """Wrap a tensor operation with validation, NaN/Inf scrubbing and
    OOM / device-mismatch fallbacks.

    Args:
        tensor_op: Operation applied to the (sanitized) input tensor.
        fallback_value: Returned for empty inputs instead of raising.
    """
    def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if not isinstance(tensor, torch.Tensor):
            raise DataError("Input must be a torch.Tensor")
        if tensor.numel() == 0:
            if fallback_value is None:
                raise DataError("Cannot operate on empty tensor")
            return fallback_value
        # Scrub NaNs first, then Infs, mirroring the documented clean-up order.
        if torch.isnan(tensor).any():
            error_manager.logger.warning("NaN values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, nan=0.0)
        if torch.isinf(tensor).any():
            error_manager.logger.warning("Inf values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6)
        try:
            return tensor_op(tensor, *args, **kwargs)
        except (RuntimeError, ValueError) as exc:
            message = str(exc).lower()
            if "out of memory" in message:
                # OOM recovery: retry the operation over dim-0 chunks.
                error_manager.logger.warning("OOM detected, attempting chunked operation")
                return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs)
            if "device" in message:
                # Device mismatch: retry on the CPU copy of the input.
                error_manager.logger.warning("Device mismatch, attempting CPU fallback")
                return tensor_op(tensor.cpu(), *args, **kwargs)
            raise
    return wrapper
def _chunked_tensor_operation(tensor_op: Callable,
tensor: torch.Tensor,
chunk_size: int = 1024,
*args, **kwargs) -> torch.Tensor:
"""Execute tensor operation in chunks to avoid OOM."""
if tensor.size(0) <= chunk_size:
return tensor_op(tensor, *args, **kwargs)
results = []
for i in range(0, tensor.size(0), chunk_size):
chunk = tensor[i:i + chunk_size]
chunk_result = tensor_op(chunk, *args, **kwargs)
results.append(chunk_result)
return torch.cat(results, dim=0)
def validate_model_inputs(inputs: torch.Tensor,
                          max_seq_len: int = 8192,
                          expected_dtype: torch.dtype = torch.long) -> torch.Tensor:
    """Validate and sanitize model inputs.

    Ensures a 2-D (batch, seq) tensor of ``expected_dtype``, truncated to
    ``max_seq_len`` and — for integer bit inputs — clamped into {0, 1}.
    """
    if not isinstance(inputs, torch.Tensor):
        raise DataError("Model inputs must be torch.Tensor")
    # Normalize rank: promote a bare sequence to a batch of one.
    if inputs.dim() == 1:
        inputs = inputs.unsqueeze(0)
    elif inputs.dim() > 2:
        raise DataError(f"Input tensor has too many dimensions: {inputs.dim()}")
    # Enforce the sequence-length budget.
    if inputs.size(-1) > max_seq_len:
        error_manager.logger.warning(f"Sequence length {inputs.size(-1)} exceeds max {max_seq_len}, truncating")
        inputs = inputs[:, :max_seq_len]
    # Coerce to the expected dtype when necessary.
    if inputs.dtype != expected_dtype:
        error_manager.logger.warning(f"Converting input dtype from {inputs.dtype} to {expected_dtype}")
        inputs = inputs.to(expected_dtype)
    # Bit sequences must live in {0, 1}.
    if expected_dtype == torch.long:
        out_of_range = (inputs < 0) | (inputs > 1)
        if out_of_range.any():
            error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]")
            inputs = torch.clamp(inputs, 0, 1)
    return inputs
def safe_model_forward(model: torch.nn.Module,
                       inputs: torch.Tensor,
                       **kwargs) -> torch.Tensor:
    """Safely execute model forward pass with error recovery.

    Inputs are normalized via ``validate_model_inputs`` first (batched,
    truncated, dtype-coerced).  On ``RuntimeError`` two retries exist:
    an "out of memory" message re-runs the forward through torch
    checkpointing, and a "device" message moves the inputs onto the
    model's parameter device before retrying.

    Args:
        model: Module to call; must expose at least one parameter — the
            device fallback reads ``next(model.parameters()).device``.
        inputs: Bit-sequence tensor; sanitized before the forward pass.
        **kwargs: Forwarded unchanged to ``model(...)``.

    Returns:
        The model's output tensor.
    """
    inputs = validate_model_inputs(inputs)
    try:
        with safe_operation("model_forward"):
            return model(inputs, **kwargs)
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            # Try with gradient checkpointing
            error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing")
            from torch.utils.checkpoint import checkpoint
            # NOTE(review): checkpoint() only saves memory while gradients
            # are being tracked, and newer torch versions warn unless
            # use_reentrant is given explicitly — confirm intended usage.
            return checkpoint(model, inputs, **kwargs)
        elif "device" in str(e).lower():
            # Device mismatch recovery
            device = next(model.parameters()).device
            inputs = inputs.to(device)
            return model(inputs, **kwargs)
        else:
            raise
def recovery_checkpoint_save(model: torch.nn.Module,
                             path: str,
                             additional_data: Optional[Dict[str, Any]] = None) -> bool:
    """Save a model checkpoint, falling back to ``<path>.backup`` on failure.

    Bug fix: ``checkpoint_data`` used to be created inside the ``try`` block
    yet referenced again in the ``except`` handler, so a failure inside
    ``model.state_dict()`` raised a ``NameError`` that masked the original
    error.  The payload dict is now pre-initialized, and the backup attempt
    is skipped when no payload was ever built.

    Args:
        model: Module whose ``state_dict`` is persisted.
        path: Destination file path.
        additional_data: Optional extra entries merged into the checkpoint.

    Returns:
        True when either the primary or the backup save succeeded.
    """
    checkpoint_data: Dict[str, Any] = {}
    try:
        checkpoint_data = {
            'model_state_dict': model.state_dict(),
            'timestamp': torch.tensor(0),  # placeholder
        }
        if additional_data:
            checkpoint_data.update(additional_data)
        torch.save(checkpoint_data, path, _use_new_zipfile_serialization=True)
        error_manager.logger.info(f"Checkpoint saved successfully to {path}")
        return True
    except Exception as e:
        error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}")
        if not checkpoint_data:
            # state_dict() itself failed — there is nothing meaningful to
            # write to the backup location.
            return False
        # Try backup location
        backup_path = path + ".backup"
        try:
            torch.save(checkpoint_data, backup_path)
            error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}")
            return True
        except Exception as backup_e:
            error_manager.logger.error(f"Backup save also failed: {backup_e}")
            return False
def setup_error_logging(log_level: "LogLevel" = "INFO",
                        log_file: Optional[str] = None) -> logging.Logger:
    """Configure and return the shared "BitTransformerLM" logger.

    Bug fix: handlers are now reset on every call.  Previously each call
    appended a fresh console (and optional file) handler to the same named
    logger, so repeated setup duplicated every log line.

    Args:
        log_level: Name of a stdlib logging level ("DEBUG", "INFO", ...).
        log_file: Optional path; when given, a file handler that also
            records filename and line number is attached.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger("BitTransformerLM")
    logger.setLevel(getattr(logging, log_level))
    # Drop handlers installed by previous calls so output is not duplicated.
    logger.handlers.clear()
    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    ))
    logger.addHandler(console_handler)
    # File handler if specified
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
        ))
        logger.addHandler(file_handler)
    return logger
# Default recovery strategies
def default_tensor_recovery() -> torch.Tensor:
    """Fallback for failed tensor operations: a single zero bit."""
    fallback = torch.zeros(1, dtype=torch.long)
    return fallback
def default_model_recovery() -> Dict[str, torch.Tensor]:
    """Fallback for failed model operations: a zeroed float output."""
    zero_output = torch.zeros(1, dtype=torch.float32)
    return {"output": zero_output}
# Register default recovery strategies with the module-level manager.
# NOTE(review): registering RuntimeError here means ANY RuntimeError routed
# through handle_error can be replaced by a 1-element zero tensor — confirm
# callers expect that shape/dtype as a substitute result.
error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery)
error_manager.register_recovery_strategy(ModelError, default_model_recovery)