Upload PAL_B_RM_opt
Browse files- learner.py +1 -2
- userLearner.py +1 -1
learner.py
CHANGED
|
@@ -122,9 +122,8 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
|
|
| 122 |
# logger.critical(f"{prompt_prime[0]=}")
|
| 123 |
# logger.critical(f"{items_prime.shape=}")
|
| 124 |
# logger.critical(f"{prompt_prime.shape=}")
|
| 125 |
-
# FIXME: bug exist here
|
| 126 |
if self.pref_learner_type == 'angle':
|
| 127 |
-
#
|
| 128 |
prompt_last_prime = prompt_prime[:, -1, :]
|
| 129 |
prompt_last_prime = prompt_last_prime.unsqueeze(1)
|
| 130 |
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
|
|
|
|
| 122 |
# logger.critical(f"{prompt_prime[0]=}")
|
| 123 |
# logger.critical(f"{items_prime.shape=}")
|
| 124 |
# logger.critical(f"{prompt_prime.shape=}")
|
|
|
|
| 125 |
if self.pref_learner_type == 'angle':
|
| 126 |
+
# NOTICE: here we implement the "last token only" version of PAL-B
|
| 127 |
prompt_last_prime = prompt_prime[:, -1, :]
|
| 128 |
prompt_last_prime = prompt_last_prime.unsqueeze(1)
|
| 129 |
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
|
userLearner.py
CHANGED
|
@@ -92,7 +92,7 @@ class UserLearner(nn.Module):
|
|
| 92 |
|
| 93 |
# embeds shape: (bs, seq_len, hid_dim)
|
| 94 |
shape = embeds.shape
|
| 95 |
-
# only last hidden state start
|
| 96 |
embeds = embeds[:, -1, :] # (bs, seq_len, hid_dim) -> (bs, hid_dim)
|
| 97 |
embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1) # (bs, hid_dim) -> (bs, seq_len, hid_dim)
|
| 98 |
# only last hidden state end
|
|
|
|
| 92 |
|
| 93 |
# embeds shape: (bs, seq_len, hid_dim)
|
| 94 |
shape = embeds.shape
|
| 95 |
+
# only last hidden state start (only use the last token of the prompt)
|
| 96 |
embeds = embeds[:, -1, :] # (bs, seq_len, hid_dim) -> (bs, hid_dim)
|
| 97 |
embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1) # (bs, hid_dim) -> (bs, seq_len, hid_dim)
|
| 98 |
# only last hidden state end
|