nroggendorff committed on
Commit
55cdf87
·
verified ·
1 Parent(s): 89b44b4

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +22 -21
train.py CHANGED
@@ -115,34 +115,35 @@ def process_shard_worker(args):
115
 
116
  return result
117
 
118
- mp.set_start_method("spawn", force=True)
 
119
 
120
- input_dataset = "none-yet/anime-captions"
121
- output_dataset = "nroggendorff/anime-captions"
122
- model_name = "datalab-to/chandra"
123
 
124
- loaded = datasets.load_dataset(input_dataset, split="train")
125
 
126
- if isinstance(loaded, datasets.DatasetDict):
127
- ds = cast(Dataset, loaded["train"])
128
- else:
129
- ds = cast(Dataset, loaded)
130
 
131
- num_gpus = torch.cuda.device_count()
132
- batch_size = 32
133
- total_size = len(ds)
134
- shard_size = total_size // num_gpus
135
 
136
- worker_args = []
137
- for i in range(num_gpus):
138
- start = i * shard_size
139
- end = start + shard_size if i < num_gpus - 1 else total_size
140
- worker_args.append((i, i, start, end, model_name, batch_size))
141
 
142
- with mp.Pool(processes=num_gpus) as pool:
143
- shards = pool.map(process_shard_worker, worker_args)
144
 
145
- ds = datasets.concatenate_datasets(shards)
146
 
147
  # %%
148
  ds.push_to_hub(output_dataset)
 
115
 
116
  return result
117
 
118
if __name__ == "__main__":
    # "spawn" is required when workers use CUDA: a forked child cannot safely
    # reuse a CUDA context the parent process may already hold.
    mp.set_start_method("spawn", force=True)

    input_dataset = "none-yet/anime-captions"
    output_dataset = "none-yet/anime-captions"
    model_name = "datalab-to/chandra"

    loaded = datasets.load_dataset(input_dataset, split="train")

    # load_dataset(..., split="train") normally returns a Dataset, but guard
    # against a DatasetDict so `ds` is always a concrete Dataset afterwards.
    if isinstance(loaded, datasets.DatasetDict):
        ds = cast(Dataset, loaded["train"])
    else:
        ds = cast(Dataset, loaded)

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Fail fast with a clear message instead of the opaque
        # ZeroDivisionError that `total_size // num_gpus` would raise below.
        raise RuntimeError(
            "No CUDA devices detected; at least one GPU is required to shard the dataset."
        )

    batch_size = 8
    total_size = len(ds)
    shard_size = total_size // num_gpus

    # One contiguous shard per GPU. The last shard absorbs the division
    # remainder so every example in the dataset is covered exactly once.
    worker_args = []
    for i in range(num_gpus):
        start = i * shard_size
        end = start + shard_size if i < num_gpus - 1 else total_size
        # NOTE(review): `i` is passed twice (presumably worker index and GPU
        # index) — confirm against process_shard_worker's signature.
        worker_args.append((i, i, start, end, model_name, batch_size))

    # One pool worker per GPU; each processes its shard independently.
    with mp.Pool(processes=num_gpus) as pool:
        shards = pool.map(process_shard_worker, worker_args)

    ds = datasets.concatenate_datasets(shards)

    # %%
    ds.push_to_hub(output_dataset)