Update train.py
train.py
CHANGED
@@ -115,34 +115,35 @@ def process_shard_worker(args):
 
     return result
 
-
+if __name__ == '__main__':
+    mp.set_start_method("spawn", force=True)
 
-input_dataset = "none-yet/anime-captions"
-output_dataset = "none-yet/anime-captions"
-model_name = "datalab-to/chandra"
+    input_dataset = "none-yet/anime-captions"
+    output_dataset = "none-yet/anime-captions"
+    model_name = "datalab-to/chandra"
 
-loaded = datasets.load_dataset(input_dataset, split="train")
+    loaded = datasets.load_dataset(input_dataset, split="train")
 
-if isinstance(loaded, datasets.DatasetDict):
-    ds = cast(Dataset, loaded["train"])
-else:
-    ds = cast(Dataset, loaded)
+    if isinstance(loaded, datasets.DatasetDict):
+        ds = cast(Dataset, loaded["train"])
+    else:
+        ds = cast(Dataset, loaded)
 
-num_gpus = torch.cuda.device_count()
-batch_size = 8
-total_size = len(ds)
-shard_size = total_size // num_gpus
+    num_gpus = torch.cuda.device_count()
+    batch_size = 8
+    total_size = len(ds)
+    shard_size = total_size // num_gpus
 
-worker_args = []
-for i in range(num_gpus):
-    start = i * shard_size
-    end = start + shard_size if i < num_gpus - 1 else total_size
-    worker_args.append((i, i, start, end, model_name, batch_size))
+    worker_args = []
+    for i in range(num_gpus):
+        start = i * shard_size
+        end = start + shard_size if i < num_gpus - 1 else total_size
+        worker_args.append((i, i, start, end, model_name, batch_size))
 
-with mp.Pool(processes=num_gpus) as pool:
-    shards = pool.map(process_shard_worker, worker_args)
+    with mp.Pool(processes=num_gpus) as pool:
+        shards = pool.map(process_shard_worker, worker_args)
 
-ds = datasets.concatenate_datasets(shards)
+    ds = datasets.concatenate_datasets(shards)
 
 # %%
 ds.push_to_hub(output_dataset)
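The substance of the hunk: the driver code moves under an `if __name__ == '__main__':` guard, and the multiprocessing start method is forced to "spawn". With "spawn", every worker process re-imports train.py, so the guard is what keeps children from re-running the driver; and spawn (rather than the Linux default "fork") is the usual requirement once CUDA is involved, since a forked child that inherits an initialized CUDA context will fail. One loose end, left unchanged as context here: `# %%` and `ds.push_to_hub(output_dataset)` still sit at module level, where a spawned worker would execute them on import before `output_dataset` exists; they would normally move under the guard as well.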
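The sharding arithmetic hands each worker a contiguous, non-overlapping slice, with the last worker absorbing the remainder of the integer division. As an illustration (numbers not from the script): with total_size = 100_001 and num_gpus = 8, shard_size = 12_500; workers 0 through 6 get 12,500 rows each, and worker 7 gets rows 87_500 through 100_000, i.e. 12,501 rows.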
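The hunk header shows the worker entry point, process_shard_worker(args), but its body lies above this hunk. A minimal sketch of what such a worker could look like, matching the (i, i, start, end, model_name, batch_size) tuples built by the driver — the argument names, the per-worker dataset reload, and the inference placeholder are assumptions, not the script's actual code:

import datasets
import torch

def process_shard_worker(args):
    # Unpack the tuple assembled by the driver (names assumed).
    rank, gpu_id, start, end, model_name, batch_size = args

    # Pin this process to its own GPU before any CUDA work.
    torch.cuda.set_device(gpu_id)

    # Each spawned worker loads the dataset itself (cheap once cached,
    # since Datasets are memory-mapped) and slices out its shard.
    shard = datasets.load_dataset("none-yet/anime-captions", split="train")
    shard = shard.select(range(start, end))

    # ... load model_name onto f"cuda:{gpu_id}" and caption the shard in
    # batches of batch_size, e.g. with shard.map(fn, batched=True,
    # batch_size=batch_size) ...

    result = shard
    return result

Whatever the real body does, its return value has to be picklable so pool.map can ship it back to the parent for concatenate_datasets; returning the Arrow-backed shard itself satisfies that.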