```python
from transformers import T5EncoderModel, T5Config, PreTrainedModel
import torch
import torch.nn as nn


class T5RegressionModel(PreTrainedModel):
    config_class = T5Config

    def __init__(self, config, d_model=None):
        super().__init__(config)
        # Loads the pretrained ProtT5 encoder at construction time; the passed
        # config's d_model should match this checkpoint (1024 for prot_t5_xl_uniref50).
        self.encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        hidden_dim = d_model if d_model is not None else config.d_model
        # Single-output linear head for scalar regression.
        self.regression_head = nn.Linear(hidden_dim, 1)
| def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs): | |
| encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) | |
| hidden_states = encoder_outputs.last_hidden_state | |
| pooled_output = hidden_states[:, -1, :] | |
| logits = self.regression_head(pooled_output).squeeze(-1) | |
| loss = None | |
| if labels is not None: | |
| labels = labels.float() | |
| loss = nn.MSELoss()(logits, labels) | |
| return {"loss": loss, "logits": logits} |