import torch
import torch.nn as nn
from transformers import (
AutoModelForCausalLM,
AutoTokenizer
)
from peft import PeftConfig, get_peft_model
from typing import Optional, Tuple
# Decorator to log function calls in blue
import functools
import logging
def log_function_call(func):
    """Decorator that logs each call to *func* at DEBUG level, in blue.

    Uses lazy %-style arguments so the message is only formatted when DEBUG
    logging is actually enabled (the default root level is WARNING, so this
    is silent unless logging is configured for debugging). ``functools.wraps``
    preserves the wrapped function's metadata (``__name__``, ``__doc__``).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # \033[94m ... \033[0m wraps the message in ANSI blue.
        logging.debug("\033[94m[Weaver] %s\033[0m", func.__name__)
        return func(*args, **kwargs)
    return wrapper
class MemGenWeaver(nn.Module):
    """
    Weaver module for the MemGen Model.

    - Input: the weaver receives `inputs_embeds` from the reasoner model's
      current decoding sequence.
    - Output: the weaver produces a sequence of hidden states with length K,
      which are concatenated to the original `inputs_embeds` to alter the
      reasoner's decoding path.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        prompt_latents_num: int,
        inference_latents_num: int,
        peft_config: Optional[PeftConfig] = None,
    ):
        """
        Args:
            pretrained_model_name_or_path: HF hub id or local path of the
                base causal LM used as the weaver backbone.
            prompt_latents_num: number of learnable query latents appended
                when augmenting the prompt.
            inference_latents_num: number of learnable query latents appended
                when augmenting an intermediate inference step.
            peft_config: optional PEFT config; when given, the base model is
                wrapped with ``get_peft_model`` so only adapter weights train.
        """
        super().__init__()
        # Base model, loaded in bfloat16 with flash-attention-2.
        # NOTE(review): assumes a runtime/GPU that supports both — confirm
        # for the target deployment.
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        if peft_config is not None:
            self.model = get_peft_model(self.model, peft_config)
        self.config = self.model.config
        # Learnable query latents (requires_grad=True is the Parameter
        # default). Stored in float32; `_augment` casts them to the
        # embeddings' dtype at use time.
        # prompt augmentation
        self.prompt_query_latents = nn.Parameter(
            torch.randn(prompt_latents_num, self.config.hidden_size)
        )
        # inference augmentation
        self.inference_query_latents = nn.Parameter(
            torch.randn(inference_latents_num, self.config.hidden_size)
        )

    @property
    def prompt_latents_num(self) -> int:
        """Number of prompt-augmentation query latents."""
        return self.prompt_query_latents.size(0)

    @property
    def inference_latents_num(self) -> int:
        """Number of inference-augmentation query latents."""
        return self.inference_query_latents.size(0)

    @property
    def device(self):
        """Device of the underlying base model."""
        return self.model.device

    @log_function_call
    def _augment(
        self,
        latents: torch.Tensor,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Append `latents` to the sequence and run the weaver model.

        Args:
            latents: (K, H) learnable query latents shared across the batch.
            inputs_embeds: (B, L, H) current sequence embeddings.
            attention_mask: (B, L) mask for `inputs_embeds`.
            position_ids: (B, L) position ids for `inputs_embeds`.

        Returns:
            Tuple of (latents_hidden_states, latents_mask, latents_position_ids):
            the last-layer hidden states of the K latent positions, plus the
            mask and position ids that describe them — ready to be appended
            to the reasoner's sequence by the caller.
        """
        batch_size = attention_mask.shape[0]
        latents_num = latents.size(0)
        # Broadcast the shared latents across the batch. Cast to the
        # embeddings' dtype first: the base model runs in bfloat16 while
        # fresh Parameters are float32, and torch.cat must not mix dtypes
        # (no-op when the dtypes already agree; the cast is differentiable,
        # so gradients still reach the float32 parameter).
        latents = latents.to(dtype=inputs_embeds.dtype)
        latents = latents.unsqueeze(0).repeat(batch_size, 1, 1)
        # inputs_embeds: latents go at the END of the sequence
        inputs_embeds = torch.cat([inputs_embeds, latents], dim=1)
        # attention_mask: (B, L_total) — latent positions are always attended
        latents_mask = torch.ones(latents.shape[:-1], dtype=attention_mask.dtype, device=attention_mask.device)
        attention_mask = torch.cat([attention_mask, latents_mask], dim=1)
        # position ids: continue counting from each row's max position,
        # so latents occupy positions max+1 .. max+K per row.
        last_position_ids = position_ids.max(dim=1)[0]
        latents_relative_positions = torch.arange(latents_num, device=attention_mask.device)
        latents_position_ids = last_position_ids.unsqueeze(1) + latents_relative_positions + 1
        position_ids = torch.cat([position_ids.long(), latents_position_ids.long()], dim=1)
        # Sanity check: all three inputs must agree on (B, L_total).
        assert inputs_embeds.shape[:2] == attention_mask.shape == position_ids.shape
        # The weaver only produces hidden states (no logits are consumed).
        outputs = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states[-1]
        # Keep only the trailing K positions — the latents' hidden states.
        latents_hidden_states = hidden_states[:, -latents_num:, :]
        return latents_hidden_states, latents_mask, latents_position_ids

    @log_function_call
    def augment_prompt(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Augment a prompt sequence with the prompt query latents."""
        return self._augment(
            latents=self.prompt_query_latents,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids
        )

    @log_function_call
    def augment_inference(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Augment an intermediate inference step with the inference query latents."""
        return self._augment(
            latents=self.inference_query_latents,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids
        )