import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedTokenizerBase,
    GenerationConfig,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from peft import PeftConfig, LoraConfig

import random
from typing import Tuple, Optional, List, Union
import logging
import functools

from larm.task.base_model import BaseModel
from larm.common.registry import registry

from .weaver import MemGenWeaver
from .trigger import MemGenTrigger, NanoTrigger
from .utils import (
    CONVERSATION_TEMPLATE,
    load_state_dict_from_safetensor,
    fix_model_parameters,
    log_trainable_params,
)

def log_function_call(func):
    """Decorator that logs each call to the wrapped function in blue (ANSI escape)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func_name = func.__name__
        logging.info(f"\033[94m[call] {func_name}\033[0m")
        return func(*args, **kwargs)
    return wrapper
def get_next_token(next_token_logits: torch.Tensor, do_sample: bool, temperature: float) -> torch.Tensor:
    """
    Select the next token from model logits.

    Two modes are supported:
        1. Sampling mode (do_sample=True and temperature > 0):
           Apply temperature scaling to the logits, compute a probability distribution
           with softmax, and randomly sample one token.
        2. Greedy mode (do_sample=False or temperature == 0):
           Select the token with the highest probability (argmax).

    Args:
        next_token_logits (torch.Tensor):
            Logits for the next tokens, shape [batch_size, vocab_size].
        do_sample (bool):
            Whether to perform stochastic sampling. If False, greedy decoding is used.
        temperature (float):
            Sampling temperature. Higher values make the distribution flatter (more randomness),
            while lower values make it sharper (more deterministic).
            When set to 0, greedy decoding is enforced.

    Returns:
        torch.Tensor:
            The selected next token indices, shape [batch_size, 1].
    """
    if next_token_logits.dim() != 2:
        raise ValueError("Input logits must be a 2D tensor [batch_size, vocab_size]")
    if do_sample and temperature != 0:
        # Sampling: temperature-scaled softmax, then one multinomial draw per row.
        probs = F.softmax(next_token_logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)
    else:
        # Greedy: highest-probability token.
        return torch.argmax(next_token_logits, dim=-1, keepdim=True)
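
# Illustrative behavior (hypothetical logits over a 4-token vocabulary):
#     logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])
#     get_next_token(logits, do_sample=False, temperature=0.0)  # -> tensor([[0]])
#     get_next_token(logits, do_sample=True, temperature=1.0)   # random draw from softmax(logits)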
def generate_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Generate position ID tensor based on the given attention mask.

    The position IDs are computed as the cumulative count of non-padding tokens
    within each sequence. Padding tokens are always assigned position ID 0.

    Args:
        attention_mask (torch.Tensor):
            A tensor of shape (batch_size, sequence_length).
            Typically, 1 indicates a valid (non-padding) token and 0 indicates a padding token.

    Returns:
        torch.Tensor:
            A tensor of shape (batch_size, sequence_length) containing position IDs.
            - For non-padding tokens: position IDs start at 0 and increase consecutively.
            - For padding tokens: position ID is always 0.
    """
    position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)
    position_ids.masked_fill_(attention_mask == 0, 0)
    return position_ids
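
# Illustrative mapping for a left-padded mask:
#     attention_mask: [[0, 0, 1, 1, 1]]
#     position_ids:   [[0, 0, 0, 1, 2]]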
@log_function_call
def is_conversation(input_ids: torch.Tensor, tokenizer) -> bool:
    """
    Check whether the given input IDs represent a conversation format.
    Only the first sample in the batch is inspected.

    The function verifies whether the sequence contains at least one pair
    of special tokens: <|im_start|> and <|im_end|>.

    Args:
        input_ids (torch.Tensor):
            Tensor of shape (batch_size, seq_len) containing token IDs.
        tokenizer:
            A HuggingFace tokenizer used to obtain the special token IDs.

    Returns:
        bool:
            True if the sequence contains both <|im_start|> and <|im_end|>,
            False otherwise.
    """
    if input_ids.dim() != 2:
        raise ValueError("input_ids must be a 2D tensor of shape (batch_size, seq_len)")

    seq = input_ids[0].tolist()

    im_start_ids = tokenizer.encode("<|im_start|>", add_special_tokens=False)
    im_end_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)

    # Scan the sequence for both marker token-ID subsequences.
    has_start = any(seq[i:i + len(im_start_ids)] == im_start_ids for i in range(len(seq) - len(im_start_ids) + 1))
    has_end = any(seq[i:i + len(im_end_ids)] == im_end_ids for i in range(len(seq) - len(im_end_ids) + 1))

    logging.info(
        f"is_conversation: has_start={has_start}, has_end={has_end}, "
        f"im_start_ids={im_start_ids}, im_end_ids={im_end_ids}"
    )

    return has_start and has_end
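
# Illustrative conversation layout the check looks for (ChatML-style markup,
# shown as text; the function matches the corresponding token-ID subsequences):
#     <|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi!<|im_end|>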
@log_function_call
def postprocess_assistant_labels(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    tokenizer
) -> torch.Tensor:
    """
    Mask out labels corresponding to the `<|im_start|>assistant` marker.

    This ensures that the special tokens used to indicate the start of the assistant's
    response do not contribute to the loss during training.

    Args:
        input_ids (torch.Tensor):
            Tensor of shape (batch_size, seq_len) containing the conversation token IDs.
        labels (torch.Tensor):
            Tensor of shape (batch_size, seq_len) containing training labels.
            A value of -100 indicates positions that should be ignored in loss computation.
        tokenizer:
            A HuggingFace tokenizer used to encode the `<|im_start|>assistant\n` marker.

    Returns:
        torch.Tensor:
            The modified labels tensor with positions corresponding to
            `<|im_start|>assistant\n` masked as -100.
    """
    if tokenizer.chat_template != CONVERSATION_TEMPLATE:
        raise ValueError(
            "Invalid tokenizer.chat_template detected.\n"
            f"Expected:\n{CONVERSATION_TEMPLATE}\n\n"
            f"Got:\n{tokenizer.chat_template}\n\n"
            "Please ensure that you are using the correct conversation template."
        )

    pattern_ids: List[int] = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)

    batch_size, seq_len = input_ids.shape
    new_labels = labels.clone()

    for b in range(batch_size):
        seq = input_ids[b].tolist()
        for i in range(len(seq) - len(pattern_ids) + 1):
            if seq[i: i + len(pattern_ids)] == pattern_ids:
                new_labels[b, i: i + len(pattern_ids)] = -100

    return new_labels
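
# Illustrative masking (hypothetical token IDs): if positions 12-14 spell
# "<|im_start|>assistant\n", then
#     labels:     [..., 512,  77,  198, 9906, ...]   (positions 12-15)
#     new_labels: [..., -100, -100, -100, 9906, ...]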
def check_ends_with_delimiter(
    input_ids: torch.Tensor, tokenizer: PreTrainedTokenizerBase, delimiters: List[str]
) -> torch.Tensor:
    """
    Check whether each sequence in the batch ends with any of the specified delimiter strings.

    Args:
        input_ids (torch.Tensor):
            Tensor of shape (batch_size, seq_len) containing token IDs for each sequence.
        tokenizer (PreTrainedTokenizerBase):
            HuggingFace tokenizer used to decode input_ids back to text.
        delimiters (List[str]):
            A list of delimiter strings to check against.
            If a sequence ends with any of these delimiters, it is marked as True.

    Returns:
        torch.Tensor:
            A boolean tensor of shape (batch_size, 1), where each entry indicates
            whether the corresponding sequence ends with one of the delimiters.
    """
    batch_size = input_ids.size(0)

    augmentation_decisions = torch.zeros(batch_size, 1, dtype=torch.bool, device=input_ids.device)

    # Reuse the delimiter-ID cache if the owning model attached itself to the
    # tokenizer (see LatentMemoryModel.__init__); otherwise encode on the fly.
    delimiter_id_tensors = getattr(getattr(tokenizer, "memgen_model_ref", None), "_delimiter_id_tensors", None)
    if delimiter_id_tensors is None:
        try:
            delimiter_id_tensors = [
                torch.tensor(tokenizer.encode(d, add_special_tokens=False), dtype=torch.long, device=input_ids.device)
                for d in delimiters
            ]
        except Exception:
            delimiter_id_tensors = []

    for i in range(batch_size):
        seq = input_ids[i]
        matched = False
        # Fast path: compare the trailing token IDs directly.
        for did in delimiter_id_tensors:
            L = did.numel()
            if L == 0 or L > seq.numel():
                continue
            if torch.equal(seq[-L:], did.to(seq.device)):
                matched = True
                break
        # Slow path: fall back to string matching on the decoded text.
        if not matched:
            decoded = tokenizer.decode(seq.tolist())
            for aug_str in delimiters:
                if decoded.endswith(aug_str):
                    matched = True
                    break
        augmentation_decisions[i] = matched

    return augmentation_decisions
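
# Illustrative check: for delimiters [",", ".", "\n"], a batch whose decoded
# rows read "...Paris." and "...Paris" yields
#     check_ends_with_delimiter(ids, tok, [",", ".", "\n"]) -> tensor([[True], [False]])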


@registry.register_model("latmem")
class LatentMemoryModel(BaseModel):
    @log_function_call
    def __init__(
        self,
        reasoner_model_name: str,
        weaver_model_name: str,
        prompt_latents_len: int,
        inference_latents_len: int,
        weaver_peft_config: Optional[PeftConfig] = None,
        trigger_model_name: Optional[str] = None,
        trigger_peft_config: Optional[PeftConfig] = None,
        max_prompt_aug_num: int = 1,
        max_inference_aug_num: int = 5,
    ):
        super().__init__()

        # Reasoner: a frozen Qwen2.5-VL backbone that does the actual generation.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            reasoner_model_name, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        self.processor = AutoProcessor.from_pretrained(reasoner_model_name)
        # VLM processors wrap a text tokenizer; plain tokenizers are used directly.
        if hasattr(self.processor, "tokenizer"):
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = self.processor
        self.config = self.model.config

        # Weaver: produces latent memory tokens that get interleaved with the prompt.
        self.weaver = MemGenWeaver(
            weaver_model_name, prompt_latents_len, inference_latents_len, weaver_peft_config
        )

        # Trigger: decides when latents are inserted during decoding; NanoTrigger never fires.
        self.trigger = NanoTrigger()
        if trigger_model_name is not None:
            self.trigger = MemGenTrigger(trigger_model_name, trigger_peft_config)
            logging.info(f"Use Trigger: {trigger_model_name}")

        # Linear projections between the reasoner and weaver embedding spaces.
        self.reasoner_to_weaver = nn.Linear(
            self.model.config.hidden_size, self.weaver.config.hidden_size, dtype=torch.bfloat16
        )
        self.weaver_to_reasoner = nn.Linear(
            self.weaver.config.hidden_size, self.model.config.hidden_size, dtype=torch.bfloat16
        )

        self.delimiters: List[str] = [",", ".", "\n"]

        self.max_prompt_aug_num = max_prompt_aug_num
        self.max_inference_aug_num = max_inference_aug_num

        if max_prompt_aug_num == 0 and max_inference_aug_num == 0:
            logging.warning(
                "⚠️ Both max_prompt_aug_num and max_inference_aug_num are set to 0. "
                "This means no latent memory will be used. "
                "If you are training, this will cause errors because no trainable parameters will be involved. "
                "For training, at least one augmentation type must be enabled (> 0)."
            )

        # Pre-encode the delimiter token IDs once so generation-time checks can
        # skip repeated tokenizer calls.
        try:
            self._delimiter_id_tensors = [
                torch.tensor(self.tokenizer.encode(d, add_special_tokens=False), dtype=torch.long)
                for d in self.delimiters
            ]
        except Exception:
            # Fall back to on-the-fly encoding inside check_ends_with_delimiter.
            self._delimiter_id_tensors = []
        # Expose the cache to check_ends_with_delimiter via the back-reference it looks for.
        self.tokenizer.memgen_model_ref = self

        self._postprocess_models()

        self.warnings_issued = {}
        self.model_tags = None
        log_trainable_params(self)
    @log_function_call
    def add_model_tags(self, tags: Union[list[str], str]) -> None:
        r"""
        Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
        not overwrite existing tags in the model.

        Args:
            tags (`Union[list[str], str]`):
                The desired tags to inject in the model.

        Examples:

        ```python
        from transformers import AutoModel

        model = AutoModel.from_pretrained("google-bert/bert-base-cased")

        model.add_model_tags(["custom", "custom-bert"])

        # Push the model to your namespace with the name "my-custom-bert".
        model.push_to_hub("my-custom-bert")
        ```
        """
        if isinstance(tags, str):
            tags = [tags]

        if self.model_tags is None:
            self.model_tags = []

        for tag in tags:
            if tag not in self.model_tags:
                self.model_tags.append(tag)

    @log_function_call
    def _postprocess_models(self):
        """
        Postprocess the components of the latent memory model: reasoner, weaver, trigger, and tokenizer.

        Steps:
            1. Freeze all parameters of the reasoner (no gradient updates).
            2. Cast all models to bfloat16 for memory and compute efficiency.
            3. Ensure the tokenizer has a valid pad token:
                - If the pad token is missing, use the EOS token as the pad token.
                - Set `padding_side` to "left" for compatibility with generation tasks.
            4. Standardize the tokenizer's chat template to `CONVERSATION_TEMPLATE` (text-only models).
        """
        # Freeze the reasoner.
        fix_model_parameters(self.model)

        # Cast every component to bfloat16.
        self.model = self.model.bfloat16()
        self.weaver = self.weaver.bfloat16()
        self.trigger = self.trigger.bfloat16()

        # Guarantee a pad token and left padding for generation.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            logging.info(
                f"Tokenizer has no pad token. Using EOS token ({self.tokenizer.eos_token}) as pad token."
            )
        self.tokenizer.padding_side = "left"

        # Propagate the special token IDs to the model config ...
        if hasattr(self.model, "config"):
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
            self.model.config.bos_token_id = self.tokenizer.bos_token_id
            self.model.config.eos_token_id = self.tokenizer.eos_token_id

        # ... the generation config ...
        if getattr(self.model, "generation_config", None) is not None:
            self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
            self.model.generation_config.bos_token_id = self.tokenizer.bos_token_id
            self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id

        # ... and the processor's tokenizer, in case it is a distinct object.
        if hasattr(self, "processor") and hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer.pad_token_id = self.tokenizer.pad_token_id
            self.processor.tokenizer.bos_token_id = self.tokenizer.bos_token_id
            self.processor.tokenizer.eos_token_id = self.tokenizer.eos_token_id

        # Only text-only models get the standardized conversation template;
        # VLM processors ship their own multimodal chat template.
        is_vlm = isinstance(self.model, Qwen2_5_VLForConditionalGeneration) or hasattr(
            getattr(self, "processor", None), "image_processor"
        )
        if not is_vlm:
            self.tokenizer.chat_template = CONVERSATION_TEMPLATE
            if hasattr(self, "processor") and hasattr(self.processor, "tokenizer"):
                self.processor.tokenizer.chat_template = CONVERSATION_TEMPLATE

        logging.info(
            f"Synchronized special tokens - pad_token_id: {self.tokenizer.pad_token_id}, "
            f"bos_token_id: {self.tokenizer.bos_token_id}, eos_token_id: {self.tokenizer.eos_token_id}"
        )

    @property
    def device(self):
        assert self.model.device == self.weaver.device == self.trigger.device
        return self.model.device

    @log_function_call
    def _forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[Tuple[int, int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Single-turn forward pass that interleaves weaver latents into the sequence.

        Latents are inserted at the selected augmentation points; the returned
        logits are filtered back to the original sequence length so they stay
        aligned with `labels`.
        """
        assert input_ids.shape == attention_mask.shape == labels.shape

        tokenizer = self.tokenizer
        reasoner = self.model
        weaver = self.weaver
        delimiters = self.delimiters
        max_augment_num = self.max_inference_aug_num
        device = self.device
        embeds_dtype = reasoner.get_input_embeddings().weight.dtype
        B, _ = input_ids.shape
        hidden_size = reasoner.config.hidden_size

        # Pick the positions where latents will be inserted.
        augmentation_indices = self._select_augment_points_after_delimiter(
            input_ids, labels, delimiters, tokenizer, max_augment_num
        )
        logging.info(f"augmentation_indices: {augmentation_indices}")

        # For multimodal inputs, take the embedding-layer hidden states so that
        # the vision features are already merged into the token embeddings.
        if pixel_values is not None:
            with torch.no_grad():
                seed_outputs = reasoner(
                    input_ids=input_ids,
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_thw,
                    attention_mask=attention_mask,
                    return_dict=True,
                    output_hidden_states=True,
                )
            inputs_embeds = seed_outputs.hidden_states[0]
        else:
            inputs_embeds = reasoner.get_input_embeddings()(input_ids)

        # Running buffers for the interleaved [tokens, latents, tokens, ...] sequence.
        current_start_idx = 0
        current_inputs_embeds = torch.empty((B, 0, hidden_size), device=device, dtype=embeds_dtype)
        current_attention_mask = torch.empty((B, 0), device=device, dtype=attention_mask.dtype)
        current_latents_mask = torch.empty((B, 0), device=device, dtype=torch.bool)

        # The weaver consumes the same sequence projected into its own hidden space.
        current_weaver_inputs_embeds = torch.empty(
            (B, 0, self.weaver.config.hidden_size), device=device, dtype=self.weaver_to_reasoner.weight.dtype
        )

        for aug_point_idx in augmentation_indices:
            # Append the token segment preceding this augmentation point.
            segment_inputs_embeds = inputs_embeds[:, current_start_idx:aug_point_idx]
            segment_attention_mask = attention_mask[:, current_start_idx:aug_point_idx]
            segment_latents_mask = torch.zeros((B, segment_inputs_embeds.size(1)), device=device, dtype=torch.bool)

            current_inputs_embeds = torch.cat([current_inputs_embeds, segment_inputs_embeds], dim=1)
            current_attention_mask = torch.cat([current_attention_mask, segment_attention_mask], dim=1)
            current_position_ids = generate_position_ids(current_attention_mask)
            current_latents_mask = torch.cat([current_latents_mask, segment_latents_mask], dim=1)

            segment_weaver_inputs = self.reasoner_to_weaver(segment_inputs_embeds)
            current_weaver_inputs_embeds = torch.cat([current_weaver_inputs_embeds, segment_weaver_inputs], dim=1)

            # A point sitting exactly on the -100 -> label transition marks the end
            # of the prompt; everything else is an in-reasoning augmentation.
            at_aug = (labels[:, aug_point_idx] != -100).all().item()
            before_aug = (labels[:, aug_point_idx - 1] == -100).all().item()
            is_prompt_end_aug = at_aug and before_aug

            if is_prompt_end_aug:
                logging.info(f"[augment] idx={aug_point_idx} TYPE=PROMPT")
                weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt(
                    current_weaver_inputs_embeds, current_attention_mask, current_position_ids
                )
            else:
                logging.info(f"[augment] idx={aug_point_idx} TYPE=INFERENCE")
                weaver_hidden_states, attn_mask, pos_ids = weaver.augment_inference(
                    current_weaver_inputs_embeds, current_attention_mask, current_position_ids
                )

            # Project the latents back into the reasoner embedding space and append them.
            latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)

            current_inputs_embeds = torch.cat([current_inputs_embeds, latent_inputs_embeds], dim=1)
            current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1)
            current_start_idx = aug_point_idx

            latent_mask = torch.ones((B, latent_inputs_embeds.size(1)), device=device, dtype=torch.bool)
            current_latents_mask = torch.cat([current_latents_mask, latent_mask], dim=1)

            current_weaver_inputs_embeds = torch.cat([current_weaver_inputs_embeds, weaver_hidden_states], dim=1)

        # Append the remaining tokens after the last augmentation point.
        remaining_inputs_embeds = inputs_embeds[:, current_start_idx:]
        remaining_attention_mask = attention_mask[:, current_start_idx:]
        latent_mask = torch.zeros((B, remaining_attention_mask.size(1)), device=device, dtype=torch.bool)

        current_inputs_embeds = torch.cat([current_inputs_embeds, remaining_inputs_embeds], dim=1)
        current_attention_mask = torch.cat([current_attention_mask, remaining_attention_mask], dim=1)
        current_position_ids = generate_position_ids(current_attention_mask)
        current_latents_mask = torch.cat([current_latents_mask, latent_mask], dim=1)

        reasoner_outputs = reasoner(
            inputs_embeds=current_inputs_embeds,
            attention_mask=current_attention_mask,
            position_ids=current_position_ids,
        )
        logits = reasoner_outputs.logits

        # Keep only positions whose successor is a real token: each original token's
        # next-token logits then come either from the token before it or from the
        # last latent preceding it, restoring a 1:1 alignment with `labels`.
        shifted = torch.zeros_like(current_latents_mask)
        shifted[:, :-1] = current_latents_mask[:, 1:]
        valid_mask = ~shifted

        valid_logits = logits[valid_mask].view(logits.size(0), -1, logits.size(2))

        return valid_logits
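
    # Illustrative alignment (hypothetical sizes): a 10-token input with two
    # augmentation points, each inserting 2 latents, yields 14 interleaved
    # positions. Dropping every position that is followed by a latent removes
    # 4 of them, so `valid_logits` is (B, 10, vocab) again and lines up with
    # `labels` for the shifted cross-entropy in `forward`.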
    def _instructional_forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[Tuple[int, int, int]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for single-turn instructional data (no multi-turn conversation required).

        This method is used for instruction-following tasks (SFT), where the input
        consists of a single instruction and the corresponding labels. It directly
        delegates to the single-turn forward method `_forward`.

        Args:
            input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing input token IDs.
            attention_mask (torch.Tensor): Tensor indicating padding positions.
            labels (torch.Tensor): Tensor containing the target labels for supervised fine-tuning.
            **kwargs: Additional keyword arguments passed to `_forward`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - logits: The output logits from the model for each input token.
                - labels: The same as the input labels, used for loss computation.
        """
        logits = self._forward(
            input_ids,
            attention_mask,
            labels,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            **kwargs,
        )

        return logits, labels
    @log_function_call
    def _conversational_forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[Tuple[int, int, int]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for conversational (multi-turn) data.

        Multi-turn forward is constructed by sequentially calling the single-turn forward
        for each conversation turn. Latents inserted in turn i-1 are not visible to turn i.

        Args:
            input_ids (torch.Tensor): Input token IDs, shape (1, seq_len). Batch size must be 1.
            attention_mask (torch.Tensor): Attention mask for input tokens.
            labels (torch.Tensor): Target labels for supervised fine-tuning (-100 for ignore positions).
            **kwargs: Additional arguments passed to `_forward`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - all_logits: Logits for the entire sequence, with zeros for unsupervised positions.
                - all_labels: Labels for the entire sequence, with -100 for unsupervised positions.
        """
        assert input_ids.shape[0] == 1, "Conversational SFT currently only supports batch_size = 1"
        seq_len = input_ids.shape[1]
        vocab_size = self.config.vocab_size
        device = input_ids.device

        label_row = labels[0]
        should_supervise = label_row != -100
        if not should_supervise.any():
            raise ValueError("At least one completion segment is required")

        # Find the [start, end) spans of contiguous supervised regions.
        valid_mask = should_supervise.long()
        diff = torch.diff(torch.cat([torch.tensor([0], device=device), valid_mask]))
        valid_starts = (diff == 1).nonzero(as_tuple=True)[0].tolist()
        ends = (diff == -1).nonzero(as_tuple=True)[0].tolist()
        if len(ends) < len(valid_starts):
            ends.append(seq_len)
        assert len(valid_starts) == len(ends)

        # (turn_start, supervised_start, supervised_end) triplets, one per turn.
        triplets = []
        start = 0
        for s, e in zip(valid_starts, ends):
            triplets.append((start, s, e))
            start = e

        # Choose which turns are processed (and thus receive prompt augmentation).
        if self.max_prompt_aug_num == 0:
            # Still supervise every turn; prompt augmentation is skipped inside _forward.
            select_turns = [1] * len(triplets)
        elif len(triplets) <= self.max_prompt_aug_num:
            select_turns = [1] * len(triplets)
        else:
            triplets_num = len(triplets)
            selected_indices = set(random.sample(range(triplets_num), self.max_prompt_aug_num))
            select_turns = [1 if i in selected_indices else 0 for i in range(triplets_num)]

        all_logits = torch.zeros(1, seq_len, vocab_size, device=device)
        all_labels = torch.full((1, seq_len), -100, dtype=labels.dtype, device=device)

        for triplet, supervise_turn in zip(triplets, select_turns):
            start, valid_start, end = triplet
            if supervise_turn:
                cur_input_ids = input_ids[0, :end].unsqueeze(0)
                cur_attention = attention_mask[0, :end].unsqueeze(0)

                # Supervise only the current turn's completion.
                cur_labels = labels[0, :end].clone().unsqueeze(0)
                cur_labels[0, :valid_start] = -100

                logits = self._forward(
                    cur_input_ids,
                    cur_attention,
                    cur_labels,
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_thw,
                    **kwargs,
                )

                all_logits[0, start:end, :] = logits[0, start:end, :]
                all_labels[0, start:end] = labels[0, start:end]

        return all_logits, all_labels
    @log_function_call
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[Tuple[int, int, int]] = None,
        **kwargs,
    ):
        """Training forward pass: route to the instructional or conversational forward and compute the LM loss."""
        tokenizer = self.tokenizer

        assert labels is not None, "Reasoning Processor requires input labels for training"

        if self.training and self.max_prompt_aug_num == 0 and self.max_inference_aug_num == 0:
            raise ValueError(
                "Cannot train MemGen model with both max_prompt_aug_num=0 and max_inference_aug_num=0. "
                "At least one augmentation type must be enabled during training. "
                "If you want to use the base model without augmentation, please use the reasoner model directly. "
                "For inference/evaluation without augmentation, set the model to eval mode: model.eval()"
            )

        forward_func = self._instructional_forward
        if is_conversation(input_ids, tokenizer):
            logging.info("is_conversation: True")
            labels = postprocess_assistant_labels(input_ids, labels, tokenizer)
            forward_func = self._conversational_forward

        # Process the batch one sample at a time (the conversational forward requires it).
        batch_size = 1
        iter_num = input_ids.size(0) // batch_size

        logits, supervised_labels = [], []
        for i in range(iter_num):
            batch_input_ids = input_ids[i * batch_size: (i + 1) * batch_size]
            batch_attention_mask = attention_mask[i * batch_size: (i + 1) * batch_size]
            batch_labels = labels[i * batch_size: (i + 1) * batch_size]

            batch_logits, batch_supervised_labels = forward_func(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                labels=batch_labels,
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
                **kwargs,
            )
            logits.append(batch_logits)
            supervised_labels.append(batch_supervised_labels)

        all_logits = torch.cat(logits, dim=0)
        all_labels = torch.cat(supervised_labels, dim=0)

        # Standard causal LM shift: position t predicts token t+1.
        shift_logits = all_logits[..., :-1, :].contiguous()
        shift_labels = all_labels[..., 1:].contiguous()

        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        outputs = CausalLMOutputWithPast(loss=loss, logits=all_logits)
        outputs.supervised_labels = all_labels
        return outputs
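
    # Illustrative training step (hypothetical batch from a DataLoader):
    #     outputs = model(input_ids=batch["input_ids"],
    #                     attention_mask=batch["attention_mask"],
    #                     labels=batch["labels"])
    #     outputs.loss.backward()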
    @log_function_call
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        generation_config: GenerationConfig = None,
        return_augmentation_mask: bool = False,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[Tuple[int, int, int]] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Token-by-token generation with trigger-gated latent augmentation."""
        tokenizer = self.tokenizer
        reasoner = self.model
        weaver = self.weaver
        max_augment_num = self.max_inference_aug_num
        invalid_token_id = -100

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        max_new_tokens = generation_config.max_new_tokens
        do_sample = generation_config.do_sample
        temperature = generation_config.temperature
        pad_token_id = tokenizer.pad_token_id

        # Prefer <|im_end|> as the EOS marker when it maps to a single token.
        try:
            im_end_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)
            if isinstance(im_end_ids, list) and len(im_end_ids) == 1:
                eos_token_id = im_end_ids[0]
            else:
                eos_token_id = tokenizer.eos_token_id
        except Exception:
            eos_token_id = tokenizer.eos_token_id
        prompt_len = input_ids.size(1)
        # Rebuild the config; max_new_tokens is filled in later for the fast-path generate call.
        generation_config = GenerationConfig(
            do_sample=do_sample,
            temperature=temperature,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )

        # The embedding-layer hidden states already contain the merged vision features.
        with torch.no_grad():
            outputs = reasoner(
                input_ids=input_ids,
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=True,
            )
        fused_embeds = outputs.hidden_states[0]

        inputs_embeds = fused_embeds

        B, _, hidden_size = inputs_embeds.shape
        device = inputs_embeds.device

        try:
            current_inputs_embeds = inputs_embeds
            current_attention_mask = attention_mask
            current_position_ids = generate_position_ids(current_attention_mask)
            current_input_ids = input_ids

            # One-shot prompt augmentation before decoding starts.
            if self.max_prompt_aug_num > 0:
                logging.info("[Generation] Applying prompt augmentation")
                weaver_inputs_embeds = self.reasoner_to_weaver(current_inputs_embeds)
                weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt(
                    weaver_inputs_embeds, current_attention_mask, current_position_ids
                )
                latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)

                current_inputs_embeds = torch.cat([current_inputs_embeds, latent_inputs_embeds], dim=1)
                current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1)
                current_position_ids = torch.cat([current_position_ids, pos_ids], dim=1)
            else:
                logging.info("[Generation] Skipping prompt augmentation (max_prompt_aug_num=0)")

            # Per-sample bookkeeping for inference-time augmentation.
            sentence_augment_count = torch.zeros(B, dtype=torch.int, device=device)
            augmentation_pos = torch.full((B, max_new_tokens), fill_value=invalid_token_id, device=device)
            inserted_embeds: List[List[torch.Tensor]] = [[] for _ in range(B)]  # kept for inspection/debugging
            for i in range(max_new_tokens):
                # Stop once every sequence has emitted EOS.
                if (current_input_ids[:, -1] == eos_token_id).all():
                    break

                # Fast path: no sample can be augmented any more, so hand the rest
                # of the decoding over to the reasoner's own generate().
                if (sentence_augment_count >= max_augment_num).all():
                    generation_config.max_new_tokens = max_new_tokens - i

                    generated = reasoner.generate(
                        inputs_embeds=current_inputs_embeds,
                        attention_mask=current_attention_mask,
                        generation_config=generation_config,
                    )
                    current_input_ids = torch.cat([current_input_ids, generated], dim=1)
                    break

                outputs = reasoner(
                    inputs_embeds=current_inputs_embeds,
                    attention_mask=current_attention_mask,
                    position_ids=current_position_ids,
                    output_hidden_states=False,
                )
                current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids = self._append_one_step(
                    outputs, current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids, do_sample, temperature
                )

                if i == max_new_tokens - 1:
                    break

                # Ask the trigger which samples should receive latents now.
                augment_decision = self._should_augment(
                    current_input_ids, current_attention_mask, sentence_augment_count=sentence_augment_count,
                    do_sample=do_sample, temperature=temperature
                )
                augmentation_pos[:, i + 1] = augment_decision
                augment_indices = torch.where(augment_decision == 1)[0]

                if len(augment_indices) > 0:
                    sentence_augment_count[augment_indices] += 1

                    # Run the weaver only on the rows that fired.
                    candidate_inputs_embeds = current_inputs_embeds[augment_indices]
                    candidate_attention_mask = current_attention_mask[augment_indices]
                    candidate_position_ids = current_position_ids[augment_indices]

                    weaver_inputs_embeds = self.reasoner_to_weaver(candidate_inputs_embeds)
                    weaver_hidden_states, attn_mask, _ = weaver.augment_inference(
                        weaver_inputs_embeds, candidate_attention_mask, candidate_position_ids
                    )
                    latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)

                    candidate_inputs_embeds = torch.cat([candidate_inputs_embeds, latent_inputs_embeds], dim=1)
                    candidate_attention_mask = torch.cat([candidate_attention_mask, attn_mask], dim=1)

                    # Re-batch: augmented rows grew by inference_latents_num positions,
                    # so left-pad the untouched rows to the same length.
                    new_len = candidate_inputs_embeds.size(1)
                    merged_inputs_embeds = torch.zeros((B, new_len, hidden_size), device=device, dtype=current_inputs_embeds.dtype)
                    merged_attention_mask = torch.zeros((B, new_len), device=device, dtype=current_attention_mask.dtype)

                    merged_inputs_embeds[augment_indices] = candidate_inputs_embeds
                    merged_attention_mask[augment_indices] = candidate_attention_mask

                    non_augment_indices = torch.where(augment_decision != 1)[0]
                    if len(non_augment_indices) > 0:
                        non_aug_inputs_embeds = current_inputs_embeds[non_augment_indices]
                        non_aug_attention_mask = current_attention_mask[non_augment_indices]
                        non_aug_inputs_embeds, non_aug_attention_mask, _ = self._left_pad(
                            non_aug_inputs_embeds, non_aug_attention_mask, None, weaver.inference_latents_num
                        )
                        merged_inputs_embeds[non_augment_indices] = non_aug_inputs_embeds
                        merged_attention_mask[non_augment_indices] = non_aug_attention_mask

                    current_inputs_embeds = merged_inputs_embeds
                    current_attention_mask = merged_attention_mask
                    current_position_ids = generate_position_ids(current_attention_mask)

                    for idx, embed in zip(augment_indices, latent_inputs_embeds):
                        inserted_embeds[idx].append(embed.clone().detach().cpu())
        except Exception:
            logging.exception(
                f"[generate] Exception. do_sample={do_sample}, temperature={temperature}, pad_id={pad_token_id}, eos_id={eos_token_id}, "
                f"pixel_present={pixel_values is not None}, max_new_tokens={max_new_tokens}"
            )
            raise

        # Trim the augmentation record to the number of newly generated tokens.
        new_generated_len = current_input_ids.size(1) - prompt_len
        augmentation_pos = augmentation_pos[:, :new_generated_len]

        if not return_augmentation_mask:
            return current_input_ids
        else:
            return current_input_ids, augmentation_pos
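
    # Illustrative usage (hypothetical prompt; config fields as read above):
    #     cfg = GenerationConfig(max_new_tokens=128, do_sample=False, temperature=0.0)
    #     enc = model.tokenizer("Describe the image.", return_tensors="pt")
    #     ids, aug = model.generate(enc["input_ids"], enc["attention_mask"],
    #                               generation_config=cfg, return_augmentation_mask=True)
    #     text = model.tokenizer.decode(ids[0], skip_special_tokens=True)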
    @classmethod
    @log_function_call
    def from_config(cls, config):
        # Reasoner settings.
        reasoner_model_name = config.get("reasoner_model_name", None)
        max_prompt_aug_num = config.get("max_prompt_aug_num", None)
        max_inference_aug_num = config.get("max_inference_aug_num", None)

        # Weaver settings.
        weaver_configs = config.get("weaver")
        weaver_model_name = weaver_configs.get("weaver_model_name", None)
        prompt_latents_len = weaver_configs.get("prompt_latents_len", 8)
        inference_latents_len = weaver_configs.get("inference_latents_len", 2)
        weaver_use_peft = weaver_configs.get("use_peft", True)
        weaver_peft_config = weaver_configs.get("peft_config", None) if weaver_use_peft else None

        if weaver_peft_config is not None:
            weaver_peft_config = LoraConfig(**weaver_peft_config)

        # Trigger settings.
        trigger_configs = config.get("trigger")
        trigger_model_name = trigger_configs.get("trigger_model_name", None)
        trigger_use_peft = trigger_configs.get("use_peft", True)
        trigger_peft_config = trigger_configs.get("peft_config", None) if trigger_use_peft else None

        if trigger_peft_config is not None:
            trigger_peft_config = LoraConfig(**trigger_peft_config)

        model = cls(
            reasoner_model_name,
            # weaver
            weaver_model_name=weaver_model_name,
            prompt_latents_len=prompt_latents_len,
            inference_latents_len=inference_latents_len,
            weaver_peft_config=weaver_peft_config,
            # trigger
            trigger_model_name=trigger_model_name,
            trigger_peft_config=trigger_peft_config,
            # augmentation budget
            max_prompt_aug_num=max_prompt_aug_num,
            max_inference_aug_num=max_inference_aug_num,
        )

        # Optionally load fine-tuned weights.
        load_model_path = config.get("load_model_path", None)
        if load_model_path is not None:
            model_state_dict = load_state_dict_from_safetensor(load_model_path)
            model.load_state_dict(model_state_dict, strict=False)
            logging.info(f"Load model state dict from: {load_model_path}")

        return model
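
    # Illustrative config consumed by from_config (YAML-style; model names and
    # values are placeholders, not recommendations):
    #     reasoner_model_name: Qwen/Qwen2.5-VL-7B-Instruct
    #     max_prompt_aug_num: 1
    #     max_inference_aug_num: 5
    #     weaver:
    #       weaver_model_name: Qwen/Qwen2.5-1.5B-Instruct
    #       prompt_latents_len: 8
    #       inference_latents_len: 2
    #       use_peft: true
    #     trigger:
    #       trigger_model_name: null
    #       use_peft: true
    #     load_model_path: null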
    @torch.no_grad()
    def _append_one_step(
        self,
        reasoner_outputs,
        current_inputs_embeds: torch.Tensor,
        current_attention_mask: torch.Tensor,
        current_position_ids: torch.Tensor,
        current_input_ids: torch.Tensor,
        do_sample: bool,
        temperature: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode one token and extend the ids/embeddings/mask/position buffers by one step."""
        reasoner = self.model
        B = current_inputs_embeds.size(0)

        # Select the next token from the last-position logits.
        next_token_logits = reasoner_outputs.logits[:, -1]
        next_token_ids = get_next_token(next_token_logits, do_sample, temperature)
        current_input_ids = torch.cat([current_input_ids, next_token_ids], dim=1)

        # Append its embedding.
        next_token_embeds = reasoner.get_input_embeddings()(next_token_ids)
        current_inputs_embeds = torch.cat([current_inputs_embeds, next_token_embeds], dim=1)

        # Extend the attention mask and position IDs by one position.
        attn_mask = torch.ones((B, 1), dtype=current_attention_mask.dtype, device=current_attention_mask.device)
        current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1)

        next_position_id = current_position_ids[:, -1:] + 1
        current_position_ids = torch.cat([current_position_ids, next_position_id], dim=1)

        return current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids
    def _select_augment_points_after_delimiter(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        delimiters: List[str],
        tokenizer: PreTrainedTokenizerBase,
        max_num: int = 10,
    ) -> List[int]:
        """
        Select positions in a sequence suitable for data augmentation based on labels and delimiters.

        This function identifies two types of augmentation points:
            1. **Prompt augmentation points**: positions where the label transitions from -100 to a valid token.
               - Typically corresponds to the start of the assistant's response.
               - Can be disabled with max_prompt_aug_num=0.
            2. **Inference augmentation points**: positions inside valid label regions that follow a delimiter.
               - Only the first `max_num` points are kept.

        Args:
            input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing token IDs.
            labels (torch.Tensor): Tensor of shape (batch_size, seq_len), where -100 indicates
                positions that should not contribute to loss.
            delimiters (List[str]): List of string delimiters used to determine inference augmentation points.
            tokenizer (PreTrainedTokenizerBase): Tokenizer used to decode input_ids for delimiter matching.
            max_num (int, optional): Maximum number of inference augmentation points to select. Default is 10.

        Returns:
            List[int]: Sorted list of selected augmentation positions in the sequence.

        Raises:
            ValueError: If the prompt augmentation point count is not exactly one when it's expected.
            RuntimeError: If no valid augmentation points are found when both types are enabled.
        """
        assert input_ids.shape == labels.shape
        _, seq_len = input_ids.shape

        prompt_augment_idx = []
        inference_augment_idx = []

        for i in range(1, seq_len):
            # -100 -> label transition: the prompt ends right before position i.
            if (labels[:, i] != -100).all() and (labels[:, i - 1] == -100).all():
                prompt_augment_idx.append(i)

            # Inside a supervised span: candidate point if the prefix ends with a delimiter.
            elif (labels[:, i] != -100).all() and (labels[:, i - 1] != -100).all():
                batch_tokens_before_i = input_ids[:, :i]
                if any(check_ends_with_delimiter(batch_tokens_before_i, tokenizer, delimiters)):
                    inference_augment_idx.append(i)

        if len(prompt_augment_idx) != 1:
            raise ValueError("Single-turn forward must have exactly one prompt augment index in data format")

        final_points = []

        if self.max_prompt_aug_num > 0:
            final_points = prompt_augment_idx[:1]
            logging.info(f"[Augmentation] Prompt augmentation enabled: adding point at {prompt_augment_idx[0]}")
        else:
            logging.info("[Augmentation] Prompt augmentation disabled (max_prompt_aug_num=0)")

        if len(inference_augment_idx) > max_num:
            inference_augment_idx = inference_augment_idx[:max_num]

        final_points.extend(inference_augment_idx)

        if len(final_points) == 0:
            if self.max_prompt_aug_num == 0 and max_num == 0:
                logging.warning("[Augmentation] No augmentation points selected (both prompt and inference disabled)")
                return []
            else:
                raise RuntimeError("No valid augmentation points found despite augmentation being enabled")

        final_points.sort()
        logging.info(f"[Augmentation] Final augmentation points: {final_points}")
        return final_points
    @torch.no_grad()
    def _should_augment(
        self,
        input_ids,
        attention_mask,
        sentence_augment_count: torch.Tensor,
        do_sample: bool,
        temperature: float = 0.0,
    ) -> torch.Tensor:
        """
        Determine whether each sentence in the batch should be augmented based on
        the model's strategy and trigger predictions.

        aug_mask values:
            - -100: no augmentation decision applicable at this step
            - 0   : candidate position, but the trigger voted against augmentation
            - 1   : candidate position and the trigger voted for augmentation

        Args:
            input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len)
            attention_mask (torch.Tensor): Attention mask for input_ids
            sentence_augment_count (torch.Tensor): Tracks how many times each sentence has been augmented
            do_sample (bool): Whether to sample from the trigger output
            temperature (float): Sampling temperature

        Returns:
            torch.Tensor: Augmentation mask for each sentence in the batch
        """
        tokenizer = self.tokenizer
        delimiters = self.delimiters
        trigger = self.trigger
        max_augment_num = self.max_inference_aug_num

        batch_size = input_ids.size(0)

        aug_mask = torch.full((batch_size,), -100, dtype=torch.long, device=input_ids.device)

        # Only sequences that just emitted a delimiter are candidates.
        ends_with_delimiters = check_ends_with_delimiter(input_ids, tokenizer, delimiters).squeeze(1)
        aug_mask[ends_with_delimiters] = 0

        # Sequences that exhausted their augmentation budget are excluded.
        over_limit = sentence_augment_count >= max_augment_num
        aug_mask[over_limit] = -100

        # Let the trigger vote on the remaining candidates.
        trigger_indices = (aug_mask != -100).nonzero(as_tuple=True)[0]
        if trigger_indices.numel() > 0:
            trigger_logits = trigger(
                input_ids=input_ids[trigger_indices],
                attention_mask=attention_mask[trigger_indices],
            )
            last_token_logits = trigger_logits[:, -1]

            next_tokens = get_next_token(last_token_logits, do_sample, temperature).view(-1)

            aug_mask[trigger_indices] = next_tokens

        return aug_mask
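
    # Illustrative outcome for a batch of 3 (hypothetical): row 0 just emitted a
    # delimiter and is under budget, row 1 did not end with a delimiter, row 2
    # ended with one but exhausted its budget; if the trigger fires on row 0:
    #     aug_mask = tensor([1, -100, -100])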
    @torch.no_grad()
    def _left_pad(
        self,
        input_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        pad_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Left-pad each provided tensor with `pad_num` zero entries (any argument may be None)."""
        if input_embeds is not None:
            B, L, D = input_embeds.shape
            pad_embeds = torch.zeros((B, pad_num, D), dtype=input_embeds.dtype, device=input_embeds.device)
            input_embeds = torch.cat([pad_embeds, input_embeds], dim=1)

        if attention_mask is not None:
            B = attention_mask.size(0)
            pad_mask = torch.zeros((B, pad_num), dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat([pad_mask, attention_mask], dim=1)

        if position_ids is not None:
            B = position_ids.size(0)
            pad_pos = torch.zeros((B, pad_num), dtype=position_ids.dtype, device=position_ids.device)
            position_ids = torch.cat([pad_pos, position_ids], dim=1)

        return input_embeds, attention_mask, position_ids

    @log_function_call
    @torch.no_grad()
    def _left_clip_pad_tokens(
        self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Left-trim padding tokens based on the attention mask.

        This function finds the first non-padding position (attention_mask != 0)
        in each sequence and trims the whole batch by the minimum left-padding
        length across sequences, removing computation that every row would
        otherwise waste on shared padding.

        Args:
            inputs_embeds (torch.Tensor): Input embeddings of shape (B, L, D)
            attention_mask (torch.Tensor): Attention mask of shape (B, L)
            position_ids (torch.Tensor): Position IDs of shape (B, L)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                Trimmed inputs_embeds, attention_mask, and position_ids
        """
        B, L, D = inputs_embeds.shape

        # First non-padding index per sequence (L if the row is all padding).
        first_nonpad_idx = []
        for b in range(B):
            nonzero = (attention_mask[b] != 0).nonzero(as_tuple=True)[0]
            if len(nonzero) == 0:
                first_nonpad_idx.append(L)
            else:
                first_nonpad_idx.append(nonzero[0].item())

        min_pad = min(first_nonpad_idx)

        if min_pad == 0:
            return inputs_embeds, attention_mask, position_ids

        inputs_embeds = inputs_embeds[:, min_pad:, :]
        attention_mask = attention_mask[:, min_pad:]
        position_ids = position_ids[:, min_pad:]

        return inputs_embeds, attention_mask, position_ids
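
    # Illustrative trim (hypothetical masks): with attention_mask
    #     [[0, 0, 1, 1],
    #      [0, 1, 1, 1]]
    # the per-row left padding is [2, 1], so min_pad = 1 and one leading column
    # is sliced off every tensor in the batch.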
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """
        Enable gradient checkpointing for all sub-modules that support it.
        """
        kwargs = gradient_checkpointing_kwargs or {}
        for module in [self.model, getattr(self, "weaver", None),
                       getattr(self, "trigger", None)]:
            if hasattr(module, "gradient_checkpointing_enable"):
                module.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        """
        Disable gradient checkpointing for all sub-modules that support it.
        """
        for module in [self.model, getattr(self, "weaver", None),
                       getattr(self, "trigger", None)]:
            if hasattr(module, "gradient_checkpointing_disable"):
                module.gradient_checkpointing_disable()

    @property
    def is_gradient_checkpointing(self):
        """
        Check if gradient checkpointing is enabled for any sub-module.
        """
        for module in [self.model, getattr(self, "weaver", None),
                       getattr(self, "trigger", None)]:
            if hasattr(module, "is_gradient_checkpointing"):
                if module.is_gradient_checkpointing:
                    return True
        return False