nroggendorff committed on
Commit
3661d37
·
verified ·
1 Parent(s): e105139

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +31 -26
train.py CHANGED
@@ -5,6 +5,7 @@ from datasets import Dataset
5
  from typing import cast
6
  import os
7
  import shutil
 
8
  from torch.utils.data import DataLoader
9
  from PIL import Image
10
  from functools import partial
@@ -71,8 +72,8 @@ def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, out
71
  print(f"[GPU {gpu_id}] Loading model...", flush=True)
72
  processor, model = load_model(model_name, gpu_id)
73
 
74
- print(f"[GPU {gpu_id}] Streaming data shard [{start}:{end}]...", flush=True)
75
- loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]", streaming=False)
76
 
77
  if isinstance(loaded, datasets.DatasetDict):
78
  shard = cast(Dataset, loaded["train"])
@@ -133,6 +134,8 @@ def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, out
133
 
134
 
135
  def main():
 
 
136
  input_dataset = "none-yet/anime-captions"
137
  output_dataset = "nroggendorff/anime-captions"
138
  model_name = "datalab-to/chandra"
@@ -154,31 +157,33 @@ def main():
154
  print(f"Using {num_gpus} GPUs")
155
  print(f"Shard size: {shard_size}")
156
 
157
- import concurrent.futures
158
-
159
  temp_files = []
160
-
161
- with concurrent.futures.ProcessPoolExecutor(max_workers=num_gpus) as executor:
162
- futures = []
163
- for i in range(num_gpus):
164
- start = i * shard_size
165
- end = start + shard_size if i < num_gpus - 1 else total_size
166
- output_file = f"temp_shard_{i}"
167
- temp_files.append(output_file)
168
-
169
- future = executor.submit(
170
- process_shard,
171
- i, start, end, model_name, batch_size, input_dataset, output_file
172
- )
173
- futures.append(future)
174
-
175
- for future in concurrent.futures.as_completed(futures):
176
- try:
177
- future.result()
178
- except Exception as e:
179
- print(f"Process failed with error: {e}", flush=True)
180
- executor.shutdown(wait=False, cancel_futures=True)
181
- raise
 
 
 
182
 
183
  print("\nAll processes completed. Loading and concatenating results...")
184
 
 
5
  from typing import cast
6
  import os
7
  import shutil
8
+ import multiprocessing as mp
9
  from torch.utils.data import DataLoader
10
  from PIL import Image
11
  from functools import partial
 
72
  print(f"[GPU {gpu_id}] Loading model...", flush=True)
73
  processor, model = load_model(model_name, gpu_id)
74
 
75
+ print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
76
+ loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]")
77
 
78
  if isinstance(loaded, datasets.DatasetDict):
79
  shard = cast(Dataset, loaded["train"])
 
134
 
135
 
136
  def main():
137
+ mp.set_start_method('spawn', force=True)
138
+
139
  input_dataset = "none-yet/anime-captions"
140
  output_dataset = "nroggendorff/anime-captions"
141
  model_name = "datalab-to/chandra"
 
157
  print(f"Using {num_gpus} GPUs")
158
  print(f"Shard size: {shard_size}")
159
 
160
+ processes = []
 
161
  temp_files = []
162
+
163
+ for i in range(num_gpus):
164
+ start = i * shard_size
165
+ end = start + shard_size if i < num_gpus - 1 else total_size
166
+ output_file = f"temp_shard_{i}"
167
+ temp_files.append(output_file)
168
+
169
+ p = mp.Process(
170
+ target=process_shard,
171
+ args=(i, start, end, model_name, batch_size, input_dataset, output_file),
172
+ )
173
+ p.start()
174
+ processes.append(p)
175
+
176
+ for p in processes:
177
+ p.join()
178
+ if p.exitcode != 0:
179
+ print(f"\nProcess failed with exit code {p.exitcode}", flush=True)
180
+ print("Terminating all processes...", flush=True)
181
+ for proc in processes:
182
+ if proc.is_alive():
183
+ proc.terminate()
184
+ for proc in processes:
185
+ proc.join()
186
+ raise RuntimeError(f"At least one process failed")
187
 
188
  print("\nAll processes completed. Loading and concatenating results...")
189