File size: 409 Bytes
3f4ebee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""Utility helpers used by the legacy-style RNN generator."""

from __future__ import annotations

import numpy as np
import torch


def variable(tensor: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Place *tensor* on the CUDA device if one is available.

    Args:
        tensor: A ``torch.Tensor``, or a ``numpy.ndarray`` which is first
            wrapped with ``torch.from_numpy``.

    Returns:
        The result of ``.cuda()`` on the (possibly converted) input when
        ``torch.cuda.is_available()`` is true; otherwise the (possibly
        converted) input itself.
    """
    # Only NumPy arrays need wrapping; anything else is passed through as-is.
    candidate = (
        torch.from_numpy(tensor) if isinstance(tensor, np.ndarray) else tensor
    )
    return candidate.cuda() if torch.cuda.is_available() else candidate