refactor: allow custom Encoder instances
Files changed:
- README.md: +2 -2
- encoder_models.py: +48 -40
- semf1.py: +157 -108
- tests.py: +135 -76
README.md
CHANGED

@@ -59,8 +59,8 @@ Sem-F1 takes 2 mandatory arguments:

 Sem-F1 also accepts multiple optional arguments:

-- `model_type (str)`: Model to use for encoding sentences. Options: ['pv1' ([paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)), 'stsb' ([stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)), 'use' ([Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual)) (Default)]. Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by SentenceTransformer,
-  such as `all-mpnet-base-v2` or `roberta-base`.
+- `model_type (Optional[Union[str, Encoder]])`: Model to use for encoding sentences. Options: ['pv1' ([paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)), 'stsb' ([stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)), 'use' ([Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual)) (Default)]. Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by SentenceTransformer,
+  such as `all-mpnet-base-v2` or `roberta-base`. Users can also pass a custom `Encoder`, which must implement the `encode` method. Refer to SemF1/encoder_models.py.
 - `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
 - `multi_references (bool)`: Flag to indicate whether multiple references are provided. Default: False.
 - `gpu (Union[bool, str, int, List[Union[str, int]]])`: Whether to use GPU, CPU or multiple-processes for computation.
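For orientation, here is a minimal usage sketch of the updated argument. The module id passed to `evaluate.load` is an assumption (substitute this Space's actual path); the keyword arguments match the `_compute` signature shown in the semf1.py section below, and the custom-`Encoder` variant is sketched after the encoder_models.py section.

import evaluate

semf1 = evaluate.load("nbansal/semf1")  # module id is an assumption
scores = semf1.compute(
    predictions=["This is a prediction sentence 1. This is a prediction sentence 2."],
    references=["This is a reference sentence 1. This is a reference sentence 2."],
    model_type="all-mpnet-base-v2",  # 'pv1'/'stsb'/'use', any SentenceTransformer name, or an Encoder instance
    tokenize_sentences=True,
    multi_references=False,
    gpu=False,
)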
encoder_models.py
CHANGED

Summary of the change: `device`, `batch_size`, and `verbose` are no longer fixed at construction time. `SBertEncoder.__init__` now takes only `model_name` (previously it also took `device`, `batch_size`, and `verbose` and stored them on the instance), `get_encoder` shrinks to `get_encoder(model_name: str) -> Encoder`, and the three settings become keyword-only parameters of `encode()`. The abstract `Encoder.encode` now pins down this signature, so a custom encoder only has to implement `encode()` to satisfy the contract.

New version of the affected code (unchanged stretches elided with `...`):

@@ -9,68 +9,83 @@ from .type_aliases import ENCODER_DEVICE_TYPE

class Encoder(abc.ABC):
    @abc.abstractmethod
    def encode(
        self,
        prediction: List[str],
        *,
        device: ENCODER_DEVICE_TYPE = "cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        """
        Abstract method to encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to print verbose information during encoding.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError("Method 'encode' must be implemented in subclass.")


class SBertEncoder(Encoder):
    def __init__(self, model_name: str):
        """
        Initialize SBertEncoder instance.

        Args:
            model_name (str): Name or path of the Sentence Transformer model.
        """
        self.model = SentenceTransformer(model_name, trust_remote_code=True)

    def encode(
        self,
        prediction: List[str],
        *,
        device: ENCODER_DEVICE_TYPE = "cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to print verbose information during encoding.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """

        # SBert output is always Batch x Dim
        if isinstance(device, list):
            # Use multiprocess encoding for list of devices
            pool = self.model.start_multi_process_pool(target_devices=device)
            embeddings = self.model.encode_multi_process(
                prediction, pool=pool, batch_size=batch_size
            )
            self.model.stop_multi_process_pool(pool)
        else:
            # Single device encoding
            embeddings = self.model.encode(
                prediction,
                device=device,
                batch_size=batch_size,
                show_progress_bar=verbose,
            )
        return embeddings


def get_encoder(model_name: str) -> Encoder:
    """
    Get the encoder instance based on the specified model name.
    ...
    Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by
    SentenceTransformer.

    Returns:
        Encoder: Instance of the selected encoder based on the model_name.
    ...
    """

    try:
        encoder = SBertEncoder(model_name)  # , device, batch_size, verbose)
    except EnvironmentError as err:
        raise EnvironmentError(str(err)) from None
    except Exception as err:
        raise RuntimeError(str(err)) from None

    return encoder
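The contract above is small, so a custom encoder fits in a few lines. The sketch below is an illustration of the contract, not part of this commit; `HashingEncoder` and its bag-of-words trick are hypothetical stand-ins for a real embedding backend, and the relative import assumes the code lives inside this package.

from typing import List

import numpy as np
from numpy.typing import NDArray

from .encoder_models import Encoder  # import path assumes use inside this package


class HashingEncoder(Encoder):
    """Toy custom encoder: hashed bag-of-words vectors. Illustration only."""

    def __init__(self, dim: int = 64):
        self.dim = dim

    def encode(
        self,
        prediction: List[str],
        *,
        device="cpu",  # accepted for interface compatibility; unused here
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        # Returns shape (num_sentences, embedding_dim), as the ABC requires.
        out = np.zeros((len(prediction), self.dim))
        for i, sent in enumerate(prediction):
            for tok in sent.lower().split():
                out[i, hash(tok) % self.dim] += 1.0
        return out

An instance of such a class can then be passed directly as `model_type`; the new dispatch in semf1.py below accepts any `Encoder` instance.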
semf1.py
CHANGED

Summary of the change: the imports gain `Union` and the `Encoder` class (plus an explicit, parenthesized `.utils` import list); the metric documentation now describes `model_type (Optional[Union[str, Encoder]])`; `_compute_cosine_similarity` and `_validate_input_format` gain full docstrings; the `datasets.Features` declarations and several call sites are reflowed; and `_compute` dispatches on the type of `model_type`, building an `SBertEncoder` for strings or None and using a passed `Encoder` instance directly.

New version of the affected code (unchanged stretches elided with `...`):

from typing import List, Optional, Tuple, Union

import datasets
import evaluate
...
from numpy.typing import NDArray
from sklearn.metrics.pairwise import cosine_similarity

from .encoder_models import get_encoder, Encoder
from .type_aliases import DEVICE_TYPE, PREDICTION_TYPE, REFERENCE_TYPE
from .utils import (
    is_nested_list_of_type,
    Scores,
    slice_embeddings,
    flatten_list,
    get_gpu,
    sent_tokenize,
)

_CITATION = """\
@inproceedings{bansal-etal-2022-sem,
...

In the metric documentation string, the argument list becomes:

Args:
    predictions (list): List of predictions. Format varies based on `tokenize_sentences` and `multi_references` flags.
    references (list): List of references. Format varies based on `tokenize_sentences` and `multi_references` flags.
    model_type (Optional[Union[str, Encoder]]): Model to use for encoding sentences.
        Options: ['pv1', 'stsb', 'use']
            pv1 - paraphrase-distilroberta-base-v1
            stsb - stsb-roberta-large
            use - Universal Sentence Encoder (Default)
        - A string path or name for any model on Huggingface/SentenceTransformer that is supported by
          SentenceTransformer, such as `all-mpnet-base-v2` or `roberta-base`.
        - A custom instance of an Encoder (must implement the encode() method). Refer to SemF1/encoder_models.py.
    tokenize_sentences (bool): Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
    multi_references (bool): Flag to indicate whether multiple references are provided. Default is False.
    gpu (Union[bool, str, int, List[Union[str, int]]]): Whether to use GPU or CPU for computation.
...

def _compute_cosine_similarity(
    pred_embeds: NDArray, ref_embeds: NDArray
) -> Tuple[float, float]:
    """
    Compute precision and recall based on cosine similarity between predicted and reference embeddings.

    Args:
        pred_embeds (NDArray): Predicted embeddings (shape: [num_pred, embedding_dim]).
        ref_embeds (NDArray): Reference embeddings (shape: [num_ref, embedding_dim]).

    Returns:
        Tuple[float, float]: Precision and recall based on cosine similarity scores.
            Precision: Average maximum cosine similarity score per predicted embedding.
            Recall: Average maximum cosine similarity score per reference embedding.
    """
    # Compute cosine similarity between predicted and reference embeddings
    cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
    ...
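The new docstring pins down the semantics: precision averages each prediction row's maximum similarity, recall averages each reference column's maximum. A standalone sketch of that computation follows; it illustrates the documented behavior, and the exact axis calls past the line shown above are my assumption, not the module's verbatim code.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

pred_embeds = np.array([[1.0, 0.0], [0.0, 1.0]])             # 2 prediction sentences
ref_embeds = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])  # 3 reference sentences

cosine_scores = cosine_similarity(pred_embeds, ref_embeds)   # shape (2, 3)
precision = cosine_scores.max(axis=-1).mean()  # best match per prediction row -> 1.0
recall = cosine_scores.max(axis=0).mean()      # best match per reference column
print(precision, recall)  # 1.0 and ~0.902: the middle reference's best match is cos(45 deg) ~ 0.707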
Continuing the new version of semf1.py:

def _validate_input_format(
    tokenize_sentences: bool,
    multi_references: bool,
    predictions: PREDICTION_TYPE,
    references: REFERENCE_TYPE,
):
    """
    Validate the format of predictions and references based on specified criteria.

    Args:
        - tokenize_sentences (bool): Flag indicating whether sentences should be tokenized.
        - multi_references (bool): Flag indicating whether multiple references are provided.
        - predictions (PREDICTION_TYPE): Predictions to validate.
        - references (REFERENCE_TYPE): References to validate.

    Raises:
        - ValueError: If the format of predictions or references does not meet the specified criteria.

    Validation Criteria:
    The function validates predictions and references based on the following conditions:
    1. If `tokenize_sentences` is True and `multi_references` is True:
        - Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
        - References must be a list of list of strings (`is_list_of_strings_at_depth(references, 2)`).

    2. If `tokenize_sentences` is False and `multi_references` is True:
        - Predictions must be a list of list of strings (`is_list_of_strings_at_depth(predictions, 2)`).
        - References must be a list of list of list of strings (`is_list_of_strings_at_depth(references, 3)`).

    3. If `tokenize_sentences` is True and `multi_references` is False:
        - Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
        - References must be a list of strings (`is_list_of_strings_at_depth(references, 1)`).

    4. If `tokenize_sentences` is False and `multi_references` is False:
        - Predictions must be a list of list of strings (`is_list_of_strings_at_depth(predictions, 2)`).
        - References must be a list of list of strings (`is_list_of_strings_at_depth(references, 2)`).

    The function checks these conditions and raises a ValueError if any condition is not met,
    indicating that predictions or references are not in the valid input format.

    Note:
        - `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file
    """

    if len(predictions) != len(references):
        raise ValueError(
            f"Predictions and references must have the same length. "
            f"Got {len(predictions)} predictions and {len(references)} references."
        )

    if len(predictions) == 0:
        raise ValueError("Can't have empty inputs")

    def check_format(lst_obj, expected_depth: int, name: str):
        is_valid, error_message = is_nested_list_of_type(
            lst_obj, element_type=str, depth=expected_depth
        )
        if not is_valid:
            raise ValueError(
                f"{name} are not in the expected format.\n" f"Error: {error_message}."
            )

    try:
        if tokenize_sentences and multi_references:
        ...

Inside the SemF1 metric class, the feature declarations and `reference_urls` are reflowed:

            datasets.Features(
                {
                    # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
                    "predictions": datasets.Sequence(
                        datasets.Value("string", id="sequence"), id="predictions"
                    ),
                    # references: List[List[str]] - List of references where each reference is a list of sentences
                    "references": datasets.Sequence(
                        datasets.Value("string", id="sequence"), id="references"
                    ),
                }
            ),
            # F1: Multi References: False, Tokenize_Sentences = True
            ...
            datasets.Features(
                {
                    # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
                    "predictions": datasets.Sequence(
                        datasets.Value("string", id="sequence"), id="predictions"
                    ),
                    # references: List[List[List[str]]] - List of multi-references.
                    # So each "reference" is also a list (r1, r2, ...).
                    # Further, each ri's are also list of sentences.
                    "references": datasets.Sequence(
                        datasets.Sequence(
                            datasets.Value("string", id="sequence"), id="ref"
                        ),
                        id="references",
                    ),
                }
            ),
            # F3: Multi References: True, Tokenize_Sentences = True
            ...
                    "predictions": datasets.Value("string", id="sequence"),
                    # references: List[List[List[str]]] - List of multi-references.
                    # So each "reference" is also a list (r1, r2, ...).
                    "references": datasets.Sequence(
                        datasets.Value("string", id="ref"), id="references"
                    ),
                }
            ),
        ],
        # # Homepage of the module for documentation
        # Additional links to the codebase or references
        reference_urls=["https://aclanthology.org/2022.emnlp-main.49/"],
    )

    def _get_model_name(self, model_type: Optional[str] = None) -> str:
    ...

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores"""
        import nltk

        nltk.download("punkt", quiet=True)

    def _compute(
        self,
        predictions,
        references,
        model_type: Optional[Union[str, Encoder]] = None,
        tokenize_sentences: bool = True,
        multi_references: bool = False,
        gpu: DEVICE_TYPE = False,
        batch_size: int = 32,
        verbose: bool = False,
        aggregate: bool = False,
    ) -> List[Scores]:
        """
        Compute precision, recall, and F1 scores for given predictions and references.

        Args:
            - predictions
            - references
            - model_type: Type of model to use for encoding.
                Options: [pv1, stsb, use]
                    pv1 - paraphrase-distilroberta-base-v1
                    stsb - stsb-roberta-large
                    use - Universal Sentence Encoder (Default)
                - A string path or name for any model on Huggingface/SentenceTransformer that is supported by
                  SentenceTransformer.
                - A custom instance of an Encoder (must implement the encode() method). Refer to SemF1/encoder_models.py.
            - tokenize_sentences: Flag to sentence tokenize the document.
            - multi_references: Flag to indicate multiple references.
            - gpu: GPU device to use.
            - batch_size: Batch size for encoding.
            - verbose: Flag to indicate verbose output.
            - aggregate: Flag to determine if output should be averaged.

        Returns:
            Singleton/List of Scores dataclass with attributes as follows -
                precision: float - precision score
                recall: List[float] - List of recall scores corresponding to single/multiple references
                f1: float - F1 score (between precision and average recall)
        """

        # Note: I have to specifically handle this case because the library considers the feature corresponding to
        # this case (F2) as the feature for the other case (F0) i.e. it can't make any distinction between
        # List[str] and List[List[str]]
        if not tokenize_sentences and multi_references:
            references = [
                [eval(ref) for ref in mul_ref_ex] for mul_ref_ex in references
            ]

        # Validate inputs corresponding to flags
        _validate_input_format(
            tokenize_sentences, multi_references, predictions, references
        )

        # Get GPU
        device = get_gpu(gpu)
        ...
        print(f"Using devices: {device}")

        # Get the encoder model
        if model_type is None or isinstance(model_type, str):
            model_name = self._get_model_name(model_type)
            encoder = get_encoder(model_name)
        elif isinstance(model_type, Encoder):
            encoder = model_type
        else:
            raise TypeError(
                f"Unsupported model_type: expected str or Encoder instance, got {type(model_type)}"
            )

        # We'll handle the single reference and multi-reference case same way. So change the data format accordingly
        if not multi_references:
        ...

        # Note: This is the most optimal way of doing it
        # Encode all sentences in one go
        embeddings = encoder.encode(
            all_sentences, device=device, batch_size=batch_size, verbose=verbose
        )

        # Get embeddings corresponding to predictions and references
        pred_embeddings = slice_embeddings(embeddings, prediction_sentences_count)
        ref_embeddings = slice_embeddings(
            embeddings[sum(prediction_sentences_count) :], reference_sentences_count
        )

        # Init output scores
        results = []
        ...
            precision = np.clip(precision, a_min=0.0, a_max=1.0).item()

            # Recall: Compute individually for each reference
            recall_scores = [
                _compute_cosine_similarity(r_embeds, preds) for r_embeds in refs
            ]
            recall_scores = [
                np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores
            ]

            results.append(Scores(precision, recall_scores))

        # run aggregation procedure
        if aggregate:
            mean_prec = np.mean([score.precision for score in results])
            mean_recall = np.mean(
                np.concatenate([np.array(score.recall) for score in results])
            )
            aggregated_score = Scores(float(mean_prec), [float(mean_recall)])
            results = aggregated_score

        return results
tests.py
CHANGED

Summary of the change: the `.utils` import becomes an explicit parenthesized list; fixtures and assertions are reflowed with wrapped calls and trailing commas; `TestSBertEncoder.setUp` now builds the encoder from `model_name` alone, the encode tests pass `device`, `batch_size`, and `verbose` to `encode()`, the initialization test drops its assertions on the removed `device`/`batch_size`/`verbose` attributes, and the multi-GPU test switches `devices` from `["cuda:0", "cuda:1"]` to `[0, 1]`; `TestGetEncoder` calls `get_encoder(model_name)` without the removed arguments; the aggregate tests print the returned score; and the `if __name__ == "__main__":` guard is spelled out.

New version of the affected code (unchanged stretches elided with `...`):

from .encoder_models import SBertEncoder, get_encoder
from .semf1 import SemF1, _compute_cosine_similarity, _validate_input_format
from .utils import (
    get_gpu,
    slice_embeddings,
    is_nested_list_of_type,
    flatten_list,
    compute_f1,
    Scores,
)


class TestUtils(unittest.TestCase):
    ...
        self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")

        # Test list input with unique elements
        self.assertEqual(
            get_gpu([True, "cpu", 0]),
            [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"],
        )

        # Test list input with duplicate elements
        self.assertEqual(
            get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"]
        )

        # Test list input with duplicate elements of different types
        self.assertEqual(
            get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"]
        )

        # Test list input but only one element
        self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")

        # Test list input with all integers
        self.assertEqual(
            get_gpu(list(range(gpu_count))),
            list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"],
        )

        with self.assertRaises(ValueError):
            get_gpu("invalid")
    ...
        num_sentences = [3, 2, 5]
        expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
        self.assertTrue(
            all(
                np.array_equal(a, b)
                for a, b in zip(
                    slice_embeddings(embeddings, num_sentences), expected_output
                )
            )
        )

        num_sentences_nested = [[2, 1], [3, 4]]
        expected_output_nested = [
            [embeddings[:2], embeddings[2:3]],
            [embeddings[3:6], embeddings[6:]],
        ]
        self.assertTrue(
            slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
        )
    ...
        self.assertEqual(is_valid, False)

        # Test case: Depth 1, list of elements matching element_type
        self.assertEqual(
            is_nested_list_of_type(["apple", "banana"], str, 1), (True, "")
        )

        # Test case: Depth 1, list of elements not matching element_type
        is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 1)
    ...
        # Depth 2
        self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
        self.assertEqual(
            is_nested_list_of_type([["1", "2"], ["3", "4"]], str, 2), (True, "")
        )
        is_valid, err_msg = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
        self.assertEqual(is_valid, False)

        # Depth 3
        is_valid, err_msg = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
        self.assertEqual(is_valid, False)
        self.assertEqual(
            is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, "")
        )

        # Test case: Depth is negative, expecting ValueError
        with self.assertRaises(ValueError):
    ...


class TestSBertEncoder(unittest.TestCase):
    def setUp(self, device=None):
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.model_name = "stsb-roberta-large"
        self.batch_size = 8
        self.verbose = False
        self.encoder = SBertEncoder(self.model_name)

    def test_initialization(self):
        self.assertIsInstance(self.encoder.model, SentenceTransformer)

    def test_encode_single_device(self):
        sentences = ["This is a test sentence.", "Here is another sentence."]
        embeddings = self.encoder.encode(
            sentences,
            device=self.device,
            batch_size=self.batch_size,
            verbose=self.verbose,
        )
        self.assertIsInstance(embeddings, np.ndarray)
        self.assertEqual(embeddings.shape[0], len(sentences))
        self.assertEqual(
            embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension()
        )

    def test_encode_multi_device(self):
        if torch.cuda.device_count() < 2:
            self.skipTest("Multi-GPU test requires at least 2 GPUs.")
        else:
            # devices = ["cuda:0", "cuda:1"]
            devices = [0, 1]
            self.setUp(devices)
            sentences = [
                "This is a test sentence.",
                "Here is another sentence.",
                "This is a test sentence.",
            ]
            embeddings = self.encoder.encode(
                sentences,
                device=devices,
                batch_size=self.batch_size,
                verbose=self.verbose,
            )
            self.assertIsInstance(embeddings, np.ndarray)
            self.assertEqual(embeddings.shape[0], 3)
            self.assertEqual(
                embeddings.shape[1],
                self.encoder.model.get_sentence_embedding_dimension(),
            )


class TestGetEncoder(unittest.TestCase):
    ...
        self.verbose = False

    def _base_test(self, model_name):
        encoder = get_encoder(model_name)
        self.assertIsInstance(encoder, SBertEncoder)

    def test_get_sbert_encoder(self):
        model_name = "stsb-roberta-large"
    ...
        model_name = "roberta-base"
        self._base_test(model_name)

    def test_get_encoder_environment_error(self):
        model_name = "abc"  # Wrong model_name
        with self.assertRaises(EnvironmentError):
            get_encoder(model_name)

    def test_get_encoder_other_exception(self):
        model_name = "apple/OpenELM-270M"  # This model is not supported by SentenceTransformer lib
        with self.assertRaises(RuntimeError):
            get_encoder(model_name)


class TestSemF1(unittest.TestCase):
    ...
        # Example cases, #Samples = 1
        self.untokenized_single_reference_predictions = [
            "This is a prediction sentence 1. This is a prediction sentence 2."
        ]
        self.untokenized_single_reference_references = [
            "This is a reference sentence 1. This is a reference sentence 2."
        ]

        self.tokenized_single_reference_predictions = [
            ["This is a prediction sentence 1.", "This is a prediction sentence 2."],
    ...
            "Prediction sentence 1. Prediction sentence 2."
        ]
        self.untokenized_multi_reference_references = [
            [
                "Reference sentence 1. Reference sentence 2.",
                "Alternative reference 1. Alternative reference 2.",
            ],
        ]

        self.tokenized_multi_reference_predictions = [
    ...
        self.tokenized_multi_reference_references = [
            [
                ["Reference sentence 1.", "Reference sentence 2."],
                ["Alternative reference 1.", "Alternative reference 2."],
            ],
        ]
        self.multi_sample_refs = [
            "this is the first reference sample",
            "this is the second reference sample",
        ]
        self.multi_sample_preds = [
            "this is the first prediction sample",
            "this is the second prediction sample",
        ]

    def test_aggregate_multi_sample(self):
        """
        check if a `Scores` class is returned instead of a list of
        `Scores`
        """
        scores = self.semf1_metric.compute(
    ...
            aggregate=True,
        )
        self.assertIsInstance(scores, Scores)
        print(f"Score: {scores}")

(The same `print(f"Score: {scores}")` line is appended to test_aggregate_untokenized_single_ref, test_aggregate_tokenized_single_ref, test_aggregate_untokenized_multi_ref, test_aggregate_tokenized_multi_ref, and test_aggregate_same_pred_and_ref.)

    def test_untokenized_single_reference(self):
        scores = self.semf1_metric.compute(
    ...
            multi_references=False,
            gpu=False,
            batch_size=32,
            verbose=False,
        )
        self.assertIsInstance(scores, list)
        self.assertEqual(
            len(scores), len(self.untokenized_single_reference_predictions)
        )

(test_tokenized_single_reference, both multi-reference tests, and test_same_predictions_and_references pick up the same `verbose=False,` trailing comma.)

    ...
        for score in scores:
            self.assertIsInstance(score, Scores)
            self.assertAlmostEqual(score.precision, 1.0, places=6)
            assert_almost_equal(
                score.recall,
                1,
                decimal=5,
                err_msg="Not all values are almost equal to 1",
            )

    def test_exact_output_scores(self):
        predictions = [
    ...
            ["I am", "I am"],
            [None, "I am"],
        ]
        print(
            f"Case I\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
        )

        # Case 2: tokenize_sentences = False, multi_references = True
        tokenize_sentences = False
    ...
            [["I am", "I am"], [None, "I am"]],
            [[None, "I am"]],
        ]
        print(
            f"Case II\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
        )

        # Case 3: tokenize_sentences = True, multi_references = False
        tokenize_sentences = True
    ...
            "I am. I am.",
            "I am. I am.",
        ]
        print(
            f"Case III\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
        )

        # Case 4: tokenize_sentences = False, multi_references = False
        # This is taken care by the library itself
    ...
            ["I am.", "I am."],
            ["I am.", "I am."],
        ]
        print(
            f"Case IV\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
        )

    def test_empty_input(self):
        predictions = ["", ""]
    ...


class TestCosineSimilarity(unittest.TestCase):
    def setUp(self):
        # Sample embeddings for testing
        self.pred_embeds = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        self.ref_embeds = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

        self.pred_embeds_random = np.random.rand(3, 3)
        self.ref_embeds_random = np.random.rand(3, 3)

    def test_cosine_similarity_perfect_match(self):
        precision, recall = _compute_cosine_similarity(
            self.pred_embeds, self.ref_embeds
        )

        # Expected values are 1.0 for both precision and recall since embeddings are identical
        self.assertAlmostEqual(precision, 1.0, places=5)
    ...
        self.assertAlmostEqual(recall, expected_recall, places=5)

    def test_cosine_similarity_random(self):
        self._test_cosine_similarity_base(
            self.pred_embeds_random, self.ref_embeds_random
        )

    def test_cosine_similarity_different_shapes(self):
        pred_embeds_diff = np.random.rand(5, 3)
    ...


class TestValidateInputFormat(unittest.TestCase):
    ...
        self.untokenized_multi_reference_references = [
            [
                "This is a reference sentence 1. This is a reference sentence 2.",
                "Another reference sentence.",
            ]
        ]
    ...
        self.tokenized_multi_reference_references = [
            [
                ["This is a reference sentence 1.", "This is a reference sentence 2."],
                ["Another reference sentence."],
            ]
        ]
    ...
            True,
            True,
            self.untokenized_single_reference_predictions,
            [
                self.untokenized_single_reference_predictions[0],
                self.untokenized_single_reference_predictions[0],
            ],
        )


def run_tests():
    unittest.main(verbosity=2)


if __name__ == "__main__":
    run_tests()