Vjeong Claude Sonnet 4.6 commited on
Commit
8a39fec
·
1 Parent(s): 8a58ffe

refactor(data): replace per-worker seed strategy with full sharding in IterableDataset

Browse files

- Add num_shards and shard_index params to _load_dataset()
- Apply ds.shard() before shuffle to eliminate document overlap across workers
- Pass worker_info.num_workers/id from __iter__() to _load_dataset()
- Maintain backward compatibility with single-process (num_workers=0) mode
- Fix .gitignore to unblock llm_lab/data/ from data/ exclusion rule

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. .gitignore +1 -0
  2. llm_lab/data/dataset.py +32 -8
.gitignore CHANGED
@@ -41,6 +41,7 @@ runs/
41
  *.csv
42
  *.tsv
43
  data/
 
44
 
45
  # Secrets
46
  .env
 
41
  *.csv
42
  *.tsv
43
  data/
44
+ !llm_lab/data/
45
 
46
  # Secrets
47
  .env
llm_lab/data/dataset.py CHANGED
@@ -44,8 +44,17 @@ class PackedStreamingDataset(IterableDataset):
44
  self.seed = seed
45
  self.max_seq_len = config.max_seq_len
46
 
47
- def _load_dataset(self):
48
- """HuggingFace 데이터셋을 스트리밍 모드로 로드합니다."""
 
 
 
 
 
 
 
 
 
49
  from datasets import load_dataset
50
 
51
  ds = load_dataset(
@@ -56,6 +65,11 @@ class PackedStreamingDataset(IterableDataset):
56
  trust_remote_code=True,
57
  )
58
 
 
 
 
 
 
59
  # 셔플 (스트리밍에서는 버퍼 기반 근사 셔플)
60
  ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
61
 
@@ -109,21 +123,31 @@ class PackedStreamingDataset(IterableDataset):
109
  def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
110
  """DataLoader가 호출하는 이터레이터.
111
 
112
- 멀티 워커 지원:
113
- - 워커가 서로 다른 시드로 셔플된 스트림을
114
- - 워커 간 데이터 중복 최소화
 
 
 
 
 
 
115
  """
116
  worker_info = torch.utils.data.get_worker_info()
117
 
118
  if worker_info is not None:
119
- # 멀티 워커: 워커 다른 시드
 
 
120
  worker_seed = self.seed + worker_info.id
121
  else:
 
 
 
122
  worker_seed = self.seed
123
 
124
- # 워커별 시드로 데이터셋 로드
125
  self.seed = worker_seed
126
- dataset = self._load_dataset()
127
 
128
  return self._tokenize_and_pack(dataset)
129
 
 
44
  self.seed = seed
45
  self.max_seq_len = config.max_seq_len
46
 
47
+ def _load_dataset(self, num_shards: int = 1, shard_index: int = 0):
48
+ """HuggingFace 데이터셋을 스트리밍 모드로 로드합니다.
49
+
50
+ Args:
51
+ num_shards: 전체 샤드 수 (= DataLoader num_workers)
52
+ shard_index: 이 워커가 담당할 샤드 번호 (0 ~ num_shards-1)
53
+
54
+ 샤딩 원리:
55
+ num_shards=4 일 때 스트림을 4등분하여 각 워커가 서로 다른 1/4만 처리.
56
+ 셔플은 샤딩 이후에 적용하므로 워커 간 문서 중복이 없음.
57
+ """
58
  from datasets import load_dataset
59
 
60
  ds = load_dataset(
 
65
  trust_remote_code=True,
66
  )
67
 
68
+ # 완전 분할(샤딩): 워커 i는 전체 스트림의 1/num_shards 구간만 처리
69
+ # 반드시 셔플 전에 적용해야 각 워커가 겹치지 않는 문서 집합을 가짐
70
+ if num_shards > 1:
71
+ ds = ds.shard(num_shards=num_shards, index=shard_index)
72
+
73
  # 셔플 (스트리밍에서는 버퍼 기반 근사 셔플)
74
  ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
75
 
 
123
  def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
124
  """DataLoader가 호출하는 이터레이터.
125
 
126
+ 멀티 워커 지원 (완전 분할 방식):
127
+ - 이전: 모든 워커가 동일한 스트림을 읽고 시드만 달함 → 문서 중복 가능
128
+ - 개선: ds.shard()로 스트림을 num_workers등분 → 워커 간 문서 중복 없음
129
+
130
+ 예시 (num_workers=4, 전체 문서 N개):
131
+ Worker 0: 문서 0, 4, 8, 12, ... (N/4개)
132
+ Worker 1: 문서 1, 5, 9, 13, ... (N/4개)
133
+ Worker 2: 문서 2, 6, 10, 14, ... (N/4개)
134
+ Worker 3: 문서 3, 7, 11, 15, ... (N/4개)
135
  """
136
  worker_info = torch.utils.data.get_worker_info()
137
 
138
  if worker_info is not None:
139
+ # 완전 분할: 워커 샤드 할당 + 독립적인 셔플 시드
140
+ num_shards = worker_info.num_workers
141
+ shard_index = worker_info.id
142
  worker_seed = self.seed + worker_info.id
143
  else:
144
+ # 단일 프로세스: 샤딩 없이 전체 스트림 처리
145
+ num_shards = 1
146
+ shard_index = 0
147
  worker_seed = self.seed
148
 
 
149
  self.seed = worker_seed
150
+ dataset = self._load_dataset(num_shards=num_shards, shard_index=shard_index)
151
 
152
  return self._tokenize_and_pack(dataset)
153