LossFunctionLover commited on
Commit
14d62ee
·
verified ·
1 Parent(s): 44d5eef

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -0
README.md CHANGED
@@ -222,6 +222,14 @@ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
222
  ckpt = torch.load(model_path, map_location="cpu")
223
  state = ckpt["model_state"] if "model_state" in ckpt else ckpt
224
 
 
 
 
 
 
 
 
 
225
  # Initialize scoring head (single linear layer)
226
  hidden_size = base_model.config.hidden_size
227
  scoring_head = torch.nn.Linear(hidden_size, 1)
 
222
  ckpt = torch.load(model_path, map_location="cpu")
223
  state = ckpt["model_state"] if "model_state" in ckpt else ckpt
224
 
225
+ head_state = {
226
+ k.replace("score.", ""): v
227
+ for k, v in state.items()
228
+ if k.startswith("score.")
229
+ }
230
+
231
+ assert set(head_state.keys()) == {"weight", "bias"}
232
+
233
  # Initialize scoring head (single linear layer)
234
  hidden_size = base_model.config.hidden_size
235
  scoring_head = torch.nn.Linear(hidden_size, 1)