Spaces:
Build error
Build error
| """ | |
| Logging utilities for Nexuss Transformer Framework | |
| Provides setup_logging and debug logging capabilities | |
| """ | |
| import logging | |
| import sys | |
| from typing import Optional | |
| from pathlib import Path | |
def setup_logging(
    level: str = "INFO",
    log_file: Optional[str] = None,
    format_string: Optional[str] = None,
) -> logging.Logger:
    """
    Set up logging configuration for NTF.

    Configures the shared ``"ntf"`` logger with a stdout handler and, when
    ``log_file`` is given, an additional file handler. Handlers already
    attached to the logger are closed and removed first, so the function can
    be called repeatedly to reconfigure logging without leaking open files.

    Args:
        level: Logging level name ('DEBUG', 'INFO', 'WARNING', 'ERROR',
            'CRITICAL'), case-insensitive. Anything that does not resolve
            to a numeric logging level falls back to INFO.
        log_file: Optional file path to write logs to. Missing parent
            directories are created.
        format_string: Custom format string (default: includes timestamp,
            level, logger name, module and line number)

    Returns:
        Configured logger instance
    """
    # Default format with detailed info for debugging
    if format_string is None:
        format_string = (
            "%(asctime)s | %(levelname)-8s | %(name)s | %(module)s:%(lineno)d | %(message)s"
        )

    # Convert string level to logging constant. Only accept integer
    # attributes: names such as "basic_format" resolve to non-level module
    # attributes that setLevel() would reject with ValueError.
    resolved = getattr(logging, level.upper(), None)
    numeric_level = resolved if isinstance(resolved, int) else logging.INFO

    logger = logging.getLogger("ntf")
    logger.setLevel(numeric_level)

    # Remove AND close existing handlers; merely clearing the list would
    # leave previously opened log files open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter(format_string)

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(numeric_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler (if specified)
    if log_file:
        Path(log_file).parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(numeric_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
def get_logger(name: str = "ntf") -> logging.Logger:
    """Return the logger registered under ``name`` (defaults to ``"ntf"``)."""
    named_logger = logging.getLogger(name)
    return named_logger
def set_log_level(level: str) -> None:
    """
    Set the logging level for the NTF logger and all of its handlers.

    Args:
        level: Logging level name ('DEBUG', 'INFO', 'WARNING', 'ERROR',
            'CRITICAL'), case-insensitive. Anything that does not resolve
            to a numeric logging level falls back to INFO.
    """
    # Only accept integer attributes of the logging module: a name such as
    # "basic_format" resolves to a string constant that setLevel() would
    # reject with ValueError instead of falling back to INFO.
    resolved = getattr(logging, level.upper(), None)
    numeric_level = resolved if isinstance(resolved, int) else logging.INFO

    logger = logging.getLogger("ntf")
    logger.setLevel(numeric_level)

    # Update all handlers so none of them filters at a stale threshold
    for handler in logger.handlers:
        handler.setLevel(numeric_level)
class DebugLogger:
    """
    Enhanced debug logger for training and model debugging.

    Provides methods for:
    - Logging tensor statistics
    - Tracking gradient norms
    - Debugging NaN/Inf values

    ``torch`` is imported lazily inside the methods that need it, so
    constructing a DebugLogger does not itself import torch.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        """
        Args:
            logger: Logger to write to. When omitted, the ``"ntf.debug"``
                logger is used.
        """
        self.logger = logger if logger is not None else get_logger("ntf.debug")

    def log_tensor_stats(self, name: str, tensor, step: int = 0):
        """
        Log shape, mean, std, min, max and NaN/Inf flags for a tensor.

        Nothing is computed when ``tensor`` is None or when DEBUG logging is
        disabled for this logger, so the call is cheap in normal runs.

        Args:
            name: Label used in the log line.
            tensor: A torch tensor, or None (ignored).
            step: Training step shown in the log line.
        """
        if tensor is None:
            return
        # Every statistic below calls .item(); skip that work entirely when
        # the resulting message would be discarded anyway.
        if not self.logger.isEnabledFor(logging.DEBUG):
            return

        import torch

        data = tensor.detach()
        count = data.numel()
        # mean()/std() are only defined for floating point / complex dtypes;
        # promote integer and bool tensors so they can be summarised too.
        if data.is_floating_point() or data.is_complex():
            numeric = data
        else:
            numeric = data.float()

        stats = {
            "shape": tuple(data.shape),
            "mean": numeric.mean().item() if count > 0 else 0,
            # std of a single element is undefined, report 0 instead
            "std": numeric.std().item() if count > 1 else 0,
            "min": data.min().item() if count > 0 else 0,
            "max": data.max().item() if count > 0 else 0,
            "has_nan": bool(torch.isnan(data).any()),
            "has_inf": bool(torch.isinf(data).any()),
        }
        self.logger.debug("[Step %s] %s: %s", step, name, stats)

    def log_gradient_norms(self, model, step: int = 0):
        """
        Log the total L2 gradient norm and the five largest per-parameter norms.

        Parameters whose ``grad`` is None are skipped. Nothing is computed
        when DEBUG logging is disabled for this logger.

        Args:
            model: Object exposing ``named_parameters()`` yielding
                ``(name, parameter)`` pairs.
            step: Training step shown in the log line.
        """
        if not self.logger.isEnabledFor(logging.DEBUG):
            return

        # One .item() call per parameter; the total is derived from these.
        layer_norms = {
            name: param.grad.detach().norm(2).item()
            for name, param in model.named_parameters()
            if param.grad is not None
        }
        total_norm = sum(norm ** 2 for norm in layer_norms.values()) ** 0.5
        self.logger.debug("[Step %s] Total gradient norm: %.6f", step, total_norm)

        # Log top 5 largest gradients
        largest = sorted(layer_norms.items(), key=lambda item: item[1], reverse=True)[:5]
        for name, norm in largest:
            self.logger.debug(" %s: %.6f", name, norm)

    def check_nan_inf(self, value, name: str = "value", raise_error: bool = False) -> bool:
        """
        Check if a value contains NaN or Inf.

        Torch tensors and plain Python floats are inspected; any other type
        is reported as clean.

        Args:
            value: Tensor or float to inspect.
            name: Label used in the warning / error message.
            raise_error: Raise instead of logging a warning when a problem
                is found.

        Returns:
            True if NaN or Inf was found, otherwise False.

        Raises:
            ValueError: If a problem is found and ``raise_error`` is True.
        """
        import math

        import torch

        if isinstance(value, torch.Tensor):
            has_nan = bool(torch.isnan(value).any())
            has_inf = bool(torch.isinf(value).any())
        elif isinstance(value, float):
            # e.g. a loss already converted with .item()
            has_nan = math.isnan(value)
            has_inf = math.isinf(value)
        else:
            return False

        if not (has_nan or has_inf):
            return False

        found = [label for label, present in (("NaN", has_nan), ("Inf", has_inf)) if present]
        msg = f"{name} contains " + " and ".join(found)
        if raise_error:
            raise ValueError(msg)
        self.logger.warning(msg)
        return True
| # Convenience function for validating configs | |
def validate_config(config) -> list:
    """
    Validate a configuration object and return list of errors.

    If the object defines ``__post_init__`` it is invoked and any exception
    it raises is recorded as an error message. Afterwards each of
    ``learning_rate``, ``per_device_train_batch_size`` and ``max_seq_len``
    that is present on the object must be strictly positive.

    Args:
        config: Configuration object to validate

    Returns:
        List of error messages (empty if valid)
    """
    errors = []

    try:
        # Try to trigger __post_init__ validation if it exists
        if hasattr(config, '__post_init__'):
            config.__post_init__()
    except Exception as e:
        # Deliberately broad: whatever the config's own validation raises
        # is reported to the caller as a message rather than propagated.
        errors.append(str(e))

    missing = object()
    for field_name in ("learning_rate", "per_device_train_batch_size", "max_seq_len"):
        value = getattr(config, field_name, missing)
        if value is missing:
            continue
        try:
            # "value > 0" rather than "not value <= 0" so NaN is rejected too
            is_positive = bool(value > 0)
        except TypeError:
            # None, strings, ...: report it instead of crashing the validator
            errors.append(
                f"{field_name} must be a positive number, got {type(value).__name__}"
            )
            continue
        if not is_positive:
            errors.append(f"{field_name} must be positive")

    return errors