# PHQ8-prototype / modeling_embedder_with_mlp.py
from transformers import PreTrainedModel, AutoModel, AutoConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from .configuration_embedder_with_mlp import EmbedderWithMLPConfig


class EmbedderWithMLP(PreTrainedModel):
    """Transformer text embedder with an MLP head that predicts a single score."""

    config_class = EmbedderWithMLPConfig

    def __init__(self, config, embedder=None):
        super().__init__(config)
        if embedder is None:
            # Build the backbone from its config alone; pretrained weights are
            # loaded later via from_pretrained (or left randomly initialized).
            backbone_config = AutoConfig.from_pretrained(config.model_name)
            self.embedder = AutoModel.from_config(backbone_config)
        else:
            self.embedder = embedder  # HF backbone (BERT, RoBERTa, etc.)

        # MLP head: (Linear -> GELU -> Dropout) blocks followed by a single output unit.
        hidden_layer_list = [config.input_size] + config.hidden_layer_list
        layers = []
        for in_dim, out_dim in zip(hidden_layer_list[:-1], hidden_layer_list[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(config.dropout_ratio))
        layers.append(nn.Linear(hidden_layer_list[-1], 1))
        self.mlp = nn.Sequential(*layers)

        self.post_init()  # required for PreTrainedModel weight initialization

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        emb = self.embedder(input_ids=input_ids, attention_mask=attention_mask)
        emb = emb.pooler_output  # pooled sentence representation from the backbone
        # emb = F.normalize(emb, p=2, dim=1)
        logit = self.mlp(emb)
        loss = None
        if labels is not None:
            # The MLP outputs raw logits, so use the numerically stable
            # BCEWithLogitsLoss rather than applying BCELoss to unbounded values.
            loss = nn.BCEWithLogitsLoss()(logit.view(-1), labels.float())
        return {"loss": loss, "logits": logit}

    @torch.no_grad()
    def inference(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        emb = self.embedder(input_ids=input_ids, attention_mask=attention_mask)
        emb = emb.pooler_output
        # emb = F.normalize(emb, p=2, dim=1)
        logit = self.mlp(emb)
        prob = torch.sigmoid(logit)
        score = prob.view(-1) * 24  # PHQ-8 total scores range from 0 to 24
        score = score.detach().cpu().item()  # assumes a single example per call
        return {"score": round(score)}