Spaces:
Runtime error
Runtime error
davebulaval
committed on
Commit
·
7217d6a
1
Parent(s):
1982c24
uniformization of interface and add .to for tokenizer output
Browse files- meaningbert.py +14 -16
meaningbert.py
CHANGED
|
@@ -64,8 +64,8 @@ _KWARGS_DESCRIPTION = """
|
|
| 64 |
MeaningBERT metric for assessing meaning preservation between sentences.
|
| 65 |
|
| 66 |
Args:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
device (str): Device to use for model inference. By default, set to "cuda".
|
| 70 |
|
| 71 |
Returns:
|
|
@@ -75,10 +75,10 @@ Returns:
|
|
| 75 |
|
| 76 |
Examples:
|
| 77 |
|
| 78 |
-
>>>
|
| 79 |
-
>>>
|
| 80 |
>>> meaning_bert = evaluate.load("davebulaval/meaningbert", device="cuda:0")
|
| 81 |
-
>>> results = meaning_bert.compute(
|
| 82 |
"""
|
| 83 |
|
| 84 |
_HASH = "21845c0cc85a2e8e16c89bb0053f489095cf64c5b19e9c3865d3e10047aba51b"
|
|
@@ -110,19 +110,17 @@ class MeaningBERT(evaluate.Metric):
|
|
| 110 |
|
| 111 |
def _compute(
|
| 112 |
self,
|
| 113 |
-
|
| 114 |
-
|
| 115 |
device: str = "cuda",
|
| 116 |
) -> Dict:
|
| 117 |
-
assert len(
|
| 118 |
-
|
| 119 |
-
), "The number of
|
| 120 |
hashcode = _HASH
|
| 121 |
|
| 122 |
# Index of sentence with perfect match between two sentences
|
| 123 |
-
matching_index = [
|
| 124 |
-
i for i, item in enumerate(documents) if item in simplifications
|
| 125 |
-
]
|
| 126 |
|
| 127 |
# We load the MeaningBERT pretrained model
|
| 128 |
scorer = AutoModelForSequenceClassification.from_pretrained(
|
|
@@ -135,12 +133,12 @@ class MeaningBERT(evaluate.Metric):
|
|
| 135 |
|
| 136 |
# We tokenize the text as a pair and return Pytorch Tensors
|
| 137 |
tokenize_text = tokenizer(
|
| 138 |
-
|
| 139 |
-
|
| 140 |
truncation=True,
|
| 141 |
padding=True,
|
| 142 |
return_tensors="pt",
|
| 143 |
-
)
|
| 144 |
|
| 145 |
with filter_logging_context():
|
| 146 |
# We process the text
|
|
|
|
| 64 |
MeaningBERT metric for assessing meaning preservation between sentences.
|
| 65 |
|
| 66 |
Args:
|
| 67 |
+
references (list of str): References sentences.
|
| 68 |
+
predictions (list of str): Predictions sentences (same number of element as documents).
|
| 69 |
device (str): Device to use for model inference. By default, set to "cuda".
|
| 70 |
|
| 71 |
Returns:
|
|
|
|
| 75 |
|
| 76 |
Examples:
|
| 77 |
|
| 78 |
+
>>> references = ["hello there", "general kenobi"]
|
| 79 |
+
>>> predictions = ["hello there", "general kenobi"]
|
| 80 |
>>> meaning_bert = evaluate.load("davebulaval/meaningbert", device="cuda:0")
|
| 81 |
+
>>> results = meaning_bert.compute(references=references, predictions=predictions)
|
| 82 |
"""
|
| 83 |
|
| 84 |
_HASH = "21845c0cc85a2e8e16c89bb0053f489095cf64c5b19e9c3865d3e10047aba51b"
|
|
|
|
| 110 |
|
| 111 |
def _compute(
|
| 112 |
self,
|
| 113 |
+
references: List,
|
| 114 |
+
predictions: List,
|
| 115 |
device: str = "cuda",
|
| 116 |
) -> Dict:
|
| 117 |
+
assert len(references) == len(
|
| 118 |
+
predictions
|
| 119 |
+
), "The number of references is different of the number of predictions."
|
| 120 |
hashcode = _HASH
|
| 121 |
|
| 122 |
# Index of sentence with perfect match between two sentences
|
| 123 |
+
matching_index = [i for i, item in enumerate(references) if item in predictions]
|
|
|
|
|
|
|
| 124 |
|
| 125 |
# We load the MeaningBERT pretrained model
|
| 126 |
scorer = AutoModelForSequenceClassification.from_pretrained(
|
|
|
|
| 133 |
|
| 134 |
# We tokenize the text as a pair and return Pytorch Tensors
|
| 135 |
tokenize_text = tokenizer(
|
| 136 |
+
references,
|
| 137 |
+
predictions,
|
| 138 |
truncation=True,
|
| 139 |
padding=True,
|
| 140 |
return_tensors="pt",
|
| 141 |
+
).to(device)
|
| 142 |
|
| 143 |
with filter_logging_context():
|
| 144 |
# We process the text
|