#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep  5 15:30:06 2023

@author: peter
"""

import torch

import qarac.models.QaracEncoderModel
import qarac.models.QaracDecoderModel


class QaracTrainerModel(torch.nn.Module):

    def __init__(self, base_model_path, tokenizer):
        """
        Sets up the Trainer model

        Parameters
        ----------
        base_model_path : str
            Path or name of the base RoBERTa model from which the encoders
            and decoder are initialised.
        tokenizer : transformers.RobertaTokenizer
            Tokenizer for the decoder.

        Returns
        -------
        None.

        """
        super(QaracTrainerModel, self).__init__()
        # The question encoder and answer encoder start from the same base
        # model but are trained as separate copies.
        self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
        self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
        config = self.answer_encoder.config
        config.is_decoder = True
        self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_model_path,
                                                                        config,
                                                                        tokenizer)
        self.cosine = torch.nn.CosineSimilarity(dim=1, eps=1.0e-12)

    def forward(self,
                all_text,
                offset_text,
                question,
                answer,
                proposition0,
                proposition1,
                conclusion_offset,
                statement0,
                statement1):
        """
        Generates training objectives from data

        Parameters
        ----------
        all_text : torch.Tensor
            Tokenized text for the encode-decode objective.
        offset_text : torch.Tensor
            As above, prefixed with the start token.
        question : torch.Tensor
            Tokenized question for the question-answering objective.
        answer : torch.Tensor
            Tokenized answer for the question-answering objective.
        proposition0 : torch.Tensor
            Tokenized proposition for the reasoning objective.
        proposition1 : torch.Tensor
            Tokenized proposition for the reasoning objective.
        conclusion_offset : torch.Tensor
            Tokenized conclusion for the reasoning objective, prefixed with
            the start token.
        statement0 : torch.Tensor
            Tokenized statement for the consistency objective.
        statement1 : torch.Tensor
            Tokenized statement for the consistency objective.

        Returns
        -------
        encode_decode : transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text for the encode-decode task.
        question_answering : torch.Tensor
            Difference between the encoded question and the encoded answer.
        reasoning : transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text for the reasoning objective.
        consistency : torch.Tensor
            Cosine similarity of the vectorized statements.

        """
        # Encode-decode objective: the decoder reconstructs the text from
        # the answer encoder's vector.
        encode_decode = self.decoder((self.answer_encoder(all_text),
                                      offset_text))
        # Question-answering objective: the difference between the encoded
        # question and the encoded answer.
        question_answering = (self.question_encoder(question)
                              - self.answer_encoder(answer))
        # Reasoning objective: the decoder generates the conclusion from
        # the sum of the two encoded propositions.
        reasoning = self.decoder((self.answer_encoder(proposition0)
                                  + self.answer_encoder(proposition1),
                                  conclusion_offset))
        # Consistency objective: cosine similarity between the vectors of
        # the two statements.
        s0 = self.answer_encoder(statement0)
        s1 = self.answer_encoder(statement1)
        consistency = self.cosine(s0, s1)
        return (encode_decode, question_answering, reasoning, consistency)
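

if __name__ == '__main__':
    # Usage sketch, not part of the training pipeline proper: a minimal
    # illustration of how the trainer model might be driven. The base model
    # name 'roberta-base', the batch() helper, and the example sentences are
    # all illustrative assumptions; the real pipeline prepares its batches
    # elsewhere, and the offset inputs are approximated here by the texts
    # themselves rather than being shifted by the start token.
    import transformers

    tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-base')
    trainer = QaracTrainerModel('roberta-base', tokenizer)

    def batch(texts):
        # Hypothetical helper: tokenize a batch of strings to padded ids.
        return tokenizer(texts, return_tensors='pt', padding=True)['input_ids']

    text = batch(['The cat sat on the mat.'])
    question = batch(['Where did the cat sit?'])
    answer = batch(['On the mat.'])
    proposition0 = batch(['All cats are mammals.'])
    proposition1 = batch(['Tom is a cat.'])
    conclusion = batch(['Tom is a mammal.'])
    statement0 = batch(['The sky is blue.'])
    statement1 = batch(['The sky is not green.'])

    encode_decode, question_answering, reasoning, consistency = trainer(
        text, text,
        question, answer,
        proposition0, proposition1, conclusion,
        statement0, statement1)
    # question_answering holds per-example difference vectors;
    # consistency holds per-example cosine similarities.
    print(question_answering.shape, consistency)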