WCNegentropy committed
Commit 75c1496 · verified · 1 Parent(s): 8414e94

🚀 Final optimization: Update error_handling.py with production-ready enhancements

Files changed (1)
  1. bit_transformer/error_handling.py +350 -0
bit_transformer/error_handling.py ADDED
@@ -0,0 +1,350 @@
+ """
+ Comprehensive error handling and recovery utilities for BitTransformerLM.
+
+ Provides robust error recovery mechanisms, graceful degradation, and detailed
+ error logging for production deployments.
+ """
+
+ import logging
+ import traceback
+ import functools
+ from typing import Dict, Any, Optional, Callable, Union, Type
+ from contextlib import contextmanager
+ import torch
+ import numpy as np
+
+ from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike
+
+
+ class BitTransformerError(Exception):
+     """Base exception class for BitTransformerLM errors."""
+
+     def __init__(self, message: str, error_code: str = "BTLM_ERROR",
+                  context: Optional[Dict[str, Any]] = None):
+         self.message = message
+         self.error_code = error_code
+         self.context = context or {}
+         super().__init__(f"[{error_code}] {message}")
+
+
+ class ModelError(BitTransformerError):
+     """Errors related to model operations."""
+     pass
+
+
+ class CompressionError(BitTransformerError):
+     """Errors related to compression/decompression."""
+     pass
+
+
+ class SafetyError(BitTransformerError):
+     """Errors related to safety gates and telemetry."""
+     pass
+
+
+ class DataError(BitTransformerError):
+     """Errors related to data processing."""
+     pass
+
+
+ class DistributedError(BitTransformerError):
+     """Errors related to distributed training."""
+     pass
+
+
+ class ErrorRecoveryManager:
+     """Manages error recovery strategies and fallback mechanisms."""
+
+     def __init__(self, logger: Optional[logging.Logger] = None):
+         self.logger = logger or logging.getLogger(__name__)
+         self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {}
+         self.error_counts: Dict[str, int] = {}
+         self.max_retries = 3
+
+     def register_recovery_strategy(self,
+                                    error_type: Type[Exception],
+                                    strategy: RecoveryStrategy) -> None:
+         """Register a recovery strategy for a specific error type."""
+         self.recovery_strategies[error_type] = strategy
+
+     def handle_error(self,
+                      error: Exception,
+                      context: Optional[Dict[str, Any]] = None,
+                      allow_recovery: bool = True) -> Any:
+         """Handle an error with potential recovery."""
+         error_key = f"{type(error).__name__}:{str(error)}"
+         self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1
+
+         self.logger.error(
+             f"Error occurred: {error}\n"
+             f"Context: {context}\n"
+             f"Traceback: {traceback.format_exc()}"
+         )
+
+         if allow_recovery and self.error_counts[error_key] <= self.max_retries:
+             # Try recovery strategy
+             for error_type, strategy in self.recovery_strategies.items():
+                 if isinstance(error, error_type):
+                     try:
+                         self.logger.info(f"Attempting recovery for {type(error).__name__}")
+                         return strategy()
+                     except Exception as recovery_error:
+                         self.logger.error(f"Recovery failed: {recovery_error}")
+                         break
+
+         # If no recovery or recovery failed, raise the original error
+         raise error
+
+
+ # Global error recovery manager instance
+ error_manager = ErrorRecoveryManager()
+
+
+ def with_error_recovery(recovery_value: Any = None,
+                         max_retries: int = 3,
+                         error_types: Optional[tuple] = None):
+     """Decorator for adding error recovery to functions."""
+     def decorator(func: Callable) -> Callable:
+         @functools.wraps(func)
+         def wrapper(*args, **kwargs):
+             last_error = None
+
+             for attempt in range(max_retries + 1):
+                 try:
+                     return func(*args, **kwargs)
+                 except Exception as e:
+                     last_error = e
+
+                     # Check if we should handle this error type
+                     if error_types and not isinstance(e, error_types):
+                         raise
+
+                     if attempt < max_retries:
+                         error_manager.logger.warning(
+                             f"Function {func.__name__} failed (attempt {attempt + 1}), retrying..."
+                         )
+                         continue
+
+                     # Final attempt failed
+                     error_manager.logger.error(
+                         f"Function {func.__name__} failed after {max_retries + 1} attempts"
+                     )
+                     break
+
+             # Return recovery value or raise last error
+             if recovery_value is not None:
+                 return recovery_value
+             raise last_error
+
+         return wrapper
+     return decorator
+
+
+ @contextmanager
+ def safe_operation(operation_name: str,
+                    context: Optional[Dict[str, Any]] = None,
+                    recovery_value: Any = None):
+     """Context manager for safe operations with error handling."""
+     try:
+         error_manager.logger.debug(f"Starting operation: {operation_name}")
+         yield
+         error_manager.logger.debug(f"Completed operation: {operation_name}")
+     except Exception as e:
+         error_context = {"operation": operation_name}
+         if context:
+             error_context.update(context)
+
+         try:
+             # handle_error either re-raises or returns a recovery value; a
+             # generator-based context manager cannot hand that value to the
+             # caller, so a successful recovery simply suppresses the error.
+             error_manager.handle_error(e, error_context)
+         except Exception:
+             if recovery_value is not None:
+                 error_manager.logger.warning(
+                     f"Operation {operation_name} failed, using recovery value"
+                 )
+                 return
+             raise
+
+
+ def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor],
+                           fallback_value: Optional[torch.Tensor] = None) -> Callable:
+     """Wrapper for tensor operations with safety checks."""
+     def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+         # Validate input tensor
+         if not isinstance(tensor, torch.Tensor):
+             raise DataError("Input must be a torch.Tensor")
+
+         if tensor.numel() == 0:
+             if fallback_value is not None:
+                 return fallback_value
+             raise DataError("Cannot operate on empty tensor")
+
+         # Check for NaN or Inf values
+         if torch.isnan(tensor).any():
+             error_manager.logger.warning("NaN values detected in tensor, attempting to clean")
+             tensor = torch.nan_to_num(tensor, nan=0.0)
+
+         if torch.isinf(tensor).any():
+             error_manager.logger.warning("Inf values detected in tensor, attempting to clean")
+             tensor = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6)
+
+         try:
+             return tensor_op(tensor, *args, **kwargs)
+         except (RuntimeError, ValueError) as e:
+             if "out of memory" in str(e).lower():
+                 # OOM recovery: try with smaller chunks
+                 error_manager.logger.warning("OOM detected, attempting chunked operation")
+                 return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs)
+             elif "device" in str(e).lower():
+                 # Device mismatch recovery
+                 error_manager.logger.warning("Device mismatch, attempting CPU fallback")
+                 return tensor_op(tensor.cpu(), *args, **kwargs)
+             else:
+                 raise
+
+     return wrapper
+
+
+ def _chunked_tensor_operation(tensor_op: Callable,
+                               tensor: torch.Tensor,
+                               *args,
+                               chunk_size: int = 1024,
+                               **kwargs) -> torch.Tensor:
+     """Execute tensor operation in chunks to avoid OOM."""
+     # chunk_size is keyword-only so that positional args forwarded from
+     # safe_tensor_operation cannot silently overwrite it
+     if tensor.size(0) <= chunk_size:
+         return tensor_op(tensor, *args, **kwargs)
+
+     results = []
+     for i in range(0, tensor.size(0), chunk_size):
+         chunk = tensor[i:i + chunk_size]
+         chunk_result = tensor_op(chunk, *args, **kwargs)
+         results.append(chunk_result)
+
+     return torch.cat(results, dim=0)
+
+
+ def validate_model_inputs(inputs: torch.Tensor,
+                           max_seq_len: int = 8192,
+                           expected_dtype: torch.dtype = torch.long) -> torch.Tensor:
+     """Validate and sanitize model inputs."""
+     if not isinstance(inputs, torch.Tensor):
+         raise DataError("Model inputs must be torch.Tensor")
+
+     # Check dimensions
+     if inputs.dim() == 1:
+         inputs = inputs.unsqueeze(0)  # Add batch dimension
+     elif inputs.dim() > 2:
+         raise DataError(f"Input tensor has too many dimensions: {inputs.dim()}")
+
+     # Check sequence length
+     if inputs.size(-1) > max_seq_len:
+         error_manager.logger.warning(f"Sequence length {inputs.size(-1)} exceeds max {max_seq_len}, truncating")
+         inputs = inputs[:, :max_seq_len]
+
+     # Check dtype
+     if inputs.dtype != expected_dtype:
+         error_manager.logger.warning(f"Converting input dtype from {inputs.dtype} to {expected_dtype}")
+         inputs = inputs.to(expected_dtype)
+
+     # Check value range for bit sequences
+     if expected_dtype == torch.long:
+         invalid_values = (inputs < 0) | (inputs > 1)
+         if invalid_values.any():
+             error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]")
+             inputs = torch.clamp(inputs, 0, 1)
+
+     return inputs
+
+
+ def safe_model_forward(model: torch.nn.Module,
+                        inputs: torch.Tensor,
+                        **kwargs) -> torch.Tensor:
+     """Safely execute model forward pass with error recovery."""
+     inputs = validate_model_inputs(inputs)
+
+     try:
+         with safe_operation("model_forward"):
+             return model(inputs, **kwargs)
+     except RuntimeError as e:
+         if "out of memory" in str(e).lower():
+             # Retry with activation checkpointing to trade compute for memory;
+             # non-reentrant mode is required to forward keyword arguments
+             error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing")
+             from torch.utils.checkpoint import checkpoint
+             return checkpoint(model, inputs, use_reentrant=False, **kwargs)
+         elif "device" in str(e).lower():
+             # Device mismatch recovery: move inputs to the model's device
+             device = next(model.parameters()).device
+             inputs = inputs.to(device)
+             return model(inputs, **kwargs)
+         else:
+             raise
+
+
+ def recovery_checkpoint_save(model: torch.nn.Module,
+                              path: str,
+                              additional_data: Optional[Dict[str, Any]] = None) -> bool:
+     """Save model checkpoint with error recovery."""
+     # Build the payload before attempting any I/O so the backup path below
+     # can reuse it even when the primary save fails
+     checkpoint_data = {
+         'model_state_dict': model.state_dict(),
+         'timestamp': torch.tensor(0),  # placeholder
+     }
+     if additional_data:
+         checkpoint_data.update(additional_data)
+
+     try:
+         torch.save(checkpoint_data, path)
+         error_manager.logger.info(f"Checkpoint saved successfully to {path}")
+         return True
+
+     except Exception as e:
+         error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}")
+
+         # Try backup location
+         backup_path = path + ".backup"
+         try:
+             torch.save(checkpoint_data, backup_path)
+             error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}")
+             return True
+         except Exception as backup_e:
+             error_manager.logger.error(f"Backup save also failed: {backup_e}")
+             return False
+
+
+ def setup_error_logging(log_level: LogLevel = "INFO",
+                         log_file: Optional[str] = None) -> logging.Logger:
+     """Set up comprehensive error logging."""
+     logger = logging.getLogger("BitTransformerLM")
+     logger.setLevel(getattr(logging, log_level))
+
+     # Console handler
+     console_handler = logging.StreamHandler()
+     console_formatter = logging.Formatter(
+         '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+     )
+     console_handler.setFormatter(console_formatter)
+     logger.addHandler(console_handler)
+
+     # File handler if specified
+     if log_file:
+         file_handler = logging.FileHandler(log_file)
+         file_formatter = logging.Formatter(
+             '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
+         )
+         file_handler.setFormatter(file_formatter)
+         logger.addHandler(file_handler)
+
+     return logger
+
+
+ # Default recovery strategies
+ def default_tensor_recovery() -> torch.Tensor:
+     """Default recovery strategy for tensor operations."""
+     return torch.zeros(1, dtype=torch.long)
+
+
+ def default_model_recovery() -> Dict[str, torch.Tensor]:
+     """Default recovery strategy for model operations."""
+     return {"output": torch.zeros(1, dtype=torch.float32)}
+
+
+ # Register default recovery strategies
+ error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery)
+ error_manager.register_recovery_strategy(ModelError, default_model_recovery)
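
For context, a minimal caller-side sketch of how the new API composes. Everything outside the import is hypothetical illustration (the log file name, the load_bits function, the step counter); it also assumes RecoveryStrategy is a zero-argument callable, as the default strategies at the end of the file suggest:

import torch

from bit_transformer.error_handling import (
    CompressionError,
    error_manager,
    safe_operation,
    setup_error_logging,
    with_error_recovery,
)

# Route module logs to the console and a file (hypothetical path).
logger = setup_error_logging(log_level="INFO", log_file="btlm_errors.log")

# Per-error-type fallback: return an empty bit tensor when compression fails.
error_manager.register_recovery_strategy(
    CompressionError, lambda: torch.zeros(1, dtype=torch.long)
)

# Retry transient RuntimeErrors (3 attempts total), then return the
# recovery value instead of raising.
@with_error_recovery(recovery_value=torch.zeros(1, 8, dtype=torch.long),
                     max_retries=2,
                     error_types=(RuntimeError,))
def load_bits() -> torch.Tensor:
    # Hypothetical stand-in for a flaky data-loading step.
    return torch.randint(0, 2, (1, 8))

# safe_operation logs and suppresses the failure; note it cannot hand the
# recovery value back to the caller (see the comment inside the module).
with safe_operation("load_bits_demo", context={"step": 0}, recovery_value=0):
    bits = load_bits()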