from typing import Any, Union

import torch
from transformers import Pipeline
|
|
|
|
class DocumentSentenceRelevancePipeline(Pipeline):
    """Pipeline that labels a document and each of its sentences.

    Every input is a dict with the keys ``"question"``, ``"response"`` and
    ``"context"`` (the document, given as a list of sentence strings). The
    wrapped model is called with ``input_ids``, ``attention_mask``,
    ``token_type_ids`` and ``sentence_positions`` and must return an output
    exposing ``doc_logits`` and ``sent_logits``. Label names are read from
    ``self.model.config.id2label``.
    """

    @staticmethod
    def _context_sentences(item):
        """Return the ``"context"`` entry of *item* as a sequence of sentences.

        A missing key yields ``[""]``. A bare string is treated as a single
        sentence, because iterating it directly would split it into
        characters.
        """
        context = item.get("context", [""])
        if isinstance(context, str):
            return [context]
        return context

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into preprocess/forward/postprocess params.

        Only ``threshold`` is recognised; it is routed to :meth:`postprocess`.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)``
            tuple of dicts.
        """
        postprocess_params = {}
        # Forward the threshold only when the caller supplied one. Returning
        # the default unconditionally would make every call carry
        # ``threshold=0.5`` and thereby override a threshold that was
        # configured when the pipeline was constructed; the 0.5 default
        # already lives in ``postprocess``'s signature.
        threshold = kwargs.get("threshold")
        if threshold is not None:
            postprocess_params["threshold"] = threshold
        return {}, {}, postprocess_params

    def preprocess(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Tokenize one example into the tensors consumed by the model.

        The question and the response are encoded separately (each with its
        own special tokens) and concatenated into segment 0. Every context
        sentence is encoded separately as well and the results are
        concatenated into segment 1.

        Args:
            inputs: Dict with optional keys ``"question"``, ``"response"``
                (default ``""``) and ``"context"`` (default ``[""]``).

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``token_type_ids``
            and ``sentence_positions``, each a ``torch.long`` tensor with a
            leading dimension of size 1.
        """
        tokenizer = self.tokenizer

        def encode(text):
            # No truncation/padding here: length is handled once, below, on
            # the combined sequence.
            encoded = tokenizer(
                text, add_special_tokens=True, truncation=False, padding=False
            )
            return encoded["input_ids"]

        ids = encode(inputs.get("question", "")) + encode(inputs.get("response", ""))
        pair_ids = [
            token_id
            for sentence in self._context_sentences(inputs)
            for token_id in encode(sentence)
        ]

        overflow = len(ids) + len(pair_ids) - tokenizer.model_max_length
        if overflow > 0:
            # Only the context (second segment) is shortened; the question
            # and response are kept intact.
            # NOTE(review): if question + response alone exceed
            # model_max_length, "only_second" cannot bring the input under
            # the limit -- confirm how the tokenizer behaves in that case.
            ids, pair_ids, _ = tokenizer.truncate_sequences(
                ids=ids,
                pair_ids=pair_ids,
                num_tokens_to_remove=overflow,
                truncation_strategy="only_second",
                stride=0,
            )

        combined_ids = ids + pair_ids
        token_types = [0] * len(ids) + [1] * len(pair_ids)

        # A sentence is located by the CLS token found in the context
        # segment; positions are absolute indices into ``combined_ids``.
        # This relies on the tokenizer emitting its CLS token once per
        # encoded sentence.
        cls_id = tokenizer.cls_token_id
        context_start = len(ids)
        sentence_positions = [
            context_start + offset
            for offset, token_id in enumerate(pair_ids)
            if token_id == cls_id
        ]

        return {
            "input_ids": torch.tensor([combined_ids], dtype=torch.long),
            "attention_mask": torch.tensor(
                [[1] * len(combined_ids)], dtype=torch.long
            ),
            "token_type_ids": torch.tensor([token_types], dtype=torch.long),
            "sentence_positions": torch.tensor(
                [sentence_positions], dtype=torch.long
            ),
        }

    def _forward(self, model_inputs):
        """Run the model on the tensors produced by :meth:`preprocess`."""
        return self.model(**model_inputs)

    def __call__(
        self,
        inputs: Union[dict[str, Any], list[dict[str, Any]]],
        **kwargs,
    ):
        """Run the pipeline and attach each sentence text to its prediction.

        Args:
            inputs: One example dict or a list of example dicts.
            **kwargs: Forwarded to ``Pipeline.__call__`` (e.g. ``threshold``).

        Returns:
            A list with one dict per example (also for a single dict input).
            Each dict has ``"document"`` (``{"label", "score"}``) and
            ``"sentences"``, a list of ``{"sentence", "label", "score"}``
            dicts. Sentences without a prediction (for instance because they
            were truncated away) are dropped by the ``zip``.
        """
        if isinstance(inputs, dict):
            inputs = [inputs]
        model_outputs = super().__call__(inputs, **kwargs)

        results = []
        for item, output in zip(inputs, model_outputs):
            # Same context normalisation as in preprocess, so a missing
            # "context" key does not raise here after having been accepted
            # there.
            sentences = self._context_sentences(item)
            predicted = output["sentences"]
            results.append(
                {
                    "document": output["document"],
                    "sentences": [
                        {"sentence": sentence, "label": label, "score": score}
                        for sentence, label, score in zip(
                            sentences, predicted["label"], predicted["score"]
                        )
                    ],
                }
            )
        return results

    def postprocess(self, model_outputs, threshold=0.5):
        """Turn the model logits of one example into labels and scores.

        Class 1 is chosen when its softmax probability exceeds ``threshold``,
        otherwise class 0. The reported score is the probability of the
        chosen class.

        Args:
            model_outputs: Model output with ``doc_logits`` indexed as
                ``(batch, classes)`` and ``sent_logits`` indexed as
                ``(batch, sentences, classes)``; only the first batch entry
                is read.
            threshold: Probability that class 1 must exceed to be selected.

        Returns:
            ``{"document": {"label": str, "score": float},
            "sentences": {"label": list, "score": list}}``.
        """
        id2label = self.model.config.id2label

        # Index the single example explicitly instead of squeezing, so a
        # document with exactly one sentence needs no special casing.
        doc_probs = torch.softmax(model_outputs.doc_logits, dim=-1)[0]
        sent_probs = torch.softmax(model_outputs.sent_logits, dim=-1)[0]

        doc_class = int(doc_probs[1] > threshold)
        doc_score = doc_probs[doc_class].item()

        sent_classes = (sent_probs[:, 1] > threshold).long()
        # Pick, for every sentence, the probability of its chosen class.
        sent_scores = sent_probs.gather(1, sent_classes.unsqueeze(1)).squeeze(1)

        document_output = {"label": id2label[doc_class], "score": doc_score}
        sentence_output = {
            "label": [id2label[class_id] for class_id in sent_classes.tolist()],
            "score": sent_scores.tolist(),
        }
        return {"document": document_output, "sentences": sentence_output}