| | |
| | |
| |
|
| | from internlm.utils.logger import get_logger |
| |
|
# Module-level logger, obtained from internlm's get_logger with this module's file path.
logger = get_logger(__file__)
| |
|
| |
|
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    """Split ``num_items`` items uniformly across pipeline stages and chunks.

    The items are first cut into ``num_chunks`` equal, contiguous chunks. Each
    chunk is then divided among the ``pipeline_parallel_size`` stages as evenly
    as possible; when a chunk does not divide evenly, the highest-indexed
    stages each receive one extra item.

    Args:
        num_items: Total number of items (e.g. layers) to partition.
        pipeline_parallel_size: Number of pipeline stages; must be positive.
        num_chunks: Number of chunks; must be positive and divide ``num_items``.

    Returns:
        A list with one entry per stage. Entry ``p`` is a list of
        ``num_chunks`` half-open ``(start, end)`` index ranges assigned to
        stage ``p``, one per chunk, in chunk order.

    Raises:
        ValueError: If ``num_chunks`` or ``pipeline_parallel_size`` is not
            positive, or if a chunk has fewer items than there are stages
            (which would leave some stage with no items).
        AssertionError: If ``num_items`` is not divisible by ``num_chunks``.
    """
    # Reject non-positive sizes up front instead of failing later with a bare
    # ZeroDivisionError or a confusing sanity-check assertion.
    if num_chunks <= 0 or pipeline_parallel_size <= 0:
        raise ValueError(
            "num_chunks and pipeline_parallel_size must be positive, "
            f"got num_chunks={num_chunks}, pipeline_parallel_size={pipeline_parallel_size}"
        )
    assert num_items % num_chunks == 0, (
        "Layer length should be divisible by the number of chunks, "
        "otherwise the parameter method is recommended"
    )

    partition_items = num_items // num_chunks
    # Every chunk has the same size, so its per-stage split is loop-invariant.
    chunk_size = partition_items // pipeline_parallel_size
    if chunk_size == 0:
        raise ValueError("Some nodes in Pipeline have no requests")
    # Stages with index >= left take one extra item to absorb the remainder.
    left = pipeline_parallel_size - partition_items % pipeline_parallel_size

    parts = [[] for _ in range(pipeline_parallel_size)]
    for idx in range(num_chunks):
        start = idx * partition_items
        for p in range(pipeline_parallel_size):
            end = start + chunk_size + (p >= left)
            parts[p].append((start, end))
            start = end

    # Internal sanity check: the ranges must cover 0..num_items-1 exactly once.
    indexes = [i for stage in parts for s, e in stage for i in range(s, e)]
    assert len(indexes) == len(set(indexes)), indexes  # no duplicates
    assert set(indexes) == set(range(num_items)), (indexes, num_items)  # full coverage
    return parts
| |
|