# dragon/optimizers/opt_utils.py
# Author: alexandretl — commit b9f197c
# Commit message: CompletedP | memory+norm logging | proper MoE with ScatterMoE,
# update bias, Latent-MoE | Muon experiments | VE for Mamba3 | fix torch
# recompiles during varlen training
import torch
from collections import defaultdict
from torch import Tensor
from torch.distributed.tensor import DTensor
from typing import Generator, List, Optional, Union
def to_local(tensor: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
"""
Convert a single DTensor or list of DTensors to local tensors.
This is a no-op for regular tensors.
"""
if isinstance(tensor, Tensor):
return tensor.to_local() if isinstance(tensor, DTensor) else tensor
return [t.to_local() if isinstance(t, DTensor) else t for t in tensor]
def dtensor_from_local(
tensor: Union[Tensor, List[Tensor]], ref: Tensor
) -> Union[DTensor, List[DTensor]]:
"""
Convert a single local Tensor or list of local Tensors to DTensor.
The reference tensor's device mesh and placements are used to create the DTensor.
if the reference tensor is not a DTensor, we return the input unmodified.
"""
if not isinstance(ref, DTensor):
assert isinstance(ref, Tensor)
return tensor
device_mesh = ref.device_mesh
placements = ref.placements
# If we have a single tensor
if isinstance(tensor, Tensor):
assert not isinstance(tensor, DTensor)
return DTensor.from_local(
tensor, device_mesh=device_mesh, placements=placements
)
# We have a list of tensors
assert not isinstance(tensor[0], DTensor)
return [
DTensor.from_local(t, device_mesh=device_mesh, placements=placements)
for t in tensor
]
def create_param_batches(
params: List[Tensor], batch_size: int
) -> Generator[List[Tensor], None, None]:
"""
Batch parameters into groups of size `batch_size`.
Tensors in each batch will have identical shape, sharding, and dtype.
"""
# Group parameters by shape, sharding, and dtype
groups = defaultdict(list)
for p in params:
sharding = p.placements if isinstance(p, DTensor) else None
groups[(p.shape, sharding, p.dtype)].append(p)
# Create batches from grouped parameters
for group in groups.values():
for i in range(0, len(group), batch_size):
batch = group[i : i + batch_size]
yield batch
def pad_batch(batch: List[Tensor], batch_size: int) -> List[Tensor]:
    """
    Pad `batch` in place with dummy tensors until it holds exactly
    `batch_size` elements, then return it.

    Dummies are uninitialized tensors matching the first element's
    shape/dtype/device (`torch.empty_like`).
    """
    assert len(batch) > 0
    assert len(batch) <= batch_size
    deficit = batch_size - len(batch)
    batch.extend(torch.empty_like(batch[0]) for _ in range(deficit))
    return batch
class AsyncTask:
    """
    Wrap a generator so it can be advanced one yield at a time.

    Each yield point is an opportunity for other tasks to run while this
    one waits on (e.g.) a distributed collective. Construction immediately
    advances the generator to its first yield.
    """

    def __init__(self, generator: Generator[None, None, None]):
        self._generator = generator
        # Eagerly advance to the first yield so the task's work begins now.
        self.run()

    def run(self) -> bool:
        """Advance one step; True while still running, False once exhausted."""
        try:
            next(self._generator)
        except StopIteration:
            return False
        return True
class AsyncRuntime:
    """
    Minimal cooperative event loop stepping multiple AsyncTasks.

    Tasks are pulled lazily from a generator — at most one new task is
    admitted per loop iteration, and never more than `max_concurrent_tasks`
    are in flight at once.
    """

    def __init__(
        self, task_gen: Generator["AsyncTask", None, None], max_concurrent_tasks: int
    ):
        # The task source is consumed lazily inside run().
        if max_concurrent_tasks <= 0:
            raise ValueError(f"{max_concurrent_tasks=} cannot be <= 0")
        self._task_gen = task_gen
        self._max_concurrent_tasks = max_concurrent_tasks

    def _get_next_task(self) -> Optional["AsyncTask"]:
        # Pull the next task from the source; None once it is exhausted.
        try:
            return next(self._task_gen)
        except StopIteration:
            return None

    def run(self):
        """Drive all tasks to completion, interleaving their steps."""
        source_exhausted = False
        pending: List["AsyncTask"] = []
        while not source_exhausted or pending:
            survivors: List["AsyncTask"] = []
            # Admit at most one new task per iteration if capacity allows.
            if not source_exhausted and len(pending) < self._max_concurrent_tasks:
                task = self._get_next_task()
                if task is None:
                    source_exhausted = True
                else:
                    # Newly admitted task is stepped first next iteration.
                    survivors.append(task)
            # Step every task admitted in earlier iterations once.
            for task in pending:
                if task.run():
                    survivors.append(task)
            pending = survivors