leandro von werra committed
Commit e3d9b66 · 1 Parent(s): 12f24f2

update training script and requirements

Files changed (2):
  1. codeparrot_training.py +40 -45
  2. requirements.txt +1 -2
codeparrot_training.py CHANGED

@@ -12,45 +12,41 @@ from argparse import Namespace
 import torch
 import logging
 import wandb
+import time
 
 
 class ConstantLengthDataset(IterableDataset):
-
-    def __init__(self, tokenizer, dataset, seq_length=1024, batch_size=3,
+    def __init__(self, tokenizer, dataset, seq_length=1024,
                  num_of_sequences=1024, chars_per_token=3.6):
         self.tokenizer = tokenizer
-        self.concatenation_token = tokenizer.bos_token
+        self.concat_token_id = tokenizer.bos_token_id
         self.dataset = dataset
         self.seq_length = seq_length
-        self.batch_size = batch_size
         self.input_characters = seq_length * chars_per_token * num_of_sequences
-
+        self.produced_samples = 0
     def __iter__(self):
         iterator = iter(self.dataset)
         more_examples = True
-        batch = []
         while more_examples:
-            buffer = ''
+            buffer = []
+            buffer_len = 0
             while True:
-                if len(buffer) >= self.input_characters:
+                if buffer_len >= self.input_characters:
                     break
                 try:
-                    next_example = next(iterator)['content']
-                    buffer = buffer + self.concatenation_token + next_example
+                    buffer.append(next(iterator)['content'])
+                    buffer_len += len(buffer[-1])
                 except StopIteration:
                     more_examples = False
                     break
-
-            tokenized_input = tokenizer(buffer, truncation=True,
-                                        max_length=self.seq_length,
-                                        return_overflowing_tokens=True)
-
-            for input_ids in tokenized_input['input_ids']:
+            tokenized_inputs = tokenizer(buffer, truncation=False)['input_ids']
+            all_token_ids = []
+            for tokenized_input in tokenized_inputs:
+                all_token_ids.extend(tokenized_input + [self.concat_token_id])
+            for i in range(0, len(all_token_ids), self.seq_length):
+                input_ids = all_token_ids[i : i + self.seq_length]
                 if len(input_ids) == self.seq_length:
-                    batch.append(input_ids)
-                    if len(batch) == self.batch_size:
-                        yield torch.tensor(batch)
-                        batch = []
+                    yield torch.tensor(input_ids)
 
 def setup_logging(project_name):
     logger = logging.getLogger(__name__)
@@ -59,31 +55,31 @@ def setup_logging(project_name):
                         datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
     if accelerator.is_main_process: # we only want to setup logging once
         wandb.init(project=project_name, config=args)
+        run_name = wandb.run.name
         tb_writer = SummaryWriter()
         tb_writer.add_hparams(vars(args), {'0': 0})
         logger.setLevel(logging.INFO)
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
     else:
+        tb_writer = None
+        run_name = ''
         logger.setLevel(logging.ERROR)
         datasets.utils.logging.set_verbosity_error()
         transformers.utils.logging.set_verbosity_error()
-    return logger, tb_writer
+    return logger, tb_writer, run_name
 
 def create_dataloaders(dataset_name):
     train_data = load_dataset(dataset_name+'-train', split="train",
                               streaming=True)
-    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer)
+    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer,
+                                    seed=args.seed)
     valid_data = load_dataset(dataset_name+'-valid', split="train",
                               streaming=True)
-
     train_dataset = ConstantLengthDataset(tokenizer, train_data,
-                                          seq_length=args.seq_length,
-                                          batch_size=args.train_batch_size)
+                                          seq_length=args.seq_length)
     valid_dataset = ConstantLengthDataset(tokenizer, valid_data,
-                                          seq_length=args.seq_length,
-                                          batch_size=args.valid_batch_size)
-
+                                          seq_length=args.seq_length)
     train_dataloader=DataLoader(train_dataset, batch_size=args.train_batch_size)
     eval_dataloader=DataLoader(valid_dataset, batch_size=args.valid_batch_size)
     return train_dataloader, eval_dataloader
@@ -107,7 +103,7 @@ def evaluate():
     losses = []
     for step, batch in enumerate(eval_dataloader):
         with torch.no_grad():
-            outputs = model(batch[0], labels=batch[0])
+            outputs = model(batch, labels=batch)
         loss = outputs.loss.repeat(args.valid_batch_size)
         losses.append(accelerator.gather(loss))
         if args.max_eval_steps > 0 and step >= args.max_eval_steps: break
@@ -119,19 +115,19 @@ def evaluate():
 # Hyperparameters
 project_name = 'transformersbook/codeparrot-small'
 dataset_name = 'transformersbook/codeparrot'
-config = {"train_batch_size": 4,
-          "valid_batch_size": 4,
+config = {"train_batch_size": 12,
+          "valid_batch_size": 12,
           "weight_decay": 0.1,
           "shuffle_buffer": 1000,
          "learning_rate": 5e-4,
           "lr_scheduler_type": "cosine",
-          "num_warmup_steps": 1000,
-          "gradient_accumulation_steps": 8,
-          "max_train_steps": 4096,
-          "max_eval_steps": 1024,
+          "num_warmup_steps": 2000,
+          "gradient_accumulation_steps": 1,
+          "max_train_steps": 8192,
+          "max_eval_steps": 512,
           "seq_length": 1024,
           "seed": 1,
-          "save_checkpoint_steps":4096,}
+          "save_checkpoint_steps":512,}
 args = Namespace(**config)
 set_seed(args.seed)
 
@@ -140,12 +136,12 @@ accelerator = Accelerator()
 samples_per_step = accelerator.state.num_processes * args.train_batch_size
 
 # Logging
-logger, tb_writer = setup_logging(project_name.split("/")[1])
+logger, tb_writer, run_name = setup_logging(project_name.split("/")[1])
 logger.info(accelerator.state)
-run_name = wandb.run.name
 
 # Load model and tokenizer
-hf_repo = Repository("./", clone_from=project_name, revision=run_name)
+if accelerator.is_main_process: # we only want to setup logging once
+    hf_repo = Repository("./", clone_from=project_name, revision=run_name)
 model = GPT2LMHeadModel.from_pretrained("./")
 tokenizer = AutoTokenizer.from_pretrained("./")
 
@@ -167,12 +163,12 @@ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
 model.train()
 completed_steps = 0
 for step, batch in enumerate(train_dataloader, start=1):
-    loss = model(batch[0], labels=batch[0]).loss
+    loss = model(batch, labels=batch).loss
     log_metrics(step, {'lr': get_lr(), 'samples': step*samples_per_step,
                        'steps': completed_steps, 'loss/train': loss.item()})
     loss = loss / args.gradient_accumulation_steps
     accelerator.backward(loss)
-    if step % args.gradient_accumulation_steps == 0:
+    if step % args.gradient_accumulation_steps == 0:
         optimizer.step()
         lr_scheduler.step()
         optimizer.zero_grad()
@@ -183,8 +179,8 @@ for step, batch in enumerate(train_dataloader, start=1):
         log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained("./")
         if accelerator.is_main_process:
+            unwrapped_model.save_pretrained("./")
             hf_repo.push_to_hub(commit_message=f'step {step}')
         model.train()
     if completed_steps >= args.max_train_steps:
@@ -196,7 +192,6 @@ eval_loss, perplexity = evaluate()
 log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
 accelerator.wait_for_everyone()
 unwrapped_model = accelerator.unwrap_model(model)
-unwrapped_model.save_pretrained("./")
 if accelerator.is_main_process:
-    try: hf_repo.push_to_hub(commit_message=f'final model')
-    except: logger.info('No changes to previously saved model.')
+    unwrapped_model.save_pretrained("./")
+    hf_repo.push_to_hub(commit_message=f'final model')
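
Note: the rewritten ConstantLengthDataset no longer assembles batches itself. It tokenizes a buffer of raw examples, joins all token ids with the tokenizer's BOS token id, slices the resulting stream into fixed-length windows, and yields one sequence at a time, leaving batching to the DataLoader. Below is a minimal, self-contained sketch of that packing logic; the PackedDataset class and the toy whitespace "tokenizer" are illustrative stand-ins, not part of the commit.

import torch
from torch.utils.data import DataLoader, IterableDataset

class PackedDataset(IterableDataset):  # hypothetical stand-in for ConstantLengthDataset
    def __init__(self, encode, examples, seq_length=8, concat_token_id=0):
        self.encode = encode                    # str -> list[int], stands in for the tokenizer
        self.examples = examples
        self.seq_length = seq_length
        self.concat_token_id = concat_token_id  # plays the role of tokenizer.bos_token_id

    def __iter__(self):
        # Concatenate all tokenized examples, separated by the concat token id.
        all_token_ids = []
        for text in self.examples:
            all_token_ids.extend(self.encode(text) + [self.concat_token_id])
        # Slice the stream into fixed-length windows and drop the short tail.
        for i in range(0, len(all_token_ids), self.seq_length):
            window = all_token_ids[i : i + self.seq_length]
            if len(window) == self.seq_length:
                yield torch.tensor(window)

vocab = {}
def toy_encode(text):
    # Toy whitespace "tokenizer": assigns a fresh id to every new word.
    return [vocab.setdefault(tok, len(vocab) + 1) for tok in text.split()]

examples = ["def add ( a , b ) : return a + b", "print ( 'hello world' )"]
loader = DataLoader(PackedDataset(toy_encode, examples), batch_size=2)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 8]): the DataLoader stacks the 1-D windows

Because every item is already a full seq_length window, batch_size moves out of the dataset and into DataLoader(..., batch_size=args.train_batch_size), which is also why the batch[0] indexing in evaluate() and in the training loop becomes model(batch, labels=batch) in this commit.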
requirements.txt CHANGED

@@ -1,7 +1,6 @@
-torch==1.9.0
 wandb
 tensorboard
-git+https://github.com/huggingface/huggingface_hub.git@push-branching
+git+https://github.com/huggingface/huggingface_hub.git
 git+https://github.com/huggingface/transformers.git
 git+https://github.com/huggingface/datasets.git@load_dataset-no-dataset-script
 git+https://github.com/huggingface/accelerate.git