|
|
from ..modules import Module |
|
|
from typing import Any, Optional |
|
|
from .common_types import _devices_t, _device_t |
|
|
|
|
|
|
|
|
class DistributedDataParallel(Module):
    """Typed declaration of ``DistributedDataParallel``, a ``Module`` subclass.

    This is a stub-style declaration: attribute values, parameter defaults
    and the constructor body are all elided with ``...``, so only names and
    types are stated here. Consult the runtime implementation for the real
    defaults and behaviour.
    """

    # Attributes whose names match constructor parameters.
    # NOTE(review): presumably populated from the same-named arguments of
    # ``__init__`` -- confirm against the implementation.
    process_group: Any = ...
    dim: int = ...
    module: Module = ...
    device_ids: _devices_t = ...
    output_device: _device_t = ...
    broadcast_buffers: bool = ...
    check_reduction: bool = ...

    # Attributes with no same-named constructor parameter.
    # NOTE(review): the names suggest byte counts, yet both are annotated as
    # ``float``; verify against the implementation before narrowing the type.
    broadcast_bucket_size: float = ...
    bucket_bytes_cap: float = ...

    def __init__(
        self,
        module: Module,
        device_ids: Optional[_devices_t] = ...,
        output_device: Optional[_device_t] = ...,
        dim: int = ...,
        broadcast_buffers: bool = ...,
        process_group: Optional[Any] = ...,
        bucket_cap_mb: float = ...,
        find_unused_parameters: bool = ...,
        check_reduction: bool = ...,
    ) -> None:
        """Declare the constructor signature.

        ``module`` is the only required argument; every other parameter
        carries an elided (``...``) default. ``device_ids``,
        ``output_device`` and ``process_group`` additionally accept ``None``.
        """
        ...
|
|
|