File size: 13,112 Bytes
75c1496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36c78b1
75c1496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
"""
Comprehensive error handling and recovery utilities for BitTransformerLM.

Provides robust error recovery mechanisms, graceful degradation, and detailed
error logging for production deployments.
"""

import logging
import traceback
import functools
from typing import Dict, Any, Optional, Callable, Union, Type
from contextlib import contextmanager
import torch
import numpy as np

from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike


class BitTransformerError(Exception):
    """Root of the BitTransformerLM exception hierarchy.

    Every subclass carries a machine-readable ``error_code`` and an optional
    ``context`` dict alongside the human-readable message; the rendered
    exception text is ``[error_code] message``.
    """

    def __init__(self, message: str, error_code: str = "BTLM_ERROR",
                 context: Optional[Dict[str, Any]] = None):
        super().__init__(f"[{error_code}] {message}")
        self.message = message
        self.error_code = error_code
        self.context = context or {}


class ModelError(BitTransformerError):
    """Errors related to model operations.

    Inherits ``error_code``/``context`` handling from BitTransformerError.
    """
    pass


class CompressionError(BitTransformerError):
    """Errors related to compression/decompression."""
    pass


class SafetyError(BitTransformerError):
    """Errors related to safety gates and telemetry."""
    pass


class DataError(BitTransformerError):
    """Errors related to data processing.

    Raised by the tensor/input validators in this module for bad types,
    empty tensors, or malformed shapes.
    """
    pass


class DistributedError(BitTransformerError):
    """Errors related to distributed training."""
    pass


class ErrorRecoveryManager:
    """Registry of per-exception-type recovery strategies with bookkeeping.

    Counts occurrences of each distinct error (keyed by type name plus
    message) and, while the count is within ``max_retries``, attempts the
    first registered strategy whose type matches before re-raising.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger or logging.getLogger(__name__)
        self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {}
        self.error_counts: Dict[str, int] = {}
        # Recovery is no longer attempted once the same error has been
        # observed more than this many times.
        self.max_retries = 3

    def register_recovery_strategy(self,
                                 error_type: Type[Exception],
                                 strategy: RecoveryStrategy) -> None:
        """Associate ``strategy`` with ``error_type`` (isinstance match)."""
        self.recovery_strategies[error_type] = strategy

    def handle_error(self,
                    error: Exception,
                    context: Optional[Dict[str, Any]] = None,
                    allow_recovery: bool = True) -> Any:
        """Log ``error``, try a matching recovery strategy, else re-raise.

        Returns the strategy's result when recovery applies and succeeds;
        otherwise the original exception propagates.
        """
        key = f"{type(error).__name__}:{str(error)}"
        self.error_counts[key] = self.error_counts.get(key, 0) + 1

        self.logger.error(
            f"Error occurred: {error}\n"
            f"Context: {context}\n"
            f"Traceback: {traceback.format_exc()}"
        )

        if allow_recovery and self.error_counts[key] <= self.max_retries:
            # Only the first matching strategy is attempted, mirroring
            # registration-order dict iteration.
            strategy = next(
                (candidate
                 for registered_type, candidate in self.recovery_strategies.items()
                 if isinstance(error, registered_type)),
                None,
            )
            if strategy is not None:
                try:
                    self.logger.info(f"Attempting recovery for {type(error).__name__}")
                    return strategy()
                except Exception as recovery_error:
                    self.logger.error(f"Recovery failed: {recovery_error}")

        # No applicable strategy, or recovery itself failed.
        raise error


# Global error recovery manager instance shared by the decorators and
# helper functions in this module (uses the module's default logger).
error_manager = ErrorRecoveryManager()


def with_error_recovery(recovery_value: Any = None, 
                       max_retries: int = 3,
                       error_types: Optional[tuple] = None):
    """Decorator that retries a function and optionally returns a fallback.

    The wrapped function is attempted up to ``max_retries + 1`` times. When
    ``error_types`` is given, exceptions outside that tuple propagate
    immediately. After the final failure, ``recovery_value`` is returned if
    it is not None (so None cannot itself serve as a fallback); otherwise
    the last exception is re-raised.
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            caught: Optional[Exception] = None
            attempt = 0
            while attempt <= max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    caught = exc

                    # Exceptions we were not asked to handle propagate as-is.
                    if error_types and not isinstance(exc, error_types):
                        raise

                    if attempt == max_retries:
                        error_manager.logger.error(
                            f"Function {func.__name__} failed after {max_retries + 1} attempts"
                        )
                        break

                    error_manager.logger.warning(
                        f"Function {func.__name__} failed (attempt {attempt + 1}), retrying..."
                    )
                    attempt += 1

            if recovery_value is not None:
                return recovery_value
            raise caught

        return wrapper
    return decorator


@contextmanager
def safe_operation(operation_name: str, 
                  context: Optional[Dict[str, Any]] = None,
                  recovery_value: Any = None):
    """Context manager for safe operations with error handling.

    Logs entry/exit, routes failures through ``error_manager.handle_error``,
    and — when ``recovery_value`` is not None — suppresses the exception
    instead of propagating it.

    NOTE: a generator-based context manager cannot deliver a value to the
    ``with`` statement, so ``recovery_value`` acts only as a
    "suppress-on-failure" flag; the value itself is never observable by the
    caller. (The previous ``return recovery_value`` silently discarded it.)
    """
    try:
        error_manager.logger.debug(f"Starting operation: {operation_name}")
        yield
        error_manager.logger.debug(f"Completed operation: {operation_name}")
    except Exception as e:
        error_context = {"operation": operation_name}
        if context:
            error_context.update(context)

        try:
            # handle_error either returns a recovery result (which suppresses
            # the original exception when this generator completes) or
            # re-raises the original error.
            error_manager.handle_error(e, error_context)
        except Exception:
            # Previously a bare `except:` — that also swallowed
            # KeyboardInterrupt/SystemExit when recovery_value was set.
            if recovery_value is not None:
                error_manager.logger.warning(
                    f"Operation {operation_name} failed, using recovery value"
                )
                return
            raise


def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor],
                         fallback_value: Optional[torch.Tensor] = None) -> Callable:
    """Wrap ``tensor_op`` with input sanitization and failure fallbacks.

    The returned callable validates its tensor argument, replaces NaN/Inf
    values before dispatch, and on RuntimeError/ValueError retries either in
    chunks (out-of-memory) or on CPU (device mismatch).
    """
    def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Reject non-tensor input outright.
        if not isinstance(tensor, torch.Tensor):
            raise DataError("Input must be a torch.Tensor")

        # Empty input: hand back the configured fallback, or fail.
        if tensor.numel() == 0:
            if fallback_value is None:
                raise DataError("Cannot operate on empty tensor")
            return fallback_value

        # Scrub non-finite values so tensor_op sees clean data.
        if torch.isnan(tensor).any():
            error_manager.logger.warning("NaN values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, nan=0.0)
        if torch.isinf(tensor).any():
            error_manager.logger.warning("Inf values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6)

        try:
            return tensor_op(tensor, *args, **kwargs)
        except (RuntimeError, ValueError) as exc:
            reason = str(exc).lower()
            if "out of memory" in reason:
                # OOM recovery: split the work along dim 0.
                error_manager.logger.warning("OOM detected, attempting chunked operation")
                return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs)
            if "device" in reason:
                # Device mismatch recovery: rerun everything on the CPU.
                error_manager.logger.warning("Device mismatch, attempting CPU fallback")
                return tensor_op(tensor.cpu(), *args, **kwargs)
            raise

    return wrapper


def _chunked_tensor_operation(tensor_op: Callable, 
                             tensor: torch.Tensor, 
                             chunk_size: int = 1024,
                             *args, **kwargs) -> torch.Tensor:
    """Execute tensor operation in chunks to avoid OOM."""
    if tensor.size(0) <= chunk_size:
        return tensor_op(tensor, *args, **kwargs)
    
    results = []
    for i in range(0, tensor.size(0), chunk_size):
        chunk = tensor[i:i + chunk_size]
        chunk_result = tensor_op(chunk, *args, **kwargs)
        results.append(chunk_result)
    
    return torch.cat(results, dim=0)


def validate_model_inputs(inputs: torch.Tensor, 
                         max_seq_len: int = 8192,
                         expected_dtype: torch.dtype = torch.long) -> torch.Tensor:
    """Validate and sanitize model inputs.

    Ensures a 2-D (batch, seq) tensor of the expected dtype: 1-D input gains
    a batch dimension, overlong sequences are truncated, dtype mismatches
    are converted, and (for bit inputs) values are clamped to {0, 1}.
    Raises DataError for non-tensor input or rank > 2.
    """
    if not isinstance(inputs, torch.Tensor):
        raise DataError("Model inputs must be torch.Tensor")

    # Normalize rank: promote 1-D to a single-example batch; refuse > 2-D.
    if inputs.dim() == 1:
        inputs = inputs.unsqueeze(0)
    elif inputs.dim() > 2:
        raise DataError(f"Input tensor has too many dimensions: {inputs.dim()}")

    # Enforce the sequence-length budget by truncating from the right.
    if inputs.size(-1) > max_seq_len:
        error_manager.logger.warning(f"Sequence length {inputs.size(-1)} exceeds max {max_seq_len}, truncating")
        inputs = inputs[:, :max_seq_len]

    # Coerce to the expected dtype if needed.
    if inputs.dtype != expected_dtype:
        error_manager.logger.warning(f"Converting input dtype from {inputs.dtype} to {expected_dtype}")
        inputs = inputs.to(expected_dtype)

    # Bit sequences must be strictly 0/1; clamp anything outside that range.
    if expected_dtype == torch.long:
        out_of_range = (inputs < 0) | (inputs > 1)
        if out_of_range.any():
            error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]")
            inputs = torch.clamp(inputs, 0, 1)

    return inputs


def safe_model_forward(model: torch.nn.Module, 
                      inputs: torch.Tensor,
                      **kwargs) -> torch.Tensor:
    """Safely execute model forward pass with error recovery.

    Inputs are validated/sanitized first. On OOM, the CUDA cache is freed
    (when available) and the forward is retried under activation
    checkpointing; on a device mismatch, inputs are moved to the model's
    device and retried. Other RuntimeErrors propagate.
    """
    inputs = validate_model_inputs(inputs)
    
    try:
        with safe_operation("model_forward"):
            return model(inputs, **kwargs)
    except RuntimeError as e:
        message = str(e).lower()
        if "out of memory" in message:
            # Try with gradient checkpointing after releasing cached blocks.
            error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # give the retry the freed memory
            from torch.utils.checkpoint import checkpoint
            # use_reentrant=False: the non-reentrant variant; omitting the
            # argument warns (and later errors) on modern PyTorch.
            return checkpoint(model, inputs, use_reentrant=False, **kwargs)
        elif "device" in message:
            # Device mismatch recovery: move inputs to the model's device.
            device = next(model.parameters()).device
            inputs = inputs.to(device)
            return model(inputs, **kwargs)
        else:
            raise


def recovery_checkpoint_save(model: torch.nn.Module, 
                            path: str,
                            additional_data: Optional[Dict[str, Any]] = None) -> bool:
    """Save model checkpoint with error recovery.

    Attempts to write to ``path``; if that fails after the checkpoint dict
    was assembled, retries at ``path + ".backup"``.

    Returns:
        True if either the primary or backup save succeeded, False otherwise.
    """
    # Built inside the try, but pre-declared so the except-path can tell
    # whether assembly (e.g. state_dict()) failed — previously that case hit
    # a NameError on the backup attempt instead of returning False.
    checkpoint_data: Optional[Dict[str, Any]] = None
    try:
        checkpoint_data = {
            'model_state_dict': model.state_dict(),
            'timestamp': torch.tensor(0),  # placeholder
        }
        if additional_data:
            checkpoint_data.update(additional_data)
        
        torch.save(checkpoint_data, path, _use_new_zipfile_serialization=True)
        error_manager.logger.info(f"Checkpoint saved successfully to {path}")
        return True
        
    except Exception as e:
        error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}")
        
        if checkpoint_data is None:
            # Assembling the checkpoint itself failed; nothing to back up.
            return False

        # Try backup location
        backup_path = path + ".backup"
        try:
            torch.save(checkpoint_data, backup_path)
            error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}")
            return True
        except Exception as backup_e:
            error_manager.logger.error(f"Backup save also failed: {backup_e}")
            return False


def setup_error_logging(log_level: LogLevel = "INFO",
                       log_file: Optional[str] = None) -> logging.Logger:
    """Set up comprehensive error logging.

    Configures the shared "BitTransformerLM" logger with a console handler
    and, optionally, a file handler. Idempotent: previously each call
    appended fresh handlers to the same named logger, so repeated setup
    duplicated every log line.

    Args:
        log_level: Name of a stdlib logging level (e.g. "INFO", "DEBUG").
        log_file: Optional path; when given, also log (with file/line info)
            to this file.

    Returns:
        The configured logger instance.
    """
    logger = logging.getLogger("BitTransformerLM")
    logger.setLevel(getattr(logging, log_level))

    # Drop handlers installed by earlier calls so reconfiguration replaces
    # rather than accumulates output destinations.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)

    # Console handler
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)
    
    # File handler if specified
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)
    
    return logger


# Default recovery strategies
def default_tensor_recovery() -> torch.Tensor:
    """Fallback for failed tensor operations: a single zero bit (int64)."""
    return torch.tensor([0], dtype=torch.long)


def default_model_recovery() -> Dict[str, torch.Tensor]:
    """Fallback for failed model operations: a zeroed float32 "output"."""
    return {"output": torch.tensor([0.0], dtype=torch.float32)}


# Register default recovery strategies on the module-level manager.
# NOTE(review): RuntimeError is very broad — any RuntimeError handled via
# error_manager will yield a zero tensor for the first max_retries hits.
error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery)
error_manager.register_recovery_strategy(ModelError, default_model_recovery)