import torch
import torch.nn as nn
import torch.nn.functional as F
from .connector import Connector
from .projector import Projector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
import numpy as np
import warnings
from typing import Literal
import logging
logger = logging.getLogger(__name__)
class UserLearner(nn.Module):
    """Learns per-user mixture weights over ``k`` group-level projectors.

    Each registered user id owns a learnable k-dim weight vector (stored in
    ``self.W``). A softmax over that vector mixes the ``k`` projector outputs
    (computed from the LLM's last prompt hidden state) into one user-specific
    "ideal point". The partition variant switches to hard argmax at eval time
    (flag ``is_argmax``; set by :meth:`train`/:meth:`eval`).
    """

    k: int                              # number of groups / projectors
    llm: nn.Module                      # backbone exposing .last_hidden_state
    projectors: "nn.ModuleDict"         # str(i) -> Projector, one per group
    u_id_set: set                       # user ids with a registered weight vector
    softmax: nn.Module                  # normalizer applied to raw user weights
    is_partition: bool                  # partition variant uses argmax in eval mode

    def __init__(
        self,
        k: int,
        llm: nn.Module,
        projectors: "list[Projector]",
        softmax: nn.Module,
        is_partition: bool = False,
    ):
        super().__init__()
        self.k = k
        self.llm = llm
        self.softmax = softmax
        # init user_id registration table and user weights dictionary
        self.u_id_set = set()
        self.W = nn.ParameterDict()
        self.tmp_store_user_ideal_points = None
        # register all k projectors in the ModuleDict so they are tracked as submodules
        assert len(projectors) == k, f"The num of projectors should match up with num of groups: {k} != {len(projectors)}"
        self.projectors = nn.ModuleDict()
        for i in range(k):
            self.projectors[str(i)] = projectors[i]
        self.is_partition = is_partition

    def init_weight(self, u_ids: list, reinit: bool = False):
        """Register a fresh random k-dim weight vector for each new user id.

        Args:
            u_ids: user id strings to register (ParameterDict keys must be str).
            reinit: when True, overwrite an existing user's weights.

        BUGFIX: the original wrapped the Parameter first and called ``.to()``
        afterwards; ``.to()`` returns a plain Tensor when a device copy is
        required, which would silently drop Parameter registration. Create the
        tensor on the target device, then wrap it in ``nn.Parameter``.
        """
        # device of the projectors is loop-invariant; look it up once
        device = next(self.projectors["0"].parameters()).device
        for u_id in u_ids:
            if u_id not in self.u_id_set or reinit:
                self.W[u_id] = nn.Parameter(
                    torch.randn((self.k,), dtype=torch.float32, device=device),
                    requires_grad=True,
                )
                self.u_id_set.add(u_id)
            else:
                logger.warning('๐ wait? same user?')

    def get_sfx_w(self, u_ids: list):
        """Return softmax-normalized mixture weights for ``u_ids``, shape (bs, k)."""
        w = torch.stack([self.W[key] for key in u_ids], dim=0)  # (bs, k)
        w = self.softmax(w)
        return w

    def get_hardmax_w(self, u_ids: list):
        """Return hard one-hot (argmax) mixture weights for ``u_ids``, shape (bs, k)."""
        w = torch.stack([self.W[key] for key in u_ids], dim=0)
        w = F.one_hot(w.argmax(dim=1), num_classes=self.k).float()  # (bs, k)
        return w

    def infer_gk(self, prompt_tokens, rm_cached=None):
        '''
        prompt_tokens: {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        If you want to activate rm_cached, please pass in the rm_cached dict or empty dict.
        '''
        input_ids = prompt_tokens['input_ids']
        attention_mask = prompt_tokens['attention_mask']
        if rm_cached is None:
            # full forward pass over the whole prompt
            embeds = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            # incremental decoding: feed only the newest token, reuse the KV cache
            res = self.llm(
                input_ids=input_ids[:, -1:],
                # attention_mask=attention_mask,
                past_key_values=rm_cached["user_learner"],
                use_cache=True
            )
            rm_cached["user_learner"] = res.past_key_values
            embeds = res.last_hidden_state
        # embeds shape: (bs, seq_len, hid_dim)
        shape = embeds.shape
        # only last hidden state start (only use the last token of the prompt,
        # broadcast back to seq_len so downstream shapes are unchanged)
        embeds = embeds[:, -1, :]  # (bs, seq_len, hid_dim) -> (bs, hid_dim)
        embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1)  # (bs, hid_dim) -> (bs, seq_len, hid_dim)
        # only last hidden state end
        # logger.critical("using only last hidden state of prompt tokens")
        embeds = embeds.view(-1, shape[-1])  # (bs*seq_len, hid_dim)
        # g(embeds) shape: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim); one map per group
        logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()], dim=1)
        if rm_cached is None:
            return logits
        else:
            return logits, rm_cached  # (bs, k, seq_len, hidden_size)

    def return_user_ideal_points(self):
        """Return the ideal points cached by the last :meth:`forward` call.

        Raises:
            ValueError: if no forward pass has stored ideal points yet.
        """
        # BUGFIX: `== None` on a stored tensor does an elementwise comparison
        # and raises on truth-testing; an identity check is required here.
        if self.tmp_store_user_ideal_points is None:
            raise ValueError('No user ideal points stored')
        return self.tmp_store_user_ideal_points

    def forward(self, uid, prompt_tokens, rm_cached=None):  # only pass the prompt tokens
        '''
        prompt_tokens: {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        '''
        if rm_cached is None:
            prompt_logits = self.infer_gk(prompt_tokens)
        else:
            prompt_logits, rm_cached = self.infer_gk(prompt_tokens, rm_cached)
        bs = prompt_tokens['input_ids'].size(0)
        # same user id repeated across the batch -> (bs, k) soft mixture weights
        w = self.get_sfx_w([uid] * bs)
        # broadcast weights over (seq_len, hidden) dims and mix the k group logits
        w = w.unsqueeze(-1).unsqueeze(-1)
        y_hat = (w * prompt_logits).sum(dim=1)
        # cache for return_user_ideal_points()
        self.tmp_store_user_ideal_points = y_hat
        if rm_cached is None:
            return y_hat
        else:
            return y_hat, rm_cached

    def eval(self):
        # NOTE: super().eval() already routes through train(False), so the
        # warning fires twice in the original as well; kept for compatibility.
        super().eval()
        if self.is_partition:
            warnings.warn("๐ค UserPromptLearner(Partition version) is in eval mode: argmax")
            self.is_argmax = True
        else:
            warnings.warn("๐ค UserPromptLearner(Mixture version) is in eval mode: sfx")
            self.is_argmax = False

    def train(self, mode: bool = True):
        super().train(mode)
        if mode:
            if self.is_partition:
                warnings.warn("๐ค UserPromptLearner(Partition version) is in train mode: sfx")
                self.is_argmax = False
            else:
                warnings.warn("๐ค UserPromptLearner(Mixture version) is in train mode: sfx")
                self.is_argmax = False
        else:
            if self.is_partition:
                warnings.warn("๐ค UserPromptLearner(Partition version) is in eval mode: argmax")
                self.is_argmax = True
            else:
                warnings.warn("๐ค UserPromptLearner(Mixture version) is in eval mode: sfx")
                self.is_argmax = False