alverciito commited on
Commit ·
4aa1cf7
1
Parent(s): 869db96
fix huggingface model weight missmatch
Browse files
model.py
CHANGED
|
@@ -142,18 +142,12 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 142 |
# Core PyTorch model
|
| 143 |
self.model = SegmentationNetwork(self.to_model_config(config))
|
| 144 |
|
| 145 |
-
weights_path = os.path.join(
|
| 146 |
-
config._name_or_path,
|
| 147 |
-
"model.safetensors"
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
state_dict = load_file(weights_path)
|
| 151 |
-
self.model.load_state_dict(state_dict, strict=True)
|
| 152 |
-
self.model.eval()
|
| 153 |
-
|
| 154 |
# Initialize weights following HF conventions
|
| 155 |
self.post_init()
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
def encode(
|
| 158 |
self,
|
| 159 |
input_ids: torch.Tensor,
|
|
|
|
| 142 |
# Core PyTorch model
|
| 143 |
self.model = SegmentationNetwork(self.to_model_config(config))
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
# Initialize weights following HF conventions
|
| 146 |
self.post_init()
|
| 147 |
|
| 148 |
+
# Set evaluation mode by default
|
| 149 |
+
self.model.eval()
|
| 150 |
+
|
| 151 |
def encode(
|
| 152 |
self,
|
| 153 |
input_ids: torch.Tensor,
|