Upload PAL_B_RM_opt
Browse files- config.json +13 -1
- configuration_pal_b_rm.py +2 -0
- modeling_pal_b_rm.py +3 -1
config.json
CHANGED
|
@@ -20,5 +20,17 @@
|
|
| 20 |
"sfx_temperature": 1.0,
|
| 21 |
"sfx_type": "softmax",
|
| 22 |
"torch_dtype": "float32",
|
| 23 |
-
"transformers_version": "4.44.2"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
}
|
|
|
|
| 20 |
"sfx_temperature": 1.0,
|
| 21 |
"sfx_type": "softmax",
|
| 22 |
"torch_dtype": "float32",
|
| 23 |
+
"transformers_version": "4.44.2",
|
| 24 |
+
"uids": [
|
| 25 |
+
"KZL1qeRzHNYSfDAuOctL1iyVV8WC5N",
|
| 26 |
+
"ZzGCcAhvqF0HnKxNsUjtJFadcZdyZj",
|
| 27 |
+
"p4Oh7rUGyLe1EpilJFWr9sPDpkO016",
|
| 28 |
+
"qo6WIyEh27cwAjWpA3Q60J7NaDxzQJ",
|
| 29 |
+
"zKV8BFGy60O0q7102ALF84S6Jo5i4q",
|
| 30 |
+
"i8YiBZlrYmlkkChr5b9BUKvDO6lR1d",
|
| 31 |
+
"M3icahkfAtC9CJrtKgQ7qvyZ5SD8wC",
|
| 32 |
+
"HNzkrs9geGu1YMMfZ5Qvdt0ZaCthfB",
|
| 33 |
+
"Jxv4hxfb9zTVa5nsMDFlnjSX5LZ8MK",
|
| 34 |
+
"UhQipwcpQmiGJmScocXOGOKyCBaFUg"
|
| 35 |
+
]
|
| 36 |
}
|
configuration_pal_b_rm.py
CHANGED
|
@@ -17,6 +17,7 @@ class PAL_B_Config(PretrainedConfig):
|
|
| 17 |
sfx_temperature: float = 1.0,
|
| 18 |
is_temperature_learnable: bool = False,
|
| 19 |
is_gumbel_hard: bool = None,
|
|
|
|
| 20 |
**kwargs,
|
| 21 |
):
|
| 22 |
self.d_hid = d_hid
|
|
@@ -31,5 +32,6 @@ class PAL_B_Config(PretrainedConfig):
|
|
| 31 |
self.sfx_temperature = sfx_temperature
|
| 32 |
self.is_temperature_learnable = is_temperature_learnable
|
| 33 |
self.is_gumbel_hard = is_gumbel_hard
|
|
|
|
| 34 |
super().__init__(**kwargs)
|
| 35 |
|
|
|
|
| 17 |
sfx_temperature: float = 1.0,
|
| 18 |
is_temperature_learnable: bool = False,
|
| 19 |
is_gumbel_hard: bool = None,
|
| 20 |
+
uids: list = None,
|
| 21 |
**kwargs,
|
| 22 |
):
|
| 23 |
self.d_hid = d_hid
|
|
|
|
| 32 |
self.sfx_temperature = sfx_temperature
|
| 33 |
self.is_temperature_learnable = is_temperature_learnable
|
| 34 |
self.is_gumbel_hard = is_gumbel_hard
|
| 35 |
+
self.uids = uids
|
| 36 |
super().__init__(**kwargs)
|
| 37 |
|
modeling_pal_b_rm.py
CHANGED
|
@@ -20,8 +20,10 @@ class PAL_B_RM_opt(PreTrainedModel):
|
|
| 20 |
sfx_temperature=config.sfx_temperature,
|
| 21 |
is_temperature_learnable=config.is_temperature_learnable,
|
| 22 |
is_gumbel_hard=config.is_gumbel_hard,
|
|
|
|
| 23 |
)
|
| 24 |
-
|
|
|
|
| 25 |
|
| 26 |
def forward(self, x):
|
| 27 |
logits = self.model(x)
|
|
|
|
| 20 |
sfx_temperature=config.sfx_temperature,
|
| 21 |
is_temperature_learnable=config.is_temperature_learnable,
|
| 22 |
is_gumbel_hard=config.is_gumbel_hard,
|
| 23 |
+
uids=config.uids,
|
| 24 |
)
|
| 25 |
+
if config.uids is not None:
|
| 26 |
+
self.model.user_learner.init_weight(config.uids)
|
| 27 |
|
| 28 |
def forward(self, x):
|
| 29 |
logits = self.model(x)
|