Liyan06 commited on
Commit ·
8c6cca8
1
Parent(s): 87eb2c4
claim format debug
Browse files- minicheck/minicheck.py +7 -4
minicheck/minicheck.py
CHANGED
|
@@ -4,7 +4,7 @@ import sys
|
|
| 4 |
sys.path.append("..")
|
| 5 |
|
| 6 |
from minicheck.inference import Inferencer
|
| 7 |
-
from typing import List
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
|
|
@@ -18,7 +18,7 @@ class MiniCheck:
|
|
| 18 |
max_input_length=max_input_length,
|
| 19 |
)
|
| 20 |
|
| 21 |
-
def score(self,
|
| 22 |
'''
|
| 23 |
pred_labels: 0 / 1 (0: unsupported, 1: supported)
|
| 24 |
max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
|
|
@@ -26,8 +26,11 @@ class MiniCheck:
|
|
| 26 |
support_prob_per_chunk: the probability of "supported" for each chunk
|
| 27 |
'''
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
|
| 33 |
pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
|
|
|
|
| 4 |
sys.path.append("..")
|
| 5 |
|
| 6 |
from minicheck.inference import Inferencer
|
| 7 |
+
from typing import List, Dict
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
|
|
|
|
| 18 |
max_input_length=max_input_length,
|
| 19 |
)
|
| 20 |
|
| 21 |
+
def score(self, inputs: Dict) -> List[float]:
|
| 22 |
'''
|
| 23 |
pred_labels: 0 / 1 (0: unsupported, 1: supported)
|
| 24 |
max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
|
|
|
|
| 26 |
support_prob_per_chunk: the probability of "supported" for each chunk
|
| 27 |
'''
|
| 28 |
|
| 29 |
+
docs = inputs['docs']
|
| 30 |
+
claims = inputs['claims']
|
| 31 |
+
|
| 32 |
+
assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
|
| 33 |
+
assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
|
| 34 |
|
| 35 |
max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
|
| 36 |
pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
|