Spaces:
Running
Running
| from typing import List, Optional | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig | |
| import torch | |
| from haystack.nodes.base import BaseComponent | |
| from haystack.modeling.utils import initialize_device_settings | |
| from haystack.schema import Document, Answer, Span | |
class EntailmentChecker(BaseComponent):
    """
    Haystack node that checks the entailment between every document and the query.

    Each document's content is used as the premise and the query as the
    hypothesis of a Natural Language Inference (NLI) model. The resulting
    per-label probabilities are stored in the document metadata under the
    ``entailment_info`` key.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub.
            Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as model).
            Defaults to ``model_name_or_path`` when not given.
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Stored on the instance as ``self.batch_size``; inference in this
            class currently processes one document at a time and does not read it.
        :raises ValueError: If the model's label mapping has no "entailment" label.
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.model.to(str(self.devices[0]))

        # Take the label mapping from the loaded model itself so that it always
        # matches the requested revision (and no second config load is needed).
        id2label = self.model.config.id2label
        # Labels ordered by class index, i.e. aligned with the logits' last axis.
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError(
                "The model config must contain entailment value in the id2label dict."
            )

    def run(self, query: str, documents: List[Document]):
        """
        Enrich every document with the entailment scores computed against the query.

        :param query: The statement used as the hypothesis.
        :param documents: Documents whose content is used as the premise.
            Their ``meta`` dict is updated in place with an ``entailment_info`` entry.
        :return: Tuple of the output dict (``{"documents": documents}``) and the
            name of the outgoing edge.
        """
        for doc in documents:
            doc.meta["entailment_info"] = self.get_entailment(
                premise=doc.content, hypotesis=query
            )
        return {"documents": documents}, "output_1"

    def run_batch(self, queries: List[str], documents: List[List[Document]]):
        """
        Batch counterpart of ``run``: process one list of documents per query.

        :param queries: One query per list of documents. A single string is
            also accepted and is applied to every list of documents.
        :param documents: One list of documents for each query.
            NOTE(review): a flat list of documents is not supported here —
            confirm against the pipelines that call this node.
        :return: Tuple of the output dict (``{"documents": documents}``) and the
            name of the outgoing edge.
        :raises ValueError: If the number of queries differs from the number of
            document lists.
        """
        if isinstance(queries, str):
            queries = [queries] * len(documents)
        if len(queries) != len(documents):
            raise ValueError(
                "The number of queries must match the number of document lists."
            )
        for query, docs in zip(queries, documents):
            self.run(query=query, documents=docs)
        return {"documents": documents}, "output_1"

    def get_entailment(self, premise, hypotesis):
        """
        Compute the NLI label probabilities for a premise/hypothesis pair.

        :param premise: Text used as the premise.
        :param hypotesis: Text used as the hypothesis (the parameter name keeps
            its historical spelling so that keyword callers keep working).
        :return: Dict mapping each lowercased model label to its probability
            as a plain Python float.
        """
        with torch.no_grad():
            # Encode as a real text pair so the tokenizer inserts the special
            # tokens the model expects between the two segments, and truncate
            # inputs that would exceed the model's maximum sequence length.
            inputs = self.tokenizer(
                premise, hypotesis, return_tensors="pt", truncation=True
            ).to(self.devices[0])
            logits = self.model(**inputs).logits
            # Single pair -> take row 0; plain floats keep doc.meta serialisable.
            probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().tolist()
        return dict(zip(self.labels, probs))