# 3d_model/ylff/services/training/distributed_adapter.py
# Author: Azan
# Clean deployment build (Squashed)
# Commit: 7a87926
"""
Distributed adapter scaffolding (Phase 4).
The plan calls for keeping multi-GPU policy decisions explicit while keeping the
interface stable. This module provides:
- SingleProcessAdapter: always available
- FSDPAdapter: scaffold that raises NotImplementedError where policy decisions are needed
Torch is imported lazily to keep non-training unit tests runnable without torch.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional, Protocol
class DistributedAdapter(Protocol):
    """Structural interface for distributed-training adapters.

    Implementations answer topology questions (rank / world size), provide
    synchronization (barrier), and wrap a model for the chosen strategy.
    Bodies raise NotImplementedError so that calling the bare protocol is
    an explicit error rather than a silent no-op.
    """

    def setup(self) -> None:
        """Perform any one-time initialization (e.g. process-group setup)."""
        raise NotImplementedError

    @property
    def is_distributed(self) -> bool:
        """Whether more than one process is participating."""
        raise NotImplementedError

    @property
    def rank(self) -> int:
        """This process's global rank (0-based)."""
        raise NotImplementedError

    @property
    def world_size(self) -> int:
        """Total number of participating processes."""
        raise NotImplementedError

    def barrier(self) -> None:
        """Block until all processes reach this point."""
        raise NotImplementedError

    def is_main_process(self) -> bool:
        """Whether this process is the designated main (rank 0) process."""
        raise NotImplementedError

    def wrap_model(self, model: Any) -> Any:
        """Return *model* wrapped for the adapter's distribution strategy."""
        raise NotImplementedError
@dataclass(frozen=True)
class SingleProcessAdapter:
    """Trivial adapter for single-process (non-distributed) runs.

    Always available: answers rank 0 / world size 1, treats barriers as
    no-ops, and returns models unwrapped. Satisfies the DistributedAdapter
    protocol without importing torch.
    """

    def setup(self) -> None:
        """Nothing to initialize for a single process."""

    @property
    def is_distributed(self) -> bool:
        """A lone process is never distributed."""
        return False

    @property
    def rank(self) -> int:
        """The only process is rank 0."""
        return 0

    @property
    def world_size(self) -> int:
        """Exactly one process participates."""
        return 1

    def barrier(self) -> None:
        """No peers to synchronize with; do nothing."""

    def is_main_process(self) -> bool:
        """The sole process is always the main process."""
        return True

    def wrap_model(self, model: Any) -> Any:
        """Return *model* untouched — no wrapping needed."""
        return model
@dataclass(frozen=True)
class FSDPAdapter:
    """
    FSDP scaffold.

    This intentionally raises NotImplementedError for policy decisions:
    - wrapping strategy / auto-wrap policy
    - sharding plan / activation checkpointing policy
    - mixed precision policy details

    Torch is imported lazily (via ``_dist``) so that merely constructing or
    querying this adapter works in environments without torch installed;
    such environments degrade to single-process answers (rank 0, world 1).

    Attributes:
        sharding_strategy: FSDP sharding strategy name (default "FULL_SHARD").
        mixed_precision: mixed-precision policy tag (default "bf16").
        auto_wrap_policy: optional auto-wrap policy name; None means default.
    """

    sharding_strategy: str = "FULL_SHARD"
    mixed_precision: str = "bf16"
    auto_wrap_policy: Optional[str] = None

    @staticmethod
    def _dist() -> Any:
        """Lazily import and return ``torch.distributed``, or None if torch
        is unavailable. Centralizes the try/except import that every query
        method previously duplicated."""
        try:
            import torch.distributed as dist  # type: ignore
        except Exception:  # pragma: no cover
            return None
        return dist

    def setup(self) -> None:
        """Validate that a torch.distributed process group exists.

        Kept as its own import (not ``_dist``) to preserve exception
        chaining (``raise ... from e``) for the missing-torch case.

        Raises:
            RuntimeError: torch.distributed cannot be imported.
            NotImplementedError: no process group is initialized
                (e.g. the script was not launched via torchrun).
        """
        try:
            import torch.distributed as dist  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("torch.distributed is required for FSDP") from e
        if not dist.is_initialized():
            raise NotImplementedError(
                "FSDPAdapter requires process-group initialization (e.g. via torchrun)."
            )

    @property
    def is_distributed(self) -> bool:
        """True iff a torch.distributed process group is initialized."""
        dist = self._dist()
        return bool(dist.is_initialized()) if dist is not None else False

    @property
    def rank(self) -> int:
        """Global rank, or 0 when torch/distributed is unavailable."""
        dist = self._dist()
        if dist is None or not dist.is_initialized():
            return 0
        return int(dist.get_rank())

    @property
    def world_size(self) -> int:
        """World size, or 1 when torch/distributed is unavailable."""
        dist = self._dist()
        if dist is None or not dist.is_initialized():
            return 1
        return int(dist.get_world_size())

    def barrier(self) -> None:
        """Synchronize all ranks; best-effort no-op without a process group."""
        dist = self._dist()
        if dist is not None and dist.is_initialized():
            dist.barrier()

    def is_main_process(self) -> bool:
        """Whether this process is rank 0 (also True when not distributed)."""
        return self.rank == 0

    def wrap_model(self, model: Any) -> Any:
        """Wrap *model* with FSDP using this adapter's configured policy.

        Explicit but non-stubbed wrapping: callers can instantiate this
        adapter with the intended policy and get a correctly wrapped model.
        """
        from ...utils.fsdp_utils import wrap_model_fsdp
        import os

        # device_id is optional; torchrun typically sets LOCAL_RANK.
        # Narrow try: only the int() conversion of a malformed LOCAL_RANK
        # can realistically fail here.
        try:
            device_id: Optional[int] = int(os.environ.get("LOCAL_RANK", "0"))
        except ValueError:
            device_id = None
        return wrap_model_fsdp(
            model,
            sharding_strategy=str(self.sharding_strategy),
            mixed_precision=(
                str(self.mixed_precision) if self.mixed_precision is not None else None
            ),
            auto_wrap_policy=self.auto_wrap_policy,
            device_id=device_id,
        )