nroggendorff commited on
Commit
c79fe94
·
verified ·
1 Parent(s): 80f71ae

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +136 -70
train.py CHANGED
@@ -31,7 +31,7 @@ def load_model(model_name, device_id=0):
31
  return processor, model
32
 
33
 
34
- def getTemplate(processor):
35
  msg = [
36
  {
37
  "role": "user",
@@ -44,29 +44,97 @@ def getTemplate(processor):
44
  ],
45
  }
46
  ]
47
-
48
  return processor.apply_chat_template(
49
  msg, add_generation_prompt=True, tokenize=False
50
  )
51
 
52
 
53
- def caption_batch(batch, processor, model, text):
54
- images = batch["image"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  pil_images = []
57
- for image in images:
58
  if isinstance(image, Image.Image):
59
  if image.mode != "RGB":
60
  image = image.convert("RGB")
61
  pil_images.append(image)
62
 
63
- texts = [text] * len(pil_images)
64
-
65
- inputs = processor(text=texts, images=pil_images, return_tensors="pt", padding=True)
66
-
67
- inputs = {k: v.pin_memory().to(model.device, non_blocking=True) for k, v in inputs.items()}
 
68
 
69
- with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
70
  generated = model.generate(
71
  **inputs,
72
  max_new_tokens=128,
@@ -76,91 +144,87 @@ def caption_batch(batch, processor, model, text):
76
  decoded = processor.batch_decode(generated, skip_special_tokens=False)
77
 
78
  captions = []
79
- special_tokens = set(processor.tokenizer.all_special_tokens)
 
80
  for d in decoded:
81
  if "<|im_start|>assistant" in d:
82
  d = d.split("<|im_start|>assistant")[-1]
83
-
84
- for token in special_tokens:
85
  d = d.replace(token, "")
 
86
 
87
- d = d.strip()
88
- captions.append(d)
89
 
90
- return {
91
- "text": captions,
92
- }
93
 
94
-
95
- def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, output_file):
 
96
  try:
97
  torch.cuda.set_device(gpu_id)
98
 
99
- print(f"[GPU {gpu_id}] Loading model...", flush=True)
100
  processor, model = load_model(model_name, gpu_id)
101
 
102
- print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
103
- loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]")
104
-
105
- if isinstance(loaded, datasets.DatasetDict):
106
- shard = cast(Dataset, loaded["train"])
107
- else:
108
- shard = cast(Dataset, loaded)
109
 
110
- print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
111
  result = shard.map(
112
- lambda batch: caption_batch(batch, processor, model, getTemplate(processor)),
113
  batched=True,
114
  batch_size=batch_size,
115
- remove_columns=[col for col in shard.column_names if col != "image"],
116
  )
117
 
118
- print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
119
  result.save_to_disk(output_file)
120
 
121
- print(f"[GPU {gpu_id}] Done!", flush=True)
122
  return output_file
 
123
  except Exception as e:
124
  print(f"[GPU {gpu_id}] Error: {e}", flush=True)
125
  raise
126
 
127
 
128
  def main():
129
- mp.set_start_method('spawn', force=True)
130
 
131
  input_dataset = "none-yet/anime-captions"
 
132
  output_dataset = "nroggendorff/anime-captions"
133
  model_name = "datalab-to/chandra"
134
  batch_size = 20
135
 
136
- print("Loading dataset info...")
137
- loaded = datasets.load_dataset(input_dataset, split="train")
138
 
139
- if isinstance(loaded, datasets.DatasetDict):
140
- ds = cast(Dataset, loaded["train"])
141
- else:
142
- ds = cast(Dataset, loaded)
143
 
144
  num_gpus = torch.cuda.device_count()
145
- total_size = len(ds)
146
- shard_size = total_size // num_gpus
147
 
148
- print(f"Dataset size: {total_size}")
149
  print(f"Using {num_gpus} GPUs")
150
- print(f"Shard size: {shard_size}")
151
 
152
  processes = []
153
  temp_files = []
154
 
155
  for i in range(num_gpus):
156
- start = i * shard_size
157
- end = start + shard_size if i < num_gpus - 1 else total_size
158
- output_file = f"temp_shard_{i}"
159
- temp_files.append(output_file)
 
160
 
161
  p = mp.Process(
162
  target=process_shard,
163
- args=(i, start, end, model_name, batch_size, input_dataset, output_file),
164
  )
165
  p.start()
166
  processes.append(p)
@@ -168,30 +232,32 @@ def main():
168
  for p in processes:
169
  p.join()
170
  if p.exitcode != 0:
171
- print(f"\nProcess failed with exit code {p.exitcode}", flush=True)
172
- print("Terminating all processes...", flush=True)
173
- for proc in processes:
174
- if proc.is_alive():
175
- proc.terminate()
176
- for proc in processes:
177
- proc.join()
178
- raise RuntimeError(f"At least one process failed")
179
-
180
- print("\nAll processes completed. Loading and concatenating results...")
181
-
182
- shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
183
- final_ds = datasets.concatenate_datasets(shards)
184
-
185
- print(f"Final dataset size: {len(final_ds)}")
186
- print("Pushing to hub...")
 
 
 
187
  final_ds.push_to_hub(output_dataset, create_pr=False)
188
 
189
- print("Cleaning up temporary files...")
190
  for f in temp_files:
191
- if os.path.exists(f):
192
- shutil.rmtree(f)
193
 
194
- print("Done!")
195
 
196
 
197
  if __name__ == "__main__":
 
31
  return processor, model
32
 
33
 
34
+ def build_template(processor):
35
  msg = [
36
  {
37
  "role": "user",
 
44
  ],
45
  }
46
  ]
 
47
  return processor.apply_chat_template(
48
  msg, add_generation_prompt=True, tokenize=False
49
  )
50
 
51
 
52
+ def iterable_to_map(ds, chunk_size=10000):
53
+ buffer = []
54
+ for ex in ds:
55
+ buffer.append(ex)
56
+ if len(buffer) >= chunk_size:
57
+ yield buffer
58
+ buffer = []
59
+
60
+
61
+ def cpu_preprocess(input_dataset, output_folder, model_name):
62
+ print("CPU preprocessing…")
63
+
64
+ processor = AutoProcessor.from_pretrained(model_name)
65
+ template = build_template(processor)
66
+
67
+ def _pp(batch):
68
+ out_images = []
69
+ for img in batch["image"]:
70
+ if isinstance(img, Image.Image):
71
+ if img.mode != "RGB":
72
+ img = img.convert("RGB")
73
+ out_images.append(img)
74
+
75
+ prompts = [template] * len(out_images)
76
+ return {
77
+ "image": out_images,
78
+ "prompt": prompts,
79
+ }
80
+
81
+ ds = datasets.load_dataset(input_dataset, split="train")
82
+
83
+ if ds is None:
84
+ raise ValueError(
85
+ f"Failed to load dataset '{input_dataset}' with split 'train'. Check the dataset name or available splits."
86
+ )
87
+
88
+ if isinstance(ds, datasets.DatasetDict):
89
+ if "train" in ds:
90
+ ds = ds["train"]
91
+ else:
92
+ raise ValueError(
93
+ f"'{input_dataset}' does not contain a 'train' split. Available splits: {list(ds.keys())}"
94
+ )
95
+
96
+ if not isinstance(ds, datasets.Dataset):
97
+ raise TypeError(f"Expected a Dataset instance, got {type(ds)}")
98
+
99
+ print(f"Dataset loaded: {len(ds)} examples")
100
+
101
+ ds2 = ds.map(
102
+ _pp,
103
+ batched=True,
104
+ remove_columns=[c for c in ds.column_names if c not in ("image",)],
105
+ )
106
+
107
+ print("Saving CPU-preprocessed dataset…")
108
+ parts = []
109
+ for chunk in iterable_to_map(ds2):
110
+ part = Dataset.from_list(chunk)
111
+ parts.append(part)
112
+
113
+ ds2 = datasets.concatenate_datasets(parts)
114
+ ds2.save_to_disk(output_folder)
115
+
116
+ print("CPU preprocessing done.")
117
+
118
+
119
+ def caption_batch(batch, processor, model):
120
+ imgs = batch["image"]
121
+ prompts = batch["prompt"]
122
 
123
  pil_images = []
124
+ for image in imgs:
125
  if isinstance(image, Image.Image):
126
  if image.mode != "RGB":
127
  image = image.convert("RGB")
128
  pil_images.append(image)
129
 
130
+ inputs = processor(
131
+ text=prompts, images=pil_images, return_tensors="pt", padding=True
132
+ )
133
+ inputs = {
134
+ k: v.pin_memory().to(model.device, non_blocking=True) for k, v in inputs.items()
135
+ }
136
 
137
+ with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16): # type: ignore
138
  generated = model.generate(
139
  **inputs,
140
  max_new_tokens=128,
 
144
  decoded = processor.batch_decode(generated, skip_special_tokens=False)
145
 
146
  captions = []
147
+ special = set(processor.tokenizer.all_special_tokens)
148
+
149
  for d in decoded:
150
  if "<|im_start|>assistant" in d:
151
  d = d.split("<|im_start|>assistant")[-1]
152
+ for token in special:
 
153
  d = d.replace(token, "")
154
+ captions.append(d.strip())
155
 
156
+ return {"text": captions}
 
157
 
 
 
 
158
 
159
+ def process_shard(
160
+ gpu_id, start, end, model_name, batch_size, prepped_folder, output_file
161
+ ):
162
  try:
163
  torch.cuda.set_device(gpu_id)
164
 
165
+ print(f"[GPU {gpu_id}] Loading model", flush=True)
166
  processor, model = load_model(model_name, gpu_id)
167
 
168
+ print(f"[GPU {gpu_id}] Loading preprocessed shard [{start}:{end}]", flush=True)
169
+ shard = datasets.load_from_disk(prepped_folder)
170
+ if isinstance(shard, datasets.DatasetDict):
171
+ shard = shard["train"]
172
+ shard = shard.select(range(start, end))
 
 
173
 
174
+ print(f"[GPU {gpu_id}] Captioning {len(shard)} examples", flush=True)
175
  result = shard.map(
176
+ lambda batch: caption_batch(batch, processor, model),
177
  batched=True,
178
  batch_size=batch_size,
179
+ remove_columns=["image", "prompt"],
180
  )
181
 
182
+ print(f"[GPU {gpu_id}] Saving {output_file}", flush=True)
183
  result.save_to_disk(output_file)
184
 
185
+ print(f"[GPU {gpu_id}] Done.", flush=True)
186
  return output_file
187
+
188
  except Exception as e:
189
  print(f"[GPU {gpu_id}] Error: {e}", flush=True)
190
  raise
191
 
192
 
193
  def main():
194
+ mp.set_start_method("spawn", force=True)
195
 
196
  input_dataset = "none-yet/anime-captions"
197
+ prepped_folder = "cpu_preprocessed"
198
  output_dataset = "nroggendorff/anime-captions"
199
  model_name = "datalab-to/chandra"
200
  batch_size = 20
201
 
202
+ if not os.path.exists(prepped_folder):
203
+ cpu_preprocess(input_dataset, prepped_folder, model_name)
204
 
205
+ ds = datasets.load_from_disk(prepped_folder)
206
+ total = len(ds)
 
 
207
 
208
  num_gpus = torch.cuda.device_count()
209
+ shard = total // num_gpus
 
210
 
211
+ print(f"Dataset size: {total}")
212
  print(f"Using {num_gpus} GPUs")
213
+ print(f"Shard size: {shard}")
214
 
215
  processes = []
216
  temp_files = []
217
 
218
  for i in range(num_gpus):
219
+ s = i * shard
220
+ e = s + shard if i < num_gpus - 1 else total
221
+
222
+ of = f"temp_shard_{i}"
223
+ temp_files.append(of)
224
 
225
  p = mp.Process(
226
  target=process_shard,
227
+ args=(i, s, e, model_name, batch_size, prepped_folder, of),
228
  )
229
  p.start()
230
  processes.append(p)
 
232
  for p in processes:
233
  p.join()
234
  if p.exitcode != 0:
235
+ print("A process failed, aborting…")
236
+ for q in processes:
237
+ if q.is_alive():
238
+ q.terminate()
239
+ for q in processes:
240
+ q.join()
241
+ raise RuntimeError("GPU worker failed.")
242
+
243
+ print("Merging shards…")
244
+ parts = []
245
+ for f in temp_files:
246
+ ds = datasets.load_from_disk(f)
247
+ if isinstance(ds, datasets.DatasetDict):
248
+ ds = ds["train"]
249
+ parts.append(ds)
250
+
251
+ final_ds = datasets.concatenate_datasets(parts)
252
+
253
+ print(f"Pushing final dataset to {output_dataset}…")
254
  final_ds.push_to_hub(output_dataset, create_pr=False)
255
 
256
+ print("Cleaning up")
257
  for f in temp_files:
258
+ shutil.rmtree(f, ignore_errors=True)
 
259
 
260
+ print("Done.")
261
 
262
 
263
  if __name__ == "__main__":