FadQ commited on
Commit
e946e9f
·
1 Parent(s): 9ada12c

fix: fix siamese config

Browse files
Files changed (1) hide show
  1. services/inference_siamese.py +2 -6
services/inference_siamese.py CHANGED
@@ -20,13 +20,9 @@ class SiameseIndoBert(nn.Module):
20
  Siamese network using IndoBERT as the encoder.
21
  Takes two text inputs and predicts binary alignment.
22
  """
23
- def __init__(self, model_name: str, dropout: float = 0.1, hidden_dropout: float = 0.1):
24
  super().__init__()
25
- self.encoder = AutoModel.from_pretrained(
26
- model_name,
27
- use_safetensors=True,
28
- output_hidden_states=False
29
- )
30
  hs = self.encoder.config.hidden_size
31
  self.dropout = nn.Dropout(dropout)
32
  # Classifier: takes concatenation [hA, hB, |hA-hB|, hA*hB]
 
20
  Siamese network using IndoBERT as the encoder.
21
  Takes two text inputs and predicts binary alignment.
22
  """
23
+ def __init__(self, encoder, dropout: float = 0.1, hidden_dropout: float = 0.1):
24
  super().__init__()
25
+ self.encoder = encoder
 
 
 
 
26
  hs = self.encoder.config.hidden_size
27
  self.dropout = nn.Dropout(dropout)
28
  # Classifier: takes concatenation [hA, hB, |hA-hB|, hA*hB]