#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@Desc: This is the implementation of PAL-B
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig
from .connector import Connector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
from .itemLearner import ItemLearner
from .userLearner import UserLearner
from collections import defaultdict
from typing import Literal, Optional, Tuple
import logging
logger = logging.getLogger(__name__)
class BasePrefLearner(nn.Module):
def __init__(
self,
d_hid: int,
d_pref: int,
k: int,
llm_name: str,
pref_learner_type: Literal["dist","dist_normalization","angle","norm","dist_logistic","angle_hinge"],
proj_arch: str,
initializer_type: Literal["gaussian"],
is_expectation_norm_init: bool, # the tensor initialization parameters
sfx_type: Literal["gumbel_softmax", "softmax"],
sfx_temperature: float,
is_temperature_learnable: bool,
is_gumbel_hard: Optional[bool]=None,
is_partition: bool=False,
seed: int=42,
**kwargs
):
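        """Shared backbone for PAL-B preference learners.

        Args:
            d_hid: input dimension of the projection heads (the LLM hidden size).
            d_pref: output dimension of the projection heads, i.e. the preference embedding space.
            k: number of group projectors g_k used by the user learner.
            llm_name: Hugging Face model name or path for the shared LLM backbone.
            pref_learner_type: which scoring rule the preference learner uses.
            proj_arch: architecture string forwarded to Connector.
            initializer_type, is_expectation_norm_init, seed: options forwarded to TensorInitializer.
            sfx_type, sfx_temperature, is_temperature_learnable, is_gumbel_hard:
                options forwarded to CustomSoftMax.
            is_partition: forwarded to UserLearner.
        """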
super().__init__()
self.pref_learner_type = pref_learner_type
self.is_temperature_learnable = is_temperature_learnable
# init all necessary modules
model_config = AutoConfig.from_pretrained(llm_name)
self.llm = AutoModel.from_pretrained(llm_name,from_tf=bool(".ckpt" in llm_name),config=model_config)
self.tensor_initializer = TensorInitializer(initializer_type, seed, is_expectation_norm_init=is_expectation_norm_init)
self.projector_f = Connector(cnct_arch=proj_arch,in_dims=d_hid,out_dims=d_pref)
        # Wrap the k group projectors in nn.ModuleList so they are registered as submodules.
        self.projectors_gk = nn.ModuleList(
            [Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref) for _ in range(k)]
        )
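        # Learnable logit scale, initialized to log(1/0.07) as in CLIP.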
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.softmax_w = CustomSoftMax(sfx_type=sfx_type,
temperature=sfx_temperature,
is_temperature_learnable=is_temperature_learnable,
is_gumbel_hard=is_gumbel_hard)
self.item_learner = ItemLearner(
llm = self.llm,
projector=self.projector_f,
)
self.is_partition = is_partition
self.user_learner = UserLearner(k=k, llm=self.llm, projectors=self.projectors_gk, softmax=self.softmax_w, is_partition=is_partition)
logger.critical('🛑 Remember to call update_trainable_params() after the model is initialized.')
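    # A minimal sketch of how the collected parameter groups can be fed to an optimizer
    # (illustrative only; the optimizer choice and learning rate are assumptions, not part
    # of this module):
    #
    #   model.update_trainable_params(fix_modules=("llm",))
    #   param_groups = [{"params": params} for params in model.trainable_params.values()]
    #   optimizer = torch.optim.AdamW(param_groups, lr=1e-4)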
    def update_trainable_params(self, fix_modules: Tuple[str, ...] = ()):
        # Collect the parameters of every module not listed in fix_modules,
        # keyed by module name; each value is a list of parameters.
        self.trainable_params = defaultdict(list)
        if "llm" not in fix_modules:
            self.trainable_params["llm"] = list(self.llm.parameters())
        else:
            self.llm.eval()
        if "itemLearnerProjector" not in fix_modules:
            self.trainable_params["projector_f"].extend(self.item_learner.projector.parameters())
        if "userLearnerProjector" not in fix_modules:
            self.trainable_params["projectors_gk"].extend(self.user_learner.projectors.parameters())
        if "W" not in fix_modules:
            self.trainable_params["W"] = list(self.user_learner.W.parameters())
        if self.pref_learner_type in ["angle", "dist_logistic"] and "logit_scale" not in fix_modules:
            self.trainable_params["logit_scale"] = [self.logit_scale]
        if self.is_temperature_learnable and "temperature" not in fix_modules:
            self.trainable_params["temperature"] = [self.softmax_w.temperature]
    def map_to_pref_embedding_space(self, x, rm_cached=None):
        # x is a tuple of:
        # (
        #     uid,
        #     {'input_ids': prompt_input_ids, 'attention_mask': prompt_attention_mask},
        #     {'input_ids': eval_input_ids, 'attention_mask': eval_attention_mask},
        # )
        # rm_cached, if provided, is an optional cache that is threaded through the
        # item and user learners and returned alongside the embeddings.
        uid, prompt, items = x
if rm_cached is None:
items_prime = self.item_learner(items)
prompt_prime = self.user_learner(uid, prompt)
return items_prime, prompt_prime
else:
items_prime, rm_cached = self.item_learner(items, rm_cached)
prompt_prime, rm_cached = self.user_learner(uid, prompt, rm_cached)
return items_prime, prompt_prime, rm_cached
class PrefLearner(BasePrefLearner): # <f(x),f(u)>
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.uid = None  # must be set via specify_user_ids() before calling forward()
def specify_user_ids(self, uid): # personalize the model for a specific user
self.uid = uid
def forward(self, x, rm_cached=None):
assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
prompt, items = x
if rm_cached is None:
items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
else:
items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached)
# logger.critical(f"{items_prime[0]=}")
# logger.critical(f"{prompt_prime[0]=}")
# logger.critical(f"{items_prime.shape=}")
# logger.critical(f"{prompt_prime.shape=}")
if self.pref_learner_type == 'angle':
# NOTICE: here we implement the "last token only" version of PAL-B
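            # Take the last-token embedding of the prompt and of each item, L2-normalize
            # both, and score them by cosine similarity scaled by the learnable logit
            # scale (clamped at 100, CLIP-style).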
prompt_last_prime = prompt_prime[:, -1, :]
prompt_last_prime = prompt_last_prime.unsqueeze(1)
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
items_last_prime = items_prime[:, -1, :]
items_last_prime = items_last_prime.unsqueeze(1)
items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
logit_scale = self.logit_scale.exp()
clamped_logit_scale = torch.clamp(logit_scale, max=100)
# logger.critical(f"{prompt_last_prime.shape=}")
# logger.critical(f"{items_last_prime.shape=}")
            sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale  # (bs, 1): only the last token is scored
if rm_cached is None:
return sim_score
else:
return sim_score, rm_cached
else:
raise NotImplementedError
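
# A minimal end-to-end usage sketch (the backbone name, dimensions, proj_arch string,
# and batch contents below are illustrative assumptions, not defaults provided by this
# module):
#
#   model = PrefLearner(
#       d_hid=768, d_pref=64, k=5, llm_name="gpt2",
#       pref_learner_type="angle", proj_arch="linear",
#       initializer_type="gaussian", is_expectation_norm_init=True,
#       sfx_type="softmax", sfx_temperature=1.0, is_temperature_learnable=False,
#   )
#   model.update_trainable_params(fix_modules=("llm",))
#   model.specify_user_ids(uid)  # uid identifying the user to personalize for
#   scores = model((prompt_batch, items_batch))  # each a dict of input_ids / attention_mask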