File size: 4,822 Bytes
b9f197c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import torch
from collections import defaultdict
from torch import Tensor
from torch.distributed.tensor import DTensor
from typing import Generator, List, Optional, Union
def to_local(tensor: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
"""
Convert a single DTensor or list of DTensors to local tensors.
This is a no-op for regular tensors.
"""
if isinstance(tensor, Tensor):
return tensor.to_local() if isinstance(tensor, DTensor) else tensor
return [t.to_local() if isinstance(t, DTensor) else t for t in tensor]
def dtensor_from_local(
tensor: Union[Tensor, List[Tensor]], ref: Tensor
) -> Union[DTensor, List[DTensor]]:
"""
Convert a single local Tensor or list of local Tensors to DTensor.
The reference tensor's device mesh and placements are used to create the DTensor.
if the reference tensor is not a DTensor, we return the input unmodified.
"""
if not isinstance(ref, DTensor):
assert isinstance(ref, Tensor)
return tensor
device_mesh = ref.device_mesh
placements = ref.placements
# If we have a single tensor
if isinstance(tensor, Tensor):
assert not isinstance(tensor, DTensor)
return DTensor.from_local(
tensor, device_mesh=device_mesh, placements=placements
)
# We have a list of tensors
assert not isinstance(tensor[0], DTensor)
return [
DTensor.from_local(t, device_mesh=device_mesh, placements=placements)
for t in tensor
]
def create_param_batches(
params: List[Tensor], batch_size: int
) -> Generator[List[Tensor], None, None]:
"""
Batch parameters into groups of size `batch_size`.
Tensors in each batch will have identical shape, sharding, and dtype.
"""
# Group parameters by shape, sharding, and dtype
groups = defaultdict(list)
for p in params:
sharding = p.placements if isinstance(p, DTensor) else None
groups[(p.shape, sharding, p.dtype)].append(p)
# Create batches from grouped parameters
for group in groups.values():
for i in range(0, len(group), batch_size):
batch = group[i : i + batch_size]
yield batch
def pad_batch(batch: List[Tensor], batch_size: int) -> List[Tensor]:
    """
    Pad `batch` in place with dummy tensors until it has exactly `batch_size`
    elements.

    Dummies are created with `torch.empty_like(batch[0])`, so they match the
    first element's shape/dtype/device; their values are uninitialized and
    must not be read.

    Args:
        batch: Non-empty list of tensors. Mutated in place.
        batch_size: Target length; must be >= len(batch).

    Returns:
        The same list object, now of length `batch_size`.

    Raises:
        ValueError: If `batch` is empty or already longer than `batch_size`.
    """
    # Validate with real exceptions rather than `assert`, which is stripped
    # when Python runs with -O.
    if not batch:
        raise ValueError("batch must contain at least one tensor")
    if len(batch) > batch_size:
        raise ValueError(
            f"batch has {len(batch)} elements, which exceeds batch_size={batch_size}"
        )
    missing = batch_size - len(batch)
    batch.extend(torch.empty_like(batch[0]) for _ in range(missing))
    return batch
class AsyncTask:
    """
    Wraps a Python generator so it can be advanced one yield at a time.

    The scheduler (see AsyncRuntime) interleaves many such tasks, letting
    other work proceed while one task is parked at a yield (e.g. waiting on
    a distributed operation).
    """

    def __init__(self, generator: Generator[None, None, None]):
        self._gen = generator
        # Eagerly advance to the first yield so the task starts its work
        # as soon as it is constructed.
        self.run()

    def run(self) -> bool:
        """Advance one step. Return True while the task has more steps,
        False once the underlying generator is exhausted."""
        try:
            next(self._gen)
        except StopIteration:
            return False
        return True
class AsyncRuntime:
    """
    A minimal cooperative event loop.

    Pulls AsyncTask objects from a generator on demand and steps every live
    task once per loop iteration, keeping at most `max_concurrent_tasks`
    tasks in flight at a time.
    """

    def __init__(
        self, task_gen: Generator["AsyncTask", None, None], max_concurrent_tasks: int
    ):
        # The task source is consumed lazily, one task per loop iteration.
        if max_concurrent_tasks <= 0:
            raise ValueError(f"{max_concurrent_tasks=} cannot be <= 0")
        self._task_gen = task_gen
        self._max_concurrent_tasks = max_concurrent_tasks

    def _get_next_task(self) -> Optional["AsyncTask"]:
        # next() with a default collapses the StopIteration handling.
        return next(self._task_gen, None)

    def run(self):
        """Drive all tasks to completion, admitting at most one new task
        per iteration while capacity allows."""
        source_exhausted = False
        active: List["AsyncTask"] = []
        while not source_exhausted or active:
            survivors: List["AsyncTask"] = []
            # Admit one new task if we are below the concurrency cap.
            if not source_exhausted and len(active) < self._max_concurrent_tasks:
                candidate = self._get_next_task()
                if candidate is None:
                    source_exhausted = True
                else:
                    # Freshly admitted tasks wait until the next iteration
                    # for their next step (their first step ran in __init__).
                    survivors.append(candidate)
            # Step every previously admitted task once; drop finished ones.
            survivors.extend(task for task in active if task.run())
            active = survivors
|