import torch import torch.nn as nn import torch.nn.functional as F from transformers import ( AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase, GenerationConfig, Qwen2_5_VLForConditionalGeneration, AutoProcessor ) from transformers.modeling_outputs import CausalLMOutputWithPast from peft import PeftConfig, LoraConfig import random from typing import Tuple, Optional, List, Union import logging import functools from larm.task.base_model import BaseModel from larm.common.registry import registry from .weaver import MemGenWeaver from .trigger import MemGenTrigger, NanoTrigger from .utils import ( CONVERSATION_TEMPLATE, load_state_dict_from_safetensor, fix_model_parameters, log_trainable_params, ) # Decorator to log function calls in blue def log_function_call(func): """Decorator to log function calls with blue color.""" @functools.wraps(func) def wrapper(*args, **kwargs): func_name = func.__name__ # logging.info(f"\033[94m[CALL] {func_name}\033[0m") # blue color return func(*args, **kwargs) return wrapper # @log_function_call def get_next_token(next_token_logits: torch.Tensor, do_sample: bool, temperature: float) -> torch.Tensor: """ Selects the next token from model logits. Two modes are supported: 1. Sampling mode (do_sample=True and temperature>0): Apply temperature scaling to the logits, compute a probability distribution with softmax, and randomly sample one token. 2. Greedy mode (do_sample=False or temperature==0): Select the token with the highest probability (argmax). Args: next_token_logits (torch.Tensor): Logits for the next tokens, shape [batch_size, vocab_size]. do_sample (bool): Whether to perform stochastic sampling. If False, greedy decoding is used. temperature (float): Sampling temperature. Higher values make the distribution flatter (more randomness), while lower values make it sharper (more deterministic). When set to 0, greedy decoding is enforced. Returns: torch.Tensor: The selected next token indices, shape [batch_size, 1]. """ if len(next_token_logits.shape) != 2: raise ValueError("Input logits must be a 2D tensor [batch_size, vocab_size]") if do_sample and temperature != 0: # Apply temperature scaling and sample from the resulting probability distribution probs = F.softmax(next_token_logits / temperature, dim=-1) return torch.multinomial(probs, num_samples=1) else: # Greedy decoding: pick the token with the highest probability return torch.argmax(next_token_logits, dim=-1, keepdim=True) # @log_function_call def generate_position_ids(attention_mask: torch.Tensor) -> torch.Tensor: """ Generate position ID tensor based on the given attention mask. The position IDs are computed as the cumulative count of non-padding tokens within each sequence. Padding tokens are always assigned position ID 0. Args: attention_mask (torch.Tensor): A tensor of shape (batch_size, sequence_length). Typically, 1 indicates a valid (non-padding) token and 0 indicates a padding token. Returns: torch.Tensor: A tensor of shape (batch_size, sequence_length) containing position IDs. - For non-padding tokens: position IDs start at 0 and increase consecutively. - For padding tokens: position ID is always 0. """ position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0) position_ids.masked_fill_(attention_mask == 0, 0) return position_ids @log_function_call def is_conversation(input_ids: torch.Tensor, tokenizer) -> bool: """ Check whether the given input IDs represent a conversation format. Only the first sample in the batch is inspected. The function verifies whether the sequence contains at least one pair of special tokens: <|im_start|> and <|im_end|>. Args: input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing token IDs. tokenizer: A HuggingFace tokenizer used to obtain the special token IDs. Returns: bool: True if the sequence contains both <|im_start|> and <|im_end|>, False otherwise. """ if len(input_ids.shape) != 2: raise ValueError("input_ids must be a 2D tensor of shape (batch_size, seq_len)") seq = input_ids[0].tolist() # Encode the special tokens to obtain their ID sequences im_start_ids = tokenizer.encode("<|im_start|>", add_special_tokens=False) im_end_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False) # Check if the sequence contains at least one <|im_start|> and one <|im_end|> has_start = any(seq[i:i+len(im_start_ids)] == im_start_ids for i in range(len(seq) - len(im_start_ids) + 1)) has_end = any(seq[i:i+len(im_end_ids)] == im_end_ids for i in range(len(seq) - len(im_end_ids) + 1)) logging.info(f"has_start: {has_start}, has_end: {has_end}") # logging.info(f"seq: {seq}") logging.info(f"im_start_ids: {im_start_ids}") logging.info(f"im_end_ids: {im_end_ids}") return has_start and has_end @log_function_call def postprocess_assistant_labels( input_ids: torch.Tensor, labels: torch.Tensor, tokenizer ) -> torch.Tensor: """ Mask out labels corresponding to the `<|im_start|>assistant` marker. This ensures that the special tokens used to indicate the start of the assistant's response do not contribute to the loss during training. Args: input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing the conversation token IDs. labels (torch.Tensor): Tensor of shape (batch_size, seq_len) containing training labels. A value of -100 indicates positions that should be ignored in loss computation. tokenizer: A HuggingFace tokenizer used to encode the `<|im_start|>assistant\n` marker. Returns: torch.Tensor: The modified labels tensor with positions corresponding to `<|im_start|>assistant\n` masked as -100. """ if tokenizer.chat_template != CONVERSATION_TEMPLATE: raise ValueError( "Invalid tokenizer.chat_template detected.\n" f"Expected:\n{CONVERSATION_TEMPLATE}\n\n" f"Got:\n{tokenizer.chat_template}\n\n" "Please ensure that you are using the correct conversation template." ) # Encode the token sequence for "<|im_start|>assistant\n" pattern_ids: List[int] = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False) batch_size, seq_len = input_ids.shape new_labels = labels.clone() for b in range(batch_size): seq = input_ids[b].tolist() for i in range(len(seq) - len(pattern_ids) + 1): # Mask positions matching the pattern if seq[i : i + len(pattern_ids)] == pattern_ids: new_labels[b, i : i + len(pattern_ids)] = -100 return new_labels # @log_function_call def check_ends_with_delimiter( input_ids: torch.Tensor, tokenizer: PreTrainedTokenizerBase, delimiters: List[str] ) -> torch.Tensor: """ Check whether each sequence in the batch ends with any of the specified delimiter strings. Args: input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing token IDs for each sequence. tokenizer (PreTrainedTokenizerBase): HuggingFace tokenizer used to decode input_ids back to text. delimiters (List[str]): A list of delimiter strings to check against. If a sequence ends with any of these delimiters, it is marked as True. Returns: torch.Tensor: A boolean tensor of shape (batch_size, 1), where each entry indicates whether the corresponding sequence ends with one of the delimiters. """ batch_size = input_ids.size(0) # Initialize result tensor: False by default augmentation_decisions = torch.zeros(batch_size, 1, dtype=torch.bool, device=input_ids.device) # Fast path: compare suffix in token ID space if delimiter id tensors precomputed delimiter_id_tensors = getattr(getattr(tokenizer, "memgen_model_ref", None), "_delimiter_id_tensors", None) if delimiter_id_tensors is None: try: delimiter_id_tensors = [ torch.tensor(tokenizer.encode(d, add_special_tokens=False), dtype=torch.long, device=input_ids.device) for d in delimiters ] except Exception: delimiter_id_tensors = [] for i in range(batch_size): seq = input_ids[i] matched = False for did in delimiter_id_tensors: L = did.numel() if L == 0 or L > seq.numel(): continue if torch.equal(seq[-L:], did): matched = True break if not matched: # Fallback to decode once when necessary decoded = tokenizer.decode(seq.tolist()) for aug_str in delimiters: if decoded.endswith(aug_str): matched = True break augmentation_decisions[i] = matched return augmentation_decisions @registry.register_model("latmem") class LatentMemoryModel(BaseModel): @log_function_call def __init__( self, reasoner_model_name: str, weaver_model_name: str, prompt_latents_len: int, inference_latents_len: int, weaver_peft_config: Optional[PeftConfig] = None, trigger_model_name: str = None, trigger_peft_config: Optional[PeftConfig] = None, max_prompt_aug_num: int = 1, max_inference_aug_num: int = 5, ): super().__init__() # # build reasoner LLM # self.model = AutoModelForCausalLM.from_pretrained( # reasoner_model_name, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" # ) # todo: add vlm: qwen2.5-vl self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( reasoner_model_name, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) self.processor = AutoProcessor.from_pretrained(reasoner_model_name) # For text-only models, processor IS the tokenizer; for VLMs, processor has a tokenizer attribute if hasattr(self.processor, 'tokenizer'): self.tokenizer = self.processor.tokenizer else: self.tokenizer = self.processor self.config = self.model.config # build weaver LLM self.weaver = MemGenWeaver( weaver_model_name, prompt_latents_len, inference_latents_len, weaver_peft_config ) # build trigger LLM self.trigger = NanoTrigger() # always return true if trigger_model_name is not None: self.trigger = MemGenTrigger( trigger_model_name, trigger_peft_config ) logging.info(f"Use Trigger: {trigger_model_name}") # projection layers for mapping embeddings between reasoner and weaver # map reasoner input embeddings to weaver input embeddings self.reasoner_to_weaver = nn.Linear( self.model.config.hidden_size, self.weaver.config.hidden_size, dtype=torch.bfloat16 ) # NOTE: 2048 -> 1536 # Map weaver hidden states to reasoner input embeddings self.weaver_to_reasoner = nn.Linear( self.weaver.config.hidden_size, self.model.config.hidden_size, dtype=torch.bfloat16 ) # NOTE: 1536 -> 2048 self.delimiters: List[str] = [",", ".", "\n"] # delimiters for detecting augmentation points # self.delimiters: List[str] = [".", "\n"] self.max_prompt_aug_num = max_prompt_aug_num # insert latents after input prompt self.max_inference_aug_num = max_inference_aug_num # insert latents after specified delimiters # Validate augmentation configuration if max_prompt_aug_num == 0 and max_inference_aug_num == 0: logging.warning( "⚠️ Both max_prompt_aug_num and max_inference_aug_num are set to 0. " "This means no latent memory will be used. " "If you are training, this will cause errors because no trainable parameters will be involved. " "For training, at least one augmentation type must be enabled (> 0)." ) # Precompute delimiter token ID sequences for fast suffix checks (avoid repeated decode) try: self._delimiter_id_tensors = [ torch.tensor(self.tokenizer.encode(d, add_special_tokens=False), dtype=torch.long) for d in self.delimiters ] except Exception: # Fallback: empty list; code path will behave as if no delimiter found self._delimiter_id_tensors = [] # postprocess self._postprocess_models() self.warnings_issued = {} self.model_tags = None log_trainable_params(self) @log_function_call def add_model_tags(self, tags: Union[list[str], str]) -> None: r""" Add custom tags into the model that gets pushed to the Hugging Face Hub. Will not overwrite existing tags in the model. Args: tags (`Union[list[str], str]`): The desired tags to inject in the model Examples: ```python from transformers import AutoModel model = AutoModel.from_pretrained("google-bert/bert-base-cased") model.add_model_tags(["custom", "custom-bert"]) # Push the model to your namespace with the name "my-custom-bert". model.push_to_hub("my-custom-bert") ``` """ if isinstance(tags, str): tags = [tags] if self.model_tags is None: self.model_tags = [] for tag in tags: if tag not in self.model_tags: self.model_tags.append(tag) @log_function_call def _postprocess_models(self): """ Postprocess the components of the latent memory model: reasoner, weaver, trigger, and tokenizer. Steps: 1. Freeze all parameters of the reasoner (no gradient updates). 2. Cast all models to bfloat16 for memory and compute efficiency. 3. Ensure the tokenizer has a valid pad token: - If pad token is missing, use the EOS token as the pad token. - Set `padding_side` to "left" for compatibility with generation tasks. 4. Standardize the tokenizer's chat template to `CONVERSATION_TEMPLATE`. """ # Freeze all parameters of the reasoner by default fix_model_parameters(self.model) # Convert all sub-models to bfloat16 self.model = self.model.bfloat16() self.weaver = self.weaver.bfloat16() self.trigger = self.trigger.bfloat16() # Ensure tokenizer has a pad token if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.tokenizer.padding_side = "left" logging.info( f"Tokenizer has no pad token. Using EOS token ({self.tokenizer.eos_token}) as pad token." ) # Synchronize model config with tokenizer's special tokens to avoid warnings # Update model config if hasattr(self.model, 'config'): self.model.config.pad_token_id = self.tokenizer.pad_token_id self.model.config.bos_token_id = self.tokenizer.bos_token_id self.model.config.eos_token_id = self.tokenizer.eos_token_id # Update generation config if hasattr(self.model, 'generation_config'): if self.model.generation_config is not None: self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id self.model.generation_config.bos_token_id = self.tokenizer.bos_token_id self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id # Also update processor's tokenizer if it exists if hasattr(self, 'processor') and hasattr(self.processor, 'tokenizer'): self.processor.tokenizer.pad_token_id = self.tokenizer.pad_token_id self.processor.tokenizer.bos_token_id = self.tokenizer.bos_token_id self.processor.tokenizer.eos_token_id = self.tokenizer.eos_token_id # IMPORTANT: Do NOT override VLM chat templates; they encode image placeholders. is_vlm = isinstance(self.model, Qwen2_5_VLForConditionalGeneration) or hasattr(self.processor, 'image_processor') if not is_vlm: self.processor.tokenizer.chat_template = CONVERSATION_TEMPLATE # Normalize the tokenizer's chat template (skip for VLMs) is_vlm = isinstance(self.model, Qwen2_5_VLForConditionalGeneration) or hasattr(getattr(self, 'processor', None), 'image_processor') if not is_vlm: self.tokenizer.chat_template = CONVERSATION_TEMPLATE logging.info( f"Synchronized special tokens - pad_token_id: {self.tokenizer.pad_token_id}, " f"bos_token_id: {self.tokenizer.bos_token_id}, eos_token_id: {self.tokenizer.eos_token_id}" ) @property def device(self): assert self.model.device == self.weaver.device == self.trigger.device return self.model.device @log_function_call def _forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[Tuple[int, int, int]] = None, **kwargs ) -> torch.Tensor: # preprocess inputs assert input_ids.shape == attention_mask.shape == labels.shape tokenizer = self.tokenizer reasoner = self.model weaver = self.weaver delimiters = self.delimiters max_augment_num = self.max_inference_aug_num # Limit the number of inference augmentation points to avoid excessive augmentation device = self.device embeds_dtype = reasoner.get_input_embeddings().weight.dtype B, _ = input_ids.shape hidden_size = reasoner.config.hidden_size # select augment idx augmentation_indices = self._select_augment_points_after_delimiter( input_ids, labels, delimiters, tokenizer, max_augment_num ) logging.info(f"augmentation_indices: {augmentation_indices}") # origin inputs embeds (use fused embeddings when image inputs are provided) if pixel_values is not None: with torch.no_grad(): seed_outputs = reasoner( input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw, attention_mask=attention_mask, return_dict=True, output_hidden_states=True, ) inputs_embeds = seed_outputs.hidden_states[0] else: inputs_embeds = reasoner.get_input_embeddings()(input_ids) # Initialize the start index and empty tensors for accumulating processed segments current_start_idx = 0 current_inputs_embeds = torch.empty((B, 0, hidden_size), device=device, dtype=embeds_dtype) current_attention_mask = torch.empty((B, 0), device=device, dtype=attention_mask.dtype) current_latents_mask = torch.empty((B, 0), device=device, dtype=torch.bool) # Prepare incremental weaver-side embeds to avoid re-projecting the whole history current_weaver_inputs_embeds = torch.empty( (B, 0, self.weaver.config.hidden_size), device=device, dtype=self.weaver_to_reasoner.weight.dtype ) # Iterate over the selected augmentation points for aug_point_idx in augmentation_indices: # Slice the current segment of original embeddings and attention mask segment_inputs_embeds = inputs_embeds[:, current_start_idx:aug_point_idx] segment_attention_mask = attention_mask[:, current_start_idx:aug_point_idx] segment_latents_mask = torch.zeros((B, segment_inputs_embeds.size(1)), device=device, dtype=torch.bool) # Concatenate the current segment to the accumulated embeddings and masks current_inputs_embeds = torch.cat([current_inputs_embeds, segment_inputs_embeds], dim=1) current_attention_mask = torch.cat([current_attention_mask, segment_attention_mask], dim=1) current_position_ids = generate_position_ids(current_attention_mask) current_latents_mask = torch.cat([current_latents_mask, segment_latents_mask], dim=1) # Project only the newly added segment to weaver space and append (incremental) segment_weaver_inputs = self.reasoner_to_weaver(segment_inputs_embeds) current_weaver_inputs_embeds = torch.cat([current_weaver_inputs_embeds, segment_weaver_inputs], dim=1) # Determine whether this point is the end of the prompt (prompt augmentation) is_prompt_end_aug = (labels[:, aug_point_idx] != -100).all() and (labels[:, aug_point_idx-1] == -100).all().item() # Depending on type, use weaver to augment prompt or inference if is_prompt_end_aug: logging.info(f"[augment] idx={aug_point_idx} TYPE=PROMPT") weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt( current_weaver_inputs_embeds, current_attention_mask, current_position_ids ) else: logging.info(f"[augment] idx={aug_point_idx} TYPE=INFERENCE") weaver_hidden_states, attn_mask, pos_ids = weaver.augment_inference( current_weaver_inputs_embeds, current_attention_mask, current_position_ids ) # Map weaver hidden states back to reasoner embeddings latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states) # NOTE # Update accumulated embeddings and masks with the newly augmented segment current_inputs_embeds = torch.cat([current_inputs_embeds, latent_inputs_embeds], dim=1) current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1) current_start_idx = aug_point_idx # Update latent mask for the newly added latent embeddings latent_mask = torch.ones((B, latent_inputs_embeds.size(1)), device=device, dtype=torch.bool) current_latents_mask = torch.cat([current_latents_mask, latent_mask], dim=1) # Keep weaver-side embeds in sync by appending the latent hidden states directly current_weaver_inputs_embeds = torch.cat([current_weaver_inputs_embeds, weaver_hidden_states], dim=1) # Process the remaining segment after the last augmentation point remaining_inputs_embeds = inputs_embeds[:, current_start_idx:] remaining_attention_mask = attention_mask[:, current_start_idx:] latent_mask = torch.zeros((B, remaining_attention_mask.size(1)), device=device, dtype=torch.bool) current_inputs_embeds = torch.cat([current_inputs_embeds, remaining_inputs_embeds], dim=1) current_attention_mask = torch.cat([current_attention_mask, remaining_attention_mask], dim=1) current_position_ids = generate_position_ids(current_attention_mask) current_latents_mask = torch.cat([current_latents_mask, latent_mask], dim=1) reasoner_outputs = reasoner( inputs_embeds=current_inputs_embeds, attention_mask=current_attention_mask, position_ids=current_position_ids ) logits = reasoner_outputs.logits # Identify valid positions in logits (positions that should contribute to loss) shifted = torch.zeros_like(current_latents_mask) shifted[:, :-1] = current_latents_mask[:, 1:] valid_mask = ~shifted valid_logits = logits[valid_mask].view(logits.size(0), -1, logits.size(2)) # assert shifted.sum() == current_latents_mask.sum() # assert valid_logits.shape[:2] == input_ids.shape return valid_logits # @log_function_call def _instructional_forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[Tuple[int, int, int]] = None, **kwargs ) -> torch.Tensor: """ Forward pass for single-turn instructional data (no multi-turn conversation required). This method is used for instruction-following tasks (SFT), where the input consists of a single instruction and the corresponding labels. It directly delegates to the single-turn forward method `_forward`. Args: input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing input token IDs. attention_mask (torch.Tensor): Tensor indicating padding positions. labels (torch.Tensor): Tensor containing the target labels for supervised fine-tuning. **kwargs: Additional keyword arguments passed to `_forward`. Returns: Tuple[torch.Tensor, torch.Tensor]: - logits: The output logits from the model for each input token. - labels: The same as input labels, used for loss computation. """ logits = self._forward( input_ids, attention_mask, labels, pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs, ) # For Instruction SFT, labels remain the same as input return logits, labels @log_function_call def _conversational_forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[Tuple[int, int, int]] = None, **kwargs ) -> torch.Tensor: """ Forward pass for conversational (multi-turn) data. Multi-turn forward is constructed by sequentially calling the single-turn forward for each conversation turn. Latents inserted in turn i-1 are not visible to turn i. Args: input_ids (torch.Tensor): Input token IDs, shape (1, seq_len). Batch size must be 1. attention_mask (torch.Tensor): Attention mask for input tokens. labels (torch.Tensor): Target labels for supervised fine-tuning (-100 for ignore positions). **kwargs: Additional arguments passed to `_forward`. Returns: Tuple[torch.Tensor, torch.Tensor]: - all_logits: Logits for the entire sequence, with zeros for unsupervised positions. - all_labels: Labels for the entire sequence, with -100 for unsupervised positions. """ assert input_ids.shape[0] == 1, "Conversational SFT currently only supports batch_size = 1" seq_len = input_ids.shape[1] vocab_size = self.config.vocab_size device = input_ids.device # Identify single-turn segments within the conversation based on labels label_row = labels[0] should_supervise = label_row != -100 if not should_supervise.any(): raise ValueError("At least one completion segment is required") # Compute the start and end indices of valid supervised segments valid_mask = should_supervise.int() diff = torch.diff(torch.cat([torch.tensor([0], device=device), valid_mask])) valid_starts = (diff == 1).nonzero(as_tuple=True)[0].tolist() # Transition 0 -> 1 ends = (diff == -1).nonzero(as_tuple=True)[0].tolist() # Transition 1 -> 0 if len(ends) < len(valid_starts): ends.append(seq_len) assert len(valid_starts) == len(ends) # Build triplets (start of previous segment, start of supervised segment, end of supervised segment) triplets = [] start = 0 for s, e in zip(valid_starts, ends): triplets.append((start, s, e)) start = e # If there are more segments than allowed, randomly select self.max_prompt_aug_num segments # Note: max_prompt_aug_num=0 means train all turns but without prompt augmentation (only inference aug) if self.max_prompt_aug_num == 0: # Train all turns, but without prompt-level augmentation (inference aug may still be used) select_turns = [1] * len(triplets) elif len(triplets) <= self.max_prompt_aug_num: select_turns = [1] * len(triplets) else: triplets_num = len(triplets) selected_indices = set(random.sample(range(triplets_num), self.max_prompt_aug_num)) select_turns = [1 if i in selected_indices else 0 for i in range(triplets_num)] # Initialize tensors to store logits and labels for the entire sequence all_logits = torch.zeros(1, seq_len, vocab_size, device=device) all_labels = torch.full((1, seq_len), -100, device=device) # Loop over each conversation turn and perform single-turn forward if supervised for triplet, should_supervise in zip(triplets, select_turns): start, valid_start, end = triplet if should_supervise: cur_input_ids = input_ids[0, :end].unsqueeze(0) cur_attention = attention_mask[0, :end].unsqueeze(0) # cur_labels only used for _forward, does not represent the true supervision range cur_labels = labels[0, :end].clone().unsqueeze(0) cur_labels[0, :valid_start] = -100 # Mask tokens before supervision start # Single-turn forward for the current conversation segment logits = self._forward( cur_input_ids, cur_attention, cur_labels, pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs, ) # Update overall logits and labels with the results of this segment all_logits[0, start:end, :] = logits[0, start:end, :] all_labels[0, start:end] = labels[0, start:end] # Return logits and labels: # - supervised positions retain computed logits and original labels # - unsupervised positions have logits = 0 and labels = -100 return all_logits, all_labels @log_function_call def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[Tuple[int, int, int]] = None, **kwargs ): tokenizer = self.tokenizer # Ensure labels are provided, required for training the reasoning processor assert labels is not None, "Reasoning Processor requires input labels for training" # Check augmentation configuration in training mode if self.training and self.max_prompt_aug_num == 0 and self.max_inference_aug_num == 0: raise ValueError( "Cannot train MemGen model with both max_prompt_aug_num=0 and max_inference_aug_num=0. " "At least one augmentation type must be enabled during training. " "If you want to use the base model without augmentation, please use the reasoner model directly. " "For inference/evaluation without augmentation, set the model to eval mode: model.eval()" ) # Determine whether the input is single-turn (instruction) or multi-turn (conversation) forward_func = self._instructional_forward if is_conversation(input_ids, tokenizer): # For conversational data, mask assistant tokens in labels logging.info("is_conversation: True") labels = postprocess_assistant_labels(input_ids, labels, tokenizer) forward_func = self._conversational_forward batch_size = 1 # Currently process one sequence per batch iter_num = input_ids.size(0) // batch_size # Forward pass per batch logits, supervised_labels = [], [] for i in range(iter_num): batch_input_ids = input_ids[i * batch_size: (i + 1) * batch_size] batch_attention_mask = attention_mask[i * batch_size: (i + 1) * batch_size] batch_labels = labels[i * batch_size: (i + 1) * batch_size] # Call the appropriate forward function (instruction or conversation) batch_logits, batch_supervised_labels = forward_func( # forward_func --> _forward input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels, pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs ) logits.append(batch_logits) supervised_labels.append(batch_supervised_labels) # Concatenate results from all batches all_logits = torch.concat(logits, dim=0) all_labels = torch.concat(supervised_labels, dim=0) # Compute causal language modeling loss (shifted by one) shift_logits = all_logits[..., :-1, :].contiguous() shift_labels = all_labels[..., 1:].contiguous() # assert shift_logits.shape[:-1] == shift_labels.shape loss_fct = nn.CrossEntropyLoss(ignore_index=-100) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # Return model outputs outputs = CausalLMOutputWithPast(loss=loss, logits=all_logits) outputs.supervised_labels = all_labels # Positions in input_ids that are supervised return outputs @log_function_call @torch.no_grad() def generate( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, generation_config: GenerationConfig = None, return_augmentation_mask: bool = False, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[Tuple[int, int, int]] = None, **kwargs ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: tokenizer = self.tokenizer reasoner = self.model weaver = self.weaver trigger = self.trigger delimiters = self.delimiters max_augment_num = self.max_inference_aug_num invalid_token_id = -100 # preproecess inputs input_ids = input_ids.to(self.device) attention_mask = attention_mask.to(self.device) max_new_tokens = generation_config.max_new_tokens do_sample = generation_config.do_sample temperature = generation_config.temperature # control reasoner generate and trigger generate pad_token_id = tokenizer.pad_token_id # Prefer chat end token (<|im_end|>) if available; fall back to tokenizer.eos try: im_end_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False) if isinstance(im_end_ids, list) and len(im_end_ids) == 1: eos_token_id = im_end_ids[0] else: eos_token_id = tokenizer.eos_token_id except Exception: eos_token_id = tokenizer.eos_token_id prompt_len = input_ids.size(1) generation_config = GenerationConfig( do_sample=do_sample, temperature=temperature, pad_token_id=pad_token_id, eos_token_id=eos_token_id, # use_cache=True ) with torch.no_grad(): outputs = reasoner( input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw, attention_mask=attention_mask, return_dict=True, output_hidden_states=True, ) # Use embedding layer output (hidden_states[0]) as inputs_embeds to ensure compatibility with model's embedding space. fused_embeds = outputs.hidden_states[0] # Embedding output after token & vision embedding projection # logging.info(f"Seed embeds shape (from hidden_states[0]): {fused_embeds.shape}") inputs_embeds = fused_embeds B, _, hidden_size = inputs_embeds.shape device = inputs_embeds.device try: current_inputs_embeds = inputs_embeds current_attention_mask = attention_mask current_position_ids = generate_position_ids(current_attention_mask) current_input_ids = input_ids # Apply prompt augmentation only if max_prompt_aug_num > 0 if self.max_prompt_aug_num > 0: logging.info("[Generation] Applying prompt augmentation") weaver_inputs_embeds = self.reasoner_to_weaver(current_inputs_embeds) weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt( weaver_inputs_embeds, current_attention_mask, current_position_ids ) latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states) # Concatenate initial augmented prompt current_inputs_embeds = torch.cat([current_inputs_embeds, latent_inputs_embeds], dim=1) current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1) current_position_ids = torch.cat([current_position_ids, pos_ids], dim=1) else: logging.info("[Generation] Skipping prompt augmentation (max_prompt_aug_num=0)") # Generation Loop Initialization sentence_augment_count = torch.zeros(B, dtype=torch.int, device=device) augmentation_pos = torch.full((B, max_new_tokens), fill_value=invalid_token_id, device=device) inserted_embeds: List[List[torch.Tensor]] = [[] for _ in range(B)] for i in range(max_new_tokens): # If all sequences in the batch have already generated an EOS token, stop early if (current_input_ids[:, -1] == eos_token_id).all(): break # Check if all sequences have reached the maximum number of augmentations if (sentence_augment_count >= max_augment_num).all(): # Adjust the remaining generation length generation_config.max_new_tokens = max_new_tokens - i # Perform generation for the remaining tokens using the reasoner generated = reasoner.generate( inputs_embeds=current_inputs_embeds, attention_mask=current_attention_mask, generation_config=generation_config, ) current_input_ids = torch.cat([current_input_ids, generated], dim=1) break outputs = reasoner( inputs_embeds=current_inputs_embeds, attention_mask=current_attention_mask, position_ids=current_position_ids, output_hidden_states=False, ) current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids = self._append_one_step( outputs, current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids, do_sample, temperature ) if i == max_new_tokens - 1: break # Determine which sentences in the batch should be augmented augment_decision = self._should_augment( current_input_ids, current_attention_mask, sentence_augment_count=sentence_augment_count, do_sample=do_sample, temperature=temperature ) augmentation_pos[:, i + 1] = augment_decision augment_indices = torch.where(augment_decision == 1)[0] # If there are sentences to augment, apply augmentation; others remain with left padding if len(augment_indices) > 0: # Increment the augmentation count for sentences that are being augmented sentence_augment_count[augment_indices] += 1 # Select embeddings, attention masks, and position IDs for sentences to be augmented candidate_inputs_embeds = current_inputs_embeds[augment_indices] candidate_attention_mask = current_attention_mask[augment_indices] candidate_position_ids = current_position_ids[augment_indices] # Perform inference augmentation using the weaver weaver_inputs_embeds = self.reasoner_to_weaver(candidate_inputs_embeds) weaver_hidden_states, attn_mask, _ = weaver.augment_inference( weaver_inputs_embeds, candidate_attention_mask, candidate_position_ids ) latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states) candidate_inputs_embeds = torch.cat([candidate_inputs_embeds, latent_inputs_embeds], dim=1) candidate_attention_mask = torch.cat([candidate_attention_mask, attn_mask], dim=1) # Create a single merged tensor for all sequences new_len = candidate_inputs_embeds.size(1) merged_inputs_embeds = torch.zeros((B, new_len, hidden_size), device=device, dtype=current_inputs_embeds.dtype) merged_attention_mask = torch.zeros((B, new_len), device=device, dtype=current_attention_mask.dtype) # Directly place augmented and non-augmented sequences merged_inputs_embeds[augment_indices] = candidate_inputs_embeds merged_attention_mask[augment_indices] = candidate_attention_mask # Non-augmented sequences now include both -100 and 0 non_augment_indices = torch.where(augment_decision != 1)[0] if len(non_augment_indices) > 0: non_aug_inputs_embeds = current_inputs_embeds[non_augment_indices] non_aug_attention_mask = current_attention_mask[non_augment_indices] non_aug_inputs_embeds, non_aug_attention_mask, _ = self._left_pad( non_aug_inputs_embeds, non_aug_attention_mask, None, weaver.inference_latents_num ) merged_inputs_embeds[non_augment_indices] = non_aug_inputs_embeds merged_attention_mask[non_augment_indices] = non_aug_attention_mask current_inputs_embeds = merged_inputs_embeds current_attention_mask = merged_attention_mask current_position_ids = generate_position_ids(current_attention_mask) # Record inserted embeds for post-processing for idx, embed in zip(augment_indices, latent_inputs_embeds): inserted_embeds[idx].append(embed.clone().detach().cpu()) except Exception: logging.exception( f"[generate] Exception. do_sample={do_sample}, temperature={temperature}, pad_id={pad_token_id}, eos_id={eos_token_id}, " f"pixel_present={pixel_values is not None}, max_new_tokens={max_new_tokens}" ) raise # postprocess new_generated_len = current_input_ids.size(1) - prompt_len augmentation_pos = augmentation_pos[:, :new_generated_len] if not return_augmentation_mask: return current_input_ids else: return current_input_ids, augmentation_pos @classmethod @log_function_call def from_config(cls, config): # reasoner configs reasoner_model_name = config.get("reasoner_model_name", None) max_prompt_aug_num = config.get("max_prompt_aug_num", None) max_inference_aug_num = config.get("max_inference_aug_num", None) # processor configs weaver_configs = config.get("weaver") weaver_model_name = weaver_configs.get("weaver_model_name", None) prompt_latents_len = weaver_configs.get("prompt_latents_len", 8) inference_latents_len = weaver_configs.get("inference_latents_len", 2) weaver_use_peft = weaver_configs.get("use_peft", True) weaver_peft_config = weaver_configs.get("peft_config", None) if weaver_use_peft else None if weaver_peft_config is not None: weaver_peft_config = LoraConfig(**weaver_peft_config) # trigger configs trigger_configs = config.get("trigger") trigger_model_name = trigger_configs.get("trigger_model_name", None) trigger_use_peft = trigger_configs.get("use_peft", True) trigger_peft_config = trigger_configs.get("peft_config", None) if trigger_use_peft else None if trigger_peft_config is not None: trigger_peft_config = LoraConfig(**trigger_peft_config) model = cls( reasoner_model_name, # weaver configs weaver_model_name=weaver_model_name, prompt_latents_len=prompt_latents_len, inference_latents_len=inference_latents_len, weaver_peft_config=weaver_peft_config, # trigger configs trigger_model_name=trigger_model_name, trigger_peft_config=trigger_peft_config, # augmentations max_prompt_aug_num=max_prompt_aug_num, max_inference_aug_num=max_inference_aug_num ) # load model state dict load_model_path = config.get("load_model_path", None) if load_model_path is not None: model_state_dict = load_state_dict_from_safetensor(load_model_path) model.load_state_dict(model_state_dict, strict=False) logging.info(f"Load model state dict from: {load_model_path}") return model # @log_function_call @torch.no_grad() def _append_one_step( self, reasoner_outputs, current_inputs_embeds: torch.Tensor, current_attention_mask: torch.Tensor, current_position_ids: torch.Tensor, current_input_ids: torch.Tensor, do_sample: bool, temperature: float ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: reasoner = self.model B = current_inputs_embeds.size(0) # Append next token next_token_logits = reasoner_outputs.logits[:, -1] next_token_ids = get_next_token(next_token_logits, do_sample, temperature) current_input_ids = torch.cat([current_input_ids, next_token_ids], dim=1) # Append next token embeds next_token_embeds = reasoner.get_input_embeddings()(next_token_ids) current_inputs_embeds = torch.cat([current_inputs_embeds, next_token_embeds], dim=1) # Append attention mask attn_mask = torch.ones((B, 1), dtype=current_attention_mask.dtype, device=current_attention_mask.device) current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1) # Append position ids next_position_id = current_position_ids[:, -1:] + 1 current_position_ids = torch.cat([current_position_ids, next_position_id], dim=1) return current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids # @log_function_call def _select_augment_points_after_delimiter( self, input_ids: torch.Tensor, labels: torch.Tensor, delimiters: List[str], tokenizer: PreTrainedTokenizerBase, max_num: int = 10, ) -> List[int]: """ Select positions in a sequence suitable for data augmentation based on labels and delimiters. This function identifies two types of augmentation points: 1. **Prompt augmentation points**: positions where the label transitions from -100 to a valid token. - Typically corresponds to the start of the assistant's response. - Can be disabled if max_prompt_aug_num=0. 2. **Inference augmentation points**: positions inside valid label regions that follow a delimiter. - Only the first `max_num` points are kept. Args: input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing token IDs. labels (torch.Tensor): Tensor of shape (batch_size, seq_len), where -100 indicates positions that should not contribute to loss. delimiters (List[str]): List of string delimiters used to determine inference augmentation points. tokenizer (PreTrainedTokenizerBase): Tokenizer used to decode input_ids for delimiter matching. max_num (int, optional): Maximum number of inference augmentation points to select. Default is 10. Returns: List[int]: Sorted list of selected augmentation positions in the sequence. Raises: ValueError: If the prompt augmentation point count is not exactly one when it's expected. RuntimeError: If no valid augmentation points are found when both types are enabled. """ assert input_ids.shape == labels.shape B, seq_len = input_ids.size(0), input_ids.size(1) prompt_augment_idx = [] inference_augment_idx = [] for i in range(1, seq_len): # Skip the first token and last token for augmentation # Detect the boundary between prompt and label for prompt augmentation if (labels[:, i] != -100).all() and (labels[:, i - 1] == -100).all(): prompt_augment_idx.append(i) # Detect valid label regions for inference augmentation elif (labels[:, i] != -100).all() and (labels[:, i - 1] != -100).all(): batch_tokens_before_i = input_ids[:, :i] # Assume check_ends_with_delimiter is defined if any(check_ends_with_delimiter(batch_tokens_before_i, tokenizer, delimiters)): inference_augment_idx.append(i) # Validate prompt augmentation point exists (data format check) if len(prompt_augment_idx) != 1: raise ValueError("Single-turn forward must have exactly one prompt augment index in data format") final_points = [] # Include prompt augmentation point only if max_prompt_aug_num > 0 if self.max_prompt_aug_num > 0: final_points = prompt_augment_idx[:1] logging.info(f"[Augmentation] Prompt augmentation enabled: adding point at {prompt_augment_idx[0]}") else: logging.info(f"[Augmentation] Prompt augmentation disabled (max_prompt_aug_num=0)") # Limit the number of inference augmentation points to max_num if len(inference_augment_idx) > max_num: inference_augment_idx = inference_augment_idx[:max_num] final_points.extend(inference_augment_idx) # Allow empty augmentation points if both are disabled if len(final_points) == 0: if self.max_prompt_aug_num == 0 and max_num == 0: logging.warning("[Augmentation] No augmentation points selected (both prompt and inference disabled)") return [] else: raise RuntimeError("No valid augmentation points found despite augmentation being enabled") final_points.sort() logging.info(f"[Augmentation] Final augmentation points: {final_points}") return final_points # @log_function_call @torch.no_grad() def _should_augment( self, input_ids, attention_mask, sentence_augment_count: torch.Tensor, do_sample: bool, temperature: float = 0.0 ) -> torch.Tensor: """ Determine whether each sentence in the batch should be augmented based on the model's strategy and trigger predictions. aug_mask values: - -100: No sampling or augmentation decision not applicable - 0 : Sampled but trigger indicates no augmentation - 1 : Sampled and trigger indicates augmentation Args: input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len) attention_mask (torch.Tensor): Attention mask for input_ids sentence_augment_count (torch.Tensor): Tracks how many times each sentence has been augmented do_sample (bool): Whether to sample from the trigger output temperature (float): Sampling temperature Returns: torch.Tensor: Augmentation mask for each sentence in the batch """ tokenizer = self.tokenizer delimiters = self.delimiters trigger = self.trigger max_augment_num = self.max_inference_aug_num batch_size = input_ids.size(0) # Initialize aug_mask with -100, meaning no augmentation by default aug_mask = torch.full((batch_size,), -100, dtype=torch.long, device=input_ids.device) # Mark sentences ending with delimiters as candidates for augmentation (set to 0) ends_with_delimiters = check_ends_with_delimiter(input_ids, tokenizer, delimiters).squeeze(1) aug_mask[ends_with_delimiters] = 0 # If a sentence has already reached the max augmentation count, reset to -100 over_limit = (sentence_augment_count >= max_augment_num) aug_mask[over_limit] = -100 # Apply trigger model only on sentences that are candidates (aug_mask != -100) trigger_indices = (aug_mask != -100).nonzero(as_tuple=True)[0] if trigger_indices.numel() > 0: trigger_logits = trigger( input_ids=input_ids[trigger_indices], attention_mask=attention_mask[trigger_indices], ) last_token_logits = trigger_logits[:, -1] # [num_trigger_samples, 2] # Sample the next token to decide whether to augment next_tokens = get_next_token(last_token_logits, do_sample, temperature).view(-1) # Update aug_mask: 0 = no augmentation, 1 = augment aug_mask[trigger_indices] = next_tokens return aug_mask # @log_function_call @torch.no_grad() def _left_pad( self, input_embeds: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, pad_num: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if input_embeds is not None: B, L, D = input_embeds.shape pad_embeds = torch.zeros((B, pad_num, D), dtype=input_embeds.dtype, device=input_embeds.device) input_embeds = torch.cat([pad_embeds, input_embeds], dim=1) # [B, pad_num + L, D] if attention_mask is not None: B = attention_mask.size(0) pad_mask = torch.zeros((B, pad_num), dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([pad_mask, attention_mask], dim=1) # [B, pad_num + L] if position_ids is not None: B = position_ids.size(0) pad_pos = torch.zeros((B, pad_num), dtype=position_ids.dtype, device=position_ids.device) position_ids = torch.cat([pad_pos, position_ids], dim=1) # [B, pad_num + L] return input_embeds, attention_mask, position_ids @log_function_call @torch.no_grad() def _left_clip_pad_tokens( self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Left-trim padding tokens based on the attention mask. This function identifies the leftmost padding (attention_mask=0) in each sequence. If all sequences in the batch have at least some padding on the left, it trims the batch by the minimal left-padding length to remove unnecessary computation. Args: inputs_embeds (torch.Tensor): Input embeddings of shape (B, L, D) attention_mask (torch.Tensor): Attention mask of shape (B, L) position_ids (torch.Tensor): Position IDs of shape (B, L) Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Trimmed inputs_embeds, attention_mask, and position_ids """ B, L, D = inputs_embeds.shape device = inputs_embeds.device # Find the index of the first non-padding token in each sequence first_nonpad_idx = [] for b in range(B): nonzero = (attention_mask[b] != 0).nonzero(as_tuple=True)[0] if len(nonzero) == 0: # Entire row is padding; can potentially trim the whole sequence first_nonpad_idx.append(L) else: first_nonpad_idx.append(nonzero[0].item()) # Determine the minimum number of left-padding tokens across the batch min_pad = min(first_nonpad_idx) # If no padding on the left, return original tensors if min_pad == 0: return inputs_embeds, attention_mask, position_ids # Trim the left-padding from all sequences in the batch inputs_embeds = inputs_embeds[:, min_pad:, :] attention_mask = attention_mask[:, min_pad:] position_ids = position_ids[:, min_pad:] return inputs_embeds, attention_mask, position_ids def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): """ Enable gradient checkpointing for all sub-modules that support it. """ kwargs = gradient_checkpointing_kwargs or {} for module in [self.model, getattr(self, "weaver", None), getattr(self, "trigger", None)]: if hasattr(module, "gradient_checkpointing_enable"): module.gradient_checkpointing_enable(**kwargs) def gradient_checkpointing_disable(self): """ Disable gradient checkpointing for all sub-modules that support it. """ for module in [self.model, getattr(self, "weaver", None), getattr(self, "trigger", None)]: if hasattr(module, "gradient_checkpointing_disable"): module.gradient_checkpointing_disable() @property def is_gradient_checkpointing(self): """ Check if gradient checkpointing is enabled for any sub-module. """ for module in [self.model, getattr(self, "weaver", None), getattr(self, "trigger", None)]: if hasattr(module, "is_gradient_checkpointing"): if module.is_gradient_checkpointing: return True return False