"""AimStack logging integration.""" from pathlib import Path from typing import Dict, Any, Optional import subprocess import json from datetime import datetime try: from aim import Run HAS_AIM = True except ImportError: HAS_AIM = False from taoTrain.config import TrainingConfig class AimLogger: """AimStack logger for tracking training metrics and hyperparameters.""" def __init__(self, config: TrainingConfig): """ Initialize AimStack logger. Args: config: Training configuration """ self.config = config self.run: Optional[Run] = None if HAS_AIM: # Initialize AimStack run repo_path = Path(config.aim_repo) repo_path.mkdir(parents=True, exist_ok=True) self.run = Run(repo=str(repo_path)) # Log hyperparameters self._log_hyperparameters() else: print("Warning: AimStack not installed. Install with: pip install aim") def _log_hyperparameters(self): """Log hyperparameters to AimStack.""" if self.run is None: return # Log model config self.run["hparams/model"] = { "architecture": self.config.model.architecture_type.value, "vocab_size": self.config.model.vocab_size, "hidden_dim": self.config.model.hidden_dim, "num_layers": self.config.model.num_layers, "num_heads": self.config.model.num_heads, "dropout": self.config.model.dropout, "max_seq_length": self.config.model.max_seq_length, } # Log training config self.run["hparams/training"] = { "batch_size": self.config.batch_size, "num_epochs": self.config.num_epochs, "learning_rate": self.config.optimizer.learning_rate, "weight_decay": self.config.optimizer.weight_decay, "gradient_accumulation_steps": self.config.gradient_accumulation_steps, "max_grad_norm": self.config.max_grad_norm, "dtype": self.config.dtype.value, "seed": self.config.seed, } # Log optimizer and scheduler config self.run["hparams/optimizer"] = { "optimizer_type": self.config.optimizer.optimizer_type.value, "learning_rate": self.config.optimizer.learning_rate, "weight_decay": self.config.optimizer.weight_decay, } self.run["hparams/scheduler"] = { "scheduler_type": self.config.scheduler.scheduler_type.value, "warmup_steps": self.config.scheduler.warmup_steps, "warmup_ratio": self.config.scheduler.warmup_ratio, } # Log dataset config self.run["hparams/dataset"] = { "dataset_name": self.config.dataset.dataset_name, "split": self.config.dataset.split, "max_samples": self.config.dataset.max_samples, } # Log mode self.run["hparams/mode"] = self.config.mode.value # Log git hash if available try: git_hash = subprocess.check_output( ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL ).decode().strip() self.run["hparams/git_hash"] = git_hash except: pass # Log timestamp self.run["hparams/timestamp"] = datetime.now().isoformat() def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): """ Log metrics to AimStack. Args: metrics: Dict of metric names to values step: Global step (optional, auto-increments if not provided) """ if self.run is None: return step = metrics.pop("step", step) for metric_name, metric_value in metrics.items(): # Flatten nested dicts if isinstance(metric_value, dict): for nested_key, nested_val in metric_value.items(): self.run.track( float(nested_val), name=f"{metric_name}/{nested_key}", step=step, ) else: try: self.run.track( float(metric_value), name=metric_name, step=step, ) except (ValueError, TypeError): # Skip non-numeric metrics pass def log_text(self, name: str, value: str, step: Optional[int] = None): """Log text content.""" if self.run is None: return # AimStack doesn't have direct text logging, use metadata metadata = getattr(self.run, '_metadata', {}) if isinstance(metadata, dict): metadata[name] = value def finish(self): """Finish the run.""" if self.run: self.run.close()