nroggendorff commited on
Commit
89b44b4
·
verified ·
1 Parent(s): 6a6733e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +42 -26
train.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
4
 
5
 
6
- def load_model(model_name="datalab-to/chandra", device_id=0):
7
  bnb_config = BitsAndBytesConfig(
8
  load_in_4bit=True,
9
  bnb_4bit_compute_dtype=torch.bfloat16,
@@ -16,7 +16,7 @@ def load_model(model_name="datalab-to/chandra", device_id=0):
16
  model = AutoModelForVision2Seq.from_pretrained(
17
  model_name,
18
  quantization_config=bnb_config,
19
- dtype=torch.bfloat16,
20
  device_map={"": device_id},
21
  )
22
 
@@ -67,14 +67,16 @@ def caption_batch(batch, processor, model):
67
  generated = model.generate(
68
  input_ids=input_ids,
69
  attention_mask=attention_mask,
 
70
  )
71
 
72
- decoded = processor.batch_decode(generated)
73
 
74
  captions = []
75
  for d in decoded:
76
  if "<|im_start|>assistant" in d:
77
  d = d.split("<|im_start|>assistant")[-1].strip()
 
78
  captions.append(d)
79
 
80
  return {
@@ -86,11 +88,38 @@ def caption_batch(batch, processor, model):
86
  import datasets
87
  from datasets import Dataset
88
  from typing import cast
89
- from concurrent.futures import ThreadPoolExecutor
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  input_dataset = "none-yet/anime-captions"
93
  output_dataset = "nroggendorff/anime-captions"
 
94
 
95
  loaded = datasets.load_dataset(input_dataset, split="train")
96
 
@@ -100,31 +129,18 @@ else:
100
  ds = cast(Dataset, loaded)
101
 
102
  num_gpus = torch.cuda.device_count()
103
- models = [load_model(device_id=i) for i in range(num_gpus)]
104
-
105
  batch_size = 32
106
- shard_size = len(ds) // num_gpus
107
-
108
-
109
- def process_shard(shard_idx, processor, model):
110
- start = shard_idx * shard_size
111
- end = start + shard_size if shard_idx < num_gpus - 1 else len(ds)
112
- shard = ds.select(range(start, end))
113
-
114
- return shard.map(
115
- lambda batch: caption_batch(batch, processor, model),
116
- batched=True,
117
- batch_size=batch_size,
118
- remove_columns=shard.column_names,
119
- )
120
 
 
 
 
 
 
121
 
122
- with ThreadPoolExecutor(max_workers=num_gpus) as executor:
123
- futures = [
124
- executor.submit(process_shard, i, proc, model)
125
- for i, (proc, model) in enumerate(models)
126
- ]
127
- shards = [f.result() for f in futures]
128
 
129
  ds = datasets.concatenate_datasets(shards)
130
 
 
3
  from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
4
 
5
 
6
+ def load_model(model_name, device_id=0):
7
  bnb_config = BitsAndBytesConfig(
8
  load_in_4bit=True,
9
  bnb_4bit_compute_dtype=torch.bfloat16,
 
16
  model = AutoModelForVision2Seq.from_pretrained(
17
  model_name,
18
  quantization_config=bnb_config,
19
+ torch_dtype=torch.bfloat16,
20
  device_map={"": device_id},
21
  )
22
 
 
67
  generated = model.generate(
68
  input_ids=input_ids,
69
  attention_mask=attention_mask,
70
+ max_new_tokens=256,
71
  )
72
 
73
+ decoded = processor.batch_decode(generated, skip_special_tokens=False)
74
 
75
  captions = []
76
  for d in decoded:
77
  if "<|im_start|>assistant" in d:
78
  d = d.split("<|im_start|>assistant")[-1].strip()
79
+ d = d.replace("<|im_end|>", "").strip()
80
  captions.append(d)
81
 
82
  return {
 
88
  import datasets
89
  from datasets import Dataset
90
  from typing import cast
91
+ import multiprocessing as mp
92
 
93
 
94
+ def process_shard_worker(args):
95
+ _, device_id, start, end, model_name, batch_size = args
96
+
97
+ torch.cuda.set_device(device_id)
98
+
99
+ processor, model = load_model(model_name, device_id)
100
+
101
+ input_dataset = "none-yet/anime-captions"
102
+ loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]")
103
+
104
+ if isinstance(loaded, datasets.DatasetDict):
105
+ shard = cast(Dataset, loaded["train"])
106
+ else:
107
+ shard = cast(Dataset, loaded)
108
+
109
+ result = shard.map(
110
+ lambda batch: caption_batch(batch, processor, model),
111
+ batched=True,
112
+ batch_size=batch_size,
113
+ remove_columns=shard.column_names,
114
+ )
115
+
116
+ return result
117
+
118
+ mp.set_start_method("spawn", force=True)
119
+
120
  input_dataset = "none-yet/anime-captions"
121
  output_dataset = "nroggendorff/anime-captions"
122
+ model_name = "datalab-to/chandra"
123
 
124
  loaded = datasets.load_dataset(input_dataset, split="train")
125
 
 
129
  ds = cast(Dataset, loaded)
130
 
131
  num_gpus = torch.cuda.device_count()
 
 
132
  batch_size = 32
133
+ total_size = len(ds)
134
+ shard_size = total_size // num_gpus
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ worker_args = []
137
+ for i in range(num_gpus):
138
+ start = i * shard_size
139
+ end = start + shard_size if i < num_gpus - 1 else total_size
140
+ worker_args.append((i, i, start, end, model_name, batch_size))
141
 
142
+ with mp.Pool(processes=num_gpus) as pool:
143
+ shards = pool.map(process_shard_worker, worker_args)
 
 
 
 
144
 
145
  ds = datasets.concatenate_datasets(shards)
146