alverciito commited on
Commit
2fde924
·
1 Parent(s): 33c844e

fix huggingface model weight missmatch

Browse files
Files changed (1) hide show
  1. model.py +3 -6
model.py CHANGED
@@ -4,9 +4,6 @@
4
  # Universidad de Alcalá - Escuela Politécnica Superior #
5
  # #
6
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
- import os
8
- from safetensors.torch import load_file
9
-
10
  import torch
11
  from transformers import PreTrainedModel, PretrainedConfig
12
  from src.model import SegmentationNetwork
@@ -291,9 +288,9 @@ class SentenceCoseNet(PreTrainedModel):
291
  """
292
  # Convert to type:
293
  if len(input_ids.shape) == 2:
294
- x = input_ids.int().unsqueeze(1)
295
- mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
296
- output = self.model(x=x, mask=mask).squeeze(1)
297
  elif len(input_ids.shape) == 3:
298
  x = input_ids.int()
299
  mask = attention_mask if attention_mask is not None else None
 
4
  # Universidad de Alcalá - Escuela Politécnica Superior #
5
  # #
6
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
 
 
7
  import torch
8
  from transformers import PreTrainedModel, PretrainedConfig
9
  from src.model import SegmentationNetwork
 
288
  """
289
  # Convert to type:
290
  if len(input_ids.shape) == 2:
291
+ x = input_ids.int().unsqueeze(0)
292
+ mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
293
+ output = self.model(x=x, mask=mask).squeeze(0)
294
  elif len(input_ids.shape) == 3:
295
  x = input_ids.int()
296
  mask = attention_mask if attention_mask is not None else None