"""Logging and Metrics Tracking for Training""" import json import logging import os from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional import numpy as np logger = logging.getLogger(__name__) def setup_logging( log_dir: str = "./logs", log_level: str = "INFO", console: bool = True, file: bool = True, ): """Setup logging configuration.""" log_dir = Path(log_dir) log_dir.mkdir(parents=True, exist_ok=True) # Create formatter formatter = logging.Formatter( fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) # Configure root logger root_logger = logging.getLogger() root_logger.setLevel(getattr(logging, log_level.upper())) # Clear existing handlers root_logger.handlers.clear() # Console handler if console: console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) root_logger.addHandler(console_handler) # File handler if file: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_file = log_dir / f"training_{timestamp}.log" file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) root_logger.addHandler(file_handler) logger.info(f"Logging initialized. Log file: {log_file if file else 'console only'}") class MetricsLogger: """Track and log metrics during training.""" def __init__( self, log_dir: str = "./logs", experiment_name: Optional[str] = None, ): self.log_dir = Path(log_dir) self.log_dir.mkdir(parents=True, exist_ok=True) self.experiment_name = experiment_name or f"zenith_{datetime.now().strftime('%Y%m%d_%H%M%S')}" self.metrics_file = self.log_dir / f"{self.experiment_name}_metrics.jsonl" self.tensorboard_logdir = self.log_dir / "tensorboard" / self.experiment_name # Metrics history self.history: List[Dict[str, Any]] = [] # TensorBoard try: from torch.utils.tensorboard import SummaryWriter self.writer = SummaryWriter(log_dir=str(self.tensorboard_logdir)) self.has_tensorboard = True except ImportError: self.has_tensorboard = False logger.warning("TensorBoard not available. Install with: pip install tensorboard") def log( self, metrics: Dict[str, float], step: int, prefix: str = "train", ): """Log metrics.""" # Add timestamp and step log_entry = { "timestamp": datetime.now().isoformat(), "step": step, "prefix": prefix, **{f"{prefix}/{k}" if prefix != k else k: v for k, v in metrics.items()}, } self.history.append(log_entry) # Write to file with open(self.metrics_file, "a") as f: f.write(json.dumps(log_entry) + "\n") # TensorBoard if self.has_tensorboard: for key, value in metrics.items(): self.writer.add_scalar(f"{prefix}/{key}", value, step) def log_hyperparams(self, params: Dict[str, Any]): """Log hyperparameters.""" if self.has_tensorboard: from torch.utils.tensorboard import SummaryWriter # TensorBoard expects flat dict flat_params = self._flatten_dict(params) self.writer.add_hparams(flat_params, {}) def _flatten_dict(self, d: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]: """Flatten nested dictionary.""" items = [] for k, v in d.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k if isinstance(v, dict): items.extend(self._flatten_dict(v, new_key, sep=sep).items()) else: items.append((new_key, v)) return dict(items) def get_metrics(self, prefix: Optional[str] = None) -> List[Dict[str, Any]]: """Get metrics history, optionally filtered by prefix.""" if prefix is None: return self.history filtered = [] for entry in self.history: if entry.get("prefix") == prefix: filtered.append(entry) return filtered def close(self): """Close logger.""" if self.has_tensorboard: self.writer.close() class ProgressLogger: """Simple progress tracking with ETA.""" def __init__(self, total: int, desc: str = "Progress"): self.total = total self.desc = desc self.current = 0 self.start_time = datetime.now() def update(self, n: int = 1): """Update progress.""" self.current += n self._log_progress() def _log_progress(self): """Log current progress.""" elapsed = (datetime.now() - self.start_time).total_seconds() if self.current > 0: rate = elapsed / self.current remaining = rate * (self.total - self.current) logger.info( f"{self.desc}: {self.current}/{self.total} " f"({100 * self.current / self.total:.1f}%) " f"- ETA: {remaining / 60:.1f}m" ) def log_metrics_summary(metrics: Dict[str, float], step: int, logger_obj: Optional[logging.Logger] = None): """Log a summary of metrics in a nice format.""" if logger_obj is None: logger_obj = logger lines = [f"Step {step} - Metrics Summary:"] for key, value in sorted(metrics.items()): if isinstance(value, float): lines.append(f" {key}: {value:.4f}") else: lines.append(f" {key}: {value}") logger_obj.info("\n".join(lines)) def save_metrics_to_csv(metrics_history: List[Dict[str, Any]], filepath: str): """Save metrics history to CSV.""" import pandas as pd df = pd.DataFrame(metrics_history) df.to_csv(filepath, index=False) logger.info(f"Metrics saved to {filepath}") def load_metrics_from_jsonl(filepath: str) -> List[Dict[str, Any]]: """Load metrics from JSONL file.""" metrics = [] with open(filepath, "r") as f: for line in f: metrics.append(json.loads(line.strip())) return metrics