File size: 1,606 Bytes
78cc03b
43c5345
54f7dc9
 
 
43c5345
 
 
 
78cc03b
54f7dc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr  
import requests
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, T5Config
import torch

# Maximum number of input tokens fed to the encoder; inputs longer than
# this are expected to be truncated by the tokenizer before forward().
MAX_SOURCE_LENGTH = 512


class ReviewerModel(T5ForConditionalGeneration):
    """T5-based review model with an extra 2-way classification head.

    Extends ``T5ForConditionalGeneration`` with ``cls_head``, a linear layer
    over the decoder hidden size (``d_model`` -> 2), so the same backbone can
    also serve a binary classification task alongside seq2seq generation.
    """

    def __init__(self, config):
        super().__init__(config)
        # Binary classification head on top of the decoder hidden states.
        self.cls_head = nn.Linear(config.d_model, 2, bias=True)
        self.init()

    def init(self):
        """(Re)initialize the lm_head and classification-head parameters."""
        nn.init.xavier_uniform_(self.lm_head.weight)
        factor = self.config.initializer_factor
        # Scaled-normal init matching T5's convention for projection layers:
        # std = initializer_factor * d_model ** -0.5.
        self.cls_head.weight.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5))
        self.cls_head.bias.data.zero_()

    def forward(
            self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels=None):
        """Run the seq2seq forward pass.

        Args:
            input_ids: encoder input token ids.
            attention_mask: encoder attention mask.
            decoder_input_ids: decoder input token ids.
            decoder_attention_mask: decoder attention mask.
            labels: optional target token ids; when provided, the returned
                output additionally carries the cross-entropy loss.

        Returns:
            The output of ``T5ForConditionalGeneration.forward`` (a
            ``Seq2SeqLMOutput``); ``.loss`` is populated only when
            ``labels`` is given.
        """
        # BUG FIX: the original body only returned inside the
        # ``labels is not None`` branch and then fell off the end of the
        # function for the generation case (implicitly returning None),
        # despite the trailing comment claiming super().forward would be
        # returned. Delegating unconditionally preserves the training-path
        # behavior exactly (labels passed through) and repairs the
        # generation path: with labels=None the parent simply skips the
        # loss computation.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )