#!/usr/bin/env python
# coding=utf-8
"""
The Aligner class simplifies the process of running alignment.
"""
import logging
import os
import sys
import time
from itertools import chain

import numpy as np
import torch
import torch.distributed as dist
import transformers
from datasets import (
    set_caching_enabled,
    Dataset,
    DatasetDict,
)
from transformers import (
    default_data_collator,
    pipeline,
    set_seed,
)
from transformers.testing_utils import CaptureLogger

from lmflow.args import DatasetArguments
from lmflow.datasets.dataset import Dataset as LMFlowDataset
from lmflow.pipeline.base_aligner import BaseAligner
from lmflow.pipeline.utils.raft_trainer import RaftTrainer


logger = logging.getLogger(__name__)


class RaftAligner(BaseAligner):
    """
    Initializes the `RaftAligner` class with the given arguments.

    Parameters
    ------------
    model_args : ModelArguments object.
        Contains the arguments required to load the model.

    data_args : DatasetArguments object.
        Contains the arguments required to load the dataset.

    aligner_args : RaftAlignerArguments object.
        Contains the arguments required to perform alignment.

    args : Optional.
        Positional arguments.

    kwargs : Optional.
        Keyword arguments.
    """
    def __init__(self, model_args, data_args, aligner_args, *args, **kwargs):
        self.model_args = model_args
        self.data_args = data_args
        self.aligner_args = aligner_args

        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            handlers=[logging.StreamHandler(sys.stdout)],
        )
        logger.setLevel(logging.INFO)

        output_reward_path = aligner_args.output_reward_path
        if output_reward_path is not None:
            os.makedirs(os.path.dirname(output_reward_path), exist_ok=True)
            # Delete the reward record file if it already exists, so that
            # each run starts with a fresh log.
            try:
                os.remove(output_reward_path)
            except OSError:
                pass
    def _initialize_trainer(self, model, tokenizer, training_args):
        """
        This function takes the model and tokenizer as input and initializes
        the trainer with a placeholder training dataset; the real dataset is
        attached later, once samples have been generated and filtered.
        """
        trainer = RaftTrainer(
            model=model,
            args=training_args,
            train_dataset=Dataset.from_dict({"text": [" "]}),
            eval_dataset=Dataset.from_dict({}),
            tokenizer=tokenizer,
            data_collator=default_data_collator,
            compute_metrics=None,
            preprocess_logits_for_metrics=None,
        )
        return trainer
    def _load_dataset(
        self,
        selected_dataset,
        model,
        tokenizer,
        model_args,
        data_args,
        training_args,
    ):
        '''
        This function prepares the dataset for every iteration.
        '''
        raw_datasets = selected_dataset

        if training_args.do_train:
            column_names = list(raw_datasets["train"].features)
        else:
            column_names = list(raw_datasets["validation"].features)
        text_column_name = "text" if "text" in column_names else column_names[0]

        # since this will be pickled to avoid _LazyModule error in Hasher,
        # force logger loading before tokenize_function
        tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

        def tokenize_function(examples):
            with CaptureLogger(tok_logger) as cl:
                output = tokenizer(examples[text_column_name])
            # clm input could be much longer than block_size
            if "Token indices sequence length is longer than the" in cl.out:
                tok_logger.warning(
                    "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller"
                    " bits before being passed to the model."
                )
            return output

        with training_args.main_process_first(desc="dataset map tokenization"):
            if not data_args.streaming:
                tokenized_datasets = raw_datasets.map(
                    tokenize_function,
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                    desc="Running tokenizer on dataset",
                )
            else:
                tokenized_datasets = raw_datasets.map(
                    tokenize_function,
                    batched=True,
                    remove_columns=column_names,
                )

        if data_args.block_size is None:
            block_size = tokenizer.model_max_length
            if block_size > 1024:
                logger.warning(
                    "The chosen tokenizer supports a `model_max_length` longer than the default `block_size` of 512"
                    " used here. If you would like a longer `block_size`, up to `tokenizer.model_max_length`, you can"
                    " override this default with `--block_size xxx`."
                )
                block_size = 512
        else:
            if data_args.block_size > tokenizer.model_max_length:
                logger.warning(
                    f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                    f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
                )
            block_size = min(data_args.block_size, tokenizer.model_max_length)

        # Main data processing function that will concatenate all texts from our dataset and generate chunks of
        # block_size.
        def group_texts(examples):
            # Concatenate all texts.
            concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder; we could add padding instead if the model supported it. You can
            # customize this part to your needs.
            if total_length >= block_size:
                total_length = (total_length // block_size) * block_size
            # Split by chunks of block_size.
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result

        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here, but a higher value
        # might be slower to preprocess.
        #
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
        with training_args.main_process_first(desc="grouping texts together"):
            group_batch_size = 1000
            if data_args.disable_group_texts:
                group_batch_size = 1
            if not data_args.streaming:
                lm_datasets = tokenized_datasets.map(
                    group_texts,
                    batched=True,
                    batch_size=group_batch_size,
                    num_proc=data_args.preprocessing_num_workers,
                    load_from_cache_file=not data_args.overwrite_cache,
                    desc=f"Grouping texts in chunks of {block_size}",
                )
            else:
                lm_datasets = tokenized_datasets.map(
                    group_texts,
                    batched=True,
                    batch_size=group_batch_size,
                )

        if training_args.do_train:
            if "train" not in tokenized_datasets:
                raise ValueError("--do_train requires a train dataset")
            train_dataset = lm_datasets["train"]
            if data_args.max_train_samples is not None:
                max_train_samples = min(len(train_dataset), data_args.max_train_samples)
                train_dataset = train_dataset.select(range(max_train_samples))

        return train_dataset
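
    # A minimal illustration of what `group_texts` above does, on toy values
    # (the batch and block_size here are hypothetical, for exposition only).
    # Pasted into a standalone script with `from itertools import chain`, it
    # would run as written:
    #
    #   examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]}
    #   block_size = 4
    #   concatenated = {k: list(chain(*v)) for k, v in examples.items()}
    #   # concatenated == {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8]}
    #   total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    #   chunks = [concatenated["input_ids"][i : i + block_size]
    #             for i in range(0, total_length, block_size)]
    #   # chunks == [[1, 2, 3, 4], [5, 6, 7, 8]]; "labels" is a copy of
    #   # "input_ids" because causal-LM training shifts labels internally.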
    def _load_input_dataset(self, dataset, tokenizer):
        """
        Load input dataset (i.e. prompt/question dataset) for training.

        Args:
            dataset: A Dataset object.
                The dataset to be loaded.

        Returns:
            ds (`datasets.Dataset`):
                The processed dataset, with each prompt truncated to its
                first `input_size` tokens.
        """
        ds = dataset.get_backend_dataset()

        def tokenize(sample):
            # Keep only the first 16 tokens of each prompt; the model
            # generates the continuation during alignment.
            input_size = 16
            review_encode = tokenizer.encode(sample["text"])
            sample["input_ids"] = review_encode[:input_size]
            sample["input"] = tokenizer.decode(sample["input_ids"])
            return sample

        ds = ds.map(tokenize, batched=False)
        ds.set_format(type="torch")
        return ds
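
    # What the truncation above does in practice (a sketch; "gpt2" is just an
    # example checkpoint, not necessarily the model LMFlow loads):
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained("gpt2")
    #   ids = tokenizer.encode("A very long prompt that keeps going ...")[:16]
    #   prompt = tokenizer.decode(ids)  # the first 16 tokens back as text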
    def _get_batch_dataset_top(
        self,
        model,
        batch_input,
        alpha=0.2,
        iter_id=0,
        local_rank=0,
        output_min_length=16,
        output_max_length=48,
        infer_batch_size=8,
        generation_kwargs={},
        tokenizer=None,
        training_args=None,
        reward_model=None,
        output_reward_path=None,
    ):
        """
        Generate responses for a batch of prompts, score them with the reward
        model, and keep the top `alpha` fraction as the next training set.

        :param batch_input: input prompts
        """
        # we will build the batch dataset via Dataset.from_dict
        output_data = []
        query_tensors = batch_input["input_ids"]
        queries = batch_input["input"]
        data_size = len(queries)
        reward_eva = []
        input_texts = []
        responses = []

        for i, query_tensor in enumerate(query_tensors):
            query = queries[i]
            input_texts.append(query)
            if (i + 1) % infer_batch_size == 0:
                # Sample a random generation length for this batch.
                gen_len = np.random.randint(output_min_length, output_max_length)
                generation_kwargs["max_new_tokens"] = gen_len
                inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(training_args.device)
                with torch.no_grad():
                    outputs = model.generate(**inputs, **generation_kwargs)
                generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
                # Strip the prompt from each decoded sequence, keeping only the response.
                generated_texts = [
                    generated_text.replace(input_text, "")
                    for input_text, generated_text in zip(input_texts, generated_texts)
                ]
                texts_for_rewards = [q + r for q, r in zip(input_texts, generated_texts)]

                texts_for_reward_dataset = LMFlowDataset.create_from_dict({
                    "type": "text_only",
                    "instances": [
                        {"text": text} for text in texts_for_rewards
                    ],
                })
                reward_dataset = reward_model.inference(texts_for_reward_dataset)
                rewards = [sample["value"] for sample in reward_dataset.to_dict()["instances"]]

                reward_eva.extend(rewards)
                responses.extend(generated_texts)
                input_texts = []

        # Keep the top `alpha` fraction of samples, ranked by reward.
        data = []
        idx = np.argsort(reward_eva)[::-1][:int(data_size * alpha)]
        for j in range(len(reward_eva)):
            sample = {}
            sample["input"] = queries[j]
            sample["output"] = [responses[j]]
            data.append(sample)
        output_data = [data[j] for j in idx]
        logger.info(f"collected {len(output_data)} samples")

        # Gather the selected samples from all processes.
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        all_process_list = [{}] * world_size
        dist.all_gather_object(all_process_list, output_data)
        gathered_data = []
        for i in range(world_size):
            gathered_data.extend(all_process_list[i])

        reward_train = [reward_eva[j] for j in idx]
        reward_to_send = [np.mean(reward_eva), np.mean(reward_train)]
        all_process_rewards = [{}] * world_size
        dist.all_gather_object(all_process_rewards, reward_to_send)
        logger.info(all_process_rewards)

        if training_args.local_rank == 0 and output_reward_path is not None:
            with open(output_reward_path, mode='a') as fout:
                fout.write(
                    'mean reward: ' + str(np.mean([all_process_rewards[i][0] for i in range(world_size)]))
                    + ', mean reward in training set: ' + str([all_process_rewards[i][1] for i in range(world_size)])
                )
                fout.write("\n")

        prompt_structure = "{definition}{input}{output}"
        output_dataset = {
            "text": [
                prompt_structure.format(
                    definition="", input=sample["input"], output=sample["output"][0]
                )
                for sample in gathered_data
            ]
        }

        return DatasetDict({"train": Dataset.from_dict(output_dataset)})
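
    # The core RAFT selection rule used above, on toy numbers (illustrative
    # only; any values work):
    #
    #   import numpy as np
    #   reward_eva = [0.1, 0.9, 0.5, 0.7]   # one reward per generated sample
    #   alpha, data_size = 0.5, 4
    #   idx = np.argsort(reward_eva)[::-1][:int(data_size * alpha)]
    #   # idx == array([1, 3]): the two highest-reward samples are kept as
    #   # the supervised fine-tuning data for the next stage.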
    def align(self, model, dataset, reward_model):
        """
        Perform alignment for a model

        Parameters
        ------------
        model : BaseModel object.
        dataset: Dataset object.
            Input dataset for the model to generate outputs. The input and
            output will then be fed into the reward model to get the reward
            for alignment.
        reward_model: RegressionModel object.
        """
        tokenizer = model.get_tokenizer()
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"

        dataset = self._load_input_dataset(dataset, tokenizer)
        set_caching_enabled(False)

        wrapped_model = model
        model = model.get_backend_model()

        generation_kwargs = {
            "min_length": -1,
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "temperature": 0.7,
        }

        aligner_args = self.aligner_args
        training_args = aligner_args
        model_args = self.model_args
        data_args = self.data_args

        set_seed(42 + training_args.local_rank)

        ITERATION = aligner_args.num_raft_iteration
        M = aligner_args.raft_batch_size
        alpha = aligner_args.top_reward_percentage
        data_size = len(dataset["input"])

        raft_trainer = self._initialize_trainer(model, tokenizer, training_args)
        raft_trainer.train(resume_from_checkpoint=False, is_first_time=True)

        for iteration in range(ITERATION):
            set_seed(88 + training_args.local_rank + 4 * (iteration + 1))

            # Sample a batch of prompts and keep only the top-reward completions.
            batch_input = dataset.select(np.random.randint(low=0, high=data_size, size=M))
            selected_dataset = self._get_batch_dataset_top(
                raft_trainer.tmp_model,
                batch_input,
                alpha,
                iteration,
                training_args.local_rank,
                output_min_length=aligner_args.output_min_length,
                output_max_length=aligner_args.output_max_length,
                infer_batch_size=aligner_args.inference_batch_size_per_device,
                generation_kwargs=generation_kwargs,
                tokenizer=tokenizer,
                training_args=training_args,
                reward_model=reward_model,
                output_reward_path=aligner_args.output_reward_path,
            )
            raft_trainer.train_dataset = self._load_dataset(
                selected_dataset,
                raft_trainer.tmp_model,
                tokenizer,
                model_args,
                data_args,
                training_args,
            )

            logger.info(f"iter {iteration}")
            start_time = time.time()
            train_result = raft_trainer.train(resume_from_checkpoint=False)
            end_time = time.time()
            logger.info("It takes %.2f s to train one stage", end_time - start_time)
        # Sample and score once more with the trained model so the final
        # reward statistics are logged; the returned dataset is not used.
        self._get_batch_dataset_top(
            raft_trainer.tmp_model,
            batch_input,
            alpha,
            iteration,
            training_args.local_rank,
            output_min_length=aligner_args.output_min_length,
            output_max_length=aligner_args.output_max_length,
            infer_batch_size=aligner_args.inference_batch_size_per_device,
            generation_kwargs=generation_kwargs,
            tokenizer=tokenizer,
            training_args=training_args,
            reward_model=reward_model,
            output_reward_path=aligner_args.output_reward_path,
        )

        if aligner_args.output_dir is not None:
            wrapped_model.save(aligner_args.output_dir)

        return wrapped_model
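

# Example usage (a minimal sketch, not from the LMFlow source). The argument
# classes are mentioned in this module's docstrings and imports, but the
# parser-based entry point shown here is an assumption that may differ
# between LMFlow versions:
#
#   from transformers import HfArgumentParser
#   from lmflow.args import (
#       ModelArguments, DatasetArguments, RaftAlignerArguments,
#   )
#
#   parser = HfArgumentParser(
#       (ModelArguments, DatasetArguments, RaftAlignerArguments)
#   )
#   model_args, data_args, aligner_args = parser.parse_args_into_dataclasses()
#   aligner = RaftAligner(model_args, data_args, aligner_args)
#   # `model`, `dataset`, and `reward_model` would be constructed with
#   # LMFlow's model/dataset wrappers, then:
#   # aligned_model = aligner.align(model, dataset, reward_model)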