refactor(data): replace per-worker seed strategy with full sharding in IterableDataset
Browse files- Add num_shards and shard_index params to _load_dataset()
- Apply ds.shard() before shuffle to eliminate document overlap across workers
- Pass worker_info.num_workers/id from __iter__() to _load_dataset()
- Maintain backward compatibility with single-process (num_workers=0) mode
- Fix .gitignore to unblock llm_lab/data/ from data/ exclusion rule
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- .gitignore +1 -0
- llm_lab/data/dataset.py +32 -8
.gitignore
CHANGED
|
@@ -41,6 +41,7 @@ runs/
|
|
| 41 |
*.csv
|
| 42 |
*.tsv
|
| 43 |
data/
|
|
|
|
| 44 |
|
| 45 |
# Secrets
|
| 46 |
.env
|
|
|
|
| 41 |
*.csv
|
| 42 |
*.tsv
|
| 43 |
data/
|
| 44 |
+
!llm_lab/data/
|
| 45 |
|
| 46 |
# Secrets
|
| 47 |
.env
|
llm_lab/data/dataset.py
CHANGED
|
@@ -44,8 +44,17 @@ class PackedStreamingDataset(IterableDataset):
|
|
| 44 |
self.seed = seed
|
| 45 |
self.max_seq_len = config.max_seq_len
|
| 46 |
|
| 47 |
-
def _load_dataset(self):
|
| 48 |
-
"""HuggingFace 데이터셋을 스트리밍 모드로 로드합니다.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
from datasets import load_dataset
|
| 50 |
|
| 51 |
ds = load_dataset(
|
|
@@ -56,6 +65,11 @@ class PackedStreamingDataset(IterableDataset):
|
|
| 56 |
trust_remote_code=True,
|
| 57 |
)
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# 셔플 (스트리밍에서는 버퍼 기반 근사 셔플)
|
| 60 |
ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
|
| 61 |
|
|
@@ -109,21 +123,31 @@ class PackedStreamingDataset(IterableDataset):
|
|
| 109 |
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
| 110 |
"""DataLoader가 호출하는 이터레이터.
|
| 111 |
|
| 112 |
-
멀티 워커 지원:
|
| 113 |
-
-
|
| 114 |
-
- 워커 간
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
"""
|
| 116 |
worker_info = torch.utils.data.get_worker_info()
|
| 117 |
|
| 118 |
if worker_info is not None:
|
| 119 |
-
#
|
|
|
|
|
|
|
| 120 |
worker_seed = self.seed + worker_info.id
|
| 121 |
else:
|
|
|
|
|
|
|
|
|
|
| 122 |
worker_seed = self.seed
|
| 123 |
|
| 124 |
-
# 워커별 시드로 데이터셋 로드
|
| 125 |
self.seed = worker_seed
|
| 126 |
-
dataset = self._load_dataset()
|
| 127 |
|
| 128 |
return self._tokenize_and_pack(dataset)
|
| 129 |
|
|
|
|
| 44 |
self.seed = seed
|
| 45 |
self.max_seq_len = config.max_seq_len
|
| 46 |
|
| 47 |
+
def _load_dataset(self, num_shards: int = 1, shard_index: int = 0):
|
| 48 |
+
"""HuggingFace 데이터셋을 스트리밍 모드로 로드합니다.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
num_shards: 전체 샤드 수 (= DataLoader num_workers)
|
| 52 |
+
shard_index: 이 워커가 담당할 샤드 번호 (0 ~ num_shards-1)
|
| 53 |
+
|
| 54 |
+
샤딩 원리:
|
| 55 |
+
num_shards=4 일 때 스트림을 4등분하여 각 워커가 서로 다른 1/4만 처리.
|
| 56 |
+
셔플은 샤딩 이후에 적용하므로 워커 간 문서 중복이 없음.
|
| 57 |
+
"""
|
| 58 |
from datasets import load_dataset
|
| 59 |
|
| 60 |
ds = load_dataset(
|
|
|
|
| 65 |
trust_remote_code=True,
|
| 66 |
)
|
| 67 |
|
| 68 |
+
# 완전 분할(샤딩): 워커 i는 전체 스트림의 1/num_shards 구간만 처리
|
| 69 |
+
# 반드시 셔플 전에 적용해야 각 워커가 겹치지 않는 문서 집합을 가짐
|
| 70 |
+
if num_shards > 1:
|
| 71 |
+
ds = ds.shard(num_shards=num_shards, index=shard_index)
|
| 72 |
+
|
| 73 |
# 셔플 (스트리밍에서는 버퍼 기반 근사 셔플)
|
| 74 |
ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
|
| 75 |
|
|
|
|
| 123 |
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
| 124 |
"""DataLoader가 호출하는 이터레이터.
|
| 125 |
|
| 126 |
+
멀티 워커 지원 (완전 분할 방식):
|
| 127 |
+
- 이전: 모든 워커가 동일한 스트림을 읽고 시드만 달리함 → 문서 중복 가능
|
| 128 |
+
- 개선: ds.shard()로 스트림을 num_workers등분 → 워커 간 문서 중복 없음
|
| 129 |
+
|
| 130 |
+
예시 (num_workers=4, 전체 문서 N개):
|
| 131 |
+
Worker 0: 문서 0, 4, 8, 12, ... (N/4개)
|
| 132 |
+
Worker 1: 문서 1, 5, 9, 13, ... (N/4개)
|
| 133 |
+
Worker 2: 문서 2, 6, 10, 14, ... (N/4개)
|
| 134 |
+
Worker 3: 문서 3, 7, 11, 15, ... (N/4개)
|
| 135 |
"""
|
| 136 |
worker_info = torch.utils.data.get_worker_info()
|
| 137 |
|
| 138 |
if worker_info is not None:
|
| 139 |
+
# 완전 분할: 워커별 샤드 할당 + 독립적인 셔플 시드
|
| 140 |
+
num_shards = worker_info.num_workers
|
| 141 |
+
shard_index = worker_info.id
|
| 142 |
worker_seed = self.seed + worker_info.id
|
| 143 |
else:
|
| 144 |
+
# 단일 프로세스: 샤딩 없이 전체 스트림 처리
|
| 145 |
+
num_shards = 1
|
| 146 |
+
shard_index = 0
|
| 147 |
worker_seed = self.seed
|
| 148 |
|
|
|
|
| 149 |
self.seed = worker_seed
|
| 150 |
+
dataset = self._load_dataset(num_shards=num_shards, shard_index=shard_index)
|
| 151 |
|
| 152 |
return self._tokenize_and_pack(dataset)
|
| 153 |
|