import pickle
import random

import torch
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm
from transformers import DefaultDataCollator, Trainer

def one_infill_process(input_ids, pad_token, prefix_end, suffix_start, instruction_tokens, prefix_delimiters, suffix_delimiters):
    """
    Performs a single fill-in-the-middle (FIM) transformation on a sequence.

    This function takes a sequence of token IDs and reformats it for a text infilling task.
    It rearranges the sequence into the FIM format:

    `[instruction_tokens] [prefix_delimiters[0]] [prefix] [prefix_delimiters[1]] [suffix_delimiters[0]] [suffix] [suffix_delimiters[1]] [middle]`

    where:
        - `instruction_tokens`: Special instruction tokens at the beginning
        - `prefix`: The part of the sequence before the selected span
        - `middle`: The selected span of text to be "filled in"
        - `suffix`: The part of the sequence after the selected span
        - `prefix_delimiters`: Pair of token sequences that wrap the prefix
        - `suffix_delimiters`: Pair of token sequences that wrap the suffix

    Args:
        input_ids (torch.Tensor): A sequence of input token IDs.
        pad_token (int): The ID of the padding token.
        prefix_end (int): The end index (exclusive) of the prefix.
        suffix_start (int): The start index (inclusive) of the suffix.
        instruction_tokens (list[int]): Instruction token IDs placed at the beginning.
        prefix_delimiters (list[list[int]]): Two token-ID sequences that wrap the prefix.
        suffix_delimiters (list[list[int]]): Two token-ID sequences that wrap the suffix.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - new_sample (torch.Tensor): The transformed sequence in the FIM format.
            - prompt_indices (torch.Tensor): A boolean mask marking the prompt tokens in `new_sample`.
    """
    device = input_ids.device
    instruction_len = len(instruction_tokens)
    prefix_open_len = len(prefix_delimiters[0])
    prefix_close_len = len(prefix_delimiters[1])
    suffix_open_len = len(suffix_delimiters[0])
    suffix_close_len = len(suffix_delimiters[1])

    # Number of non-padding tokens in the original sequence.
    input_len = (input_ids != pad_token).sum()

    instruction_tokens = torch.tensor(instruction_tokens, dtype=input_ids.dtype, device=device)
    prefix_open_delim = torch.tensor(prefix_delimiters[0], dtype=input_ids.dtype, device=device)
    prefix_close_delim = torch.tensor(prefix_delimiters[1], dtype=input_ids.dtype, device=device)
    suffix_open_delim = torch.tensor(suffix_delimiters[0], dtype=input_ids.dtype, device=device)
    suffix_close_delim = torch.tensor(suffix_delimiters[1], dtype=input_ids.dtype, device=device)

    # The output keeps the original sequence length, so the input must contain enough
    # padding to absorb the added instruction and delimiter tokens.
    new_sample = torch.full((input_ids.shape[0],), pad_token, dtype=input_ids.dtype, device=device)
    new_sample[:instruction_len] = instruction_tokens
    new_sample[instruction_len:instruction_len + prefix_open_len] = prefix_open_delim
    new_sample[instruction_len + prefix_open_len:instruction_len + prefix_open_len + prefix_end] = input_ids[:prefix_end]
    new_sample[instruction_len + prefix_open_len + prefix_end:instruction_len + prefix_open_len + prefix_end + prefix_close_len] = prefix_close_delim

    suffix_offset = instruction_len + prefix_open_len + prefix_end + prefix_close_len
    new_sample[suffix_offset:suffix_offset + suffix_open_len] = suffix_open_delim
    new_sample[suffix_offset + suffix_open_len:suffix_offset + suffix_open_len + (input_len - suffix_start)] = input_ids[suffix_start:input_len]
    new_sample[suffix_offset + suffix_open_len + (input_len - suffix_start):suffix_offset + suffix_open_len + (input_len - suffix_start) + suffix_close_len] = suffix_close_delim

    middle_start = suffix_offset + suffix_open_len + (input_len - suffix_start) + suffix_close_len
    new_sample[middle_start:middle_start + (suffix_start - prefix_end)] = input_ids[prefix_end:suffix_start]

    # Everything before the middle span is prompt; the middle (and trailing padding) is not.
    prompt_indices = torch.zeros_like(new_sample, dtype=torch.bool)
    prompt_indices[:middle_start] = True
    return new_sample, prompt_indices
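
# --- Illustrative example (editor's sketch, not part of the original pipeline) ---
# A minimal walk-through of one_infill_process on a toy sequence. All token IDs
# below (instruction id 1, delimiter ids 2-5, pad id 0) are made-up placeholders;
# real runs should use the model's actual FIM/special-token IDs.
def _demo_one_infill_process():
    input_ids = torch.tensor([10, 11, 12, 13, 14, 0, 0, 0, 0, 0, 0, 0])
    new_sample, prompt_mask = one_infill_process(
        input_ids,
        pad_token=0,
        prefix_end=2,                   # prefix = [10, 11]
        suffix_start=4,                 # middle = [12, 13], suffix = [14]
        instruction_tokens=[1],
        prefix_delimiters=[[2], [3]],   # [open], [close]
        suffix_delimiters=[[4], [5]],   # [open], [close]
    )
    # new_sample  -> [1, 2, 10, 11, 3, 4, 14, 5, 12, 13, 0, 0]
    # prompt_mask -> True for every position before the middle span (index 8).
    return new_sample, prompt_mask
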
def vectorized_infill_process(input_ids, pad_token, prefix_cutoff, instruction_tokens, prefix_delimiters, suffix_delimiters):
    """
    Samples a random (prefix, middle, suffix) split per sequence and applies the FIM
    transformation to every sequence in the batch.

    `prefix_cutoff` gives the minimum prefix length for each sample; the prefix end is
    drawn uniformly from [prefix_cutoff, input_length) and the suffix start from
    [prefix_end + 1, input_length).
    """
    batch_size, _ = input_ids.shape
    device = input_ids.device
    input_lengths = (input_ids != pad_token).sum(dim=1)

    # Sample prefix_end indices in [prefix_cutoff, input_length).
    prefix_range = input_lengths - prefix_cutoff
    rand_prefix_ends = torch.rand(batch_size, device=device) * prefix_range
    prefix_ends = (prefix_cutoff + rand_prefix_ends).long()

    # Sample suffix_start indices in [prefix_end + 1, input_length).
    low_suffix_start = prefix_ends + 1
    high_suffix_start = input_lengths
    suffix_range = high_suffix_start - low_suffix_start
    rand_suffix_starts = torch.rand(batch_size, device=device) * suffix_range
    suffix_starts = (low_suffix_start + rand_suffix_starts).long()

    # The split sampling above is vectorized; the reformatting itself runs per sample.
    new_samples, prompt_indices = [], []
    for i in range(batch_size):
        new_sample, prompt_index = one_infill_process(
            input_ids[i],
            pad_token,
            prefix_ends[i],
            suffix_starts[i],
            instruction_tokens,
            prefix_delimiters,
            suffix_delimiters,
        )
        new_samples.append(new_sample)
        prompt_indices.append(prompt_index)
    new_samples = torch.stack(new_samples)
    prompt_indices = torch.stack(prompt_indices)
    return new_samples, prompt_indices
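
# --- Illustrative example (editor's sketch; values are hypothetical) ---
# Batched entry point: prefix_cutoff is assumed to be a per-sample tensor of minimum
# prefix lengths. Instruction and delimiter IDs are placeholders for the sketch only.
def _demo_vectorized_infill_process():
    batch = torch.tensor([
        [10, 11, 12, 13, 14, 0, 0, 0, 0, 0, 0, 0],
        [20, 21, 22, 23, 24, 25, 26, 0, 0, 0, 0, 0],
    ])
    new_samples, prompt_masks = vectorized_infill_process(
        batch,
        pad_token=0,
        prefix_cutoff=torch.tensor([1, 2]),  # minimum prefix length per sample
        instruction_tokens=[1],
        prefix_delimiters=[[2], [3]],
        suffix_delimiters=[[4], [5]],
    )
    return new_samples, prompt_masks  # shapes: (2, 12) and (2, 12)
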
class dLLMDataCollator(DefaultDataCollator):
    """
    Adds the forward noising process to the batch.
    Modify forward_process to change the noise schedule.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.mask_token_id = kwargs["tokenizer"].mask_token_id
        self.tokenizer = kwargs["tokenizer"]
        if "max_length" in kwargs:
            self.max_length = kwargs["max_length"]
        if kwargs["tokenizer"].mask_token_id is None:
            assert (
                "mask_token_id" in kwargs
            ), "For dLLM models, pass a mask_token_id or set it equal to tokenizer.mask_token_id"
            self.mask_token_id = kwargs["mask_token_id"]
        # Optional infill parameters
        self.is_infill_task = kwargs.get("is_infill_task", False)
        if self.is_infill_task:
            self.instruction_tokens = kwargs.get("instruction_tokens", [])
            self.prefix_delimiters = kwargs.get("prefix_delimiters", [])
            self.suffix_delimiters = kwargs.get("suffix_delimiters", [])
            # Unwrap one extra level of nesting, if present, so that each delimiter ends up
            # as a flat list of token IDs ([open_ids, close_ids], as one_infill_process expects).
            if (
                self.prefix_delimiters
                and isinstance(self.prefix_delimiters[0], list)
                and self.prefix_delimiters[0]
                and isinstance(self.prefix_delimiters[0][0], list)
            ):
                self.prefix_delimiters = [tokens for sublist in self.prefix_delimiters for tokens in sublist]
            if (
                self.suffix_delimiters
                and isinstance(self.suffix_delimiters[0], list)
                and self.suffix_delimiters[0]
                and isinstance(self.suffix_delimiters[0][0], list)
            ):
                self.suffix_delimiters = [tokens for sublist in self.suffix_delimiters for tokens in sublist]

    def forward_process(self, batch, eps=1e-3):
        input_ids = batch["input_ids"]
        B, N = input_ids.shape
        if "t" not in batch:
            t = torch.rand((B,), device=input_ids.device)
        else:
            t = batch["t"]
        t = (1 - eps) * t + eps
        t = t[:, None].repeat(1, N)
        # Mask each token independently with probability t (absorbing-state noising).
        mask_indices = torch.rand((B, N), device=input_ids.device) < t
        noisy_batch = torch.where(mask_indices, self.mask_token_id, input_ids)
        return noisy_batch, t, mask_indices

    def __call__(self, batch):
        batch = super().__call__(batch)
        # Pad input_ids to max_length before processing
        if hasattr(self, "max_length"):
            batch = self.tokenizer.pad(
                batch,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
        batch["labels"] = batch["input_ids"].clone()
        # Apply the infill transformation if enabled
        if self.is_infill_task:
            batch["input_ids"], infill_prompt_indices = vectorized_infill_process(
                batch["input_ids"],
                self.tokenizer.pad_token_id,
                batch["prefix_cutoff"],
                self.instruction_tokens,
                self.prefix_delimiters,
                self.suffix_delimiters,
            )
            batch["labels"] = batch["input_ids"].clone()
        noisy_batch, batch["t"], mask_indices = self.forward_process(batch)
        batch["labels"][~mask_indices] = -100
        batch["num_prompt_tokens"] = 0
        if "prompt_lengths" in batch:
            prompt_lengths = batch.pop("prompt_lengths")
            prompt_lengths = prompt_lengths.view(-1, 1)  # (B, 1)
            prompt_length_indices = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).unsqueeze(0)  # (1, L)
            # Keep prompt tokens un-noised and exclude them from the loss
            prompt_mask = prompt_length_indices < prompt_lengths  # (B, L)
            noisy_batch[prompt_mask] = batch["input_ids"][prompt_mask].clone()
            batch["labels"][prompt_mask] = -100
            batch["num_prompt_tokens"] = prompt_mask.sum()
        elif self.is_infill_task:
            # Use the infill prompt indices to keep prompt tokens un-noised
            noisy_batch[infill_prompt_indices] = batch["input_ids"][infill_prompt_indices].clone()
            batch["labels"][infill_prompt_indices] = -100
            batch["num_prompt_tokens"] = infill_prompt_indices.sum()
        batch["input_ids"] = noisy_batch.long()
        return batch
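
# --- Illustrative construction (editor's sketch; the token IDs are placeholders) ---
# Minimal wiring of the collator for an infilling run. A real setup must pass a
# mask_token_id whenever the tokenizer does not define one, and should replace the
# placeholder instruction/delimiter IDs with the model's actual FIM tokens. Each
# example in the batch is also expected to carry a "prefix_cutoff" field.
def _demo_build_infill_collator(tokenizer, max_length=512):
    return dLLMDataCollator(
        tokenizer=tokenizer,
        mask_token_id=tokenizer.mask_token_id,
        max_length=max_length,
        is_infill_task=True,
        instruction_tokens=[1],          # placeholder instruction token IDs
        prefix_delimiters=[[2], [3]],    # placeholder [open], [close] IDs
        suffix_delimiters=[[4], [5]],    # placeholder [open], [close] IDs
    )
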

class dLLMTrainer(Trainer):
    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """
        Absorbing state diffusion loss computation.
        NOTE: time step t here is different from ours
        """
        normalize_constant = 4096
        batch_size = inputs["input_ids"].size(0)
        labels, t, num_prompt_tokens = inputs.pop("labels"), inputs.pop("t"), inputs.pop("num_prompt_tokens")
        if "prefix_cutoff" in inputs:
            inputs.pop("prefix_cutoff")
        outputs = model(**inputs)
        logits = outputs.logits
        unscaled_loss = F.cross_entropy(
            logits.view(-1, logits.shape[-1]), labels.view(-1), reduction="none"
        ).view(logits.shape[0], -1)
        if (self.state.global_step + 1) % self.args.logging_steps == 0:
            self.log({"unscaled_loss": (unscaled_loss.sum() / (labels != -100).sum()).item()})
        loss = unscaled_loss / t
        loss = loss.sum() / (batch_size * normalize_constant)
        # double-check debug
        if return_outputs:
            print("Returning outputs")
            return loss, {"dummy": None}
        return loss
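
# --- Loss weighting note (editor's addition) ---
# compute_loss above evaluates, per batch,
#     loss = (1 / (B * normalize_constant)) * sum_over_masked_tokens(CE_i / t_i),
# where t_i is the per-sequence masking rate broadcast to every position and
# normalize_constant = 4096 (presumably the padded sequence length). Unmasked and
# prompt positions contribute nothing because their labels are -100, which
# F.cross_entropy ignores even with reduction="none"; dividing by t applies the
# standard 1/t weighting of the absorbing-state (masked) diffusion objective.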

class dLLMSFTDataset(torch.utils.data.Dataset):
    """
    Similar to AR datasets, except that at evaluation time we keep the timesteps fixed.
    """

    def __init__(self, data, tokenizer, max_length, eval=False):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.eval = eval
        if self.eval:
            self.t = torch.linspace(0, 1, len(self.data))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        out = self.data[idx]
        if self.eval:
            out["t"] = self.t[idx]
        return out

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
Your reasoning here
</reasoning>
<answer>
...
</answer>
"""

def preprocess_dataset(data, tokenizer, max_length, test_split=0.01):
    preprocessed_data = []
    # TODO: check if the pad_token = mask_token
    for i in tqdm(range(len(data)), desc="Preprocessing dataset"):
        question = SYSTEM_PROMPT + "\n\n" + data[i]["question"]
        trajectory = f"<reasoning>{data[i]['thinking_trajectories'][0]}</reasoning>\n<answer>{data[i]['attempt']}</answer>"
        prompt = [{"role": "user", "content": question}]
        response = [{"role": "assistant", "content": trajectory}]
        inputs = tokenizer.apply_chat_template(prompt + response, tokenize=False)
        prompt = tokenizer.apply_chat_template(prompt, tokenize=False) + "\n"
        tokenized_input = tokenizer(
            inputs, return_tensors="pt", truncation=True, max_length=max_length, padding="max_length"
        ).input_ids.squeeze(0)
        num_tokens = tokenized_input.shape[0]
        tokenized_prompt = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
        preprocessed_data.append(
            {
                "input_ids": tokenized_input,
                "prompt_lengths": tokenized_prompt.attention_mask.sum(-1),
            }
        )
    random.shuffle(preprocessed_data)
    test_data = preprocessed_data[: int(len(preprocessed_data) * test_split)]
    train_data = preprocessed_data[int(len(preprocessed_data) * test_split) :]
    return train_data, test_data
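
# --- End-to-end wiring sketch (editor's addition; names and values are placeholders) ---
# A minimal, hypothetical example of how the pieces above are meant to fit together for
# SFT. It assumes `model`/`tokenizer` are a diffusion LM and its tokenizer (with a mask
# token defined), and `raw_data` is an s1K-style dataset with "question",
# "thinking_trajectories", and "attempt" fields; none of these names come from the
# original file, and the TrainingArguments values are illustrative only.
def _demo_sft_wiring(model, tokenizer, raw_data, output_dir="dllm-sft-out", max_length=4096):
    from transformers import TrainingArguments

    train_data, eval_data = preprocess_dataset(raw_data, tokenizer, max_length)
    # Inputs are already padded to max_length in preprocess_dataset, so the collator
    # is built without its optional max_length re-padding step here.
    collator = dLLMDataCollator(
        tokenizer=tokenizer,
        mask_token_id=tokenizer.mask_token_id,
    )
    trainer = dLLMTrainer(
        model=model,
        args=TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=1,
            remove_unused_columns=False,  # keep prompt_lengths for the collator
            logging_steps=10,
        ),
        train_dataset=dLLMSFTDataset(train_data, tokenizer, max_length),
        eval_dataset=dLLMSFTDataset(eval_data, tokenizer, max_length, eval=True),
        data_collator=collator,
    )
    trainer.train()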