Liyan06
committed on
Commit
·
09efa05
1
Parent(s):
3fe4664
customize chunk_size in score function
Browse files- handler.py +13 -1
- minicheck_web/inference.py +4 -2
- minicheck_web/minicheck.py +6 -2
handler.py
CHANGED
|
@@ -55,10 +55,20 @@ class EndpointHandler():
|
|
| 55 |
|
| 56 |
self.tfidf_order = True
|
| 57 |
self.num_highlights = 1
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def __call__(self, data):
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
claim = data['inputs']['claims'][0]
|
| 63 |
ents = extract_entities(claim)
|
| 64 |
|
|
@@ -128,9 +138,11 @@ class EndpointHandler():
|
|
| 128 |
retrieved_data = {
|
| 129 |
'inputs': {
|
| 130 |
'docs': list(retrieved_docs),
|
| 131 |
-
'claims': [claim]*len(retrieved_docs)
|
|
|
|
| 132 |
}
|
| 133 |
}
|
|
|
|
| 134 |
_, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
|
| 135 |
end = time()
|
| 136 |
num_chunks = len([item for items in used_chunk for item in items])
|
|
|
|
| 55 |
|
| 56 |
self.tfidf_order = True
|
| 57 |
self.num_highlights = 1
|
| 58 |
+
|
| 59 |
+
self.default_chunk_size = 500
|
| 60 |
+
self.chunk_size = 500
|
| 61 |
|
| 62 |
|
| 63 |
def __call__(self, data):
|
| 64 |
|
| 65 |
+
# this is necessary for setting the chunk size for
|
| 66 |
+
# retrived docs
|
| 67 |
+
if 'chunk_size' in data['inputs']:
|
| 68 |
+
self.chunk_size = int(data['inputs']['chunk_size'])
|
| 69 |
+
else:
|
| 70 |
+
self.chunk_size = self.default_chunk_size
|
| 71 |
+
|
| 72 |
claim = data['inputs']['claims'][0]
|
| 73 |
ents = extract_entities(claim)
|
| 74 |
|
|
|
|
| 138 |
retrieved_data = {
|
| 139 |
'inputs': {
|
| 140 |
'docs': list(retrieved_docs),
|
| 141 |
+
'claims': [claim]*len(retrieved_docs),
|
| 142 |
+
'chunk_size': self.chunk_size
|
| 143 |
}
|
| 144 |
}
|
| 145 |
+
|
| 146 |
_, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
|
| 147 |
end = time()
|
| 148 |
num_chunks = len([item for items in used_chunk for item in items])
|
minicheck_web/inference.py
CHANGED
|
@@ -28,7 +28,7 @@ def sent_tokenize_with_newlines(text):
|
|
| 28 |
|
| 29 |
|
| 30 |
class Inferencer():
|
| 31 |
-
def __init__(self, path,
|
| 32 |
|
| 33 |
self.path = path
|
| 34 |
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
@@ -36,7 +36,9 @@ class Inferencer():
|
|
| 36 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
|
| 37 |
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
| 38 |
|
| 39 |
-
self.
|
|
|
|
|
|
|
| 40 |
self.max_input_length=2048 if max_input_length is None else max_input_length
|
| 41 |
self.max_output_length = 256
|
| 42 |
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
class Inferencer():
|
| 31 |
+
def __init__(self, path, max_input_length, batch_size) -> None:
|
| 32 |
|
| 33 |
self.path = path
|
| 34 |
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 36 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
|
| 37 |
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
| 38 |
|
| 39 |
+
self.default_chunk_size = 500
|
| 40 |
+
self.chunk_size=500
|
| 41 |
+
|
| 42 |
self.max_input_length=2048 if max_input_length is None else max_input_length
|
| 43 |
self.max_output_length = 256
|
| 44 |
|
minicheck_web/minicheck.py
CHANGED
|
@@ -9,12 +9,11 @@ import numpy as np
|
|
| 9 |
|
| 10 |
|
| 11 |
class MiniCheck:
|
| 12 |
-
def __init__(self, path,
|
| 13 |
|
| 14 |
self.model = Inferencer(
|
| 15 |
path=path,
|
| 16 |
batch_size=batch_size,
|
| 17 |
-
chunk_size=chunk_size,
|
| 18 |
max_input_length=max_input_length,
|
| 19 |
)
|
| 20 |
|
|
@@ -30,6 +29,11 @@ class MiniCheck:
|
|
| 30 |
docs = inputs['docs']
|
| 31 |
claims = inputs['claims']
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
|
| 34 |
assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
|
| 35 |
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class MiniCheck:
|
| 12 |
+
def __init__(self, path, max_input_length=None, batch_size=16) -> None:
|
| 13 |
|
| 14 |
self.model = Inferencer(
|
| 15 |
path=path,
|
| 16 |
batch_size=batch_size,
|
|
|
|
| 17 |
max_input_length=max_input_length,
|
| 18 |
)
|
| 19 |
|
|
|
|
| 29 |
docs = inputs['docs']
|
| 30 |
claims = inputs['claims']
|
| 31 |
|
| 32 |
+
if 'chunk_size' in inputs:
|
| 33 |
+
self.model.chunk_size = int(inputs['chunk_size'])
|
| 34 |
+
else:
|
| 35 |
+
self.model.chunk_size = self.model.default_chunk_size
|
| 36 |
+
|
| 37 |
assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
|
| 38 |
assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
|
| 39 |
|