Update README.md
Browse files
README.md
CHANGED
|
@@ -222,6 +222,14 @@ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
|
|
| 222 |
ckpt = torch.load(model_path, map_location="cpu")
|
| 223 |
state = ckpt["model_state"] if "model_state" in ckpt else ckpt
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
# Initialize scoring head (single linear layer)
|
| 226 |
hidden_size = base_model.config.hidden_size
|
| 227 |
scoring_head = torch.nn.Linear(hidden_size, 1)
|
|
|
|
| 222 |
ckpt = torch.load(model_path, map_location="cpu")
|
| 223 |
state = ckpt["model_state"] if "model_state" in ckpt else ckpt
|
| 224 |
|
| 225 |
+
head_state = {
|
| 226 |
+
k.replace("score.", ""): v
|
| 227 |
+
for k, v in state.items()
|
| 228 |
+
if k.startswith("score.")
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
assert set(head_state.keys()) == {"weight", "bias"}
|
| 232 |
+
|
| 233 |
# Initialize scoring head (single linear layer)
|
| 234 |
hidden_size = base_model.config.hidden_size
|
| 235 |
scoring_head = torch.nn.Linear(hidden_size, 1)
|