FlowMo-WM / driftwm /utils.py
cccat6's picture
Initial FlowMo-WM public code release
604e535 verified
from __future__ import annotations
import json
import math
import os
import random
from pathlib import Path
from typing import Any
import numpy as np
import torch
import yaml
def ensure_dir(path: str | os.PathLike[str]) -> Path:
    """Make sure a directory exists and hand it back as a ``Path``.

    Missing parent directories are created as needed, and an already
    existing directory is not treated as an error.

    Args:
        path: Location of the directory to create.

    Returns:
        ``Path(path)``, after the directory has been created.
    """
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    return directory
def load_yaml(path: str | os.PathLike[str]) -> dict[str, Any]:
    """Parse a UTF-8 YAML file with the safe loader.

    Args:
        path: Location of the YAML file to read.

    Returns:
        The parsed document. An empty (or comment-only) file yields an
        empty dict rather than ``None`` so the declared return type holds
        and callers can index or ``.get`` the result directly.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None when the stream contains no document.
    return {} if data is None else data
def save_json(obj: Any, path: str | os.PathLike[str]) -> None:
    """Serialize ``obj`` as pretty-printed JSON at ``path``.

    The parent directory is created first if it does not exist. Output is
    UTF-8 encoded, indented by two spaces, with dictionary keys sorted.

    Args:
        obj: JSON-serializable object to write.
        path: Destination file; an existing file is overwritten.
    """
    target = Path(path)
    ensure_dir(target.parent)
    with target.open("w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2, sort_keys=True)
def set_seed(seed: int) -> None:
    """Seed every random number generator this module knows about.

    Seeds, in order, Python's ``random`` module, NumPy's global generator
    and PyTorch. When CUDA is available, all CUDA devices are seeded too.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def wrap_angle(theta: float | np.ndarray) -> float | np.ndarray:
    """Wrap an angle in radians into the interval ``[-pi, pi)``.

    Works element-wise on NumPy arrays as well as on plain floats. The
    interval bounds hold up to floating-point rounding.

    Args:
        theta: Angle(s) in radians.

    Returns:
        The equivalent angle(s) shifted by a whole number of turns.
    """
    full_turn = 2.0 * math.pi
    # Shift so the target interval starts at zero, reduce, then shift back.
    shifted = theta + math.pi
    return shifted % full_turn - math.pi
def angle_error(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Return the signed angular difference ``pred - target``, wrapped.

    The raw difference is passed through ``arctan2(sin, cos)`` so the
    result lies in ``[-pi, pi]`` regardless of how many turns separate the
    inputs.

    Args:
        pred: Predicted angle(s) in radians.
        target: Reference angle(s) in radians.

    Returns:
        Element-wise wrapped difference in radians.
    """
    delta = pred - target
    return np.arctan2(np.sin(delta), np.cos(delta))
def obs_from_state(state: np.ndarray) -> np.ndarray:
    """Build a 4-component observation from the first three state entries.

    Components 0 and 1 of the last axis are copied through unchanged, and
    component 2 is treated as an angle and replaced by its cosine and sine.
    Any further components are ignored and leading axes are preserved.

    Args:
        state: Array whose last axis has at least three entries
            (presumably two coordinates and a heading in radians;
            confirm against callers).

    Returns:
        ``float32`` array of shape ``state.shape[:-1] + (4,)`` laid out as
        ``[s0, s1, cos(s2), sin(s2)]``.
    """
    angle = state[..., 2]
    first, second = state[..., 0], state[..., 1]
    components = (first, second, np.cos(angle), np.sin(angle))
    stacked = np.stack(components, axis=-1)
    return stacked.astype(np.float32)
def state_dim_for_action_dim(action_dim: int) -> int:
    """Return the state size that goes with a given action size.

    The state is six entries wider than the action. What those six
    entries represent is not visible in this module.

    Args:
        action_dim: Number of action components.

    Returns:
        ``6 + action_dim``.
    """
    fixed_entries = 6
    return fixed_entries + action_dim
def pad_action(action: np.ndarray, max_dim: int = 3) -> np.ndarray:
    """Zero-pad an action vector on the right to a fixed length.

    Args:
        action: Action values. Any array-like is accepted, including a
            plain list or a scalar (treated as a one-component action).
        max_dim: Length of the padded output.

    Returns:
        ``float32`` array of shape ``(max_dim,)`` holding the action values
        followed by zeros.

    Raises:
        ValueError: If the action has more than ``max_dim`` components.
    """
    # Normalize first so lists and scalars have a usable ``shape[-1]``.
    values = np.atleast_1d(np.asarray(action))
    width = values.shape[-1]
    if width > max_dim:
        # Fail with an explicit message instead of NumPy's broadcast error.
        raise ValueError(
            f"action has {width} components but max_dim is {max_dim}"
        )
    out = np.zeros(max_dim, dtype=np.float32)
    out[:width] = values
    return out
def unpad_action(action: np.ndarray, action_dim: int) -> np.ndarray:
    """Keep only the first ``action_dim`` entries along the last axis.

    Leading axes are preserved, so batched actions are supported.

    Args:
        action: Padded action array.
        action_dim: Number of leading components to keep.

    Returns:
        ``float32`` array of shape ``action.shape[:-1] + (k,)`` where ``k``
        is at most ``action_dim``.
    """
    keep = (Ellipsis, slice(None, action_dim))
    trimmed = action[keep]
    return np.asarray(trimmed, dtype=np.float32)
def device_from_arg(device: str | None = None) -> torch.device:
    """Resolve an optional device string to a ``torch.device``.

    Args:
        device: Device specifier such as ``"cpu"`` or ``"cuda:0"``. When
            ``None`` or empty, CUDA is selected if available and CPU
            otherwise.

    Returns:
        The corresponding ``torch.device``.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device)
def configure_torch_runtime() -> None:
    """Switch to PyTorch's native RNN kernels on ROCm builds.

    On the local rocm712 environment, some exposed RDNA targets fail to
    compile MIOpen's GRU kernel. Turning off the cuDNN/MIOpen backend keeps
    the same ROCm device while falling back to PyTorch's native kernels.

    Nothing is changed unless a GPU is available and ``torch.version.hip``
    is set.
    """
    if not torch.cuda.is_available():
        return
    hip_version = getattr(torch.version, "hip", None)
    if not hip_version:
        return
    torch.backends.cudnn.enabled = False