"""Logging, device, seeding, checkpoint and timing helpers for codsworth."""

import logging
import os
import random
import sys
import time
from contextlib import suppress
from datetime import datetime
from pathlib import Path
from types import ModuleType
from typing import Optional

import torch


def setup_logging(
    log_level: str = "INFO",
    log_file: Optional[str] = None,
    log_dir: Optional[str] = None,
) -> logging.Logger:
    """Configure and return the ``codsworth`` logger.

    Handlers previously attached to the logger are closed and removed, so
    calling this repeatedly neither duplicates output nor leaks open files.
    A stdout handler is always attached; a file handler is attached when
    ``log_file`` or ``log_dir`` is given.

    Args:
        log_level: Name of a ``logging`` level, case-insensitive (e.g. "info").
        log_file: Path of the log file. Ignored when ``log_dir`` is given.
            Missing parent directories are created.
        log_dir: Directory in which a timestamped
            ``codsworth_YYYYmmdd_HHMMSS.log`` file is created. Takes
            precedence over ``log_file``.

    Returns:
        The configured ``codsworth`` logger.

    Raises:
        AttributeError: If ``log_level`` is not a level name defined by
            the ``logging`` module.
    """
    logger = logging.getLogger("codsworth")
    logger.setLevel(getattr(logging, log_level.upper()))

    # Close old handlers before dropping them; ``handlers.clear()`` alone
    # would leave a previous FileHandler's file descriptor open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_dir is not None:
        os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = os.path.join(log_dir, f"codsworth_{timestamp}.log")

    if log_file is not None:
        # FileHandler does not create directories; do it so a fresh run
        # directory works out of the box.
        parent_dir = os.path.dirname(log_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger


def setup_wandb(
    project: str = "codsworth",
    entity: Optional[str] = None,
    config: Optional[dict] = None,
    name: Optional[str] = None,
    notes: Optional[str] = None,
    tags: Optional[list[str]] = None,
    resume: bool = False,
) -> Optional[ModuleType]:
    """Initialise a Weights & Biases run, if ``wandb`` is installed.

    All arguments are forwarded unchanged to ``wandb.init``.

    Returns:
        The imported ``wandb`` module after ``wandb.init`` has been called,
        or ``None`` (after logging a warning) when ``wandb`` is not installed.
    """
    try:
        import wandb
    except ImportError:
        logging.getLogger("codsworth").warning(
            "wandb not installed. Run 'pip install wandb' to enable logging."
        )
        return None

    wandb.init(
        project=project,
        entity=entity,
        config=config,
        name=name,
        notes=notes,
        tags=tags,
        resume=resume,
    )
    return wandb


def get_device() -> torch.device:
    """Return the preferred device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # ``torch.backends.mps`` is absent on older torch builds; treat that as
    # "MPS unavailable" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def get_device_count() -> int:
    """Return the number of CUDA devices, or 1 when CUDA is unavailable."""
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return 1


def set_seed(seed: int) -> None:
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

    cuDNN is switched to deterministic mode with autotuning (``benchmark``)
    disabled, which trades some speed for reproducibility. These flags are
    set even when CUDA is not available.
    """
    import numpy as np  # local: keeps numpy optional for importers of this module

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def count_parameters(model: torch.nn.Module, trainable_only: bool = False) -> int:
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: If true, count only parameters with ``requires_grad``.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


def format_time(seconds: float) -> str:
    """Format a duration as ``"Hh Mm Ss"``, omitting leading zero units.

    Fractional seconds are truncated, e.g. ``3725.9 -> "1h 2m 5s"``.
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    if hours > 0:
        return f"{hours}h {minutes}m {secs}s"
    if minutes > 0:
        return f"{minutes}m {secs}s"
    return f"{secs}s"


def format_memory(bytes: int) -> str:  # noqa: A002 - name kept for keyword callers
    """Format a byte count with binary (1024-based) units, two decimals.

    Units run B, KB, MB, GB, TB; anything of 1024 TB or more is shown in PB.
    """
    size = float(bytes)  # work on a local instead of rebinding the parameter
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} PB"


def get_model_size(model: torch.nn.Module) -> dict:
    """Return the in-memory size of ``model``'s parameters and buffers.

    Returns:
        A dict with raw byte counts under ``param_size``, ``buffer_size`` and
        ``total_size``, and human-readable versions (via ``format_memory``)
        under the same keys with a ``_formatted`` suffix.
    """
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    total_size = param_size + buffer_size
    return {
        "param_size": param_size,
        "buffer_size": buffer_size,
        "total_size": total_size,
        "param_size_formatted": format_memory(param_size),
        "buffer_size_formatted": format_memory(buffer_size),
        "total_size_formatted": format_memory(total_size),
    }


def load_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    device: Optional[torch.device] = None,
    strict: bool = True,
    weights_only: Optional[bool] = None,
) -> dict:
    """Load weights from ``checkpoint_path`` into ``model``.

    Accepts both checkpoints written by ``save_checkpoint`` (weights under
    ``"model_state_dict"``) and bare state dicts.

    Security: ``torch.load`` is pickle-based; never load checkpoints from
    untrusted sources unless ``weights_only=True`` is in effect.

    Args:
        model: Module to load the weights into.
        checkpoint_path: Path of the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.
        strict: Forwarded to ``model.load_state_dict``.
        weights_only: Forwarded to ``torch.load`` when not ``None``; when
            ``None`` the installed torch version's default is used.

    Returns:
        The full loaded checkpoint object.
    """
    load_kwargs = {"map_location": device}
    if weights_only is not None:
        # Only pass it when requested: very old torch versions reject the kwarg.
        load_kwargs["weights_only"] = weights_only
    checkpoint = torch.load(checkpoint_path, **load_kwargs)

    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
    else:
        model.load_state_dict(checkpoint, strict=strict)
    return checkpoint


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    epoch: int = 0,
    step: int = 0,
    loss: float = 0.0,
    metrics: Optional[dict] = None,
    path: str = "checkpoint.pt",
) -> None:
    """Save model/optimizer (and optional scheduler/metrics) state to ``path``.

    The checkpoint dict always has ``epoch``, ``step``, ``loss``,
    ``model_state_dict`` and ``optimizer_state_dict``. It has
    ``scheduler_state_dict`` / ``metrics`` only when those are given.
    Parent directories are created as needed.

    The data is first written to ``path + ".tmp"`` and then moved into place
    with ``os.replace``. A crash mid-save therefore cannot leave a truncated
    file at ``path``.
    """
    checkpoint = {
        "epoch": epoch,
        "step": step,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    if scheduler is not None:
        checkpoint["scheduler_state_dict"] = scheduler.state_dict()
    if metrics is not None:
        checkpoint["metrics"] = metrics

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    tmp_path = f"{path}.tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind.
        with suppress(FileNotFoundError):
            os.remove(tmp_path)
        raise


def ensure_dir(path: str) -> None:
    """Create directory ``path`` (and parents) if it does not exist."""
    Path(path).mkdir(parents=True, exist_ok=True)


def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the most recently modified ``checkpoint_*.pt`` in a directory.

    Recency is judged by file modification time. The result is a POSIX-style
    path string, or ``None`` when no matching file exists.
    """
    checkpoints = list(Path(checkpoint_dir).glob("checkpoint_*.pt"))
    if not checkpoints:
        return None
    return max(checkpoints, key=lambda p: p.stat().st_mtime).as_posix()


class AverageMeter:
    """Track the latest value and running weighted mean of a metric."""

    def __init__(self, name: str = "metric"):
        self.name = name
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated statistics."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        """Record ``val`` as the mean of ``n`` samples.

        An update with ``n=0`` on an empty meter keeps ``avg`` at 0.0
        instead of dividing by zero.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0

    def __str__(self) -> str:
        return f"{self.name}: {self.avg:.4f} (current: {self.val:.4f})"


class Timer:
    """Measure elapsed time, either explicitly or as a context manager.

    Uses ``time.perf_counter`` (monotonic), so measurements are not affected
    by system clock adjustments.
    """

    def __init__(self):
        self.start_time: Optional[float] = None
        self.elapsed = 0.0

    def start(self) -> None:
        """Start (or restart) timing."""
        self.start_time = time.perf_counter()

    def stop(self) -> float:
        """Stop timing and return the elapsed seconds.

        If the timer is not running, the previous ``elapsed`` is returned
        unchanged.
        """
        if self.start_time is not None:
            self.elapsed = time.perf_counter() - self.start_time
            self.start_time = None
        return self.elapsed

    def __enter__(self) -> "Timer":
        self.start()
        return self

    def __exit__(self, *args) -> None:
        self.stop()