| """ |
| Logging utilities for Nexuss Transformer Framework |
| Provides setup_logging and debug logging capabilities |
| """ |
|
|
| import logging |
| import sys |
| from typing import Optional |
| from pathlib import Path |
|
|
|
|
| def setup_logging( |
| level: str = "INFO", |
| log_file: Optional[str] = None, |
| format_string: Optional[str] = None, |
| ) -> logging.Logger: |
| """ |
| Set up logging configuration for NTF. |
| |
| Args: |
| level: Logging level ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL') |
| log_file: Optional file path to write logs to |
| format_string: Custom format string (default: includes timestamp, level, module) |
| |
| Returns: |
| Configured logger instance |
| """ |
| |
| if format_string is None: |
| format_string = ( |
| "%(asctime)s | %(levelname)-8s | %(name)s | %(module)s:%(lineno)d | %(message)s" |
| ) |
| |
| |
| numeric_level = getattr(logging, level.upper(), logging.INFO) |
| |
| |
| logger = logging.getLogger("ntf") |
| logger.setLevel(numeric_level) |
| |
| |
| logger.handlers.clear() |
| |
| |
| console_handler = logging.StreamHandler(sys.stdout) |
| console_handler.setLevel(numeric_level) |
| console_formatter = logging.Formatter(format_string) |
| console_handler.setFormatter(console_formatter) |
| logger.addHandler(console_handler) |
| |
| |
| if log_file: |
| log_path = Path(log_file) |
| log_path.parent.mkdir(parents=True, exist_ok=True) |
| |
| file_handler = logging.FileHandler(log_file) |
| file_handler.setLevel(numeric_level) |
| file_handler.setFormatter(console_formatter) |
| logger.addHandler(file_handler) |
| |
| return logger |
|
|
|
|
| def get_logger(name: str = "ntf") -> logging.Logger: |
| """Get a logger instance with the specified name.""" |
| return logging.getLogger(name) |
|
|
|
|
| def set_log_level(level: str): |
| """Set the logging level for the NTF logger.""" |
| logger = logging.getLogger("ntf") |
| numeric_level = getattr(logging, level.upper(), logging.INFO) |
| logger.setLevel(numeric_level) |
| |
| |
| for handler in logger.handlers: |
| handler.setLevel(numeric_level) |
|
|
|
|
| class DebugLogger: |
| """ |
| Enhanced debug logger for training and model debugging. |
| |
| Provides methods for: |
| - Logging tensor statistics |
| - Tracking gradient norms |
| - Monitoring memory usage |
| - Debugging NaN/Inf values |
| """ |
| |
| def __init__(self, logger: Optional[logging.Logger] = None): |
| self.logger = logger or get_logger("ntf.debug") |
| |
| def log_tensor_stats(self, name: str, tensor, step: int = 0): |
| """Log statistics for a tensor""" |
| if tensor is None: |
| return |
| |
| stats = { |
| "shape": tuple(tensor.shape), |
| "mean": tensor.mean().item() if tensor.numel() > 0 else 0, |
| "std": tensor.std().item() if tensor.numel() > 1 else 0, |
| "min": tensor.min().item() if tensor.numel() > 0 else 0, |
| "max": tensor.max().item() if tensor.numel() > 0 else 0, |
| "has_nan": bool(torch.isnan(tensor).any()) if hasattr(torch, 'isnan') else False, |
| "has_inf": bool(torch.isinf(tensor).any()) if hasattr(torch, 'isinf') else False, |
| } |
| |
| self.logger.debug(f"[Step {step}] {name}: {stats}") |
| |
| def log_gradient_norms(self, model, step: int = 0): |
| """Log gradient norms for all parameters""" |
| total_norm = 0 |
| layer_norms = {} |
| |
| for name, param in model.named_parameters(): |
| if param.grad is not None: |
| param_norm = param.grad.data.norm(2) |
| total_norm += param_norm.item() ** 2 |
| layer_norms[name] = param_norm.item() |
| |
| total_norm = total_norm ** 0.5 |
| |
| self.logger.debug(f"[Step {step}] Total gradient norm: {total_norm:.6f}") |
| |
| |
| sorted_norms = sorted(layer_norms.items(), key=lambda x: x[1], reverse=True)[:5] |
| for name, norm in sorted_norms: |
| self.logger.debug(f" {name}: {norm:.6f}") |
| |
| def check_nan_inf(self, value, name: str = "value", raise_error: bool = False) -> bool: |
| """Check if a value contains NaN or Inf""" |
| import torch |
| |
| has_issues = False |
| |
| if isinstance(value, torch.Tensor): |
| has_nan = bool(torch.isnan(value).any()) |
| has_inf = bool(torch.isinf(value).any()) |
| |
| if has_nan or has_inf: |
| has_issues = True |
| msg = f"{name} contains " |
| if has_nan: |
| msg += "NaN" |
| if has_inf: |
| msg += " and " if has_nan else "" |
| if has_inf: |
| msg += "Inf" |
| |
| if raise_error: |
| raise ValueError(msg) |
| else: |
| self.logger.warning(msg) |
| |
| return has_issues |
|
|
|
|
| |
| def validate_config(config) -> list: |
| """ |
| Validate a configuration object and return list of errors. |
| |
| Args: |
| config: Configuration object to validate |
| |
| Returns: |
| List of error messages (empty if valid) |
| """ |
| errors = [] |
| |
| try: |
| |
| if hasattr(config, '__post_init__'): |
| config.__post_init__() |
| except Exception as e: |
| errors.append(str(e)) |
| |
| |
| if hasattr(config, 'learning_rate'): |
| if config.learning_rate <= 0: |
| errors.append("learning_rate must be positive") |
| |
| if hasattr(config, 'per_device_train_batch_size'): |
| if config.per_device_train_batch_size <= 0: |
| errors.append("per_device_train_batch_size must be positive") |
| |
| if hasattr(config, 'max_seq_len'): |
| if config.max_seq_len <= 0: |
| errors.append("max_seq_len must be positive") |
| |
| return errors |
|
|