Spaces:
Runtime error
Runtime error
File size: 1,606 Bytes
78cc03b 43c5345 54f7dc9 43c5345 78cc03b 54f7dc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import gradio as gr
import requests
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, T5Config
import torch
# Upper bound on source length (512). Not referenced in the visible code;
# presumably used for tokenizer truncation elsewhere in the app — TODO confirm.
MAX_SOURCE_LENGTH: int = 512
class ReviewerModel(T5ForConditionalGeneration):
    """T5 sequence-to-sequence model carrying an extra two-logit head.

    ``cls_head`` (``d_model -> 2``) is created and initialised here, but
    ``forward`` never uses it: every call is delegated unchanged to
    ``T5ForConditionalGeneration.forward``.
    """

    def __init__(self, config):
        """Build the base T5 model, add ``cls_head``, then re-initialise weights.

        Args:
            config: the model configuration; ``config.d_model`` sizes the
                input of ``cls_head``.
        """
        super().__init__(config)
        self.cls_head = nn.Linear(config.d_model, 2, bias=True)
        self.init()

    def init(self):
        """Re-initialise ``lm_head`` (Xavier-uniform) and ``cls_head``.

        ``cls_head`` weights are drawn from
        N(0, initializer_factor * d_model ** -0.5) and its bias is zeroed.
        """
        # NOTE(review): if lm_head is tied to the input embeddings for this
        # config, this also re-initialises the shared embedding — confirm.
        nn.init.xavier_uniform_(self.lm_head.weight)
        factor = self.config.initializer_factor
        self.cls_head.weight.data.normal_(
            mean=0.0, std=factor * (self.config.d_model ** -0.5)
        )
        self.cls_head.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        """Run the underlying T5 forward pass and return its output.

        Args:
            input_ids: encoder input token ids.
            attention_mask: encoder attention mask.
            decoder_input_ids: decoder input token ids.
            decoder_attention_mask: decoder attention mask.
            labels: optional target ids; when given, the base model also
                computes a loss.
            **kwargs: any other keyword accepted by
                ``T5ForConditionalGeneration.forward`` (e.g. ``encoder_outputs``,
                ``past_key_values``, ``use_cache``), forwarded unchanged.

        Returns:
            Whatever ``T5ForConditionalGeneration.forward`` returns.
        """
        # Always delegate. Returning only when labels were supplied made every
        # inference call (including generate(), which passes no labels plus
        # extra keywords such as encoder_outputs / past_key_values) get None.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            **kwargs,
        )
|