# Uploaded by LCZZZZ: "Upload MemGen code and data" (commit e34b94f, verified)
import torch
import torch.nn as nn
from transformers import (
PreTrainedModel,
AutoModelForCausalLM,
AutoTokenizer,
Cache
)
from transformers.utils import TransformersKwargs
from transformers.processing_utils import Unpack
from typing import Optional
from abc import ABC, abstractmethod
from peft import PeftConfig, get_peft_model
class Trigger(torch.nn.Module, ABC):
    """Abstract base class for trigger modules.

    Concrete subclasses must implement :meth:`forward`; the base class itself
    cannot be instantiated.
    """

    def __init__(self):
        # Module.__init__ creates the parameter/buffer/submodule registries and
        # hook tables. Without it, assigning a submodule or parameter in a
        # subclass (or simply calling the module) raises AttributeError.
        super().__init__()

    @abstractmethod
    def forward(self, **kwargs) -> bool:
        """Compute the trigger decision from the given keyword inputs."""
        ...
class NanoTrigger(torch.nn.Module):
    """Parameter-free trigger that always favours insertion.

    Useful as a stand-in where a trained trigger is not available: every token
    position receives the same fixed pair of logits.
    """

    def __init__(self):
        super().__init__()
        # Scalar buffer that moves with the module on `.to(...)`, so that
        # `device` can be reported even though the module has no parameters.
        self.register_buffer("_device", torch.tensor(0.0))

    @property
    def device(self):
        """Device on which the module's tracking buffer currently lives."""
        return self._device.device

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return constant logits that select insertion at every position.

        Args:
            input_ids: Tensor of shape `(batch_size, seq_len)`; only its shape
                and device are used.
            attention_mask: Accepted for interface compatibility; ignored.
            **kwargs: Ignored.

        Returns:
            Tensor of shape `(batch_size, seq_len, 2)` with logit 0.0 at
            index 0 and logit 1.0 at index 1 (insertion).
        """
        batch_size, seq_len = input_ids.shape
        logits = torch.zeros(batch_size, seq_len, 2, device=input_ids.device)
        # These are logits, not probabilities: [0, 1] makes index 1 the argmax
        # everywhere (softmax probability of roughly 0.73, not 1.0).
        logits[..., 1] = 1.0
        return logits
class MemGenTrigger(torch.nn.Module):
    """
    Trigger module for the MemGen Model.

    Wraps a pretrained causal language model whose `lm_head` is replaced by a
    two-way linear classification head.

    - Input: `input_ids` (and optionally `attention_mask`) of the sequence
      decoded so far.
    - Output: per-position logits over two classes, shape
      `(batch_size, seq_len, 2)`.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        peft_config: Optional[PeftConfig] = None,
        *,
        torch_dtype: torch.dtype = torch.bfloat16,
        attn_implementation: str = "flash_attention_2",
    ):
        """Load the backbone, swap in the binary head and optionally wrap with PEFT.

        Args:
            pretrained_model_name_or_path: Model identifier or local path,
                used for both the model and the tokenizer.
            peft_config: If given, the model (including the new head) is
                wrapped with `get_peft_model`.
            torch_dtype: Dtype the backbone is loaded in. Defaults to
                bfloat16, as before.
            attn_implementation: Attention backend passed to
                `from_pretrained`. Defaults to "flash_attention_2", as before.
        """
        super().__init__()

        # base model
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            attn_implementation=attn_implementation,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

        # The head must be replaced before PEFT wrapping so the adapter
        # configuration is applied to the final module layout.
        self.model = self._postprocess(self.model)

        if peft_config is not None:
            self.model = get_peft_model(self.model, peft_config)

        self.config = self.model.config

    @property
    def device(self):
        """Device of the wrapped model."""
        return self.model.device

    def _postprocess(self, model: PreTrainedModel):
        """Replace `lm_head` with a 2-class linear head and unfreeze all parameters.

        Args:
            model: The freshly loaded causal LM; modified in place.

        Returns:
            The same model instance.
        """
        # Replace lm_head with a binary classification head. It is created in
        # the backbone's dtype and on its device: a default float32 head would
        # receive bfloat16 hidden states and fail with a dtype mismatch unless
        # the forward pass runs under autocast.
        hidden_size = model.config.hidden_size
        model.lm_head = nn.Linear(
            hidden_size, 2, device=model.device, dtype=model.dtype
        )

        # Make every parameter trainable, the new head included.
        for parameter in model.parameters():
            parameter.requires_grad = True

        return model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Trigger decision mechanism for sequence generation.

        The trigger determines its decision based on the already generated `input_ids`.
        It is influenced by the data distribution but is independent of the weaver module.

        Args:
            input_ids (Optional[torch.LongTensor]): Token IDs of the generated sequence.
            attention_mask (Optional[torch.Tensor]): Attention mask to avoid attending to padding tokens. Defaults to None.
            **kwargs: Additional keyword arguments passed to the underlying model.

        Returns:
            torch.Tensor: Logits tensor of shape `(batch_size, seq_len, num_classes)`
                with `num_classes == 2`. These are unnormalized scores; apply
                softmax/argmax to obtain probabilities or decisions.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return outputs.logits