StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""AimStack logging integration."""
from pathlib import Path
from typing import Dict, Any, Optional
import subprocess
import json
from datetime import datetime
try:
from aim import Run
HAS_AIM = True
except ImportError:
HAS_AIM = False
from taoTrain.config import TrainingConfig
class AimLogger:
    """AimStack logger for tracking training metrics and hyperparameters.

    When the ``aim`` package could not be imported (``HAS_AIM`` is False),
    ``self.run`` stays ``None`` and every logging method returns immediately,
    so the logger can be used unconditionally by the training loop.
    """

    def __init__(self, config: TrainingConfig):
        """
        Initialize AimStack logger.

        Args:
            config: Training configuration. ``config.aim_repo`` is the Aim
                repository directory (created if missing); the model,
                optimizer, scheduler, dataset and training fields are
                recorded as hyperparameters on the new run.
        """
        self.config = config
        self.run: Optional[Run] = None

        if HAS_AIM:
            # Make sure the repository directory exists before Aim opens it.
            repo_path = Path(config.aim_repo)
            repo_path.mkdir(parents=True, exist_ok=True)
            self.run = Run(repo=str(repo_path))

            # Record the run configuration once, up front.
            self._log_hyperparameters()
        else:
            print("Warning: AimStack not installed. Install with: pip install aim")

    def _log_hyperparameters(self):
        """Log hyperparameters to AimStack under the ``hparams/`` namespace."""
        if self.run is None:
            return

        # Log model config
        self.run["hparams/model"] = {
            "architecture": self.config.model.architecture_type.value,
            "vocab_size": self.config.model.vocab_size,
            "hidden_dim": self.config.model.hidden_dim,
            "num_layers": self.config.model.num_layers,
            "num_heads": self.config.model.num_heads,
            "dropout": self.config.model.dropout,
            "max_seq_length": self.config.model.max_seq_length,
        }

        # Log training config
        self.run["hparams/training"] = {
            "batch_size": self.config.batch_size,
            "num_epochs": self.config.num_epochs,
            "learning_rate": self.config.optimizer.learning_rate,
            "weight_decay": self.config.optimizer.weight_decay,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "max_grad_norm": self.config.max_grad_norm,
            "dtype": self.config.dtype.value,
            "seed": self.config.seed,
        }

        # Log optimizer and scheduler config
        self.run["hparams/optimizer"] = {
            "optimizer_type": self.config.optimizer.optimizer_type.value,
            "learning_rate": self.config.optimizer.learning_rate,
            "weight_decay": self.config.optimizer.weight_decay,
        }
        self.run["hparams/scheduler"] = {
            "scheduler_type": self.config.scheduler.scheduler_type.value,
            "warmup_steps": self.config.scheduler.warmup_steps,
            "warmup_ratio": self.config.scheduler.warmup_ratio,
        }

        # Log dataset config
        self.run["hparams/dataset"] = {
            "dataset_name": self.config.dataset.dataset_name,
            "split": self.config.dataset.split,
            "max_samples": self.config.dataset.max_samples,
        }

        # Log mode
        self.run["hparams/mode"] = self.config.mode.value

        # Log git hash if available (best effort).
        git_hash = self._get_git_hash()
        if git_hash is not None:
            self.run["hparams/git_hash"] = git_hash

        # Log timestamp
        self.run["hparams/timestamp"] = datetime.now().isoformat()

    @staticmethod
    def _get_git_hash() -> Optional[str]:
        """Return the ``HEAD`` commit hash of the current directory, or None."""
        try:
            output = subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                stderr=subprocess.DEVNULL,
            )
        except (OSError, subprocess.CalledProcessError):
            # OSError: git executable not found.
            # CalledProcessError: not inside a git repository.
            return None
        return output.decode().strip()

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
        """
        Log metrics to AimStack.

        Args:
            metrics: Dict of metric names to values. A ``"step"`` entry, if
                present, overrides ``step`` and is not tracked as a metric.
                Dict values are flattened one level into ``"<name>/<key>"``.
                Values that ``float()`` cannot convert are skipped. The
                caller's dict is not modified.
            step: Global step (optional, auto-incremented by Aim if not provided)
        """
        if self.run is None:
            return

        # Read (rather than pop) "step" so the caller's dict is left intact.
        step = metrics.get("step", step)
        for metric_name, metric_value in metrics.items():
            if metric_name == "step":
                continue
            if isinstance(metric_value, dict):
                # Flatten nested dicts one level deep.
                for nested_key, nested_val in metric_value.items():
                    self._track_scalar(f"{metric_name}/{nested_key}", nested_val, step)
            else:
                self._track_scalar(metric_name, metric_value, step)

    def _track_scalar(self, name: str, value: Any, step: Optional[int]) -> None:
        """Track ``value`` as a float under ``name``; silently skip non-numeric values."""
        try:
            scalar = float(value)
        except (ValueError, TypeError):
            # Skip non-numeric metrics
            return
        self.run.track(scalar, name=name, step=step)

    def log_text(self, name: str, value: str, step: Optional[int] = None):
        """
        Log text content.

        Args:
            name: Name of the text sequence.
            value: Text to record.
            step: Optional step index passed to Aim.
        """
        if self.run is None:
            return
        try:
            from aim import Text
        except ImportError:
            # Aim build without the Text object: keep the latest value as a
            # run-level entry instead of dropping it.
            self.run[name] = value
            return
        self.run.track(Text(value), name=name, step=step)

    def finish(self):
        """Close the run. Later logging calls become no-ops; calling twice is safe."""
        if self.run is not None:
            self.run.close()
            self.run = None