alverciito committed on
Commit
8021b9c
·
1 Parent(s): f663b1c

zero shot experiment (fix v2)

Browse files
model.py CHANGED
@@ -169,8 +169,8 @@ class SentenceCoseNet(PreTrainedModel):
169
  `(batch_size, sequence_length, emb_dim)`.
170
  """
171
  # Convert to type:
172
- x = input_ids.int().unsqueeze(0)
173
- mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
174
 
175
  # Embedding and positional encoding:
176
  x = self.model.embedding(x)
@@ -213,8 +213,8 @@ class SentenceCoseNet(PreTrainedModel):
213
  Sentence embeddings of shape (B, D)
214
  """
215
  # Convert to type:
216
- x = input_ids.int().unsqueeze(0)
217
- mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
218
 
219
  # Embedding and positional encoding:
220
  x = self.model.embedding(x)
 
169
  `(batch_size, sequence_length, emb_dim)`.
170
  """
171
  # Convert to type:
172
+ x = input_ids.int()
173
+ mask = attention_mask if attention_mask is not None else None
174
 
175
  # Embedding and positional encoding:
176
  x = self.model.embedding(x)
 
213
  Sentence embeddings of shape (B, D)
214
  """
215
  # Convert to type:
216
+ x = input_ids.int()
217
+ mask = attention_mask if attention_mask is not None else None
218
 
219
  # Embedding and positional encoding:
220
  x = self.model.embedding(x)
research_files/bench.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import json
7
  from benchmark.segmentation_benchmark.proposed import evaluate_proposed
8
  from benchmark.segmentation_benchmark.heuristic import evaluate_textile
9
- from benchmark.segmentation_benchmark.transformers import evaluate_lms
10
  from benchmark.segmentation_benchmark.inference_proposed import evaluate_hf_proposed
11
 
12
 
@@ -41,7 +41,7 @@ if __name__ == '__main__':
41
  # 'hiiamsid/sentence_similarity_spanish_es', # Spanish similarity - sBERT
42
  # "jaimevera1107/all-MiniLM-L6-v2-similarity-es", # Spanish similarity - sBERT
43
  # "google-bert/bert-base-multilingual-cased", # mBERT (google)
44
- "sentence-transformers/LaBSE", # LaBSE (google)
45
  "FacebookAI/xlm-roberta-base" # XLM-R (facebook)
46
  ]:
47
  print("Evaluating Model (3 methods):", model)
 
6
  import json
7
  from benchmark.segmentation_benchmark.proposed import evaluate_proposed
8
  from benchmark.segmentation_benchmark.heuristic import evaluate_textile
9
+ from benchmark.segmentation_benchmark.sota_transformers import evaluate_lms
10
  from benchmark.segmentation_benchmark.inference_proposed import evaluate_hf_proposed
11
 
12
 
 
41
  # 'hiiamsid/sentence_similarity_spanish_es', # Spanish similarity - sBERT
42
  # "jaimevera1107/all-MiniLM-L6-v2-similarity-es", # Spanish similarity - sBERT
43
  # "google-bert/bert-base-multilingual-cased", # mBERT (google)
44
+ # "sentence-transformers/LaBSE", # LaBSE (google)
45
  "FacebookAI/xlm-roberta-base" # XLM-R (facebook)
46
  ]:
47
  print("Evaluating Model (3 methods):", model)
research_files/benchmark/segmentation_benchmark/{transformers.py → sota_transformers.py} RENAMED
File without changes
research_files/benchmark/segmentation_benchmark/zero_shot_transfer.py CHANGED
@@ -45,7 +45,7 @@ def zero_shot_proposed(
45
 
46
  with torch.no_grad():
47
  for batch in tqdm.tqdm(dataset['test'].batch(batch_size)):
48
- if not hasattr(model, 'get_sentence_emebedding'):
49
  inputs_1 = tokenizer(batch['sentence1'], return_tensors="pt", padding=True, truncation=True, max_length=382)
50
  inputs_2 = tokenizer(batch['sentence2'], return_tensors="pt", padding=True, truncation=True, max_length=382)
51
  inputs_1 = {k: v.to(device) for k, v in inputs_1.items()}
 
45
 
46
  with torch.no_grad():
47
  for batch in tqdm.tqdm(dataset['test'].batch(batch_size)):
48
+ if not hasattr(model, 'get_sentence_embedding'):
49
  inputs_1 = tokenizer(batch['sentence1'], return_tensors="pt", padding=True, truncation=True, max_length=382)
50
  inputs_2 = tokenizer(batch['sentence2'], return_tensors="pt", padding=True, truncation=True, max_length=382)
51
  inputs_1 = {k: v.to(device) for k, v in inputs_1.items()}
research_files/zero_shot_tranfer_experiment.py CHANGED
@@ -14,7 +14,7 @@ from benchmark.segmentation_benchmark.zero_shot_transfer import zero_shot_propos
14
  __file_path__ = os.path.dirname(__file__)
15
 
16
  if __name__ == '__main__':
17
- zero_shot_proposed("hiiamsid/sentence_similarity_spanish_es", "nflechas/semantic_sentence_similarity_ES")
18
  zero_shot_proposed("Alverciito/wikipedia_segmentation", "nflechas/semantic_sentence_similarity_ES")
19
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
20
  # END OF FILE #
 
14
  __file_path__ = os.path.dirname(__file__)
15
 
16
  if __name__ == '__main__':
17
+ # zero_shot_proposed("hiiamsid/sentence_similarity_spanish_es", "nflechas/semantic_sentence_similarity_ES")
18
  zero_shot_proposed("Alverciito/wikipedia_segmentation", "nflechas/semantic_sentence_similarity_ES")
19
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
20
  # END OF FILE #