import gc
import random
import warnings
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import TopKLogitsWarper, TopPLogitsWarper

from .import_utils import is_npu_available, is_xpu_available


try:
    from collections.abc import Mapping
except ImportError:
    from collections import Mapping


# Padding value used by `stack_dicts` when logging stats to Weights & Biases.
WANDB_PADDING = -1


def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering).
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits


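# Illustrative sketch (not part of the original module): with top_k=2, every logit
# outside the two highest-scoring tokens is pushed to -inf, so sampling can only
# pick from those two.
#
#   >>> top_k_top_p_filtering(torch.tensor([[1.0, 2.0, 3.0, 4.0]]), top_k=2)
#   tensor([[-inf, -inf, 3., 4.]])

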
def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
    """Flatten dictionary and concatenate nested keys with separator."""

    def recurse(nest: Dict, prefix: str, into: Dict) -> None:
        for k, v in nest.items():
            if sep in k:
                raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
            if isinstance(v, Mapping):
                recurse(v, prefix + k + sep, into)
            else:
                into[prefix + k] = v

    flat = {}
    recurse(nested, "", flat)
    return flat


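# Illustrative sketch (not part of the original module): nested keys are joined
# with the separator so stats can be logged under flat names.
#
#   >>> flatten_dict({"ppo": {"loss": 0.1, "kl": 0.02}, "step": 3})
#   {'ppo/loss': 0.1, 'ppo/kl': 0.02, 'step': 3}

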
def convert_to_scalar(stats: Dict) -> Dict:
    """
    Converts the stats from a flattened dict to single scalar dicts.
    """
    tensorboard_stats = {}
    for k, v in stats.items():
        # For tensorboard compatibility: arrays and tensors are ignored by tensorboard,
        # so single-element tensors/arrays are converted to Python scalars.
        if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (
            len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)
        ):
            v = v.item()
        tensorboard_stats[k] = v
    return tensorboard_stats


def stack_dicts(stats_dicts: List[Dict]) -> Dict:
    """Stack the values of a list of dicts, padding to equal length with `WANDB_PADDING`."""
    results = dict()
    for k in stats_dicts[0]:
        stats_list = [torch.flatten(d[k]) for d in stats_dicts]
        results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
    return results


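# Illustrative sketch (not part of the original module): values are flattened and
# right-padded with WANDB_PADDING so tensors of unequal length can be stacked.
#
#   >>> stack_dicts([{"r": torch.tensor([1.0, 2.0])}, {"r": torch.tensor([3.0])}])
#   {'r': tensor([[ 1.,  2.],
#                 [ 3., -1.]])}

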
def add_suffix(input_dict: Dict, suffix: str) -> Dict:
    """Add suffix to dict keys."""
    return dict((k + suffix, v) for k, v in input_dict.items())


def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
    """Pad tensor to size. Note that the current size is read along `dim`, but the
    padding is applied to the last dimension."""
    t_size = tensor.size()[dim]
    if t_size == size:
        return tensor
    else:
        return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding)


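# Illustrative sketch (not part of the original module): the default padding value
# 50256 is GPT-2's end-of-text token id.
#
#   >>> pad_to_size(torch.tensor([[1, 2]]), size=4)
#   tensor([[    1,     2, 50256, 50256]])

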
def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
    """
    Compute per-token log-probabilities; if `gather` is True, return the log-probability
    of each label token. See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
    """
    logp = F.log_softmax(logits, dim=2)

    if not gather:
        return logp
    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logpy


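# Illustrative sketch (not part of the original module): with gather=True the
# result has shape (batch, seq_len), i.e. the log-probability of each label token.
#
#   >>> logits = torch.randn(1, 4, 10)         # (batch, seq_len, vocab)
#   >>> labels = torch.randint(0, 10, (1, 4))  # (batch, seq_len)
#   >>> logprobs_from_logits(logits, labels).shape
#   torch.Size([1, 4])

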
def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Whiten values."""
    mean, var = torch.mean(values), torch.var(values)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened


def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Compute mean of tensor with masked values."""
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()


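# Illustrative sketch (not part of the original module): only positions where the
# mask is 1 contribute, which is how padding tokens are excluded from PPO statistics.
#
#   >>> values = torch.tensor([1.0, 2.0, 3.0, 100.0])
#   >>> mask = torch.tensor([1.0, 1.0, 1.0, 0.0])  # last entry is padding
#   >>> masked_mean(values, mask)
#   tensor(2.)

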
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Compute variance of tensor with masked values."""
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            raise ValueError(
                "The sum of the mask is zero, which can happen when `mini_batch_size=1`; "
                "try increasing the `mini_batch_size` or `gradient_accumulation_steps`"
            )
        # Note that if mask_sum == 1, the Bessel correction below divides by zero;
        # avoid this by calling the function with a larger mini-batch.
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Whiten values with masked values."""
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened


def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
    """
    Tensor extension of torch.clamp that supports tensor-valued bounds.
    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
    """
    clipped = torch.max(torch.min(x, tensor_max), tensor_min)
    return clipped


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Calculate entropy from logits: H(p) = logsumexp(logits) - sum(p * logits)."""
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
    return entropy


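# Illustrative sketch (not part of the original module): for uniform logits the
# entropy equals log(vocab_size).
#
#   >>> entropy_from_logits(torch.zeros(1, 4))
#   tensor([1.3863])  # log(4)

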
def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict:
    """Average values of a list of dicts with torch tensors."""
    average_dict = dict()
    for key in list_of_dicts[0].keys():
        average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), dim=0)
    return average_dict


def stats_to_np(stats_dict: Dict) -> Dict:
    """Cast all torch.tensors in dict to numpy arrays."""
    new_dict = dict()
    for k, v in stats_dict.items():
        if isinstance(v, torch.Tensor):
            new_dict[k] = v.detach().cpu()
            if new_dict[k].dtype == torch.bfloat16:
                # numpy has no bfloat16 dtype, so upcast to float32 first
                new_dict[k] = new_dict[k].float()
            new_dict[k] = new_dict[k].numpy()
        else:
            new_dict[k] = v
        if np.isscalar(new_dict[k]):
            new_dict[k] = float(new_dict[k])
    return new_dict


def respond_to_batch(
    model: nn.Module, queries: torch.LongTensor, txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
) -> torch.LongTensor:
    """Sample text from language model."""
    input_ids = queries
    for _ in range(txt_len):
        # Get logits for the last position
        outputs = model(input_ids)
        next_token_logits = outputs[0][:, -1, :]
        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        # Sample the next token and append it to the sequence
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    return input_ids[:, -txt_len:]


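# Illustrative sketch (not part of the original module): `model` is assumed to be a
# causal LM whose forward pass returns logits first, e.g. a `transformers` GPT-2 head:
#
#   from transformers import GPT2LMHeadModel, GPT2Tokenizer
#   model = GPT2LMHeadModel.from_pretrained("gpt2")
#   tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#   queries = tokenizer("Hello, my dog", return_tensors="pt").input_ids
#   response = respond_to_batch(model, queries, txt_len=20)  # (batch, 20) sampled tokens

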
def set_seed(seed: int) -> None:
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if is_xpu_available():
        torch.xpu.manual_seed_all(seed)
    elif is_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        torch.cuda.manual_seed_all(seed)


class LengthSampler:
    """
    Samples a length uniformly from the half-open range `[min_value, max_value)`.
    """

    def __init__(self, min_value: int, max_value: int):
        self.values = list(range(min_value, max_value))

    def __call__(self) -> int:
        return np.random.choice(self.values)


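# Illustrative sketch (not part of the original module): `max_value` itself is
# never returned because the underlying range is half-open.
#
#   >>> sampler = LengthSampler(min_value=4, max_value=16)
#   >>> sampler()  # some int in [4, 16)

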
class PPODecorators(object):
    # When True, `empty_device_cache` frees accelerator memory after each wrapped block.
    optimize_device_cache = False

    @classmethod
    @contextmanager
    def empty_device_cache(cls):
        yield
        if cls.optimize_device_cache:
            if is_xpu_available():
                gc.collect()
                torch.xpu.empty_cache()
                gc.collect()
            elif is_npu_available():
                gc.collect()
                torch.npu.empty_cache()
                gc.collect()
            elif torch.cuda.is_available():
                gc.collect()
                torch.cuda.empty_cache()
                gc.collect()


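# Illustrative sketch (not part of the original module): the context manager is a
# no-op unless `optimize_device_cache` is enabled, in which case the accelerator
# cache is emptied after the wrapped block runs.
#
#   >>> PPODecorators.optimize_device_cache = True
#   >>> with PPODecorators.empty_device_cache():
#   ...     pass  # forward/backward pass here; cache is freed afterwards

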
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
) -> torch.Tensor:
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch item individually. If CPU generators are passed, the
    tensor is always created on the CPU.
    """
    # device on which the tensor is created defaults to `device`
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                warnings.warn(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure a generator list of length 1 is treated like a single generator
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents


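# Illustrative sketch (not part of the original module): passing one generator per
# batch item makes every row independently reproducible.
#
#   >>> gens = [torch.Generator().manual_seed(i) for i in range(2)]
#   >>> x = randn_tensor((2, 3), generator=gens)
#   >>> gens2 = [torch.Generator().manual_seed(i) for i in range(2)]
#   >>> torch.equal(x, randn_tensor((2, 3), generator=gens2))
#   True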