nroggendorff committed on
Commit
dae35de
·
verified ·
1 Parent(s): 4b8a72e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +48 -33
train.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
3
  import datasets
4
  from datasets import Dataset
5
  from typing import cast
@@ -18,13 +18,13 @@ def load_model(model_name, device_id=0):
18
  )
19
 
20
  processor = AutoProcessor.from_pretrained(model_name)
 
21
 
22
- model = AutoModelForVision2Seq.from_pretrained(
23
  model_name,
24
  quantization_config=bnb_config,
25
  dtype=torch.bfloat16,
26
  device_map={"": device_id},
27
- torch_dtype=torch.bfloat16,
28
  attn_implementation="flash_attention_2",
29
  )
30
 
@@ -63,7 +63,7 @@ def caption_batch(batch, processor, model):
63
 
64
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
65
 
66
- with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
67
  generated = model.generate(
68
  **inputs,
69
  max_new_tokens=128,
@@ -91,40 +91,44 @@ def caption_batch(batch, processor, model):
91
 
92
 
93
  def process_shard_worker(
94
- gpu_id, start, end, model_name, batch_size, input_dataset, output_file
95
  ):
96
- torch.cuda.set_device(gpu_id)
97
-
98
- print(f"[GPU {gpu_id}] Loading model...", flush=True)
99
- processor, model = load_model(model_name, gpu_id)
100
-
101
- print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
102
- loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]")
103
-
104
- if isinstance(loaded, datasets.DatasetDict):
105
- shard = cast(Dataset, loaded["train"])
106
- else:
107
- shard = cast(Dataset, loaded)
108
-
109
- print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
110
- result = shard.map(
111
- lambda batch: caption_batch(batch, processor, model),
112
- batched=True,
113
- batch_size=batch_size,
114
- remove_columns=[col for col in shard.column_names if col != "image"],
115
- writer_batch_size=1000,
116
- )
 
117
 
118
- print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
119
- result.save_to_disk(output_file)
120
 
121
- print(f"[GPU {gpu_id}] Done!", flush=True)
122
- return output_file
 
 
 
123
 
124
 
125
  def main():
126
- input_dataset = "none-yet/wikiart"
127
- output_dataset = "nroggendorff/wikiart"
128
  model_name = "datalab-to/chandra"
129
  batch_size = 16
130
 
@@ -148,6 +152,7 @@ def main():
148
 
149
  processes = []
150
  temp_files = []
 
151
 
152
  for i in range(num_gpus):
153
  start = i * shard_size
@@ -157,13 +162,23 @@ def main():
157
 
158
  p = mp.Process(
159
  target=process_shard_worker,
160
- args=(i, start, end, model_name, batch_size, input_dataset, output_file),
161
  )
162
  p.start()
163
  processes.append(p)
164
 
165
  for p in processes:
166
  p.join()
 
 
 
 
 
 
 
 
 
 
167
 
168
  print("\nAll processes completed. Loading and concatenating results...")
169
 
 
1
  import torch
2
+ from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
3
  import datasets
4
  from datasets import Dataset
5
  from typing import cast
 
18
  )
19
 
20
  processor = AutoProcessor.from_pretrained(model_name)
21
+ processor.tokenizer.padding_side = "left"
22
 
23
+ model = AutoModelForImageTextToText.from_pretrained(
24
  model_name,
25
  quantization_config=bnb_config,
26
  dtype=torch.bfloat16,
27
  device_map={"": device_id},
 
28
  attn_implementation="flash_attention_2",
29
  )
30
 
 
63
 
64
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
65
 
66
+ with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
67
  generated = model.generate(
68
  **inputs,
69
  max_new_tokens=128,
 
91
 
92
 
93
  def process_shard_worker(
94
+ gpu_id, start, end, model_name, batch_size, input_dataset, output_file, error_queue
95
  ):
96
+ try:
97
+ torch.cuda.set_device(gpu_id)
98
+
99
+ print(f"[GPU {gpu_id}] Loading model...", flush=True)
100
+ processor, model = load_model(model_name, gpu_id)
101
+
102
+ print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
103
+ loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]")
104
+
105
+ if isinstance(loaded, datasets.DatasetDict):
106
+ shard = cast(Dataset, loaded["train"])
107
+ else:
108
+ shard = cast(Dataset, loaded)
109
+
110
+ print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
111
+ result = shard.map(
112
+ lambda batch: caption_batch(batch, processor, model),
113
+ batched=True,
114
+ batch_size=batch_size,
115
+ remove_columns=[col for col in shard.column_names if col != "image"],
116
+ writer_batch_size=1000,
117
+ )
118
 
119
+ print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
120
+ result.save_to_disk(output_file)
121
 
122
+ print(f"[GPU {gpu_id}] Done!", flush=True)
123
+ return output_file
124
+ except Exception as e:
125
+ error_queue.put((gpu_id, e))
126
+ raise
127
 
128
 
129
  def main():
130
+ input_dataset = "none-yet/anime-captions"
131
+ output_dataset = "nroggendorff/anime-captions"
132
  model_name = "datalab-to/chandra"
133
  batch_size = 16
134
 
 
152
 
153
  processes = []
154
  temp_files = []
155
+ error_queue = mp.Queue()
156
 
157
  for i in range(num_gpus):
158
  start = i * shard_size
 
162
 
163
  p = mp.Process(
164
  target=process_shard_worker,
165
+ args=(i, start, end, model_name, batch_size, input_dataset, output_file, error_queue),
166
  )
167
  p.start()
168
  processes.append(p)
169
 
170
  for p in processes:
171
  p.join()
172
+ if not error_queue.empty():
173
+ gpu_id, error = error_queue.get()
174
+ print(f"\n[GPU {gpu_id}] Error occurred: {error}", flush=True)
175
+ print("Terminating all processes...", flush=True)
176
+ for proc in processes:
177
+ if proc.is_alive():
178
+ proc.terminate()
179
+ for proc in processes:
180
+ proc.join()
181
+ raise RuntimeError(f"Process for GPU {gpu_id} failed with error: {error}")
182
 
183
  print("\nAll processes completed. Loading and concatenating results...")
184