#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep  5 15:30:06 2023

@author: peter
"""

import torch
import qarac.models.QaracEncoderModel
import qarac.models.QaracDecoderModel



class QaracTrainerModel(torch.nn.Module):
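    """Trainer that combines the QARAC encoders and decoder and produces
    outputs for the four training objectives."""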
    
    def __init__(self,base_model_path,tokenizer):
        """
        Sets up the Trainer model

        Parameters
        ----------
        base_model_path : str
            Path or name of the base model from which the encoders and
            decoder are built.
        tokenizer : transformers.RobertaTokenizer
            Tokenizer for the decoder.

        Returns
        -------
        None.

        """
        super(QaracTrainerModel,self).__init__()
        self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
        self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
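        # Build the decoder from the same base model, with the encoder's
        # configuration switched into decoder mode.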
        config = self.answer_encoder.config
        config.is_decoder = True
        self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_model_path,
                                                                        config,
                                                                        tokenizer)
        self.cosine = torch.nn.CosineSimilarity(dim=1,eps=1.0e-12)
        
    def forward(self,
                all_text,
                offset_text,
                question,
                answer,
                proposition0,
                proposition1,
                conclusion_offset,
                statement0,
                statement1):
        """
        Generates training objectives from data

        Parameters
        ----------
        all_text : torch.tensor
            Tokenized text for the encode-decode objective.
        offset_text : torch.tensor
            As above, prefixed with <s>.
        question : torch.tensor
            Tokenized question for the question-answering objective.
        answer : torch.tensor
            Tokenized answer for the question-answering objective.
        proposition0 : torch.tensor
            Tokenized proposition for the reasoning objective.
        proposition1 : torch.tensor
            Tokenized proposition for the reasoning objective.
        conclusion_offset : torch.tensor
            Tokenized conclusion for the reasoning objective, prefixed with <s>.
        statement0 : torch.tensor
            Tokenized statement for the consistency objective.
        statement1 : torch.tensor
            Tokenized statement for the consistency objective.

        Returns
        -------
        encode_decode : transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text for the encode-decode task.
        question_answering : torch.tensor
            Difference between the encoded question and the encoded answer.
        reasoning : transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text for the reasoning objective.
        consistency : torch.tensor
            Cosine similarity of the vectorized statements.

        """
        # Encode-decode objective: decode the encoded text back into the
        # original text.
        encode_decode = self.decoder((self.answer_encoder(all_text),
                                      offset_text))
        # Question-answering objective: difference between the encoded
        # question and the encoded answer.
        question_answering = self.question_encoder(question) - self.answer_encoder(answer)
        # Reasoning objective: decode the sum of the encoded propositions
        # into the conclusion.
        reasoning = self.decoder((self.answer_encoder(proposition0)
                                  +self.answer_encoder(proposition1),
                                  conclusion_offset))
        # Consistency objective: cosine similarity of the encoded statements.
        s0 = self.answer_encoder(statement0)
        s1 = self.answer_encoder(statement1)
        consistency = self.cosine(s0,s1)
        return (encode_decode,question_answering,reasoning,consistency)
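

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes that a RoBERTa checkpoint name is a valid base_model_path and
# that each argument to forward() is a 2-D tensor of token ids, matching how
# the tensors are used above. The example texts and the encode() helper are
# hypothetical.
if __name__ == '__main__':
    import transformers

    tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-base')
    trainer = QaracTrainerModel('roberta-base',tokenizer)

    def encode(text):
        # Tokenize a sentence to a (1, sequence_length) tensor of token ids.
        return tokenizer(text,return_tensors='pt')['input_ids']

    (encode_decode,
     question_answering,
     reasoning,
     consistency) = trainer(encode('The cat sat on the mat.'),      # all_text
                            encode('<s> The cat sat on the mat.'),  # offset_text
                            encode('Where did the cat sit?'),       # question
                            encode('On the mat.'),                  # answer
                            encode('All cats are mammals.'),        # proposition0
                            encode('All mammals are animals.'),     # proposition1
                            encode('<s> All cats are animals.'),    # conclusion_offset
                            encode('The mat is red.'),              # statement0
                            encode('The mat is not red.'))          # statement1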