#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
from typing import List
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
def get_strategy(args):
    """Build a ``DeepspeedStrategy`` from parsed CLI arguments.

    Every option except the required ``args.zero_stage`` falls back to a
    default when the attribute is absent on *args*.
    """
    from openrlhf.utils.deepspeed import DeepspeedStrategy

    # Optional knobs and their fallbacks when missing on the namespace.
    defaults = {
        "seed": 42,
        "full_determinism": False,
        "max_norm": 1.0,
        "micro_train_batch_size": 1,
        "train_batch_size": 128,
        "bf16": True,
    }
    options = {name: getattr(args, name, fallback) for name, fallback in defaults.items()}
    return DeepspeedStrategy(zero_stage=args.zero_stage, args=args, **options)
def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True):
    """Load a Hugging Face tokenizer and make sure a pad token is defined.

    Sets the requested ``padding_side``; when the tokenizer has no pad token,
    reuses the EOS token and mirrors the id onto ``model.config`` (if a model
    was given).  ``strategy`` is accepted for interface compatibility but
    unused here.
    """
    tok = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
    tok.padding_side = padding_side
    # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM.
    # https://github.com/facebookresearch/llama-recipes/pull/196
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
        if model is not None:
            model.config.pad_token_id = tok.pad_token_id
    return tok
def convert_token_to_id(token, tokenizer):
    """Resolve *token* to a single token id.

    Args:
        token: Either an ``int`` token id (returned unchanged) or a ``str``
            that must encode to exactly one token under *tokenizer*.
        tokenizer: Tokenizer providing ``encode(text, add_special_tokens=...)``;
            only used for the string path.

    Returns:
        int: The token id.

    Raises:
        ValueError: If *token* is neither ``int`` nor ``str``.
        AssertionError: If a string token encodes to more than one id.
    """
    # Bug fix: the error message always promised "int or str", but the original
    # code raised ValueError for int inputs instead of passing them through.
    if isinstance(token, int):
        return token
    if isinstance(token, str):
        ids = tokenizer.encode(token, add_special_tokens=False)
        assert len(ids) == 1
        return ids[0]
    raise ValueError("token should be int or str")
def zero_pad_sequences(
    sequences: List[torch.Tensor], side: str = "left", value: int = 0, stack: bool = False
) -> torch.Tensor:
    """Pad each sequence to the batch max length and combine into one tensor.

    Args:
        sequences: Tensors whose last dim is the sequence length.
        side: "left" or "right" — where the padding is inserted.
        value: Fill value for the padded positions.
        stack: Stack along a new dim 0 when True, otherwise concatenate
            along dim 0.
    """
    assert side in ("left", "right")
    target_len = max(seq.size(-1) for seq in sequences)
    if side == "left":
        padded = [F.pad(seq, (target_len - seq.size(-1), 0), value=value) for seq in sequences]
    else:
        padded = [F.pad(seq, (0, target_len - seq.size(-1)), value=value) for seq in sequences]
    combine = torch.stack if stack else torch.cat
    return combine(padded, dim=0)
def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Drop padded positions from every row using the attention mask.

    Works regardless of padding side, since it keeps exactly the positions
    where the mask is nonzero.

    Args:
        input_ids: [bs, seq_length] token ids.
        attention_mask: [bs, seq_length] mask; nonzero marks real tokens.

    Returns:
        List[torch.Tensor]: one 1-D tensor of unpadded token ids per row.
    """
    return [row[keep.bool()] for row, keep in zip(input_ids, attention_mask)]
|