alverciito commited on
Commit
4aa1cf7
·
1 Parent(s): 869db96

fix huggingface model weight missmatch

Browse files
Files changed (1) hide show
  1. model.py +3 -9
model.py CHANGED
@@ -142,18 +142,12 @@ class SentenceCoseNet(PreTrainedModel):
142
  # Core PyTorch model
143
  self.model = SegmentationNetwork(self.to_model_config(config))
144
 
145
- weights_path = os.path.join(
146
- config._name_or_path,
147
- "model.safetensors"
148
- )
149
-
150
- state_dict = load_file(weights_path)
151
- self.model.load_state_dict(state_dict, strict=True)
152
- self.model.eval()
153
-
154
  # Initialize weights following HF conventions
155
  self.post_init()
156
 
 
 
 
157
  def encode(
158
  self,
159
  input_ids: torch.Tensor,
 
142
  # Core PyTorch model
143
  self.model = SegmentationNetwork(self.to_model_config(config))
144
 
 
 
 
 
 
 
 
 
 
145
  # Initialize weights following HF conventions
146
  self.post_init()
147
 
148
+ # Set evaluation mode by default
149
+ self.model.eval()
150
+
151
  def encode(
152
  self,
153
  input_ids: torch.Tensor,