File size: 829 Bytes
26dae50
 
df6b3ac
 
 
26dae50
 
 
 
df6b3ac
26dae50
 
 
 
 
 
df6b3ac
26dae50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""ZeroGPU helpers.

`gpu(duration)` returns a decorator: on HF Spaces (ZeroGPU) it allocates a GPU
for each call of the decorated function; locally (without the `spaces` package)
it becomes a transparent pass-through, so the same code runs on local GPUs and
on the Space.
"""
import functools

try:
    import spaces  # available only on ZeroGPU Spaces
    _HAS_SPACES = True
except Exception:
    _HAS_SPACES = False


def gpu(duration: int = 60):
    """Build a decorator that reserves a ZeroGPU device for each call.

    When the ``spaces`` package was importable, the decorated function is
    handed to ``spaces.GPU(duration=duration)``. Otherwise it is wrapped in a
    plain pass-through that keeps the original name and docstring, so the
    same code runs unchanged outside a Space.

    Args:
        duration: forwarded to ``spaces.GPU`` as the requested GPU time
            (seconds, per the original docstring).

    Returns:
        A decorator; use it as ``@gpu()`` or ``@gpu(duration=120)``.
    """
    def _apply(target):
        # The spaces/no-spaces choice happens when the decorator is applied,
        # not on every call.
        if not _HAS_SPACES:
            @functools.wraps(target)
            def passthrough(*call_args, **call_kwargs):
                return target(*call_args, **call_kwargs)

            return passthrough

        gpu_decorator = spaces.GPU(duration=duration)
        return gpu_decorator(target)

    return _apply


def device() -> str:
    """Return the torch device name to use: ``"cuda"`` if CUDA is available, else ``"cpu"``."""
    # Imported inside the function so importing this module does not pull in torch.
    import torch

    if torch.cuda.is_available():
        return "cuda"
    return "cpu"