from __future__ import annotations

import json
import math
import os
import random
from pathlib import Path
from typing import Any

import numpy as np
import torch
import yaml


def ensure_dir(path: str | os.PathLike[str]) -> Path:
    """Create *path* (and any missing parents) if needed and return it as a Path."""
    p = Path(path)
    p.mkdir(parents=True, exist_ok=True)
    return p


def load_yaml(path: str | os.PathLike[str]) -> dict[str, Any]:
    """Parse a UTF-8 YAML file with ``yaml.safe_load``.

    An empty file (which ``safe_load`` parses as ``None``) yields ``{}`` so the
    result always supports mapping access, as the return annotation promises.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return {} if data is None else data


def save_json(obj: Any, path: str | os.PathLike[str]) -> None:
    """Write *obj* as pretty-printed JSON (indent=2, sorted keys) to *path*.

    The parent directory is created if it does not exist. Note that the
    stdlib ``json`` encoder rejects numpy scalars/arrays; convert them first.
    """
    path = Path(path)
    ensure_dir(path.parent)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, sort_keys=True)


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, numpy's global RNG and torch (all CUDA devices too)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def wrap_angle(theta: float | np.ndarray) -> float | np.ndarray:
    """Wrap angle(s) in radians into the half-open interval [-pi, pi)."""
    return (theta + math.pi) % (2.0 * math.pi) - math.pi


def angle_error(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Return the signed shortest angular difference ``pred - target`` in [-pi, pi].

    Going through sin/cos + arctan2 avoids the discontinuity of a raw
    subtraction near the +/-pi seam.
    """
    return np.arctan2(np.sin(pred - target), np.cos(pred - target))


def obs_from_state(state: np.ndarray) -> np.ndarray:
    """Build a float32 observation ``[s0, s1, cos(s2), sin(s2)]`` along the last axis.

    Only the first three entries of the last axis are read; the angle at
    index 2 is encoded as (cos, sin) to remove its wrap-around discontinuity.
    Presumably s0/s1 are planar position and s2 a heading -- confirm with callers.
    """
    theta = state[..., 2]
    return np.stack(
        [state[..., 0], state[..., 1], np.cos(theta), np.sin(theta)],
        axis=-1,
    ).astype(np.float32)


def state_dim_for_action_dim(action_dim: int) -> int:
    """Return the state dimensionality: 6 fixed components plus *action_dim*.

    NOTE(review): the meaning of the 6 base components is defined elsewhere.
    """
    return 6 + action_dim


def pad_action(action: np.ndarray, max_dim: int = 3) -> np.ndarray:
    """Zero-pad a single action vector to length *max_dim* (float32).

    Lists and scalars are accepted and converted with ``np.asarray``.

    Raises:
        ValueError: if the action's last dimension exceeds *max_dim*.
    """
    arr = np.atleast_1d(np.asarray(action, dtype=np.float32))
    dim = arr.shape[-1]
    if dim > max_dim:
        # Without this check numpy fails with an opaque broadcast error.
        raise ValueError(
            f"action has dimension {dim}, which exceeds max_dim={max_dim}"
        )
    out = np.zeros(max_dim, dtype=np.float32)
    out[:dim] = arr
    return out


def unpad_action(action: np.ndarray, action_dim: int) -> np.ndarray:
    """Keep the first *action_dim* entries of the last axis, as float32 (inverse of pad_action)."""
    return np.asarray(action[..., :action_dim], dtype=np.float32)


def device_from_arg(device: str | None = None) -> torch.device:
    """Return ``torch.device(device)`` if given, else CUDA when available, otherwise CPU.

    An empty string is treated like ``None`` (auto-select).
    """
    if device:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def configure_torch_runtime() -> None:
    """Prefer stable kernels on local ROCm builds.

    The local rocm712 environment can expose RDNA targets where MIOpen's GRU
    kernel compilation fails. Setting ``torch.backends.cudnn.enabled = False``
    keeps the same ROCm device but makes PyTorch use its native kernels.
    Note that this switch is global: it disables cuDNN/MIOpen for every op
    (convolutions included), not just RNNs. Non-ROCm builds are left untouched.
    """
    if torch.cuda.is_available() and getattr(torch.version, "hip", None):
        torch.backends.cudnn.enabled = False