File size: 2,342 Bytes
604e535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | from __future__ import annotations
import json
import math
import os
import random
from pathlib import Path
from typing import Any
import numpy as np
import torch
import yaml
def ensure_dir(path: str | os.PathLike[str]) -> Path:
    """Make sure the directory *path* exists and return it as a ``Path``.

    Missing parent directories are created as well; an already-existing
    directory is left untouched.
    """
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
    return directory
def load_yaml(path: str | os.PathLike[str]) -> dict[str, Any]:
    """Load a UTF-8 YAML file and return its top-level content.

    Parsing uses ``yaml.safe_load``, so the file cannot instantiate arbitrary
    Python objects.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The parsed document. An empty (or comment-only) file yields ``{}``
        instead of ``None``, so callers can always treat the result as a dict.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None when the stream contains no document; normalise
    # so callers doing cfg.get(...) / cfg["key"] don't crash on an empty file.
    return {} if data is None else data
def _json_default(obj: Any) -> Any:
    """Convert NumPy values that the ``json`` module cannot encode natively.

    ``json.dump`` calls this only for objects it does not already know how
    to serialise.
    """
    if isinstance(obj, np.generic):
        # NumPy scalars such as np.float32 / np.int64 / np.bool_ -> Python scalars.
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_json(obj: Any, path: str | os.PathLike[str]) -> None:
    """Write *obj* to *path* as pretty-printed UTF-8 JSON.

    Parent directories are created if missing, and an existing file is
    overwritten. Output uses 2-space indentation with sorted keys. NumPy
    scalars and arrays are converted to plain Python numbers and lists.

    Args:
        obj: JSON-serialisable data (NumPy scalars/arrays allowed).
        path: Destination file path.

    Raises:
        TypeError: If *obj* contains a value that is neither natively JSON
            serialisable nor a NumPy scalar/array.
    """
    path = Path(path)
    ensure_dir(path.parent)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, sort_keys=True, default=_json_default)
def set_seed(seed: int) -> None:
    """Seed all random number generators used here with the same *seed*.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    default generator. When CUDA is available, every CUDA device's generator
    is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def wrap_angle(theta: float | np.ndarray) -> float | np.ndarray:
    """Wrap angle(s) in radians into [-pi, pi) (up to floating-point rounding).

    Works element-wise on arrays as well as on plain floats.
    """
    full_turn = 2.0 * math.pi
    shifted = theta + math.pi
    return shifted % full_turn - math.pi
def angle_error(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Return the signed angular difference ``pred - target`` in [-pi, pi].

    Mapping the raw difference through sin/cos and ``arctan2`` removes
    whole-turn offsets. For example, 0.1 and 2*pi - 0.1 differ by ~0.2,
    not ~-6.08.
    """
    delta = pred - target
    return np.arctan2(np.sin(delta), np.cos(delta))
def obs_from_state(state: np.ndarray) -> np.ndarray:
    """Build a float32 observation from the leading components of *state*.

    Uses only components 0, 1 and 2 along the last axis. Component 2 is
    treated as an angle and replaced by its cosine and sine. The output's
    last axis therefore holds
    ``[state[..., 0], state[..., 1], cos(state[..., 2]), sin(state[..., 2])]``.
    Components 0 and 1 are presumably a planar position; confirm against the
    caller. Any leading (batch) axes are preserved.
    """
    heading = state[..., 2]
    columns = (state[..., 0], state[..., 1], np.cos(heading), np.sin(heading))
    observation = np.stack(columns, axis=-1)
    return observation.astype(np.float32)
def state_dim_for_action_dim(action_dim: int) -> int:
    """Return the state length for an action space of size *action_dim*.

    The state is a fixed block of 6 entries followed by one entry per
    action component. NOTE(review): what the 6 fixed entries hold is not
    visible here; confirm against the environment definition.
    """
    fixed_entries = 6
    return fixed_entries + action_dim
def pad_action(action: np.ndarray, max_dim: int = 3) -> np.ndarray:
    """Zero-pad a single action to a fixed length of ``max_dim``.

    The first ``n`` entries of the result hold the action values, where
    ``n`` is the size of the action's last axis. The remaining entries are
    zero. The result is always a 1-D ``float32`` array of shape ``(max_dim,)``.

    Args:
        action: Action values. Anything ``np.asarray`` accepts (ndarray,
            list, tuple, scalar); a scalar is treated as a one-element action.
        max_dim: Length of the padded output.

    Raises:
        ValueError: If the action has more than ``max_dim`` components.
    """
    values = np.atleast_1d(np.asarray(action, dtype=np.float32))
    n = values.shape[-1]
    if n > max_dim:
        # Fail with an explicit message instead of numpy's broadcast error.
        raise ValueError(f"action has {n} components but max_dim is {max_dim}")
    out = np.zeros(max_dim, dtype=np.float32)
    # Keep the original slice assignment so every input the previous version
    # accepted produces exactly the same result.
    out[:n] = values
    return out
def unpad_action(action: np.ndarray, action_dim: int) -> np.ndarray:
    """Keep the first *action_dim* entries along the last axis, as float32.

    Leading (batch) axes are preserved. Like ``np.asarray``, this returns a
    view rather than a copy when *action* is already a float32 ndarray.
    """
    leading = action[..., :action_dim]
    return np.asarray(leading, dtype=np.float32)
def device_from_arg(device: str | None = None) -> torch.device:
    """Turn an optional device string into a ``torch.device``.

    A falsy *device* (``None`` or ``""``) selects ``"cuda"`` when CUDA is
    available and ``"cpu"`` otherwise. Any other value is passed straight to
    ``torch.device``.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device)
def configure_torch_runtime() -> None:
    """Prefer stable kernels on local ROCm builds.

    The local rocm712 environment can expose RDNA targets where MIOpen's GRU
    kernel compilation fails. Disabling the cuDNN/MIOpen backend keeps the
    same ROCm device while using PyTorch's native kernels.

    Only acts when a GPU is available and PyTorch is a HIP (ROCm) build;
    otherwise it does nothing. The switch is the process-wide
    ``torch.backends.cudnn.enabled`` flag. It is therefore not limited to
    RNN layers, and it is never restored here.
    """
    if not torch.cuda.is_available():
        return
    if not getattr(torch.version, "hip", None):
        # CUDA (non-ROCm) build: leave cuDNN as configured.
        return
    torch.backends.cudnn.enabled = False
|