| from argparse import Namespace |
| from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
def get_markers_for_model():
    """Return the special-token markers used to format model input/output.

    Returns:
        argparse.Namespace with three string attributes:
        - separator_different_qa: delimiter between distinct QA pairs
        - separator_output_question_answer: delimiter between question and answer
        - source_prefix: task prefix prepended to every input sequence
    """
    markers = Namespace()
    markers.separator_different_qa = "&&&"
    markers.separator_output_question_answer = "SSEEPP"
    markers.source_prefix = "qa: "
    return markers
|
|
def load_trained_model(name_or_path):
    """Load a seq2seq model and its tokenizer from a checkpoint name or path.

    Args:
        name_or_path: HuggingFace hub repo id or local directory.

    Returns:
        (model, tokenizer) tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
    return loaded_model, loaded_tokenizer
|
|
|
|
class QADiscourse_Pipeline(Text2TextGenerationPipeline):
    """Text2text pipeline that generates discourse question-answer pairs.

    Wraps a trained seq2seq model: inputs are prefixed with the task marker,
    and the decoded output is split into a list of {"question", "answer"}
    dicts using the model's special separator tokens.
    """

    def __init__(self, model_repo: str, **kwargs):
        """Load the model/tokenizer from `model_repo` and set up markers."""
        model, tokenizer = load_trained_model(model_repo)
        super().__init__(model, tokenizer, framework="pt")
        self.special_tokens = get_markers_for_model()

    def preprocess(self, inputs):
        """Prefix a single string or an iterable of strings, then delegate.

        Raises:
            ValueError: if `inputs` is neither a str nor iterable of str.
        """
        if isinstance(inputs, str):
            processed_inputs = self._preprocess_string(inputs)
        elif hasattr(inputs, "__iter__"):
            processed_inputs = [self._preprocess_string(s) for s in inputs]
        else:
            raise ValueError("inputs must be str or Iterable[str]")
        return super().preprocess(processed_inputs)

    def _preprocess_string(self, seq: str) -> str:
        """Prepend the task prefix expected by the trained model."""
        # NOTE: removed a leftover debug print of every input sequence.
        return self.special_tokens.source_prefix + seq

    def _forward(self, *args, **kwargs):
        """Pass-through to the base pipeline's forward step."""
        return super()._forward(*args, **kwargs)

    def postprocess(self, model_outputs):
        """Decode generated ids and parse them into QA-pair dicts.

        Returns:
            list of {"question": str, "answer": str}; malformed segments
            (missing the question/answer separator) are dropped.
        """
        predictions = self.tokenizer.decode(
            model_outputs["output_ids"].squeeze(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        qas = []
        for qa_pair in self._split_to_list(predictions):
            parsed = self._postprocess_qa(qa_pair)
            if parsed is not None:
                qas.append(parsed)
        return qas

    def _split_to_list(self, output_seq: str) -> list:
        """Split the decoded output into individual QA-pair segments."""
        return output_seq.split(self.special_tokens.separator_different_qa)

    def _postprocess_qa(self, seq: str):
        """Parse one segment into a QA dict, or None if it is malformed.

        Fixes vs. original: method name typo corrected (private, class-local),
        wrong `-> str` annotation removed (returns dict or None), and
        `maxsplit=1` added so an answer containing the separator no longer
        raises ValueError on unpacking.
        """
        sep = self.special_tokens.separator_output_question_answer
        if sep not in seq:
            return None
        question, answer = seq.split(sep, 1)
        return {"question": question, "answer": answer}
| |
| |
| if __name__ == "__main__": |
| pipe = QADiscourse_Pipeline("RonEliav/QA_discourse_v2") |
| res1 = pipe("I don't like chocolate, but I like cookies.") |
| res2 = pipe(["I don't like chocolate, but I like cookies.", |
| "I dived in the sea easily"], num_beams=10) |
| print(res1) |
| print(res2) |