# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from transformers import DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback


class AdaptiveKLController:
    """
    Adaptive KL controller described in the paper:
    https://arxiv.org/pdf/1909.08593.pdf

    The current KL coefficient is exposed as `self.value` and is adjusted multiplicatively by `update`.
    """

    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current, n_steps):
        """
        Nudge `self.value` towards keeping the observed KL (`current`) close to `self.target`.

        The proportional error `current / target - 1` is clipped to [-0.2, 0.2] and scaled by `n_steps / horizon`, so a
        single update changes the coefficient by a bounded factor.
        """
        target = self.target
        proportional_error = np.clip(current / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult


class FixedKLController:
    """Fixed KL controller: `self.value` never changes."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current, n_steps):
        """No-op; present so it can be used interchangeably with `AdaptiveKLController`."""
        pass


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    """
    Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
    when they do not come from the assistant. This ensures that the loss is only calculated on the completion made by
    the assistant.

    Args:
        instruction_template (`Optional[Union[str, List[int]]]`): the template form that indicates the start of the
            human instruction, typically something like '### Human:\n'. Useful for assistant-style conversation
            datasets. It can also be passed as tokenized ids.
        response_template (`Union[str, List[int]]`): the template form that indicates the start of the response,
            typically something like '### Response:\n'. It can also be passed as tokenized ids, which can be useful
            when using a tokenizer that encodes the response differently if it does not have proper context.
        mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the
            underlying `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is
            present for flexibility and backwards-compatibility.
        ignore_index (`int`, *optional*, defaults to `-100`):
            The label value written over every token that must not contribute to the loss (everything that is not
            part of an assistant response, and whole instances whose templates cannot be found).
    """

    def __init__(
        self,
        response_template: Union[str, List[int]],
        instruction_template: Optional[Union[str, List[int]]] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)

        self.instruction_template = instruction_template
        if isinstance(instruction_template, str):
            # The user provides a string, must tokenize
            self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
        else:
            # The user already provides the token ids
            self.instruction_token_ids = instruction_template

        self.response_template = response_template
        if isinstance(response_template, str):
            # The user provides a string, must tokenize
            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
        else:
            # The user already provides the token ids
            self.response_token_ids = response_template

        if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            warnings.warn(
                "The pad_token_id and eos_token_id values of this tokenizer are identical. "
                "If you are planning for multi-turn training, "
                "it can result in the model continuously generating questions and answers without eos token. "
                "To avoid this, set the pad_token_id to a different value."
            )

        self.ignore_index = ignore_index

    @staticmethod
    def _find_subsequence_starts(labels: torch.Tensor, token_ids: List[int]) -> List[int]:
        """
        Return, in ascending order, every index at which `token_ids` occurs as a contiguous run inside `labels`.

        Candidate positions are found with a vectorized search for the first template token and then confirmed by
        comparing the full window.
        """
        window = len(token_ids)
        return [
            int(idx) for idx in np.where(labels == token_ids[0])[0] if token_ids == labels[idx : idx + window].tolist()
        ]

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """
        Collate `examples` with the parent collator, then mask every label that is not part of an assistant response.

        Instances in which a required template cannot be found are fully masked with `ignore_index` (and a warning
        is emitted), so they do not contribute to the loss.
        """
        batch = super().torch_call(examples)

        if self.instruction_template is None:
            for i in range(len(examples)):
                # Only the LAST occurrence of the response template counts: everything up to and including it is
                # masked.
                starts = self._find_subsequence_starts(batch["labels"][i], self.response_token_ids)

                if not starts:
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index
                else:
                    response_token_ids_end_idx = starts[-1] + len(self.response_token_ids)

                    # Make pytorch loss function ignore all tokens up through the end of the response key
                    batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
        else:
            for i in range(len(examples)):
                labels = batch["labels"][i]

                # Index just past each response template, i.e. where the assistant's tokens begin.
                response_token_ids_idxs = [
                    start + len(self.response_token_ids)
                    for start in self._find_subsequence_starts(labels, self.response_token_ids)
                ]
                if not response_token_ids_idxs:
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index
                    # The instance is now fully masked; searching its labels for the instruction template would only
                    # produce a second, misleading warning.
                    continue

                # Index of the start of each instruction template.
                human_token_ids_idxs = self._find_subsequence_starts(labels, self.instruction_token_ids)
                if not human_token_ids_idxs:
                    warnings.warn(
                        f"Could not find instruction key `{self.instruction_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index
                    continue

                # If the conversation opens with an assistant turn before any instruction (e.g. a system prompt
                # followed directly by a response), pair that first response with the start of the sequence.
                # Otherwise every later instruction would be paired with the wrong response and left unmasked.
                if human_token_ids_idxs[0] > response_token_ids_idxs[0]:
                    human_token_ids_idxs = [0] + human_token_ids_idxs

                for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
                    # Make pytorch loss function ignore all non response tokens
                    if idx != 0:
                        batch["labels"][i, start:end] = self.ignore_index
                    else:
                        batch["labels"][i, :end] = self.ignore_index

                # A trailing instruction with no response after it is masked through the end of the sequence.
                if len(response_token_ids_idxs) < len(human_token_ids_idxs):
                    batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index

        return batch


@dataclass
class RewardDataCollatorWithPadding:
    r"""
    Reward DataCollator class that pads the inputs to the maximum length of the batch.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for encoding the data.
        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
            padding_strategy to pass to the tokenizer.
        max_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the sequence to be processed.
        pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`):
            If set will pad the sequence to a multiple of the provided value.
        return_tensors (`str`, `optional`, defaults to `"pt"`):
            The tensor type to use.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Pad the chosen and rejected sequences of `features` separately and return them in one batch dict.

        Every feature must carry `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and
        `attention_mask_rejected` (a ValueError is raised otherwise). If the first feature has a `margin` key, the
        margins of all features are batched into a float tensor under `margin`. The returned dict also sets
        `return_loss=True`.
        """
        features_chosen = []
        features_rejected = []
        margin = []
        # check if we have a margin. If we do, we need to batch it as well
        has_margin = "margin" in features[0]
        for feature in features:
            # check if the keys are named as expected
            if (
                "input_ids_chosen" not in feature
                or "input_ids_rejected" not in feature
                or "attention_mask_chosen" not in feature
                or "attention_mask_rejected" not in feature
            ):
                raise ValueError(
                    "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`"
                )

            features_chosen.append(
                {
                    "input_ids": feature["input_ids_chosen"],
                    "attention_mask": feature["attention_mask_chosen"],
                }
            )
            features_rejected.append(
                {
                    "input_ids": feature["input_ids_rejected"],
                    "attention_mask": feature["attention_mask_rejected"],
                }
            )
            if has_margin:
                margin.append(feature["margin"])
        batch_chosen = self.tokenizer.pad(
            features_chosen,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch_rejected = self.tokenizer.pad(
            features_rejected,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = {
            "input_ids_chosen": batch_chosen["input_ids"],
            "attention_mask_chosen": batch_chosen["attention_mask"],
            "input_ids_rejected": batch_rejected["input_ids"],
            "attention_mask_rejected": batch_rejected["attention_mask"],
            "return_loss": True,
        }
        if has_margin:
            margin = torch.tensor(margin, dtype=torch.float)
            batch["margin"] = margin
        return batch


@dataclass
class DPODataCollatorWithPadding:
    r"""
    DPO DataCollator class that pads the inputs to the maximum length of the batch.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for encoding the data.
        model (Optional[`PreTrainedModel`]):
            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
            prepare the *decoder_input_ids*.
        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
            padding_strategy to pass to the tokenizer.
        max_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the sequence to be processed. When `None`, decoder-only inputs are not truncated.
        max_prompt_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the prompt to be processed.
        label_pad_token_id (`int`, defaults to -100):
            The label used for masking.
        padding_value (`int`, defaults to 0):
            The value used for padding the attention masks of decoder-only models.
        is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `False`):
            Whether or not your model has an encoder_decoder architecture.
        max_target_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the target to be processed. Only useful for encoder-decoder architectures.
        truncation_mode: (`str`, defaults to "keep_end"):
            The truncation mode to use when truncating the prompt ("keep_start" or "keep_end").
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[PreTrainedModel] = None
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_prompt_length: Optional[int] = None
    label_pad_token_id: int = -100
    padding_value: int = 0
    truncation_mode: str = "keep_end"
    is_encoder_decoder: Optional[bool] = False
    max_target_length: Optional[int] = None

    @staticmethod
    def _mask_eos_positions(tokens, eos_token_id) -> List[int]:
        """Return a copy of `tokens["attention_mask"]` with 0 at every position whose input id is `eos_token_id`."""
        return [
            0 if token_id == eos_token_id else mask
            for token_id, mask in zip(tokens["input_ids"], tokens["attention_mask"])
        ]

    def tokenize_batch_element(
        self,
        prompt: str,
        chosen: str,
        rejected: str,
    ) -> Dict:
        """Tokenize a single batch element.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
            in case the prompt + chosen or prompt + rejected responses is/are too long. First
            we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
            the sum of the length of the prompt and the chosen/rejected response, with
            label_pad_token_id  for the prompt tokens.

        Truncation only happens for decoder-only models and only when `max_length` is set.
        """
        batch = {}

        if not self.is_encoder_decoder:
            chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
            rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
            prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

            # EOS tokens already present in the raw text are kept in `input_ids` but hidden from attention.
            eos_token_id = self.tokenizer.eos_token_id
            prompt_tokens["attention_mask"] = self._mask_eos_positions(prompt_tokens, eos_token_id)
            chosen_tokens["attention_mask"] = self._mask_eos_positions(chosen_tokens, eos_token_id)
            rejected_tokens["attention_mask"] = self._mask_eos_positions(rejected_tokens, eos_token_id)

            # add EOS token to the end of each response (this trailing EOS is attended to)
            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
            chosen_tokens["attention_mask"].append(1)

            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
            rejected_tokens["attention_mask"].append(1)

            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # `max_length=None` (the documented default) disables truncation; comparing an int against None would
            # otherwise raise a TypeError.
            if self.max_length is not None:
                # if combined sequence is too long, truncate the prompt
                if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
                    if self.truncation_mode == "keep_start":
                        prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
                    elif self.truncation_mode == "keep_end":
                        prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

                # if that's still too long, truncate the response
                if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
                    response_budget = self.max_length - self.max_prompt_length
                    chosen_tokens = {k: v[:response_budget] for k, v in chosen_tokens.items()}
                    rejected_tokens = {k: v[:response_budget] for k, v in rejected_tokens.items()}

            # Create labels: the full sequence, with the prompt positions replaced by label_pad_token_id
            prompt_len = len(prompt_tokens["input_ids"])
            chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
            rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
            chosen_sequence_tokens["labels"][:prompt_len] = [self.label_pad_token_id] * prompt_len
            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
            rejected_sequence_tokens["labels"][:prompt_len] = [self.label_pad_token_id] * prompt_len

            for k, toks in {
                "chosen": chosen_sequence_tokens,
                "rejected": rejected_sequence_tokens,
                "prompt": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}_{type_key}"] = tokens

        else:
            chosen_tokens = self.tokenizer(
                chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            rejected_tokens = self.tokenizer(
                rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            prompt_tokens = self.tokenizer(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels(
                    labels=batch["rejected_labels"]
                )
                batch["chosen_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels(
                    labels=batch["chosen_labels"]
                )

        batch["prompt"] = prompt
        batch["chosen"] = prompt + chosen
        batch["rejected"] = prompt + rejected
        batch["chosen_response_only"] = chosen
        batch["rejected_response_only"] = rejected

        return batch

    def collate(self, batch):
        """
        Pad every `*_input_ids`, `*_attention_mask` and `*_labels` entry of `batch` into a LongTensor; gather all other
        keys into plain lists.

        For decoder-only models, prompt entries are reversed before padding and flipped back afterwards, which
        places their padding on the left.
        """
        # first, pad everything to the same length
        padded_batch = {}
        for k in batch[0].keys():
            if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
                if self.is_encoder_decoder:
                    to_pad = [torch.LongTensor(ex[k]) for ex in batch]

                    if (k.startswith("prompt")) and (k.endswith("input_ids")):
                        padding_value = self.tokenizer.pad_token_id
                    elif k.endswith("_attention_mask"):
                        padding_value = 0
                    elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
                        padding_value = self.label_pad_token_id
                    else:
                        raise ValueError(f"Unexpected key in batch '{k}'")
                    padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                else:
                    # adapted from https://stackoverflow.com/questions/73256206
                    if "prompt" in k:
                        to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
                    else:
                        to_pad = [torch.LongTensor(ex[k]) for ex in batch]
                    if k.endswith("_input_ids"):
                        padding_value = self.tokenizer.pad_token_id
                    elif k.endswith("_labels"):
                        padding_value = self.label_pad_token_id
                    elif k.endswith("_attention_mask"):
                        padding_value = self.padding_value
                    else:
                        raise ValueError(f"Unexpected key in batch '{k}'")

                    padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                    # for the prompt, flip back so padding is on left side
                    if "prompt" in k:
                        padded_batch[k] = padded_batch[k].flip(dims=[1])
            else:
                padded_batch[k] = [ex[k] for ex in batch]

        return padded_batch

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Tokenize each feature's `prompt`/`chosen`/`rejected` strings and return the padded batch."""
        tokenized_batch = []

        for feature in features:
            prompt = feature["prompt"]
            chosen = feature["chosen"]
            rejected = feature["rejected"]

            batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
            tokenized_batch.append(batch_element)

        # return collated batch
        return self.collate(tokenized_batch)


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.
    The dataset also formats the text before tokenization with a specific format that is provided
    by the user.

        Args:
            tokenizer (`transformers.PreTrainedTokenizer`):
                The processor used for processing the data.
            dataset (`dataset.Dataset`):
                Dataset with text files.
            dataset_text_field (`str`, **optional**):
                Name of the field in the dataset that contains the text. Used only if
                `formatting_func` is `None`.
            formatting_func (`Callable`, **optional**):
                Function that formats the text before tokenization. It is usually recommended that its output
                follows a certain pattern such as `"### Question: {question}\n ### Answer: {answer}\n"`
            infinite (`bool`, *optional*, defaults to `False`):
                If True the iterator is reset after dataset reaches end else stops.
            seq_length (`int`, *optional*, defaults to `1024`):
                Length of token sequences to return.
            num_of_sequences (`int`, *optional*, defaults to `1024`):
                Number of token sequences to keep in buffer.
            chars_per_token (`float`, *optional*, defaults to `3.6`):
                Number of characters per token used to estimate number of tokens in text buffer.
            eos_token_id (`int`, *optional*, defaults to `0`):
                Id of the end of sequence token if the passed tokenizer does not have an EOS token.
            shuffle (`bool`, *optional*, defaults to `True`):
                Shuffle the examples before they are returned
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        dataset_text_field=None,
        formatting_func=None,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        eos_token_id=0,
        shuffle=True,
    ):
        self.tokenizer = tokenizer

        if tokenizer.eos_token_id is None:
            warnings.warn(
                "The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds"
                f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id."
            )

        # 0 is a valid EOS id, so fall back only when the tokenizer has no EOS token at all.
        self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        # Buffer budget is measured in characters (an estimate of seq_length * num_of_sequences tokens).
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.shuffle = shuffle
        if formatting_func is None:
            self.formatting_func = lambda x: x[dataset_text_field]
        else:
            self.formatting_func = formatting_func

        if formatting_func is not None:
            # Callables such as `functools.partial` objects have no `__code__`; skip this heuristic check for them
            # instead of raising an AttributeError.
            code = getattr(formatting_func, "__code__", None)
            if code is not None and code.co_argcount > 1:
                warnings.warn(
                    "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
                    " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
                )

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        """
        Yield dicts with identical `input_ids` and `labels` LongTensors of exactly `seq_length` tokens.

        Formatted texts are buffered until `max_buffer_size` characters are collected, tokenized, joined with
        `concat_token_id` and cut into chunks. The last chunk of each buffer is dropped when it is shorter than
        `seq_length`.
        """
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(self.formatting_func(next(iterator)))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        warnings.warn("The dataset reached end and the iterator is reset to the start.")
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            if self.shuffle:
                random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }


class PeftSavingCallback(TrainerCallback):
    """Trainer callback that saves the model with `save_pretrained` and removes any `pytorch_model.bin` left in the checkpoint."""

    def on_save(self, args, state, control, **kwargs):
        if args.should_save:
            checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            kwargs["model"].save_pretrained(checkpoint_path)

            if "pytorch_model.bin" in os.listdir(checkpoint_path):
                os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class RunningMoments:
    def __init__(self, accelerator):
        """
        Calculates the running mean and standard deviation of a data stream. Reference:
        https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75
        """
        self.mean = 0
        self.std = 1
        self.var = 1
        # Tiny non-zero count so the first update does not divide by zero.
        self.count = 1e-24
        self.accelerator = accelerator

    @torch.no_grad()
    def update(self, xs: torch.Tensor) -> Tuple[float, float]:
        """
        Updates running moments from batch's moments computed across ranks

        Returns the batch mean and the batch standard deviation (with Bessel's correction) as Python floats.
        """
        if self.accelerator.use_distributed:
            xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs)
        else:
            xs_count = xs.numel()
            xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
        xs_mean, xs_var = xs_mean.float(), xs_var.float()

        delta = xs_mean - self.mean
        tot_count = self.count + xs_count

        new_sum = xs_var * xs_count
        # correct old_sum deviation accounting for the new mean
        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
        tot_sum = old_sum + new_sum

        self.mean += delta * xs_count / tot_count
        self.var = tot_sum / tot_count
        self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt()
        self.count = tot_count

        return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()


@torch.no_grad()
def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
    """
    Computes element-wise mean and variance of the tensor across processes. Reference:
    https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75

    When `mask` is given, only masked-in elements contribute to the sum, the count and the variance. The mean,
    variance and count are returned as tensors moved to `device`.
    """
    xs = xs.to(accelerator.device)
    if mask is not None:
        mask = mask.to(xs.device)
    # The mean's numerator must be masked exactly like the count and the variance; otherwise masked-out values leak in.
    local_sum = xs.sum() if mask is None else (xs * mask).sum()
    sum_and_count = torch.tensor([local_sum, (xs.numel() if mask is None else mask.sum())], device=xs.device)
    sum_and_count = accelerator.reduce(sum_and_count)
    global_sum, count = sum_and_count
    global_mean = global_sum / count

    sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask))
    sum_var = accelerator.reduce(sum_var)
    global_var = sum_var / count

    return global_mean.to(device), global_var.to(device), count.to(device)


def compute_accuracy(eval_pred) -> Dict[str, float]:
    """
    Fraction of rows where the argmax over the two reward columns equals the label.

    Warns when some rows have equal rewards for both options, since `np.argmax` then picks index 0.
    """
    predictions, labels = eval_pred
    # Here, predictions is rewards_chosen and rewards_rejected.
    # We want to see how much of the time rewards_chosen > rewards_rejected.
    num_ties = np.array(predictions[:, 0] == predictions[:, 1]).sum()
    if num_ties > 0:
        warnings.warn(
            f"There are {num_ties} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
        )
    predictions = np.argmax(predictions, axis=1)

    accuracy = np.array(predictions == labels, dtype=float).mean().item()
    return {"accuracy": accuracy}


def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
    """Right-pad `tensor` along `dim` with `pad_value` up to `length`; tensors already that long are returned as-is."""
    if tensor.size(dim) >= length:
        return tensor
    else:
        pad_size = list(tensor.shape)
        pad_size[dim] = length - tensor.size(dim)
        return torch.cat(
            [
                tensor,
                pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
            ],
            dim=dim,
        )


def disable_dropout_in_model(model: torch.nn.Module) -> None:
    """Set the drop probability of every `torch.nn.Dropout` submodule of `model` to 0, in place."""
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0


def exact_div(a, b, a_str, b_str, custom_error_message=""):
    """Return `a // b`, raising ValueError (mentioning `a_str`/`b_str`) when `b` does not divide `a` exactly."""
    q = a // b
    if a != q * b:
        raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}")
    return q


# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5
class PerPromptStatTracker:
    r"""
    Class for tracking statistics per prompt. Mainly used to calculate advantage for the DDPO algorithm

    Args:
        buffer_size (`int`):
            Size of the buffer to keep for each prompt.
        min_count (`int`):
            Minimum number of samples to keep in the buffer before calculating the mean and std.
    """

    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}

    def update(self, prompts, rewards):
        """
        Record `rewards` per prompt and return their normalized advantages (same shape as `rewards`).

        A prompt with fewer than `min_count` buffered rewards is normalized with the mean/std of the whole `rewards`
        batch; otherwise its own buffer statistics are used.
        """
        prompts = np.array(prompts)
        rewards = np.array(rewards)
        # `np.empty_like` inherits the dtype of `rewards`; with integer rewards the normalized advantages would be
        # silently truncated, so promote non-floating inputs (floating inputs keep their precision).
        if not np.issubdtype(rewards.dtype, np.floating):
            rewards = rewards.astype(np.float64)
        unique = np.unique(prompts)
        advantages = np.empty_like(rewards)
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            self.stats[prompt].extend(prompt_rewards)

            if len(self.stats[prompt]) < self.min_count:
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(self.stats[prompt])
                std = np.std(self.stats[prompt]) + 1e-6
            advantages[prompts == prompt] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        """Return `{prompt: {"mean", "std", "count"}}` computed over each prompt's buffer."""
        return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}


def neftune_post_forward_hook(module, input, output):
    """
    Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
    torch.nn.Embedding layers. This method is slightly adapted from the original source code
    that can be found here: https://github.com/neelsjain/NEFTune

    Simply add it to your model as follows:
    ```python
    model = ...
    model.embed_tokens.neftune_noise_alpha = 0.1
    model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
    ```

    Noise is added only while `module.training` is True. It is uniform in [-alpha / sqrt(L * d), alpha / sqrt(L * d)],
    where L and d are `output.size(1)` and `output.size(2)`.

    Args:
        module (`torch.nn.Module`):
            The embedding module where the hook is attached. Note that you need to set
            `module.neftune_noise_alpha` to the desired noise alpha value.
        input (`torch.Tensor`):
            The input tensor to the model.
        output (`torch.Tensor`):
            The output tensor of the model (i.e. the embeddings).
    """
    if module.training:
        dims = torch.tensor(output.size(1) * output.size(2))
        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output