Not-Grim-Refer's picture
Update app.py
54f7dc9
import gradio as gr
import requests
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, T5Config
import torch
# NOTE(review): presumably the tokenizer truncation limit for source inputs —
# unused in this chunk of the file; confirm against the tokenization call site.
MAX_SOURCE_LENGTH = 512
class ReviewerModel(T5ForConditionalGeneration):
    """T5 conditional-generation model extended with a 2-way classification head.

    Inherits everything from ``T5ForConditionalGeneration`` and adds
    ``cls_head``, a linear layer mapping the model dimension to 2 logits.
    """

    def __init__(self, config):
        super().__init__(config)
        # Extra binary classification head on top of the T5 encoder/decoder
        # hidden size (config.d_model -> 2 logits).
        self.cls_head = nn.Linear(config.d_model, 2, bias=True)
        self.init()

    def init(self):
        """Re-initialize the LM head and the classification head weights.

        NOTE(review): called from __init__, so these values are overwritten
        if pretrained weights are later loaded for these parameters.
        """
        nn.init.xavier_uniform_(self.lm_head.weight)
        # Same scaled-normal scheme T5 uses for its own projections:
        # std = initializer_factor * d_model ** -0.5
        factor = self.config.initializer_factor
        self.cls_head.weight.data.normal_(
            mean=0.0, std=factor * (self.config.d_model ** -0.5)
        )
        self.cls_head.bias.data.zero_()

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        labels=None,
    ):
        """Run the underlying T5 forward pass.

        Args:
            input_ids: Encoder token ids.
            attention_mask: Encoder attention mask.
            decoder_input_ids: Decoder token ids.
            decoder_attention_mask: Decoder attention mask.
            labels: Optional target ids; when given, the superclass computes
                the seq2seq cross-entropy loss.

        Returns:
            The ``Seq2SeqLMOutput`` from ``T5ForConditionalGeneration.forward``
            (with ``loss`` populated when ``labels`` is provided).
        """
        if labels is not None:
            # Training/eval path: let the superclass compute the loss.
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                labels=labels,
            )
        # BUG FIX: the original fell through here and implicitly returned
        # None for the generation case, despite its own comment saying it
        # should "return super() forward directly". Delegate explicitly.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )