Spaces:
Paused
Paused
Update train.py
Browse files
train.py
CHANGED
|
@@ -5,6 +5,7 @@ from datasets import Dataset
|
|
| 5 |
from typing import cast
|
| 6 |
import os
|
| 7 |
import shutil
|
|
|
|
| 8 |
from torch.utils.data import DataLoader
|
| 9 |
from PIL import Image
|
| 10 |
from functools import partial
|
|
@@ -71,8 +72,8 @@ def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, out
|
|
| 71 |
print(f"[GPU {gpu_id}] Loading model...", flush=True)
|
| 72 |
processor, model = load_model(model_name, gpu_id)
|
| 73 |
|
| 74 |
-
print(f"[GPU {gpu_id}]
|
| 75 |
-
loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]"
|
| 76 |
|
| 77 |
if isinstance(loaded, datasets.DatasetDict):
|
| 78 |
shard = cast(Dataset, loaded["train"])
|
|
@@ -133,6 +134,8 @@ def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, out
|
|
| 133 |
|
| 134 |
|
| 135 |
def main():
|
|
|
|
|
|
|
| 136 |
input_dataset = "none-yet/anime-captions"
|
| 137 |
output_dataset = "nroggendorff/anime-captions"
|
| 138 |
model_name = "datalab-to/chandra"
|
|
@@ -154,31 +157,33 @@ def main():
|
|
| 154 |
print(f"Using {num_gpus} GPUs")
|
| 155 |
print(f"Shard size: {shard_size}")
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
temp_files = []
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
print("\nAll processes completed. Loading and concatenating results...")
|
| 184 |
|
|
|
|
| 5 |
from typing import cast
|
| 6 |
import os
|
| 7 |
import shutil
|
| 8 |
+
import multiprocessing as mp
|
| 9 |
from torch.utils.data import DataLoader
|
| 10 |
from PIL import Image
|
| 11 |
from functools import partial
|
|
|
|
| 72 |
print(f"[GPU {gpu_id}] Loading model...", flush=True)
|
| 73 |
processor, model = load_model(model_name, gpu_id)
|
| 74 |
|
| 75 |
+
print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
|
| 76 |
+
loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]")
|
| 77 |
|
| 78 |
if isinstance(loaded, datasets.DatasetDict):
|
| 79 |
shard = cast(Dataset, loaded["train"])
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
def main():
|
| 137 |
+
mp.set_start_method('spawn', force=True)
|
| 138 |
+
|
| 139 |
input_dataset = "none-yet/anime-captions"
|
| 140 |
output_dataset = "nroggendorff/anime-captions"
|
| 141 |
model_name = "datalab-to/chandra"
|
|
|
|
| 157 |
print(f"Using {num_gpus} GPUs")
|
| 158 |
print(f"Shard size: {shard_size}")
|
| 159 |
|
| 160 |
+
processes = []
|
|
|
|
| 161 |
temp_files = []
|
| 162 |
+
|
| 163 |
+
for i in range(num_gpus):
|
| 164 |
+
start = i * shard_size
|
| 165 |
+
end = start + shard_size if i < num_gpus - 1 else total_size
|
| 166 |
+
output_file = f"temp_shard_{i}"
|
| 167 |
+
temp_files.append(output_file)
|
| 168 |
+
|
| 169 |
+
p = mp.Process(
|
| 170 |
+
target=process_shard,
|
| 171 |
+
args=(i, start, end, model_name, batch_size, input_dataset, output_file),
|
| 172 |
+
)
|
| 173 |
+
p.start()
|
| 174 |
+
processes.append(p)
|
| 175 |
+
|
| 176 |
+
for p in processes:
|
| 177 |
+
p.join()
|
| 178 |
+
if p.exitcode != 0:
|
| 179 |
+
print(f"\nProcess failed with exit code {p.exitcode}", flush=True)
|
| 180 |
+
print("Terminating all processes...", flush=True)
|
| 181 |
+
for proc in processes:
|
| 182 |
+
if proc.is_alive():
|
| 183 |
+
proc.terminate()
|
| 184 |
+
for proc in processes:
|
| 185 |
+
proc.join()
|
| 186 |
+
raise RuntimeError(f"At least one process failed")
|
| 187 |
|
| 188 |
print("\nAll processes completed. Loading and concatenating results...")
|
| 189 |
|