# KlikBERT / model.py
# Uploaded by TrioF ("Upload 3 files"), commit 75d06a1 (verified).
# Nama file: model.py
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
class IndoBERTClassifier(nn.Module):
    """Two-headed classifier on top of a pretrained ``transformers`` encoder.

    A single encoder is shared by two linear heads: one produces clickbait
    logits, the other produces category ("kategori") logits. Both heads read
    the same dropout-regularised representation of the first token.

    Args:
        config: A ``transformers`` configuration object. It must provide
            ``_name_or_path`` (passed to ``AutoModel.from_pretrained``),
            ``num_clickbait_labels`` and ``num_kategori_labels``. It may
            provide ``classifier_dropout``; when that is absent or ``None``,
            a dropout probability of 0.1 is used.
    """

    def __init__(self, config):
        super().__init__()
        # Load the pretrained base encoder; its own config supplies hidden_size.
        self.bert = AutoModel.from_pretrained(config._name_or_path, config=config)
        self.dropout = nn.Dropout(self._resolve_dropout_prob(config))
        hidden_size = self.bert.config.hidden_size
        self.num_clickbait_labels = config.num_clickbait_labels
        self.num_kategori_labels = config.num_kategori_labels
        self.clickbait_classifier = nn.Linear(hidden_size, self.num_clickbait_labels)
        self.kategori_classifier = nn.Linear(hidden_size, self.num_kategori_labels)

    @staticmethod
    def _resolve_dropout_prob(config):
        """Return the dropout probability to use for the classification heads.

        Uses ``config.classifier_dropout`` when it is set. Falls back to 0.1
        when the attribute is missing *or* ``None``. A plain ``hasattr``
        check is not enough, because configs such as ``BertConfig`` define
        ``classifier_dropout`` with a default of ``None``, and
        ``nn.Dropout(None)`` raises ``TypeError``.
        """
        prob = getattr(config, "classifier_dropout", None)
        return 0.1 if prob is None else prob

    def forward(self, input_ids, attention_mask=None, clickbait_labels=None,
                kategori_labels=None, **kwargs):
        """Run the encoder and both classification heads.

        Args:
            input_ids: Token ids passed straight to the encoder.
                NOTE(review): presumably a (batch, seq_len) LongTensor;
                confirm against the tokenizer/collator.
            attention_mask: Passed straight to the encoder; ``None`` is
                forwarded unchanged.
            clickbait_labels: Accepted but not used; no loss is computed here.
            kategori_labels: Accepted but not used; no loss is computed here.
            **kwargs: Any other inputs (e.g. ``token_type_ids``) are ignored.

        Returns:
            A dict with ``"clickbait_logits"`` and ``"kategori_logits"``, the
            outputs of the two linear heads.
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Take the first token's hidden state ([CLS] for BERT-style tokenizers).
        pooled_output = output.last_hidden_state[:, 0, :]
        dropout_output = self.dropout(pooled_output)
        return {
            "clickbait_logits": self.clickbait_classifier(dropout_output),
            "kategori_logits": self.kategori_classifier(dropout_output),
        }