alverciito committed on
Commit ·
8021b9c
1
Parent(s): f663b1c
zero shot experiment (fix v2)
Browse files
model.py
CHANGED
|
@@ -169,8 +169,8 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 169 |
`(batch_size, sequence_length, emb_dim)`.
|
| 170 |
"""
|
| 171 |
# Convert to type:
|
| 172 |
-
x = input_ids.int()
|
| 173 |
-
mask = attention_mask
|
| 174 |
|
| 175 |
# Embedding and positional encoding:
|
| 176 |
x = self.model.embedding(x)
|
|
@@ -213,8 +213,8 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 213 |
Sentence embeddings of shape (B, D)
|
| 214 |
"""
|
| 215 |
# Convert to type:
|
| 216 |
-
x = input_ids.int()
|
| 217 |
-
mask = attention_mask
|
| 218 |
|
| 219 |
# Embedding and positional encoding:
|
| 220 |
x = self.model.embedding(x)
|
|
|
|
| 169 |
`(batch_size, sequence_length, emb_dim)`.
|
| 170 |
"""
|
| 171 |
# Convert to type:
|
| 172 |
+
x = input_ids.int()
|
| 173 |
+
mask = attention_mask if attention_mask is not None else None
|
| 174 |
|
| 175 |
# Embedding and positional encoding:
|
| 176 |
x = self.model.embedding(x)
|
|
|
|
| 213 |
Sentence embeddings of shape (B, D)
|
| 214 |
"""
|
| 215 |
# Convert to type:
|
| 216 |
+
x = input_ids.int()
|
| 217 |
+
mask = attention_mask if attention_mask is not None else None
|
| 218 |
|
| 219 |
# Embedding and positional encoding:
|
| 220 |
x = self.model.embedding(x)
|
research_files/bench.py
CHANGED
|
@@ -6,7 +6,7 @@ import os
|
|
| 6 |
import json
|
| 7 |
from benchmark.segmentation_benchmark.proposed import evaluate_proposed
|
| 8 |
from benchmark.segmentation_benchmark.heuristic import evaluate_textile
|
| 9 |
-
from benchmark.segmentation_benchmark.transformers import evaluate_lms
|
| 10 |
from benchmark.segmentation_benchmark.inference_proposed import evaluate_hf_proposed
|
| 11 |
|
| 12 |
|
|
@@ -41,7 +41,7 @@ if __name__ == '__main__':
|
|
| 41 |
# 'hiiamsid/sentence_similarity_spanish_es', # Spanish similarity - sBERT
|
| 42 |
# "jaimevera1107/all-MiniLM-L6-v2-similarity-es", # Spanish similarity - sBERT
|
| 43 |
# "google-bert/bert-base-multilingual-cased", # mBERT (google)
|
| 44 |
-
"sentence-transformers/LaBSE", # LaBSE (google)
|
| 45 |
"FacebookAI/xlm-roberta-base" # XLM-R (facebook)
|
| 46 |
]:
|
| 47 |
print("Evaluating Model (3 methods):", model)
|
|
|
|
| 6 |
import json
|
| 7 |
from benchmark.segmentation_benchmark.proposed import evaluate_proposed
|
| 8 |
from benchmark.segmentation_benchmark.heuristic import evaluate_textile
|
| 9 |
+
from benchmark.segmentation_benchmark.sota_transformers import evaluate_lms
|
| 10 |
from benchmark.segmentation_benchmark.inference_proposed import evaluate_hf_proposed
|
| 11 |
|
| 12 |
|
|
|
|
| 41 |
# 'hiiamsid/sentence_similarity_spanish_es', # Spanish similarity - sBERT
|
| 42 |
# "jaimevera1107/all-MiniLM-L6-v2-similarity-es", # Spanish similarity - sBERT
|
| 43 |
# "google-bert/bert-base-multilingual-cased", # mBERT (google)
|
| 44 |
+
# "sentence-transformers/LaBSE", # LaBSE (google)
|
| 45 |
"FacebookAI/xlm-roberta-base" # XLM-R (facebook)
|
| 46 |
]:
|
| 47 |
print("Evaluating Model (3 methods):", model)
|
research_files/benchmark/segmentation_benchmark/{transformers.py → sota_transformers.py}
RENAMED
|
File without changes
|
research_files/benchmark/segmentation_benchmark/zero_shot_transfer.py
CHANGED
|
@@ -45,7 +45,7 @@ def zero_shot_proposed(
|
|
| 45 |
|
| 46 |
with torch.no_grad():
|
| 47 |
for batch in tqdm.tqdm(dataset['test'].batch(batch_size)):
|
| 48 |
-
if not hasattr(model, '
|
| 49 |
inputs_1 = tokenizer(batch['sentence1'], return_tensors="pt", padding=True, truncation=True, max_length=382)
|
| 50 |
inputs_2 = tokenizer(batch['sentence2'], return_tensors="pt", padding=True, truncation=True, max_length=382)
|
| 51 |
inputs_1 = {k: v.to(device) for k, v in inputs_1.items()}
|
|
|
|
| 45 |
|
| 46 |
with torch.no_grad():
|
| 47 |
for batch in tqdm.tqdm(dataset['test'].batch(batch_size)):
|
| 48 |
+
if not hasattr(model, 'get_sentence_embedding'):
|
| 49 |
inputs_1 = tokenizer(batch['sentence1'], return_tensors="pt", padding=True, truncation=True, max_length=382)
|
| 50 |
inputs_2 = tokenizer(batch['sentence2'], return_tensors="pt", padding=True, truncation=True, max_length=382)
|
| 51 |
inputs_1 = {k: v.to(device) for k, v in inputs_1.items()}
|
research_files/zero_shot_tranfer_experiment.py
CHANGED
|
@@ -14,7 +14,7 @@ from benchmark.segmentation_benchmark.zero_shot_transfer import zero_shot_propos
|
|
| 14 |
__file_path__ = os.path.dirname(__file__)
|
| 15 |
|
| 16 |
if __name__ == '__main__':
|
| 17 |
-
zero_shot_proposed("hiiamsid/sentence_similarity_spanish_es", "nflechas/semantic_sentence_similarity_ES")
|
| 18 |
zero_shot_proposed("Alverciito/wikipedia_segmentation", "nflechas/semantic_sentence_similarity_ES")
|
| 19 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 20 |
# END OF FILE #
|
|
|
|
| 14 |
__file_path__ = os.path.dirname(__file__)
|
| 15 |
|
| 16 |
if __name__ == '__main__':
|
| 17 |
+
# zero_shot_proposed("hiiamsid/sentence_similarity_spanish_es", "nflechas/semantic_sentence_similarity_ES")
|
| 18 |
zero_shot_proposed("Alverciito/wikipedia_segmentation", "nflechas/semantic_sentence_similarity_ES")
|
| 19 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 20 |
# END OF FILE #
|