|
|
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from transformers import AutoModel, AutoConfig
|
|
|
|
|
|
class IndoBERTClassifier(nn.Module):
    """Multi-task text classifier: one pretrained encoder, two linear heads.

    One head produces clickbait logits, the other produces "kategori"
    (category) logits. Both heads read the same dropout-regularized hidden
    state of the encoder's first token (the ``[CLS]`` position for
    BERT-style tokenizers — assumed, not enforced here).
    """

    def __init__(self, config):
        """Load the encoder and build both classification heads.

        Args:
            config: A Hugging Face config object. Besides the usual encoder
                fields it must provide ``_name_or_path`` (the checkpoint passed
                to ``AutoModel.from_pretrained``), ``num_clickbait_labels`` and
                ``num_kategori_labels``. ``classifier_dropout`` is optional;
                when it is missing or ``None`` a dropout of 0.1 is used.
        """
        super().__init__()

        self.bert = AutoModel.from_pretrained(config._name_or_path, config=config)

        # BertConfig defines ``classifier_dropout`` but defaults it to None, so a
        # plain hasattr() check would hand None to nn.Dropout and crash.
        self.dropout = nn.Dropout(self._resolve_dropout_prob(config))

        hidden_size = self.bert.config.hidden_size

        self.num_clickbait_labels = config.num_clickbait_labels
        self.num_kategori_labels = config.num_kategori_labels

        self.clickbait_classifier = nn.Linear(hidden_size, self.num_clickbait_labels)
        self.kategori_classifier = nn.Linear(hidden_size, self.num_kategori_labels)

    @staticmethod
    def _resolve_dropout_prob(config, default=0.1):
        """Return the head dropout probability configured on *config*.

        Uses ``config.classifier_dropout`` when it exists and is not ``None``
        (an explicit 0.0 is respected); otherwise returns *default*.
        """
        prob = getattr(config, "classifier_dropout", None)
        return default if prob is None else prob

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                clickbait_labels=None, kategori_labels=None, **kwargs) -> dict:
        """Compute logits for both classification heads.

        Args:
            input_ids: Token ids for the encoder; presumably shaped
                (batch, seq_len) — TODO confirm against the data collator.
            attention_mask: Attention mask matching ``input_ids``.
            clickbait_labels: Accepted but NOT used; no loss is computed here.
            kategori_labels: Accepted but NOT used; no loss is computed here.
            **kwargs: Ignored. Note that this includes ``token_type_ids``,
                which is therefore never forwarded to the encoder.

        Returns:
            A dict with ``"clickbait_logits"`` (last dim
            ``num_clickbait_labels``) and ``"kategori_logits"`` (last dim
            ``num_kategori_labels``), one row per input sequence.
        """
        # NOTE(review): a stock Hugging Face Trainer expects a "loss" key in the
        # output; confirm the caller computes the loss from these logits.
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Hidden state at position 0 of every sequence (not the encoder's pooler output).
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        features = self.dropout(first_token_state)

        return {
            "clickbait_logits": self.clickbait_classifier(features),
            "kategori_logits": self.kategori_classifier(features),
        }