File size: 409 Bytes
"""Utility helpers used by the legacy-style RNN generator."""
from __future__ import annotations
import numpy as np
import torch
def variable(tensor: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Return *tensor* as a ``torch.Tensor``, placed on the GPU when available.

    Args:
        tensor: Either an existing ``torch.Tensor`` or a NumPy array. Arrays
            are wrapped with ``torch.from_numpy``; tensors are used as-is.

    Returns:
        The result of ``tensor.cuda()`` when ``torch.cuda.is_available()``
        reports a usable CUDA device, otherwise the (possibly converted)
        tensor unchanged.
    """
    if isinstance(tensor, np.ndarray):
        # from_numpy wraps the array's buffer rather than copying it, so on
        # the CPU path the returned tensor aliases the caller's array.
        tensor = torch.from_numpy(tensor)
    if torch.cuda.is_available():
        return tensor.cuda()
    return tensor
|