import torch
import torch.nn as nn
import torch.nn.functional as F

from .connector import Connector
from .projector import Projector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax

import numpy as np
import warnings
import logging
from typing import Literal

logger = logging.getLogger(__name__)


class UserLearner(nn.Module):
    """Learns a per-user mixture over k group projectors on top of a shared LLM backbone."""

    k: int
    llm: nn.Module
    projectors: nn.ModuleDict
    u_id_set: set
    softmax: nn.Module
    is_partition: bool
    is_argmax: bool

    def __init__(
        self,
        k: int,
        llm: nn.Module,
        projectors: list[Projector],
        softmax: nn.Module,
        is_partition: bool = False,
    ):
        super().__init__()

        self.k = k
        self.llm = llm
        self.softmax = softmax

        # Per-user mixture logits over the k groups, keyed by user id.
        self.u_id_set = set()
        self.W = nn.ParameterDict()
        self.tmp_store_user_ideal_points = None

        assert len(projectors) == k, \
            f"The number of projectors must match the number of groups: {k} != {len(projectors)}"
        self.projectors = nn.ModuleDict()
        for i in range(k):
            self.projectors[str(i)] = projectors[i]
        self.is_partition = is_partition
        # Whether forward() uses hard (argmax) group assignments; toggled by train()/eval().
        self.is_argmax = False

    def init_weight(self, u_ids: list, reinit: bool = False):
        '''Create (or, with reinit=True, re-create) the k-dimensional mixture logits for each user id.'''
        device = next(self.projectors['0'].parameters()).device
        for u_id in u_ids:
            if u_id not in self.u_id_set or reinit:
                # Build the parameter directly on the projector device: calling
                # .to(device) on an nn.Parameter afterwards returns a plain tensor,
                # which would not be registered (or optimized) as a parameter.
                self.W[u_id] = nn.Parameter(
                    torch.randn(self.k, dtype=torch.float32, device=device),
                    requires_grad=True,
                )
                self.u_id_set.add(u_id)
            else:
                logger.warning('User %s is already initialized; pass reinit=True to re-initialize.', u_id)

    def get_sfx_w(self, u_ids: list):
        '''Soft mixture weights: softmax over each user's k group logits, shape (len(u_ids), k).'''
        w = torch.stack([self.W[key] for key in u_ids], dim=0)
        w = self.softmax(w)
        return w

    def get_hardmax_w(self, u_ids: list):
        '''Hard one-hot weights from the argmax over each user's k group logits, shape (len(u_ids), k).'''
        w = torch.stack([self.W[key] for key in u_ids], dim=0)
        w = F.one_hot(w.argmax(dim=1), num_classes=self.k).float()
        return w

    def infer_gk(self, prompt_tokens, rm_cached=None):
        '''
        Run the LLM backbone and every group projector on the prompt.

        prompt_tokens: {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
        To reuse the backbone's KV cache during decoding, pass the rm_cached dict;
        its "user_learner" entry is used as past_key_values and updated in place.
        Returns per-group logits of shape (batch, k, seq_len, out_dim).
        '''
        input_ids = prompt_tokens['input_ids']
        attention_mask = prompt_tokens['attention_mask']

        if rm_cached is None:
            embeds = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            # Incremental decoding: feed only the newest token and reuse the cache.
            res = self.llm(
                input_ids=input_ids[:, -1:],
                past_key_values=rm_cached["user_learner"],
                use_cache=True,
            )
            rm_cached["user_learner"] = res.past_key_values
            embeds = res.last_hidden_state

        shape = embeds.shape

        # Keep only the last-token embedding and broadcast it across the sequence
        # dimension, then flatten so each projector sees one vector per position.
        embeds = embeds[:, -1, :]
        embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1)
        embeds = embeds.view(-1, shape[-1])

        # One output per group projector, stacked along dim=1 -> (batch, k, seq_len, out_dim).
        logits = torch.stack(
            [g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],
            dim=1,
        )
        if rm_cached is None:
            return logits
        else:
            return logits, rm_cached

    def return_user_ideal_points(self):
        if self.tmp_store_user_ideal_points is None:
            raise ValueError('No user ideal points stored; call forward() first.')
        return self.tmp_store_user_ideal_points

    def forward(self, uid, prompt_tokens, rm_cached=None):
        '''
        prompt_tokens: {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
        Returns the user's mixed logits of shape (batch, seq_len, out_dim).
        '''
        if rm_cached is None:
            prompt_logits = self.infer_gk(prompt_tokens)
        else:
            prompt_logits, rm_cached = self.infer_gk(prompt_tokens, rm_cached)
        bs = prompt_tokens['input_ids'].size(0)

        # Mixture weights over the k groups for this user, one row per batch item.
        # Partition models switch to hard (argmax) assignments in eval mode
        # (see train()/eval()); mixture models always use the softmax weights.
        if self.is_argmax:
            w = self.get_hardmax_w([uid] * bs)
        else:
            w = self.get_sfx_w([uid] * bs)

        # (bs, k) -> (bs, k, 1, 1) so the weights broadcast over the per-group
        # logits (bs, k, seq_len, out_dim); sum over the group axis.
        w = w.unsqueeze(-1).unsqueeze(-1)
        y_hat = (w * prompt_logits).sum(dim=1)
        self.tmp_store_user_ideal_points = y_hat

        if rm_cached is None:
            return y_hat
        else:
            return y_hat, rm_cached

    def eval(self):
        # nn.Module.eval() delegates to self.train(False), which issues the
        # mode warning and sets is_argmax, so no extra work is needed here.
        return super().eval()

    def train(self, mode: bool = True):
        super().train(mode)
        variant = "Partition" if self.is_partition else "Mixture"
        if mode:
            # Both variants train with soft (softmax) mixture weights.
            warnings.warn(f"UserLearner ({variant} version) is in train mode: softmax")
            self.is_argmax = False
        else:
            if self.is_partition:
                warnings.warn("UserLearner (Partition version) is in eval mode: argmax")
                self.is_argmax = True
            else:
                warnings.warn("UserLearner (Mixture version) is in eval mode: softmax")
                self.is_argmax = False
        return self
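

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API). A minimal
# example of wiring up UserLearner, assuming a Hugging Face backbone that
# returns last_hidden_state; "gpt2", k=3, the user id "user_42", nn.Softmax
# as the mixing function, and the Projector constructor arguments (left as
# ...) are placeholders, not values this module defines.
#
# from transformers import AutoModel, AutoTokenizer
#
# backbone = AutoModel.from_pretrained("gpt2")
# projectors = [Projector(...) for _ in range(3)]          # one head per group
# learner = UserLearner(k=3, llm=backbone, projectors=projectors,
#                       softmax=nn.Softmax(dim=-1))
#
# learner.init_weight(["user_42"])                          # create this user's k mixture logits
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# tokens = tokenizer("I prefer concise answers.", return_tensors="pt")
# y_hat = learner("user_42", {"input_ids": tokens["input_ids"],
#                             "attention_mask": tokens["attention_mask"]})
# # y_hat: (batch, seq_len, out_dim) user-specific logits, mixed over the k groups.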