import torch
import torch.nn as nn
import torch.nn.functional as F
from .connector import Connector
from .projector import Projector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
import numpy as np
import warnings

from typing import Literal

import logging
logger = logging.getLogger(__name__)

class UserLearner(nn.Module):
    '''Learns per-user mixture weights over k group-level projectors.

    Each registered user id is mapped to a k-dimensional weight vector. Prompts
    are encoded by the LLM backbone, projected by each of the k projectors, and
    combined with the (softmaxed or argmaxed) user weights to produce that
    user's ideal points.
    '''

    k: int                      # number of groups
    llm: nn.Module              # LLM backbone used to encode prompts
    projectors: nn.ModuleDict   # k group-level projectors, keyed by "0" .. str(k-1)
    u_id_set: set               # user ids that already have a registered weight vector
    softmax: nn.Module          # softmax (or custom softmax) applied to user weights
    is_partition: bool          # partition variant uses argmax weights at eval time

    def __init__(
        self,
        k: int,
        llm: nn.Module,
        projectors: list[Projector],
        softmax: nn.Module,
        is_partition: bool=False,
    ):
        super().__init__()

        self.k = k
        self.llm = llm
        self.softmax = softmax
        # user-id registration table and per-user weight dictionary
        self.u_id_set = set()
        self.W = nn.ParameterDict()
        self.tmp_store_user_ideal_points = None
        # register all k projectors in a ModuleDict keyed by their group index
        assert len(projectors) == k, f"The number of projectors must match the number of groups: {len(projectors)} != {k}"
        self.projectors = nn.ModuleDict()
        for i in range(k):
            self.projectors[str(i)] = projectors[i]
        self.is_partition = is_partition
        # modules start in train mode, so mixture (softmax) weights are used by default
        self.is_argmax = False

    def init_weight(self, u_ids: list, reinit: bool = False):
        '''Register a random k-dim weight vector for every new (or reinitialized) user id.'''
        device = next(self.projectors["0"].parameters()).device
        for u_id in u_ids:
            if u_id not in self.u_id_set or reinit:
                # create the parameter directly on the projectors' device so it stays a leaf tensor
                self.W[u_id] = nn.Parameter(
                    torch.randn(self.k, dtype=torch.float32, device=device),
                    requires_grad=True,
                )
                self.u_id_set.add(u_id)
            else:
                logger.warning('👋 user %s is already registered; pass reinit=True to reset its weights', u_id)

    def get_sfx_w(self, u_ids: list):
        '''Softmaxed (mixture) weights for a batch of user ids.'''
        w = torch.stack([self.W[key] for key in u_ids], dim=0)  # (bs, k)
        w = self.softmax(w)
        return w

    def get_hardmax_w(self, u_ids: list):
        '''One-hot (argmax) weights for a batch of user ids.'''
        w = torch.stack([self.W[key] for key in u_ids], dim=0)  # (bs, k)
        w = F.one_hot(w.argmax(dim=1), num_classes=self.k).float()  # (bs, k)
        return w

    def infer_gk(self, prompt_tokens, rm_cached=None):
        '''
        Encode the prompt with the LLM and project it through all k group projectors.

        prompt_tokens: {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
        rm_cached: optional dict holding past key/values under the 'user_learner' key;
                   pass it to enable incremental decoding with the KV cache.
        Returns logits of shape (bs, k, seq_len, pref_dim), plus rm_cached if it was given.
        '''
        input_ids = prompt_tokens['input_ids']
        attention_mask = prompt_tokens['attention_mask']
        
        if rm_cached is None:
            embeds = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            res = self.llm(
                input_ids=input_ids[:, -1:],
                # attention_mask=attention_mask,
                past_key_values=rm_cached["user_learner"],
                use_cache=True 
            )
            rm_cached["user_learner"] = res.past_key_values
            embeds = res.last_hidden_state

        # embeds shape: (bs, seq_len, hid_dim)
        shape = embeds.shape
        # use only the last token's hidden state, broadcast back over the sequence length
        embeds = embeds[:, -1, :]  # (bs, seq_len, hid_dim) -> (bs, hid_dim)
        embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1)  # (bs, hid_dim) -> (bs, seq_len, hid_dim)
        embeds = embeds.view(-1, shape[-1])  # (bs*seq_len, hid_dim)
        # each projector g maps (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
        logits = torch.stack(
            [g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],
            dim=1,
        )  # (bs, k, seq_len, pref_dim)
        if rm_cached is None:
            return logits
        else:
            return logits, rm_cached

    def return_user_ideal_points(self):
        '''Return the user ideal points cached by the most recent forward pass.'''
        if self.tmp_store_user_ideal_points is None:
            raise ValueError('No user ideal points stored; run forward() first')
        return self.tmp_store_user_ideal_points

    def forward(self, uid, prompt_tokens, rm_cached=None):
        '''
        Compute the user's ideal points for a batch of prompts (only pass the prompt tokens).

        prompt_tokens: {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
        See the __main__ sketch at the bottom of this file for a minimal usage example.
        '''
        if rm_cached is None:
            prompt_logits = self.infer_gk(prompt_tokens)
        else:
            prompt_logits, rm_cached = self.infer_gk(prompt_tokens, rm_cached)
        bs = prompt_tokens['input_ids'].size(0)
        # mixture (softmax) weights in train mode; one-hot (argmax) weights for the
        # partition variant in eval mode -- see train()/eval() below
        if getattr(self, 'is_argmax', False):
            w = self.get_hardmax_w([uid] * bs)
        else:
            w = self.get_sfx_w([uid] * bs)
        w = w.unsqueeze(-1).unsqueeze(-1)        # (bs, k) -> (bs, k, 1, 1)
        y_hat = (w * prompt_logits).sum(dim=1)   # (bs, k, seq_len, pref_dim) -> (bs, seq_len, pref_dim)
        self.tmp_store_user_ideal_points = y_hat
        if rm_cached is None:
            return y_hat
        else:
            return y_hat, rm_cached
    
    def eval(self):
        # nn.Module.eval() is equivalent to train(False); delegate so the flag and
        # warning are handled in one place (and not emitted twice)
        return self.train(False)

    def train(self, mode: bool = True):
        super().train(mode)
        variant = "Partition" if self.is_partition else "Mixture"
        if mode:
            # training always mixes the k projectors with softmax weights
            warnings.warn(f"🤖 UserLearner ({variant} version) is in train mode: sfx")
            self.is_argmax = False
        else:
            # at eval time the partition variant commits to a single group via argmax
            self.is_argmax = self.is_partition
            warnings.warn(
                f"🤖 UserLearner ({variant} version) is in eval mode: "
                f"{'argmax' if self.is_argmax else 'sfx'}"
            )
        return self
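

# ------------------------------------------------------------------------------
# Minimal usage sketch (hedged): the stub encoder and the nn.Linear projectors
# below are illustrative stand-ins, not the package's real LLM backbone,
# Projector, or CustomSoftMax classes; they only mimic the interfaces
# UserLearner relies on (an object exposing .last_hidden_state, and modules
# mapping hid_dim -> pref_dim). Because this file uses relative imports, run it
# as a module, e.g. `python -m <package>.<this_module>` (module path assumed).
if __name__ == "__main__":
    from types import SimpleNamespace

    class _StubEncoder(nn.Module):
        '''Toy encoder mimicking a HF-style model that returns last_hidden_state.'''
        def __init__(self, vocab_size: int = 100, hid_dim: int = 16):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, hid_dim)

        def forward(self, input_ids=None, attention_mask=None, **kwargs):
            # ignore attention_mask / cache kwargs; just embed the tokens
            return SimpleNamespace(last_hidden_state=self.emb(input_ids))

    k, hid_dim, pref_dim = 3, 16, 8
    learner = UserLearner(
        k=k,
        llm=_StubEncoder(hid_dim=hid_dim),
        projectors=[nn.Linear(hid_dim, pref_dim) for _ in range(k)],  # stand-ins for Projector
        softmax=nn.Softmax(dim=-1),                                   # stand-in for CustomSoftMax
    )
    learner.init_weight(["user_0"])

    prompt_tokens = {
        "input_ids": torch.randint(0, 100, (2, 5)),
        "attention_mask": torch.ones(2, 5, dtype=torch.long),
    }
    y_hat = learner("user_0", prompt_tokens)
    print(y_hat.shape)  # (bs, seq_len, pref_dim) -> torch.Size([2, 5, 8])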