alverciito commited on
Commit ·
869db96
1
Parent(s): 5da73fb
fix huggingface model weight missmatch
Browse files
model.py
CHANGED
|
@@ -4,6 +4,9 @@
|
|
| 4 |
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
# #
|
| 6 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 9 |
from src.model import SegmentationNetwork
|
|
@@ -138,7 +141,13 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 138 |
|
| 139 |
# Core PyTorch model
|
| 140 |
self.model = SegmentationNetwork(self.to_model_config(config))
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
self.model.load_state_dict(state_dict, strict=True)
|
| 143 |
self.model.eval()
|
| 144 |
|
|
|
|
| 4 |
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
# #
|
| 6 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
import os
|
| 8 |
+
from safetensors.torch import load_file
|
| 9 |
+
|
| 10 |
import torch
|
| 11 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 12 |
from src.model import SegmentationNetwork
|
|
|
|
| 141 |
|
| 142 |
# Core PyTorch model
|
| 143 |
self.model = SegmentationNetwork(self.to_model_config(config))
|
| 144 |
+
|
| 145 |
+
weights_path = os.path.join(
|
| 146 |
+
config._name_or_path,
|
| 147 |
+
"model.safetensors"
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
state_dict = load_file(weights_path)
|
| 151 |
self.model.load_state_dict(state_dict, strict=True)
|
| 152 |
self.model.eval()
|
| 153 |
|