# pal-b-large-opt-350m / userLearner.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from .connector import Connector
from .projector import Projector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
import warnings
import logging
logger = logging.getLogger(__name__)
class UserLearner(nn.Module):
    k: int                      # number of user groups
    llm: nn.Module              # shared backbone that embeds the prompt
    projectors: nn.ModuleDict   # one projector g_k per group, keyed by str(k)
    u_id_set: set               # registered user ids
    softmax: nn.Module          # turns per-user logits into mixture weights
    is_partition: bool          # partition variant uses argmax weights at eval
def __init__(
self,
k: int,
llm: nn.Module,
projectors: list[Projector],
softmax: nn.Module,
        is_partition: bool = False,
):
super().__init__()
self.k = k
self.llm = llm
self.softmax = softmax
        # initialize the user-id registration table and per-user weight dict
        self.u_id_set = set()
        self.W = nn.ParameterDict()
        self.tmp_store_user_ideal_points = None
        # register all k projectors in a ModuleDict keyed by the group index
        assert len(projectors) == k, \
            f"The number of projectors must match the number of groups: {k} != {len(projectors)}"
        self.projectors = nn.ModuleDict()
        for i in range(k):
            self.projectors[str(i)] = projectors[i]
        self.is_partition = is_partition
        # default to soft mixture weights; train()/eval() toggles this flag
        self.is_argmax = False
    def init_weight(self, u_ids: list, reinit: bool = False):
        device = next(self.projectors['0'].parameters()).device
        for u_id in u_ids:
            if u_id not in self.u_id_set or reinit:
                # create the tensor on the projector device first, so the
                # registered Parameter stays a leaf and remains trainable
                self.W[u_id] = nn.Parameter(
                    torch.randn(self.k, dtype=torch.float32, device=device),
                    requires_grad=True,
                )
                self.u_id_set.add(u_id)
            else:
                logger.warning(f'👋 user {u_id} is already registered; skipping re-initialization')
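    # Registration sketch (hypothetical calls, not part of the original file):
    # init_weight skips already-known ids unless reinit=True, e.g.
    #
    #   learner.init_weight(['u0', 'u1'])         # registers two users
    #   learner.init_weight(['u0'])               # warns: already registered
    #   learner.init_weight(['u0'], reinit=True)  # draws fresh logits for u0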
    def get_sfx_w(self, u_ids: list):
        # soft mixture weights over the k groups
        w = torch.stack([self.W[key] for key in u_ids], dim=0)  # (bs, k)
        w = self.softmax(w)
        return w

    def get_hardmax_w(self, u_ids: list):
        # hard (one-hot) assignment to the dominant group
        w = torch.stack([self.W[key] for key in u_ids], dim=0)
        w = F.one_hot(w.argmax(dim=1), num_classes=self.k).float()  # (bs, k)
        return w
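    # Illustrative contrast between the two weightings (hypothetical numbers,
    # assuming k = 3 and a registered user 'u0'):
    #
    #   learner.get_sfx_w(['u0'])      # e.g. tensor([[0.2, 0.5, 0.3]]); rows sum to 1
    #   learner.get_hardmax_w(['u0'])  # tensor([[0., 1., 0.]]); one-hot argmax of the same logits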
    def infer_gk(self, prompt_tokens, rm_cached=None):
        '''
        prompt_tokens: {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
        To enable KV caching, pass an rm_cached dict (an empty dict on the
        first call); it is primed with the full prompt once, and thereafter
        only the newest token is run through the backbone.
        '''
        input_ids = prompt_tokens['input_ids']
        attention_mask = prompt_tokens['attention_mask']
        if rm_cached is None:
            embeds = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            past = rm_cached.get("user_learner")
            # prime the cache with the full prompt on the first call; feed
            # only the newest token afterwards
            res = self.llm(
                input_ids=input_ids if past is None else input_ids[:, -1:],
                attention_mask=attention_mask if past is None else None,
                past_key_values=past,
                use_cache=True,
            )
            rm_cached["user_learner"] = res.past_key_values
            embeds = res.last_hidden_state
        # embeds: (bs, seq_len, hid_dim)
        shape = embeds.shape
        # use only the last hidden state of the prompt, broadcast back over seq_len
        embeds = embeds[:, -1, :]  # (bs, seq_len, hid_dim) -> (bs, hid_dim)
        embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1)  # (bs, hid_dim) -> (bs, seq_len, hid_dim)
        embeds = embeds.view(-1, shape[-1])  # (bs * seq_len, hid_dim)
        # each projector g_k: (bs * seq_len, hid_dim) -> (bs * seq_len, pref_dim)
        logits = torch.stack(
            [g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],
            dim=1,
        )  # (bs, k, seq_len, pref_dim)
        if rm_cached is None:
            return logits
        return logits, rm_cached
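    # Cache usage sketch (an assumed calling pattern, mirroring the docstring
    # above; `tokens` and `learner` are hypothetical names):
    #
    #   rm_cached = {}                                           # empty dict activates caching
    #   logits, rm_cached = learner.infer_gk(tokens, rm_cached)  # primes on the full prompt
    #   # ...append each newly generated token to tokens['input_ids']...
    #   logits, rm_cached = learner.infer_gk(tokens, rm_cached)  # only the last token runs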
    def return_user_ideal_points(self):
        if self.tmp_store_user_ideal_points is None:
            raise ValueError('No user ideal points stored')
        return self.tmp_store_user_ideal_points
    def forward(self, uid, prompt_tokens, rm_cached=None):  # only pass the prompt tokens
        '''
        prompt_tokens: {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
        '''
        if rm_cached is None:
            prompt_logits = self.infer_gk(prompt_tokens)
        else:
            prompt_logits, rm_cached = self.infer_gk(prompt_tokens, rm_cached)
        bs = prompt_tokens['input_ids'].size(0)
        # honor the weighting scheme announced by train()/eval(): hard one-hot
        # weights for the partition variant at eval time, softmax otherwise
        if self.is_argmax:
            w = self.get_hardmax_w([uid] * bs)
        else:
            w = self.get_sfx_w([uid] * bs)
        w = w.unsqueeze(-1).unsqueeze(-1)  # (bs, k) -> (bs, k, 1, 1)
        y_hat = (w * prompt_logits).sum(dim=1)  # (bs, seq_len, pref_dim)
        self.tmp_store_user_ideal_points = y_hat
        if rm_cached is None:
            return y_hat
        return y_hat, rm_cached
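    # In equation form (a restatement of the forward pass above, with h(x) the
    # last-token prompt embedding and w_u the logits registered for user u):
    #
    #     y_hat(u, x) = sum_k softmax(w_u)_k * g_k(h(x))
    #
    # i.e. the user's ideal point is a convex combination of the k group
    # projections; the Partition variant swaps softmax for a one-hot argmax at eval.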
    def eval(self):
        # nn.Module.eval() delegates to self.train(False), so the overridden
        # train() below already sets is_argmax and emits the warning once;
        # re-implementing that logic here would duplicate both.
        return super().eval()

    def train(self, mode: bool = True):
        super().train(mode)
        # the partition variant switches to hard argmax weights at eval time;
        # every other combination uses the softmax ('sfx') mixture weights
        self.is_argmax = self.is_partition and not mode
        version = "Partition" if self.is_partition else "Mixture"
        stage = "train" if mode else "eval"
        scheme = "argmax" if self.is_argmax else "sfx"
        warnings.warn(f"🤖 UserPromptLearner({version} version) is in {stage} mode: {scheme}")
        return self