alverciito commited on
Commit ·
2fde924
1
Parent(s): 33c844e
fix huggingface model weight missmatch
Browse files
model.py
CHANGED
|
@@ -4,9 +4,6 @@
|
|
| 4 |
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
# #
|
| 6 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
-
import os
|
| 8 |
-
from safetensors.torch import load_file
|
| 9 |
-
|
| 10 |
import torch
|
| 11 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 12 |
from src.model import SegmentationNetwork
|
|
@@ -291,9 +288,9 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 291 |
"""
|
| 292 |
# Convert to type:
|
| 293 |
if len(input_ids.shape) == 2:
|
| 294 |
-
x = input_ids.int().unsqueeze(
|
| 295 |
-
mask = attention_mask.unsqueeze(
|
| 296 |
-
output = self.model(x=x, mask=mask).squeeze(
|
| 297 |
elif len(input_ids.shape) == 3:
|
| 298 |
x = input_ids.int()
|
| 299 |
mask = attention_mask if attention_mask is not None else None
|
|
|
|
| 4 |
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
# #
|
| 6 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 9 |
from src.model import SegmentationNetwork
|
|
|
|
| 288 |
"""
|
| 289 |
# Convert to type:
|
| 290 |
if len(input_ids.shape) == 2:
|
| 291 |
+
x = input_ids.int().unsqueeze(0)
|
| 292 |
+
mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
|
| 293 |
+
output = self.model(x=x, mask=mask).squeeze(0)
|
| 294 |
elif len(input_ids.shape) == 3:
|
| 295 |
x = input_ids.int()
|
| 296 |
mask = attention_mask if attention_mask is not None else None
|