| """Logging and Metrics Tracking for Training"""
|
|
|
| import json
|
| import logging
|
| import os
|
| from datetime import datetime
|
| from pathlib import Path
|
| from typing import Any, Dict, List, Optional
|
|
|
| import numpy as np
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
def setup_logging(
    log_dir: str = "./logs",
    log_level: str = "INFO",
    console: bool = True,
    file: bool = True,
):
    """Configure root logging with optional console and file handlers.

    Args:
        log_dir: Directory for log files; created if missing.
        log_level: Name of a standard ``logging`` level, e.g. "INFO", "DEBUG".
        console: If True, attach a ``StreamHandler`` (stderr).
        file: If True, attach a timestamped ``FileHandler`` under ``log_dir``.

    Raises:
        AttributeError: If ``log_level`` is not a valid logging level name.
    """
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)

    formatter = logging.Formatter(
        fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    root_logger = logging.getLogger()
    root_logger.setLevel(getattr(logging, log_level.upper()))

    # Close existing handlers before removing them so their underlying
    # streams/files are released; a bare ``handlers.clear()`` would leak
    # open file descriptors on repeated setup_logging() calls.
    for handler in list(root_logger.handlers):
        handler.close()
        root_logger.removeHandler(handler)

    if console:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

    log_file = None
    if file:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = log_path / f"training_{timestamp}.log"
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        root_logger.addHandler(file_handler)

    # Equivalent to the module-level ``logger`` (same getLogger(__name__)).
    logging.getLogger(__name__).info(
        f"Logging initialized. Log file: {log_file if file else 'console only'}"
    )
|
|
|
|
|
class MetricsLogger:
    """Track and log metrics during training.

    Each ``log()`` call is appended to an on-disk JSONL file, kept in an
    in-memory history list, and (when TensorBoard is available) mirrored
    to a TensorBoard run directory.
    """

    def __init__(
        self,
        log_dir: str = "./logs",
        experiment_name: Optional[str] = None,
    ):
        """Create the logger.

        Args:
            log_dir: Base directory for metric files; created if missing.
            experiment_name: Label for this run. Defaults to a timestamped
                ``zenith_*`` name.
        """
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.experiment_name = experiment_name or f"zenith_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.metrics_file = self.log_dir / f"{self.experiment_name}_metrics.jsonl"
        self.tensorboard_logdir = self.log_dir / "tensorboard" / self.experiment_name

        # Full in-memory record of every log() call for this run.
        self.history: List[Dict[str, Any]] = []

        # TensorBoard is optional; degrade gracefully when not installed.
        try:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=str(self.tensorboard_logdir))
            self.has_tensorboard = True
        except ImportError:
            self.has_tensorboard = False
            # Same target as the module-level ``logger`` (getLogger(__name__)).
            logging.getLogger(__name__).warning("TensorBoard not available. Install with: pip install tensorboard")

    def log(
        self,
        metrics: Dict[str, float],
        step: int,
        prefix: str = "train",
    ):
        """Record *metrics* at *step* under *prefix* (e.g. "train", "eval")."""
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "step": step,
            "prefix": prefix,
            # Namespace each metric key with the prefix unless the key
            # already equals the prefix (avoids keys like "train/train").
            **{f"{prefix}/{k}" if prefix != k else k: v for k, v in metrics.items()},
        }
        self.history.append(log_entry)

        # Append-mode write so each entry is durable immediately.
        with open(self.metrics_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(log_entry) + "\n")

        if self.has_tensorboard:
            for key, value in metrics.items():
                self.writer.add_scalar(f"{prefix}/{key}", value, step)

    def log_hyperparams(self, params: Dict[str, Any]):
        """Log (possibly nested) hyperparameters to TensorBoard, if available."""
        if self.has_tensorboard:
            # Fix: removed a redundant, unused re-import of SummaryWriter
            # here — the writer created in __init__ is all that is needed.
            flat_params = self._flatten_dict(params)
            self.writer.add_hparams(flat_params, {})

    def _flatten_dict(self, d: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
        """Flatten a nested dict into single-level keys joined by *sep*."""
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.extend(self._flatten_dict(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)

    def get_metrics(self, prefix: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return logged entries, optionally only those matching *prefix*."""
        if prefix is None:
            return self.history

        filtered = []
        for entry in self.history:
            if entry.get("prefix") == prefix:
                filtered.append(entry)
        return filtered

    def close(self):
        """Flush and close the TensorBoard writer, if one was created."""
        if self.has_tensorboard:
            self.writer.close()
|
|
|
|
|
class ProgressLogger:
    """Simple progress tracking with ETA.

    Counts completed units against a known total and logs percentage
    plus an estimated time remaining after every update.
    """

    def __init__(self, total: int, desc: str = "Progress"):
        self.total = total
        self.desc = desc
        self.current = 0
        self.start_time = datetime.now()

    def update(self, n: int = 1):
        """Advance progress by *n* units and log the new state."""
        self.current += n
        self._log_progress()

    def _log_progress(self):
        """Emit one progress line with percent complete and ETA in minutes."""
        elapsed = (datetime.now() - self.start_time).total_seconds()
        if self.current <= 0:
            # No completed work yet — rate (and ETA) is undefined.
            return
        seconds_per_unit = elapsed / self.current
        eta_seconds = seconds_per_unit * (self.total - self.current)
        logger.info(
            f"{self.desc}: {self.current}/{self.total} "
            f"({100 * self.current / self.total:.1f}%) "
            f"- ETA: {eta_seconds / 60:.1f}m"
        )
|
|
|
|
|
def log_metrics_summary(metrics: Dict[str, float], step: int, logger_obj: Optional[logging.Logger] = None):
    """Log a summary of metrics in a nice format.

    Keys are sorted alphabetically; floats are shown with four decimals,
    everything else verbatim. Falls back to this module's logger when
    *logger_obj* is not given.
    """
    target = logger_obj if logger_obj is not None else logger

    def _render(key, value) -> str:
        # Floats get fixed-point formatting; other values print as-is.
        if isinstance(value, float):
            return f"  {key}: {value:.4f}"
        return f"  {key}: {value}"

    body = [_render(k, v) for k, v in sorted(metrics.items())]
    target.info("\n".join([f"Step {step} - Metrics Summary:"] + body))
|
|
|
|
|
def save_metrics_to_csv(metrics_history: List[Dict[str, Any]], filepath: str):
    """Write the metrics history to *filepath* as CSV, one row per entry."""
    import pandas as pd

    pd.DataFrame(metrics_history).to_csv(filepath, index=False)
    # Same target as the module-level ``logger`` (getLogger(__name__)).
    logging.getLogger(__name__).info(f"Metrics saved to {filepath}")
|
|
|
|
|
def load_metrics_from_jsonl(filepath: str) -> List[Dict[str, Any]]:
    """Load metrics entries from a JSONL file.

    Args:
        filepath: Path to a file with one JSON object per line.

    Returns:
        Parsed entries in file order. Blank lines are skipped, so a
        trailing newline or accidental empty line no longer raises
        ``json.JSONDecodeError``.
    """
    metrics: List[Dict[str, Any]] = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank/trailing lines
                metrics.append(json.loads(line))
    return metrics
|
|
|