"""Helpers for config loading, model checkpointing, seeding and device selection."""

import json
import os
import random

import torch

try:
    import yaml
except ImportError:  # PyYAML is only needed for .yaml/.yml configs.
    yaml = None


def load_config(config_path: str):
    """Load configuration from a YAML or JSON file.

    The parser is chosen from the file extension (case-insensitive):
    ``.yaml``/``.yml`` use ``yaml.safe_load`` and ``.json`` uses ``json.load``.

    Args:
        config_path: Path to the configuration file.

    Returns:
        The parsed document exactly as returned by the chosen parser.

    Raises:
        ValueError: If the extension is not a supported config format.
        ImportError: If a YAML file is requested but PyYAML is not installed.
    """
    ext = os.path.splitext(config_path)[1].lower()
    if ext in (".yaml", ".yml"):
        if yaml is None:
            raise ImportError("PyYAML is required to load YAML config files")
        # safe_load never constructs arbitrary Python objects from the file.
        with open(config_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    if ext == ".json":
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)
    raise ValueError("Unsupported config file format: " + ext)


def _unwrap_model(model):
    """Return ``model.module`` for wrappers such as DataParallel, else *model*."""
    return model.module if hasattr(model, "module") else model


def save_checkpoint(model, optimizer, filepath: str) -> None:
    """Save model (and optionally optimizer) state to *filepath*.

    The saved object is a dict with key ``"model_state"`` and, when
    *optimizer* is not None, ``"optimizer_state"``. Wrapped models are
    unwrapped first, so the stored parameter names carry no ``module.`` prefix.

    Args:
        model: Module (or wrapper exposing ``.module``) whose state is saved.
        optimizer: Optimizer whose state is saved, or None to skip it.
        filepath: Destination path; missing parent directories are created.
    """
    # dirname is "" for a bare filename like "ckpt.pt", and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)

    checkpoint = {"model_state": _unwrap_model(model).state_dict()}
    if optimizer is not None:
        checkpoint["optimizer_state"] = optimizer.state_dict()
    torch.save(checkpoint, filepath)


def load_checkpoint(model, optimizer, filepath: str, device: str) -> None:
    """Load model (and optimizer) state from a checkpoint file.

    Accepts both the dict layout written by ``save_checkpoint`` and a bare
    model ``state_dict`` (used as-is when ``"model_state"`` is absent).
    Optimizer state is restored only when *optimizer* is given and the file
    contains ``"optimizer_state"``.

    Args:
        model: Module (or wrapper exposing ``.module``) to load weights into.
        optimizer: Optimizer to restore, or None to skip it.
        filepath: Path of the checkpoint to read.
        device: Passed to ``torch.load`` as ``map_location``.
    """
    checkpoint = torch.load(filepath, map_location=device)
    model_state = checkpoint.get("model_state", checkpoint)
    optim_state = checkpoint.get("optimizer_state")

    _unwrap_model(model).load_state_dict(model_state)
    if optimizer is not None and optim_state is not None:
        optimizer.load_state_dict(optim_state)


def set_seed(seed: int) -> None:
    """Seed Python's ``random`` and torch RNGs for reproducibility.

    CUDA generators on all devices are seeded when CUDA is available.
    NumPy is not seeded here.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_device() -> str:
    """Return ``"cuda"`` if a GPU is available, else ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"