 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from peft import PeftConfig, get_peft_model

from typing import Optional, Tuple

# Decorator to log function calls in blue
import functools
import logging

def log_function_call(func):
    """Decorator intended to log each call to *func* in blue.

    The logging line is currently disabled (commented out), so the wrapper is
    a transparent pass-through that preserves the wrapped function's metadata
    via ``functools.wraps``.
    """
    @functools.wraps(func)
    def inner(*call_args, **call_kwargs):
        # Re-enable the line below to trace weaver calls (blue ANSI color).
        # logging.info("\033[94m[Weaver] %s\033[0m", func.__name__)
        return func(*call_args, **call_kwargs)
    return inner


class MemGenWeaver(torch.nn.Module):
    """
    Weaver module for the MemGen Model.

    - Input: the weaver receives `inputs_embeds` from the reasoner model's
      current decoding sequence.
    - Output: the weaver produces a sequence of hidden states with length K,
      which are concatenated to the original `inputs_embeds` to alter the
      reasoner's decoding path.
    """
    def __init__(
        self, 
        pretrained_model_name_or_path: str, 
        prompt_latents_num: int,
        inference_latents_num: int,
        peft_config: Optional[PeftConfig] = None
    ):
        """Build the weaver around a pretrained causal LM.

        Args:
            pretrained_model_name_or_path: HF hub id or local path of the base LM.
            prompt_latents_num: number of learnable query latents appended when
                augmenting the prompt.
            inference_latents_num: number of learnable query latents appended
                during step-wise inference augmentation.
            peft_config: optional PEFT config; when given, the base LM is
                wrapped with `get_peft_model` so only adapter weights train.
        """
        super().__init__()

        # Base LM that maps (inputs_embeds + query latents) to latent memories.
        # NOTE(review): bfloat16 + flash_attention_2 assumes a compatible
        # GPU/runtime — confirm deployment targets support both.
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        if peft_config is not None:
            self.model = get_peft_model(self.model, peft_config)

        self.config = self.model.config

        # Learnable query latents for prompt augmentation.
        # (`requires_grad=True` is the nn.Parameter default, so it is omitted.)
        self.prompt_query_latents = nn.Parameter(
            torch.randn(prompt_latents_num, self.config.hidden_size)
        )

        # Learnable query latents for inference-time augmentation.
        self.inference_query_latents = nn.Parameter(
            torch.randn(inference_latents_num, self.config.hidden_size)
        )

    @property
    def prompt_latents_num(self) -> int:
        """Number of prompt-augmentation query latents (K for prompts)."""
        return self.prompt_query_latents.size(0)

    @property
    def inference_latents_num(self) -> int:
        """Number of inference-augmentation query latents (K for inference)."""
        return self.inference_query_latents.size(0)

    @property
    def device(self):
        """Device of the underlying base model."""
        return self.model.device

    @log_function_call
    def _augment(
        self, 
        latents: torch.Tensor,
        inputs_embeds: torch.Tensor, 
        attention_mask: torch.Tensor, 
        position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Append `latents` to the sequence and run the base LM over it.

        Args:
            latents: (K, hidden_size) learnable query latents, shared across
                the batch.
            inputs_embeds: (B, L, hidden_size) current sequence embeddings.
            attention_mask: (B, L) mask for `inputs_embeds`.
            position_ids: (B, L) position ids for `inputs_embeds`.

        Returns:
            Tuple of
            - latents_hidden_states: (B, K, hidden_size) last-layer hidden
              states at the appended latent slots,
            - latents_mask: (B, K) all-ones attention mask for the latents,
            - latents_position_ids: (B, K) positions continuing each row's
              maximum position.

        Raises:
            ValueError: if the augmented embeds/mask/position shapes disagree.
        """
        batch_size = attention_mask.shape[0]
        latents_num = latents.size(0)
        # Broadcast the shared latents across the batch. `expand` avoids the
        # copy `repeat` would make — the `torch.cat` below materializes anyway.
        latents = latents.unsqueeze(0).expand(batch_size, -1, -1)

        # inputs_embeds: (B, L + K, hidden_size)
        inputs_embeds = torch.cat([inputs_embeds, latents], dim=1)

        # attention_mask: (B, L + K); latents are always attended to.
        latents_mask = torch.ones(
            (batch_size, latents_num), dtype=attention_mask.dtype, device=attention_mask.device
        )
        attention_mask = torch.cat([attention_mask, latents_mask], dim=1)

        # Latent positions continue right after each row's last position id.
        last_position_ids = position_ids.max(dim=1).values
        latents_relative_positions = torch.arange(latents_num, device=attention_mask.device)
        latents_position_ids = last_position_ids.unsqueeze(1) + latents_relative_positions + 1
        position_ids = torch.cat([position_ids.long(), latents_position_ids.long()], dim=1)

        # Explicit sanity check: a plain `assert` is stripped under `-O`.
        if not (inputs_embeds.shape[:2] == attention_mask.shape == position_ids.shape):
            raise ValueError(
                f"Shape mismatch after augmentation: embeds {tuple(inputs_embeds.shape[:2])}, "
                f"mask {tuple(attention_mask.shape)}, positions {tuple(position_ids.shape)}"
            )

        # The processor only outputs the hidden states.
        outputs = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,  
            output_hidden_states=True,
        )
        # Keep only the last-layer states at the appended latent slots.
        hidden_states = outputs.hidden_states[-1]
        latents_hidden_states = hidden_states[:, -latents_num:, :]

        return latents_hidden_states, latents_mask, latents_position_ids

    @log_function_call
    def augment_prompt(
        self, 
        inputs_embeds: torch.Tensor, 
        attention_mask: torch.Tensor, 
        position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Augment a prompt sequence with the prompt query latents."""
        return self._augment(
            latents=self.prompt_query_latents,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids
        )

    @log_function_call
    def augment_inference(
        self, 
        inputs_embeds: torch.Tensor, 
        attention_mask: torch.Tensor, 
        position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Augment an in-progress decode sequence with the inference latents."""
        return self._augment(
            latents=self.inference_query_latents,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids
        )