import torch
from collections import defaultdict
from torch import Tensor
from torch.distributed.tensor import DTensor
from typing import Generator, List, Optional, Union


def to_local(tensor: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
    """
    Convert a single DTensor or list of DTensors to local tensors.
    This is a no-op for regular tensors.
    """
    if isinstance(tensor, Tensor):
        return tensor.to_local() if isinstance(tensor, DTensor) else tensor
    return [t.to_local() if isinstance(t, DTensor) else t for t in tensor]
def dtensor_from_local(
    tensor: Union[Tensor, List[Tensor]], ref: Tensor
) -> Union[DTensor, List[DTensor]]:
    """
    Convert a single local Tensor or list of local Tensors to DTensors.
    The reference tensor's device mesh and placements are used to create each DTensor.
    If the reference tensor is not a DTensor, the input is returned unmodified.
    """
    if not isinstance(ref, DTensor):
        assert isinstance(ref, Tensor)
        return tensor

    device_mesh = ref.device_mesh
    placements = ref.placements

    # Single tensor: wrap it directly
    if isinstance(tensor, Tensor):
        assert not isinstance(tensor, DTensor)
        return DTensor.from_local(
            tensor, device_mesh=device_mesh, placements=placements
        )

    # List of tensors: wrap each one
    assert not isinstance(tensor[0], DTensor)
    return [
        DTensor.from_local(t, device_mesh=device_mesh, placements=placements)
        for t in tensor
    ]
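
# Hedged example (not part of the original file): round-trip a DTensor through
# its local shard and back using the two helpers above. Assumes torch.distributed
# is already initialized on a CUDA cluster; the mesh size and placements are
# illustrative, and the function name is ours.
def _example_local_round_trip() -> None:
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    dt = distribute_tensor(torch.randn(1024, 1024), mesh, [Shard(0)])

    local = to_local(dt)  # plain Tensor holding this rank's shard
    local.mul_(2.0)       # arbitrary math on the local data
    dt2 = dtensor_from_local(local, ref=dt)  # re-wrap with dt's mesh/placements
    assert isinstance(dt2, DTensor)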
def create_param_batches(
    params: List[Tensor], batch_size: int
) -> Generator[List[Tensor], None, None]:
    """
    Batch parameters into groups of size `batch_size`.
    Tensors in each batch will have identical shape, sharding, and dtype.
    """
    # Group parameters by shape, sharding, and dtype
    groups = defaultdict(list)
    for p in params:
        sharding = p.placements if isinstance(p, DTensor) else None
        groups[(p.shape, sharding, p.dtype)].append(p)

    # Create batches from grouped parameters
    for group in groups.values():
        for i in range(0, len(group), batch_size):
            yield group[i : i + batch_size]
def pad_batch(batch: List[Tensor], batch_size: int) -> List[Tensor]:
    """
    Insert dummy tensors so the batch has exactly `batch_size` elements.
    """
    assert len(batch) > 0
    assert len(batch) <= batch_size
    while len(batch) < batch_size:
        batch.append(torch.empty_like(batch[0]))
    return batch
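
# Hedged example (not part of the original file): group a mixed parameter list
# into uniform batches, padding the last partial batch so every batch has the
# same length. Shapes and the batch size are arbitrary; runs on CPU as-is.
def _example_batching() -> None:
    params = [torch.randn(128, 64) for _ in range(5)] + [torch.randn(32)]
    for batch in create_param_batches(params, batch_size=4):
        batch = pad_batch(batch, batch_size=4)
        # Every batch now holds exactly 4 same-shape, same-dtype tensors,
        # so they can be stacked and processed in one fused operation.
        stacked = torch.stack(to_local(batch))
        assert stacked.shape[0] == 4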
class AsyncTask:
    """
    AsyncTask wraps a Python generator to run until the next yield statement.
    This is used to allow other tasks to run while waiting for distributed operations.
    """

    def __init__(self, generator: Generator[None, None, None]):
        self._generator = generator
        self.run()  # Start running the generator

    def run(self) -> bool:
        """
        Run the next step of the async task.
        Returns True if the task is still running and False if completed.
        """
        try:
            next(self._generator)
            return True
        except StopIteration:
            return False
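
# Hedged example (not part of the original file): a task generator that kicks
# off a non-blocking collective, yields while it completes, then finishes up.
# Assumes a process group is initialized; dist.all_reduce(..., async_op=True)
# is standard torch.distributed API, but the generator itself is ours.
def _example_task(t: Tensor) -> Generator[None, None, None]:
    import torch.distributed as dist

    work = dist.all_reduce(t, async_op=True)  # returns immediately
    while not work.is_completed():
        yield  # let the runtime step other tasks while we wait
    t.div_(dist.get_world_size())  # finish: turn the sum into a mean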
class AsyncRuntime:
    """
    Event loop for running multiple async tasks concurrently.
    """

    def __init__(
        self, task_gen: Generator["AsyncTask", None, None], max_concurrent_tasks: int
    ):
        # Initialize runtime with a generator that produces AsyncTask objects
        if max_concurrent_tasks <= 0:
            raise ValueError(f"{max_concurrent_tasks=} cannot be <= 0")
        self._task_gen = task_gen
        self._max_concurrent_tasks = max_concurrent_tasks

    def _get_next_task(self) -> Optional["AsyncTask"]:
        try:
            return next(self._task_gen)
        except StopIteration:
            return None

    def run(self):
        """
        Run the event loop until all tasks are completed.
        """
        have_new_tasks = True
        previous_tasks: List["AsyncTask"] = []

        while have_new_tasks or previous_tasks:
            # See if we can add another task
            running_tasks = []
            if have_new_tasks and len(previous_tasks) < self._max_concurrent_tasks:
                new_task = self._get_next_task()
                if new_task is not None:
                    # Add the new task to the queue
                    running_tasks.append(new_task)
                else:
                    # No more tasks left
                    have_new_tasks = False

            # Run all previous tasks for one step
            for task in previous_tasks:
                if task.run():
                    running_tasks.append(task)

            # Update the task list for the next iteration
            previous_tasks = running_tasks
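
# Hedged example (not part of the original file): wiring the pieces together.
# Builds one task per parameter (reusing the _example_task generator sketched
# above) and lets AsyncRuntime keep a bounded number of collectives in flight.
# Because tasks are constructed lazily by the runtime, at most
# max_concurrent_tasks operations are ever outstanding at once.
def _example_run(params: List[Tensor]) -> None:
    def make_tasks() -> Generator[AsyncTask, None, None]:
        for batch in create_param_batches(params, batch_size=4):
            for p in pad_batch(batch, batch_size=4):
                # AsyncTask.__init__ runs to the first yield,
                # which launches the collective.
                yield AsyncTask(_example_task(to_local(p)))

    AsyncRuntime(make_tasks(), max_concurrent_tasks=2).run()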