import threading

import torch._C._lazy
from torch.utils._pytree import tree_flatten, tree_unflatten

from .closure import add_step_closure, run_step_closures


def mark_step(device: str = "", wait=False):
    """Triggers a mark step, which amounts to

    - collecting a group of 'live' lazy tensors to index into the compilation cache
      (lowering/compiling their IR graphs if not cached)
    - kicking off execution of the compiled function
    - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator)
    """
    torch._C._lazy._mark_step(device, [], wait=wait)

    run_step_closures()
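

# Example usage (a minimal sketch, assuming the TorchScript lazy backend is available
# and has been initialized via torch._lazy.ts_backend.init()):
#
#     import torch
#     import torch._lazy
#     import torch._lazy.ts_backend
#
#     torch._lazy.ts_backend.init()
#     x = torch.randn(2, 2, device="lazy")
#     y = x.relu() + 1           # ops are only traced into an IR graph at this point
#     torch._lazy.mark_step()    # lower/compile the graph and kick off its execution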


def wait_device_ops(devices=None):
    """Waits for all the async operations on the given devices to complete.

    Args:
        devices (string..., optional): The devices whose async ops need to be waited
            for. If empty, all the local devices will be waited for.
    """
    if devices is None:
        devices = []
    torch._C._lazy._wait_device_ops(devices=devices)
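

# Example usage (a sketch, assuming async work was kicked off by a prior mark_step;
# this blocks the host until the backend finishes):
#
#     torch._lazy.mark_step()
#     torch._lazy.wait_device_ops()   # passing no devices waits on all local devices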


def sync_multi(tensors, devices):
    """
    Sync the list of lazy tensors so their IR graphs get lowered for the active backend
    and the compiled computation graph gets cached.
    """
    torch._C._lazy._sync_multi(tensors, devices)
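

# Example usage (a sketch; sync_multi takes a flat list of lazy tensors plus the device
# strings to sync on, which is how to_cpu below uses it):
#
#     xs = [torch.randn(2, 2, device="lazy") for _ in range(3)]
#     ys = [x * 2 for x in xs]               # traced lazily, not yet executed
#     torch._lazy.sync_multi(ys, ["lazy"])   # lower/compile and cache the graph for ys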


def get_tensor_id(tensor):
    """Return a unique id of the lazy tensor maintained by LTC"""
    return torch._C._lazy._get_tensor_id(tensor)


def to_cpu(tensors, devices=None):
    devices = devices or ["lazy"]

    flattened, spec = tree_flatten(tensors)
    sync_multi(flattened, devices)
    return tree_unflatten([t.to("cpu") for t in flattened], spec)
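

# Example usage (a sketch; `tensors` can be any pytree of lazy tensors, e.g. a dict of
# outputs: to_cpu flattens it, syncs the lazy tensors, and rebuilds the same structure
# with cpu tensors):
#
#     out = {"logits": logits, "loss": loss}   # hypothetical lazy tensors
#     cpu_out = torch._lazy.to_cpu(out)        # same dict structure, tensors now on cpu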


def save(tensors, *args, **kwargs):
    torch.save(to_cpu(tensors), *args, **kwargs)
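

# Example usage (a sketch; save() moves the lazy tensors to cpu via to_cpu and then
# forwards to torch.save, so the checkpoint can be loaded without a lazy backend):
#
#     torch._lazy.save(model.state_dict(), "model.pt")   # hypothetical model and path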