Add pipeline tag and library name
Browse files

This PR improves the model card by adding a relevant pipeline tag and library name.
README.md
CHANGED
|
@@ -1,7 +1,9 @@
|
|
| 1 |
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
base_model:
|
| 4 |
- google/gemma-2-9b-it
|
|
|
|
|
|
|
|
|
|
| 5 |
---
|
| 6 |
|
| 7 |
# General Preference Representation Model (GPM)
|
|
@@ -131,8 +133,7 @@ def get_reward_model(base_causal_model, base_llm_model, value_head_dim: int, add
|
|
| 131 |
block_values = self.prompt_head(prompt_hidden_states).view(batch_size, dim // 2)
|
| 132 |
block_values = torch.softmax(block_values, dim=-1)
|
| 133 |
|
| 134 |
-
# Create a batch of zero matrices [batch_size, dim, dim]
|
| 135 |
-
batch_R_matrices = torch.zeros((batch_size, dim, dim), device=device, dtype=dtype)
|
| 136 |
|
| 137 |
# Fill only the block diagonal entries with the learned values
|
| 138 |
for i in range(0, dim, 2):
|
|
|
|
| 1 |
---
|
|
|
|
| 2 |
base_model:
|
| 3 |
- google/gemma-2-9b-it
|
| 4 |
+
license: apache-2.0
|
| 5 |
+
pipeline_tag: text-generation
|
| 6 |
+
library_name: transformers
|
| 7 |
---
|
| 8 |
|
| 9 |
# General Preference Representation Model (GPM)
|
|
|
|
| 133 |
block_values = self.prompt_head(prompt_hidden_states).view(batch_size, dim // 2)
|
| 134 |
block_values = torch.softmax(block_values, dim=-1)
|
| 135 |
|
| 136 |
+
# Create a batch of zero matrices [batch_size, dim, dim]
+
batch_R_matrices = torch.zeros((batch_size, dim, dim), device=device, dtype=dtype)
|
|
|
|
| 137 |
|
| 138 |
# Fill only the block diagonal entries with the learned values
|
| 139 |
for i in range(0, dim, 2):
|