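"""Data utilities for instruction fine-tuning on MetaMathQA-style (query, response) pairs.

TokenizerMetaMath formats each query into an Alpaca-style prompt and tokenizes
prompt + response; DataCollator pads batches and builds the attention mask and
completion-only labels.
"""
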
import torch
from transformers import AutoTokenizer


class TokenizerMetaMath:
    # Alpaca-style prompt templates, with and without an extra input block.
    PROMPT_NO_INPUT = (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query}\n\n### Response: "
    )
    PROMPT = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query}\n\n### Input:\n{input}\n\n### Response: "
    )

    def __init__(self, tokenizer_path):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def format_prompt(self, query):
        # Everything after the first line of the query is treated as the optional input block.
        query = query.split("\n", 1)
        if len(query) == 1 or query[1].strip("\n") == "":
            return self.PROMPT_NO_INPUT.format(query=query[0])
        else:
            return self.PROMPT.format(query=query[0], input=query[1])

    def __call__(self, examples):
        # Expects a batched example dict with "query" and "response" columns.
        prompts = [self.format_prompt(text) for text in examples["query"]]
        completions = examples["response"]
        return self._tokenize_fn(prompts, completions)

    def _tokenize_fn(self, prompts, completions):
        prompt_tokens = self.tokenizer(prompts, add_special_tokens=False)["input_ids"]
        input_tokens = self.tokenizer(
            [x + y for x, y in zip(prompts, completions)], add_special_tokens=False
        )["input_ids"]
        # Add bos/eos manually so the prompt/completion token boundary stays exact.
        input_tokens = [[self.tokenizer.bos_token_id] + x + [self.tokenizer.eos_token_id] for x in input_tokens]
        prompt_length = [len(x) + 1 for x in prompt_tokens]  # +1 for the prepended bos token
        input_length = [len(x) for x in input_tokens]
        return {"input_ids": input_tokens, "prompt_length": prompt_length, "input_length": input_length}


class DataCollator:
    def __init__(self, eos_token_id, max_length=None):
        self.eos_token_id = eos_token_id
        self.max_length = max_length

    def __call__(self, batch):
        # Turn the list of feature dicts into a dict of feature lists.
        batch = {k: [item[k] for item in batch] for k in batch[0]}
        input_lengths = torch.stack(batch["input_length"])
        prompt_lengths = torch.stack(batch["prompt_length"])
        # Right-pad with eos; padding positions are masked out below.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            batch["input_ids"], batch_first=True, padding_value=self.eos_token_id
        )
        col_indices = torch.arange(input_ids.size(1)).unsqueeze(0)
        attention_mask = col_indices < input_lengths.unsqueeze(1)
        # Train only on completion tokens: set prompt and padding positions to
        # -100 so the cross-entropy loss ignores them.
        label_mask = torch.logical_or(col_indices < prompt_lengths.unsqueeze(1), ~attention_mask)
        labels = input_ids.masked_fill(label_mask, -100)
        if self.max_length is not None:
            input_ids = input_ids[:, : self.max_length]
            attention_mask = attention_mask[:, : self.max_length]
            labels = labels[:, : self.max_length]
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
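

# A minimal usage sketch (not part of the original module). It assumes the
# Hugging Face "meta-math/MetaMathQA" dataset with "query"/"response" columns;
# the tokenizer path below is a placeholder for whatever model you fine-tune.
if __name__ == "__main__":
    from datasets import load_dataset
    from torch.utils.data import DataLoader

    tokenizer_fn = TokenizerMetaMath("meta-llama/Llama-2-7b-hf")
    dataset = load_dataset("meta-math/MetaMathQA", split="train[:100]")
    dataset = dataset.map(tokenizer_fn, batched=True, remove_columns=dataset.column_names)
    # The collator calls torch.stack on the length columns, so the dataset must
    # return torch tensors rather than Python lists.
    dataset.set_format("torch")
    collator = DataCollator(tokenizer_fn.tokenizer.eos_token_id, max_length=1024)
    loader = DataLoader(dataset, batch_size=4, collate_fn=collator)
    batch = next(iter(loader))
    print({k: v.shape for k, v in batch.items()})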