| """Utility helpers used by the legacy-style RNN generator.""" | |
| from __future__ import annotations | |
| import numpy as np | |
| import torch | |
def variable(tensor: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Turn *tensor* into a torch tensor, on the GPU whenever CUDA is present.

    A NumPy array is wrapped with ``torch.from_numpy``, so the tensor shares
    memory with the array. Any other input is used unchanged.

    If ``torch.cuda.is_available()`` reports True, the result of
    ``.cuda()`` is returned, which places the tensor on the default CUDA
    device. Otherwise the (possibly wrapped) tensor is returned as-is.
    """
    # Only ndarrays need converting; torch tensors pass straight through.
    as_torch = torch.from_numpy(tensor) if isinstance(tensor, np.ndarray) else tensor
    return as_torch.cuda() if torch.cuda.is_available() else as_torch