Spaces:
Running
Running
Merge pull request #1 from davidberenstein1957/main
Browse files- app_utils/entailment_checker.py +29 -12
app_utils/entailment_checker.py
CHANGED
|
@@ -60,8 +60,10 @@ class EntailmentChecker(BaseComponent):
|
|
| 60 |
def run(self, query: str, documents: List[Document]):
|
| 61 |
|
| 62 |
scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
doc.meta["entailment_info"] = entailment_info
|
| 66 |
|
| 67 |
scores += doc.score
|
|
@@ -93,17 +95,32 @@ class EntailmentChecker(BaseComponent):
|
|
| 93 |
return entailment_checker_result, "output_1"
|
| 94 |
|
| 95 |
def run_batch(self, queries: List[str], documents: List[Document]):
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
def
|
|
|
|
| 99 |
with torch.inference_mode():
|
| 100 |
-
inputs = self.tokenizer(
|
| 101 |
-
f"{premise}{self.tokenizer.sep_token}{hypotesis}", return_tensors="pt"
|
| 102 |
-
).to(self.devices[0])
|
| 103 |
out = self.model(**inputs)
|
| 104 |
logits = out.logits
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
|
| 109 |
-
return entailment_dict
|
|
|
|
| 60 |
def run(self, query: str, documents: List[Document]):
|
| 61 |
|
| 62 |
scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
|
| 63 |
+
premise_batch = [doc.content for doc in documents]
|
| 64 |
+
hypotesis_batch = [query] * len(documents)
|
| 65 |
+
entailment_info_batch = self.get_entailment_batch(premise_batch=premise_batch, hypotesis_batch=hypotesis_batch)
|
| 66 |
+
for i, (doc, entailment_info) in enumerate(zip(documents, entailment_info_batch)):
|
| 67 |
doc.meta["entailment_info"] = entailment_info
|
| 68 |
|
| 69 |
scores += doc.score
|
|
|
|
| 95 |
return entailment_checker_result, "output_1"
|
| 96 |
|
| 97 |
def run_batch(self, queries: List[str], documents: List[Document]):
    """Batch counterpart of ``run``: check entailment of each document w.r.t. its query.

    For every document, the entailment probabilities are stored in
    ``doc.meta["entailment_info"]`` and an aggregate (probability divided by the
    document's retrieval score) is collected per document.

    :param queries: one hypothesis string per document.
    :param documents: retrieved documents acting as premises; each must carry
        a ``content`` string, a ``score`` and a ``meta`` dict.
    :return: tuple of (list of per-document result dicts, "output_1").
    """
    entailment_checker_result_batch = []
    # BUG FIX: get_entailment_batch formats premises into an f-string, so it
    # needs the documents' text, not the Document objects themselves — mirror
    # ``run``, which builds its premise batch from ``doc.content``.
    premise_batch = [doc.content for doc in documents]
    entailment_info_batch = self.get_entailment_batch(premise_batch=premise_batch, hypotesis_batch=queries)
    for doc, entailment_info in zip(documents, entailment_info_batch):
        doc.meta["entailment_info"] = entailment_info
        # NOTE(review): round() without ndigits truncates these ratios to whole
        # integers — confirm whether round(x, 2) was intended, as elsewhere.
        aggregate_entailment_info = {
            "contradiction": round(entailment_info["contradiction"] / doc.score),
            "neutral": round(entailment_info["neutral"] / doc.score),
            "entailment": round(entailment_info["entailment"] / doc.score),
        }
        entailment_checker_result_batch.append(
            {
                "documents": [doc],
                "aggregate_entailment_info": aggregate_entailment_info,
            }
        )
    return entailment_checker_result_batch, "output_1"
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_entailment_dict(self, probs):
    """Pair class probabilities with lowercase label names.

    ``probs`` must be ordered exactly like ``self.labels``; the result maps
    each lowercased label (e.g. "contradiction", "neutral", "entailment")
    to its probability.
    """
    lowered = (label.lower() for label in self.labels)
    return dict(zip(lowered, probs))
|
| 117 |
|
| 118 |
+
def get_entailment_batch(self, premise_batch: List[str], hypotesis_batch: List[str]):
    """Run the NLI model once over a whole batch of (premise, hypothesis) pairs.

    Each premise is concatenated to its hypothesis with the tokenizer's
    separator token, the batch is encoded with padding/truncation and pushed
    through the model in a single forward pass, and the per-pair softmax
    probabilities are converted into label dicts via ``get_entailment_dict``.
    """
    # One "<premise><SEP><hypothesis>" input string per pair.
    pair_texts = []
    for premise, hypotesis in zip(premise_batch, hypotesis_batch):
        pair_texts.append(f"{premise}{self.tokenizer.sep_token}{hypotesis}")
    with torch.inference_mode():
        encoded = self.tokenizer(
            pair_texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.devices[0])
        logits = self.model(**encoded).logits
        # softmax over the label axis; move to CPU/numpy for plain-float output
        probs_batch = torch.nn.functional.softmax(logits, dim=-1).detach().cpu().numpy()
    return [self.get_entailment_dict(row) for row in probs_batch]
|
| 126 |
+
|
|
|
|
|
|