| import os |
| import sys |
| import logging |
| from pathlib import Path |
| from typing import Optional |
| from datetime import datetime |
|
|
| import torch |
|
|
|
|
def setup_logging(
    log_level: str = "INFO",
    log_file: Optional[str] = None,
    log_dir: Optional[str] = None,
) -> logging.Logger:
    """Configure and return the shared ``"codsworth"`` logger.

    Any handlers already attached to the logger are closed and removed, so
    the function can be called repeatedly without duplicating output or
    leaking open log files. A console handler writing to ``sys.stdout`` is
    always installed; a file handler is added when ``log_file`` or
    ``log_dir`` is given.

    Args:
        log_level: Name of a level constant in the ``logging`` module
            (case-insensitive), e.g. ``"INFO"`` or ``"debug"``.
        log_file: Path of the log file to write. Missing parent directories
            are created. Ignored when ``log_dir`` is also given.
        log_dir: Directory in which a ``codsworth_<YYYYmmdd_HHMMSS>.log``
            file is created. Takes precedence over ``log_file``.

    Returns:
        The configured ``"codsworth"`` logger.

    Raises:
        AttributeError: If ``log_level`` does not name an attribute of the
            ``logging`` module.
    """
    logger = logging.getLogger("codsworth")
    logger.setLevel(getattr(logging, log_level.upper()))

    # Close handlers left over from an earlier call before dropping them;
    # merely clearing the list would leave their log files open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_dir is not None:
        # A directory request wins over an explicit file name.
        os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = os.path.join(log_dir, f"codsworth_{timestamp}.log")
    elif log_file is not None:
        # FileHandler does not create missing directories itself.
        parent_dir = os.path.dirname(log_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    if log_file is not None:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
|
|
|
|
def setup_wandb(
    project: str = "codsworth",
    entity: Optional[str] = None,
    config: Optional[dict] = None,
    name: Optional[str] = None,
    notes: Optional[str] = None,
    tags: Optional[list[str]] = None,
    resume: bool = False,
) -> Optional["wandb"]:
    """Start a Weights & Biases run if the ``wandb`` package is importable.

    Args:
        project: Forwarded to ``wandb.init(project=...)``.
        entity: Forwarded to ``wandb.init(entity=...)``.
        config: Forwarded to ``wandb.init(config=...)``.
        name: Forwarded to ``wandb.init(name=...)``.
        notes: Forwarded to ``wandb.init(notes=...)``.
        tags: Forwarded to ``wandb.init(tags=...)``.
        resume: Forwarded to ``wandb.init(resume=...)``.

    Returns:
        The imported ``wandb`` module after ``wandb.init`` has been called,
        or ``None`` when the package cannot be imported.
    """
    # Only the import is guarded: an ImportError raised from inside
    # wandb.init would otherwise be misreported as "wandb not installed".
    try:
        import wandb
    except ImportError:
        # Use the named logger; the module-level logging.warning() would
        # implicitly configure the root logger when it has no handlers.
        logging.getLogger("codsworth").warning(
            "wandb not installed. Run 'pip install wandb' to enable logging."
        )
        return None

    wandb.init(
        project=project,
        entity=entity,
        config=config,
        name=name,
        notes=notes,
        tags=tags,
        resume=resume,
    )
    return wandb
|
|
|
|
def get_device() -> torch.device:
    """Return the preferred torch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
|
|
|
|
def get_device_count() -> int:
    """Return the number of CUDA devices, or 1 when CUDA is unavailable."""
    return torch.cuda.device_count() if torch.cuda.is_available() else 1
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

    Args:
        seed: Value passed to ``random.seed``, ``numpy.random.seed``,
            ``torch.manual_seed`` and, when CUDA is available,
            ``torch.cuda.manual_seed_all``.
    """
    import random
    import numpy as np

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
def count_parameters(model: torch.nn.Module, trainable_only: bool = False) -> int:
    """Return the total number of elements across the model's parameters.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: When true, count only parameters with
            ``requires_grad`` set.
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
|
|
|
|
def format_time(seconds: float) -> str:
    """Render a duration in seconds as ``"Xh Ym Zs"``.

    Leading units are omitted while they are zero: durations under an hour
    render as ``"Ym Zs"`` and durations under a minute as ``"Zs"``.
    Fractional seconds are truncated.
    """
    whole_hours, within_hour = divmod(seconds, 3600)
    hours = int(whole_hours)
    minutes = int(within_hour // 60)
    secs = int(seconds % 60)

    if hours > 0:
        parts = [f"{hours}h", f"{minutes}m", f"{secs}s"]
    elif minutes > 0:
        parts = [f"{minutes}m", f"{secs}s"]
    else:
        parts = [f"{secs}s"]
    return " ".join(parts)
|
|
|
|
def format_memory(bytes: int) -> str:
    """Format a byte count as a string with a binary (1024-based) unit.

    The value is divided by 1024 until its magnitude drops below 1024 or
    the largest unit is reached, then rendered with two decimals. Negative
    values (e.g. memory deltas) are scaled by magnitude and keep their sign.

    Args:
        bytes: Size in bytes. The name shadows the builtin but is kept for
            backward compatibility with keyword callers.

    Returns:
        A string such as ``"1.50 KB"``; values of 1024 TB or more are
        reported in ``PB``.
    """
    size = bytes
    for unit in ("B", "KB", "MB", "GB", "TB"):
        # Compare the magnitude so negative sizes are scaled too.
        if abs(size) < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} PB"
|
|
|
|
def get_model_size(model: torch.nn.Module) -> dict:
    """Report the memory footprint of a model's parameters and buffers.

    Returns:
        A dict with the raw byte counts under ``param_size``,
        ``buffer_size`` and ``total_size``, plus a ``*_formatted`` string
        for each produced by ``format_memory``.
    """
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())

    raw_sizes = {
        "param_size": param_bytes,
        "buffer_size": buffer_bytes,
        "total_size": param_bytes + buffer_bytes,
    }

    # Raw counts first, then the human-readable variants in the same order.
    report = dict(raw_sizes)
    for key, value in raw_sizes.items():
        report[f"{key}_formatted"] = format_memory(value)
    return report
|
|
|
|
def load_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    device: torch.device = None,
    strict: bool = True,
) -> dict:
    """Load weights from a checkpoint file into ``model``.

    The file may hold either a dict containing a ``"model_state_dict"``
    entry or a bare state dict.

    Args:
        model: Module that receives the weights via ``load_state_dict``.
        checkpoint_path: File passed to ``torch.load``.
        device: Forwarded to ``torch.load`` as ``map_location``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The object returned by ``torch.load``, unmodified.
    """
    # NOTE(review): torch.load deserializes via pickle; presumably only
    # trusted checkpoint files reach this function - confirm with callers.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    wrapped = "model_state_dict" in checkpoint
    state_dict = checkpoint["model_state_dict"] if wrapped else checkpoint
    model.load_state_dict(state_dict, strict=strict)

    return checkpoint
|
|
|
|
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    epoch: int = 0,
    step: int = 0,
    loss: float = 0.0,
    metrics: Optional[dict] = None,
    path: str = "checkpoint.pt",
) -> None:
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict with ``epoch``, ``step``, ``loss``,
    ``model_state_dict`` and ``optimizer_state_dict``; it also contains
    ``scheduler_state_dict`` and ``metrics`` when those arguments are
    given. Missing parent directories of ``path`` are created.

    The data is first written to ``<path>.tmp`` and then moved into place
    with ``os.replace``, so an interrupted save cannot leave a truncated
    file at ``path`` or destroy a previous checkpoint stored there.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: Optional scheduler whose ``state_dict()`` is stored.
        epoch: Stored under ``"epoch"``.
        step: Stored under ``"step"``.
        loss: Stored under ``"loss"``.
        metrics: Optional dict stored under ``"metrics"``.
        path: Destination file.
    """
    checkpoint = {
        "epoch": epoch,
        "step": step,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }

    if scheduler is not None:
        checkpoint["scheduler_state_dict"] = scheduler.state_dict()

    if metrics is not None:
        checkpoint["metrics"] = metrics

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    # Write next to the destination so the final rename stays on one
    # filesystem; the ".tmp" suffix keeps partial files from matching
    # "*.pt" globs.
    tmp_path = f"{path}.tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only
        # cleans up after a failed save.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
def ensure_dir(path: str) -> None:
    """Create directory ``path`` and any missing parents; no-op if present."""
    target = Path(path)
    target.mkdir(parents=True, exist_ok=True)
|
|
|
|
def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the most recently modified ``checkpoint_*.pt`` file.

    Args:
        checkpoint_dir: Directory searched (non-recursively) for files
            matching ``checkpoint_*.pt``.

    Returns:
        The POSIX-style path of the match with the largest modification
        time, or ``None`` when nothing matches.
    """
    newest = max(
        Path(checkpoint_dir).glob("checkpoint_*.pt"),
        key=lambda candidate: candidate.stat().st_mtime,
        default=None,
    )
    return None if newest is None else newest.as_posix()
|
|
|
|
class AverageMeter:
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self, name: str = "metric"):
        """Create an empty meter labelled ``name``."""
        self.name = name
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, count and average."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: Latest value of the metric.
            n: Weight of the observation (e.g. a batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 on an empty meter (e.g. an empty batch) must
        # not divide by zero; the average stays at 0.0 until data arrives.
        self.avg = self.sum / self.count if self.count else 0.0

    def __str__(self) -> str:
        """Return ``"<name>: <avg> (current: <val>)"`` with four decimals."""
        return f"{self.name}: {self.avg:.4f} (current: {self.val:.4f})"
|
|
|
|
class Timer:
    """Stopwatch based on ``time.time()``, usable as a context manager."""

    def __init__(self):
        # start_time is None whenever the timer is not running.
        self.start_time = None
        self.elapsed = 0.0

    def start(self):
        """Record the current ``time.time()`` as the start of the interval."""
        import time
        self.start_time = time.time()

    def stop(self):
        """Stop the timer and return the elapsed seconds.

        When the timer is not running, the previously stored ``elapsed``
        value is returned unchanged.
        """
        import time
        if self.start_time is None:
            return self.elapsed
        began_at = self.start_time
        self.elapsed = time.time() - began_at
        self.start_time = None
        return self.elapsed

    def __enter__(self):
        """Start timing and return the timer itself."""
        self.start()
        return self

    def __exit__(self, *args):
        """Stop timing; exceptions from the ``with`` body are not suppressed."""
        self.stop()