| |
|
|
| import sys |
| sys.path.append("..") |
|
|
| from minicheck_web.inference import Inferencer |
| from typing import List, Dict |
| import numpy as np |
|
|
|
|
class MiniCheck:
    """Wrapper around ``Inferencer`` that scores (document, claim) pairs."""

    def __init__(self, path, max_input_length=None, batch_size=16) -> None:
        """Create the underlying ``Inferencer``.

        Args:
            path: Forwarded to ``Inferencer`` as ``path``.
            max_input_length: Forwarded to ``Inferencer`` unchanged.
            batch_size: Forwarded to ``Inferencer`` unchanged.
        """
        self.model = Inferencer(
            path=path,
            batch_size=batch_size,
            max_input_length=max_input_length,
        )

    def score(self, data: Dict) -> tuple:
        """Fact-check every claim against its document.

        Args:
            data: Request payload. ``data['inputs']`` must hold ``'docs'`` and
                ``'claims'`` (each a ``list`` or ``np.ndarray``) and may hold
                ``'chunk_size'``; when absent, the model's
                ``default_chunk_size`` is used.

        Returns:
            A 4-tuple ``(pred_label, max_support_prob, used_chunk,
            support_prob_per_chunk)``:

            * ``pred_label``: 0 / 1 per claim (0: unsupported, 1: supported);
              1 exactly when the matching ``max_support_prob`` is > 0.5.
            * ``max_support_prob``: the probability of "supported" for the
              chunk that determines the final ``pred_label``.
            * ``used_chunk``: divided chunks of the input document.
            * ``support_prob_per_chunk``: the probability of "supported" for
              each chunk.

        Raises:
            KeyError: If ``'inputs'``, ``'docs'`` or ``'claims'`` is missing.
            AssertionError: If ``docs`` or ``claims`` is neither a ``list``
                nor an ``np.ndarray``.
        """
        inputs = data['inputs']
        docs = inputs['docs']
        claims = inputs['claims']

        # Resolve the chunk size up front but do not apply it yet, so a
        # request that fails validation leaves the model configuration as is.
        if 'chunk_size' in inputs:
            chunk_size = int(inputs['chunk_size'])
        else:
            chunk_size = self.model.default_chunk_size

        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; AssertionError is kept so callers that
        # already handle it keep working.
        for name, value in (("docs", docs), ("claims", claims)):
            if not isinstance(value, (list, np.ndarray)):
                raise AssertionError(f"{name} must be a list or np.ndarray")

        self.model.chunk_size = chunk_size

        max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
        pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]

        return pred_label, max_support_prob, used_chunk, support_prob_per_chunk