alverciito commited on
Commit
869db96
·
1 Parent(s): 5da73fb

fix huggingface model weight missmatch

Browse files
Files changed (1) hide show
  1. model.py +10 -1
model.py CHANGED
@@ -4,6 +4,9 @@
4
  # Universidad de Alcalá - Escuela Politécnica Superior #
5
  # #
6
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
 
 
7
  import torch
8
  from transformers import PreTrainedModel, PretrainedConfig
9
  from src.model import SegmentationNetwork
@@ -138,7 +141,13 @@ class SentenceCoseNet(PreTrainedModel):
138
 
139
  # Core PyTorch model
140
  self.model = SegmentationNetwork(self.to_model_config(config))
141
- state_dict = torch.load("model.safetensors", map_location="cpu")
 
 
 
 
 
 
142
  self.model.load_state_dict(state_dict, strict=True)
143
  self.model.eval()
144
 
 
4
  # Universidad de Alcalá - Escuela Politécnica Superior #
5
  # #
6
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ import os
8
+ from safetensors.torch import load_file
9
+
10
  import torch
11
  from transformers import PreTrainedModel, PretrainedConfig
12
  from src.model import SegmentationNetwork
 
141
 
142
  # Core PyTorch model
143
  self.model = SegmentationNetwork(self.to_model_config(config))
144
+
145
+ weights_path = os.path.join(
146
+ config._name_or_path,
147
+ "model.safetensors"
148
+ )
149
+
150
+ state_dict = load_file(weights_path)
151
  self.model.load_state_dict(state_dict, strict=True)
152
  self.model.eval()
153