Jaqshanahan's picture
Initial upload of Codsworth model
b84d85a verified
import os
import sys
import logging
from pathlib import Path
from typing import Optional
from datetime import datetime
import torch
def setup_logging(
    log_level: str = "INFO",
    log_file: Optional[str] = None,
    log_dir: Optional[str] = None,
) -> logging.Logger:
    """Configure and return the shared ``"codsworth"`` logger.

    Any handlers already attached to the logger are removed and closed, then
    a stdout handler is attached. A file handler is added when ``log_file``
    and/or ``log_dir`` is given.

    Args:
        log_level: Name of a ``logging`` level (case-insensitive), e.g.
            ``"INFO"`` or ``"debug"``.
        log_file: Path of the log file. When ``log_dir`` is also given, the
            file is placed inside ``log_dir`` (an absolute ``log_file`` is
            used unchanged, per ``os.path.join`` semantics).
        log_dir: Directory for log files; created if missing. When
            ``log_file`` is omitted, a timestamped
            ``codsworth_YYYYmmdd_HHMMSS.log`` file is created in it.

    Returns:
        The configured logger.

    Raises:
        AttributeError: If ``log_level`` does not name an attribute of the
            ``logging`` module.
    """
    logger = logging.getLogger("codsworth")
    logger.setLevel(getattr(logging, log_level.upper()))

    # Close handlers as they are detached: merely clearing the list would
    # leave previously opened log files open on every reconfiguration.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_dir is not None:
        os.makedirs(log_dir, exist_ok=True)
        if log_file is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_file = f"codsworth_{timestamp}.log"
        # Honour an explicit file name instead of silently discarding it.
        log_file = os.path.join(log_dir, log_file)

    if log_file is not None:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
def setup_wandb(
    project: str = "codsworth",
    entity: Optional[str] = None,
    config: Optional[dict] = None,
    name: Optional[str] = None,
    notes: Optional[str] = None,
    tags: Optional[list[str]] = None,
    resume: bool = False,
) -> Optional["wandb"]:
    """Start a Weights & Biases run if the ``wandb`` package can be imported.

    All arguments are forwarded unchanged to ``wandb.init`` as keyword
    arguments of the same name.

    Returns:
        The imported ``wandb`` module after ``wandb.init`` has been called,
        or ``None`` (after logging a warning) when an ``ImportError`` is
        raised.
    """
    init_kwargs = {
        "project": project,
        "entity": entity,
        "config": config,
        "name": name,
        "notes": notes,
        "tags": tags,
        "resume": resume,
    }
    try:
        import wandb

        wandb.init(**init_kwargs)
    except ImportError:
        logging.warning("wandb not installed. Run 'pip install wandb' to enable logging.")
        return None
    return wandb
def get_device() -> torch.device:
    """Pick the preferred torch device.

    CUDA is checked first, then MPS; CPU is the fallback.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)
def get_device_count() -> int:
    """Return the number of CUDA devices, or 1 when CUDA is unavailable."""
    return torch.cuda.device_count() if torch.cuda.is_available() else 1
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and torch with the same value.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    set to deterministic mode with benchmarking turned off.

    Args:
        seed: Value passed to every seeding function.
    """
    import random
    import numpy as np

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def count_parameters(model: torch.nn.Module, trainable_only: bool = False) -> int:
    """Count the elements across a module's parameters.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: When true, only parameters with ``requires_grad`` set
            contribute to the total.

    Returns:
        The summed ``numel()`` of the selected parameters.
    """
    return sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad or not trainable_only
    )
def format_time(seconds: float) -> str:
    """Render a duration in seconds as ``"Xh Ym Zs"``, ``"Ym Zs"`` or ``"Zs"``.

    Leading units are omitted while they are zero; fractional seconds are
    truncated.
    """
    whole_hours, remainder = divmod(seconds, 3600)
    hours = int(whole_hours)
    minutes = int(remainder // 60)
    secs = int(seconds % 60)

    if hours > 0:
        return f"{hours}h {minutes}m {secs}s"
    if minutes > 0:
        return f"{minutes}m {secs}s"
    return f"{secs}s"
def format_memory(bytes: int) -> str:
    """Format a byte count with a binary (1024-based) unit suffix.

    The value is divided by 1024 until its magnitude drops below 1024 or the
    unit reaches PB, and is rendered with two decimals. Comparing on the
    magnitude lets negative values (e.g. memory deltas) scale like positive
    ones instead of always being reported in bytes.

    Args:
        bytes: Size in bytes. The name shadows the builtin but is kept so
            keyword callers continue to work.

    Returns:
        A string such as ``"1.50 KB"`` or ``"-2.00 MB"``.
    """
    size = bytes
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if abs(size) < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} PB"
def get_model_size(model: torch.nn.Module) -> dict:
    """Report the memory taken by a module's parameters and buffers.

    Returns:
        A dict with the raw byte counts under ``param_size``,
        ``buffer_size`` and ``total_size``, plus a ``*_formatted`` string
        for each produced by ``format_memory``.
    """

    def _nbytes(tensors) -> int:
        # Element count times per-element width, summed over all tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    param_bytes = _nbytes(model.parameters())
    buffer_bytes = _nbytes(model.buffers())
    total_bytes = param_bytes + buffer_bytes

    return {
        "param_size": param_bytes,
        "buffer_size": buffer_bytes,
        "total_size": total_bytes,
        "param_size_formatted": format_memory(param_bytes),
        "buffer_size_formatted": format_memory(buffer_bytes),
        "total_size_formatted": format_memory(total_bytes),
    }
def load_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    device: torch.device = None,
    strict: bool = True,
) -> dict:
    """Load weights from a checkpoint file into ``model``.

    The file may hold either a bare state dict or a mapping that stores the
    state dict under the ``"model_state_dict"`` key.

    Args:
        model: Module that receives the weights.
        checkpoint_path: Path passed to ``torch.load``.
        device: Forwarded to ``torch.load`` as ``map_location``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The object returned by ``torch.load``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = (
        checkpoint["model_state_dict"]
        if "model_state_dict" in checkpoint
        else checkpoint
    )
    model.load_state_dict(state_dict, strict=strict)
    return checkpoint
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    epoch: int = 0,
    step: int = 0,
    loss: float = 0.0,
    metrics: Optional[dict] = None,
    path: str = "checkpoint.pt",
) -> None:
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The saved dict always contains ``epoch``, ``step``, ``loss``,
    ``model_state_dict`` and ``optimizer_state_dict``.
    ``scheduler_state_dict`` and ``metrics`` are included only when the
    corresponding argument is not ``None``. The parent directory of ``path``
    is created if needed.
    """
    state = dict(
        epoch=epoch,
        step=step,
        loss=loss,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    if scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()
    if metrics is not None:
        state["metrics"] = metrics

    # A bare file name has an empty dirname; fall back to the current directory.
    parent_dir = os.path.dirname(path) or "."
    os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, path)
def ensure_dir(path: str) -> None:
    """Create ``path`` and any missing parents; do nothing if it already exists."""
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the most recently modified ``checkpoint_*.pt`` file in a directory.

    Args:
        checkpoint_dir: Directory to search (not recursive).

    Returns:
        POSIX-style path of the matching file with the largest modification
        time, or ``None`` when nothing matches.
    """

    def _modified_at(candidate: Path) -> float:
        return candidate.stat().st_mtime

    matches = Path(checkpoint_dir).glob("checkpoint_*.pt")
    newest = max(matches, key=_modified_at, default=None)
    return None if newest is None else newest.as_posix()
class AverageMeter:
    """Track the latest value and the weighted running average of a metric.

    Attributes set by ``reset``/``update``: ``val`` (last value passed to
    ``update``), ``sum`` (weighted sum), ``count`` (total weight) and ``avg``
    (``sum / count``).
    """

    def __init__(self, name: str = "metric"):
        """Create an empty meter labelled ``name``."""
        self.name = name
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated statistics."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: New observation (typically a per-batch mean).
            n: Weight of the observation, e.g. the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An empty batch (n == 0) on a fresh meter leaves count at zero;
        # report 0.0 rather than raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0.0

    def __str__(self) -> str:
        return f"{self.name}: {self.avg:.4f} (current: {self.val:.4f})"
class Timer:
    """Wall-clock stopwatch usable directly or as a context manager.

    ``elapsed`` holds the duration, in seconds, of the most recently
    completed start/stop interval.
    """

    def __init__(self):
        self.start_time = None
        self.elapsed = 0.0

    def start(self):
        """Record the current time as the start of an interval."""
        import time

        self.start_time = time.time()

    def stop(self):
        """End the running interval, if any, and return ``elapsed``."""
        import time

        if self.start_time is None:
            # Not running: report the last measured interval unchanged.
            return self.elapsed
        self.elapsed = time.time() - self.start_time
        self.start_time = None
        return self.elapsed

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()