| | """ |
| | Fine-Tune SantaCoder on code/text dataset |
| | """ |
import argparse
import os
import random
import sys

import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--model_path", type=str, default="bigcode/santacoder")
    parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-dedup")
    parser.add_argument("--subset", type=str, default=None)
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--size_valid_set", type=int, default=4000)
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--shuffle_buffer", type=int, default=5000)
    parser.add_argument("--data_column", type=str, default="content")

    parser.add_argument("--seq_length", type=int, default=1024)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--eos_token_id", type=int, default=49152)

    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--weight_decay", type=float, default=0.05)

    parser.add_argument("--local_rank", type=int, default=0)
    # The "no_*" flags use store_false, so they default to True: fp16 and
    # gradient checkpointing are ON unless the flag is passed to disable them.
    parser.add_argument("--no_fp16", action="store_false")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--no_gradient_checkpointing", action="store_false")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./checkpoints")
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--eval_freq", default=1000, type=int)
    parser.add_argument("--save_freq", default=1000, type=int)

    return parser.parse_args()


def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        total_characters += len(example[data_column])
        total_tokens += len(tokenizer(example[data_column]).tokens())

    return total_characters / total_tokens
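
# Rough intuition (illustrative numbers): with a ratio around the 3.6 default
# used below, a 1024-token training sequence corresponds to roughly 3.7k
# characters of raw text, which is how ConstantLengthDataset sizes its buffer.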


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
    Args:
        tokenizer (Tokenizer): The tokenizer used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True, the iterator restarts once the dataset is exhausted; otherwise it stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in buffer.
        chars_per_token (float): Number of characters per token used to estimate the number of tokens in the text buffer.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        content_field="content",
    ):
        self.tokenizer = tokenizer
        # Use the tokenizer's EOS id as the document separator; fall back to
        # the --eos_token_id CLI value (this relies on the module-level `args`
        # parsed in the __main__ block).
        self.concat_token_id = (
            tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
        )
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            # Fill a character buffer large enough to produce roughly
            # num_of_sequences fixed-length examples once tokenized.
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            # Concatenate all documents into one token stream, separated by
            # the EOS token.
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            # Slice the stream into seq_length-sized examples, dropping the
            # trailing remainder.
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }
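
# Minimal usage sketch (illustrative; `tokenizer` and `train_data` stand in
# for the objects built in create_datasets below, and DataLoader is the
# standard torch.utils.data.DataLoader):
#
#   ds = ConstantLengthDataset(tokenizer, train_data, seq_length=1024)
#   loader = DataLoader(ds, batch_size=2)
#   batch = next(iter(loader))  # batch["input_ids"].shape == (2, 1024)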


def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(
            f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}"
        )
    chars_per_token = chars_token_ratio(train_data, tokenizer, args.data_column)
    print(f"The character-to-token ratio of the dataset is: {chars_per_token:.2f}")
    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
        content_field=args.data_column,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
        content_field=args.data_column,
    )

    return train_dataset, valid_dataset


def run_training(args, train_data, val_data):
    print("Loading the model")
    # use_cache is incompatible with gradient checkpointing, so it is disabled
    # whenever checkpointing is on (args.no_gradient_checkpointing is a
    # store_false flag that defaults to True, i.e. checkpointing enabled).
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        use_cache=not args.no_gradient_checkpointing,
    )
    train_data.start_iteration = 0

    print("Starting main loop")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.no_gradient_checkpointing,
        fp16=args.no_fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name=f"santacoder-{args.subset}",
    )

    trainer = Trainer(
        model=model, args=training_args, train_dataset=train_data, eval_dataset=val_data
    )

    print("Training...")
    trainer.train(args.resume_from_checkpoint)

    print("Saving last checkpoint of the model")
    model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))


def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_auth_token=True)

    train_dataset, eval_dataset = create_datasets(tokenizer, args)

    run_training(args, train_dataset, eval_dataset)


if __name__ == "__main__":
    print(sys.argv)
    args = get_args()
    print(args)
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_info()

    main(args)