| """ |
| Distributed adapter scaffolding (Phase 4). |
| |
| The plan calls for keeping multi-GPU policy decisions explicit while keeping the |
| interface stable. This module provides: |
| - SingleProcessAdapter: always available |
| - FSDPAdapter: scaffold that raises NotImplementedError where policy decisions are needed |
| |
| Torch is imported lazily to keep non-training unit tests runnable without torch. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Any, Optional, Protocol |
|
|
|
|
class DistributedAdapter(Protocol):
    """Structural interface every distributed backend must satisfy.

    Training code depends only on these members, so a no-op single-process
    adapter and a real multi-GPU adapter are interchangeable.  The stub
    bodies raise so that an explicit subclass which forgets an override
    fails loudly instead of silently doing nothing.
    """

    def setup(self) -> None:
        """Perform one-time backend initialization."""
        raise NotImplementedError

    @property
    def is_distributed(self) -> bool:
        """Whether an initialized multi-process context is active."""
        raise NotImplementedError

    @property
    def rank(self) -> int:
        """Global rank of this process (0 when single-process)."""
        raise NotImplementedError

    @property
    def world_size(self) -> int:
        """Total number of participating processes (>= 1)."""
        raise NotImplementedError

    def barrier(self) -> None:
        """Block until all participating processes reach this point."""
        raise NotImplementedError

    def is_main_process(self) -> bool:
        """True on exactly one process (rank 0 in the implementations here)."""
        raise NotImplementedError

    def wrap_model(self, model: Any) -> Any:
        """Return *model* prepared for this backend (possibly unchanged)."""
        raise NotImplementedError
|
|
|
|
@dataclass(frozen=True)
class SingleProcessAdapter:
    """No-op adapter for ordinary single-process training.

    Every query answers as the sole process in a world of one, and the
    collective operations do nothing, so training code can call the
    distributed interface unconditionally.
    """

    def setup(self) -> None:
        """Nothing to initialize for a single process."""

    @property
    def is_distributed(self) -> bool:
        """Never distributed."""
        return False

    @property
    def rank(self) -> int:
        """The only process is rank 0."""
        return 0

    @property
    def world_size(self) -> int:
        """Exactly one process."""
        return 1

    def barrier(self) -> None:
        """No peers to wait for; returns immediately."""

    def is_main_process(self) -> bool:
        """The sole process is always the main one."""
        return True

    def wrap_model(self, model: Any) -> Any:
        """Return *model* untouched; no wrapping is needed."""
        return model
|
|
|
|
@dataclass(frozen=True)
class FSDPAdapter:
    """
    FSDP scaffold.

    This intentionally raises NotImplementedError for policy decisions:
    - wrapping strategy / auto-wrap policy
    - sharding plan / activation checkpointing policy
    - mixed precision policy details

    torch is imported lazily inside each member so importing this module
    does not require torch to be installed (see module docstring).
    """

    # Forwarded verbatim to the project FSDP wrapper helper.
    sharding_strategy: str = "FULL_SHARD"
    # None disables mixed precision; wrap_model() explicitly handles the
    # None case, so the annotation is Optional to match the runtime contract.
    mixed_precision: Optional[str] = "bf16"
    auto_wrap_policy: Optional[str] = None

    def setup(self) -> None:
        """Verify that a torch.distributed process group is ready.

        Raises:
            RuntimeError: torch (or torch.distributed) cannot be imported.
            NotImplementedError: distributed support is unavailable or no
                process group is initialized; launching via torchrun (or
                equivalent) is a prerequisite for FSDP.
        """
        try:
            import torch.distributed as dist
        except Exception as e:
            raise RuntimeError("torch.distributed is required for FSDP") from e

        # is_available() guards torch builds that ship without distributed
        # support, where initialization can never have happened.
        if not (dist.is_available() and dist.is_initialized()):
            raise NotImplementedError(
                "FSDPAdapter requires process-group initialization (e.g. via torchrun)."
            )

    @property
    def is_distributed(self) -> bool:
        """True iff torch.distributed is importable, available, and initialized."""
        try:
            import torch.distributed as dist
        except Exception:
            return False
        return bool(dist.is_available() and dist.is_initialized())

    @property
    def rank(self) -> int:
        """Global rank, or 0 when not running under an initialized group."""
        try:
            import torch.distributed as dist
        except Exception:
            return 0
        if dist.is_available() and dist.is_initialized():
            return int(dist.get_rank())
        return 0

    @property
    def world_size(self) -> int:
        """Number of participating processes, or 1 when not distributed."""
        try:
            import torch.distributed as dist
        except Exception:
            return 1
        if dist.is_available() and dist.is_initialized():
            return int(dist.get_world_size())
        return 1

    def barrier(self) -> None:
        """Synchronize all ranks; a silent no-op when not distributed."""
        try:
            import torch.distributed as dist
        except Exception:
            return
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    def is_main_process(self) -> bool:
        """True on rank 0 (which includes the single-process fallback)."""
        return self.rank == 0

    def wrap_model(self, model: Any) -> Any:
        """Wrap *model* with FSDP via the project helper.

        Sharding, mixed-precision, and auto-wrap policy are forwarded from
        this adapter's fields; the local CUDA device is taken from the
        LOCAL_RANK environment variable set by torchrun.
        """
        from ...utils.fsdp_utils import wrap_model_fsdp

        import os

        # LOCAL_RANK defaults to "0" when unset (single-node debugging);
        # a non-integer value degrades to None rather than crashing, and
        # only parse failures are swallowed — not arbitrary errors.
        raw = os.environ.get("LOCAL_RANK", "0")
        try:
            device_id: Optional[int] = int(raw)
        except (TypeError, ValueError):
            device_id = None

        return wrap_model_fsdp(
            model,
            sharding_strategy=str(self.sharding_strategy),
            mixed_precision=(
                str(self.mixed_precision) if self.mixed_precision is not None else None
            ),
            auto_wrap_policy=self.auto_wrap_policy,
            device_id=device_id,
        )
|
|