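# --- Hugging Face file-viewer residue (kept for provenance; not part of the module) ---
# FikriRiyadi's picture
# Create model.py
# 4aebad0 verified
# raw
# history blame contribute delete
# 633 Bytes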
import torch
import torch.nn as nn
from transformers import AutoModel
class CyberRoBERTa(nn.Module):
    """Pretrained transformer encoder with a dropout + linear scoring head.

    The encoder is loaded with ``AutoModel.from_pretrained``. The hidden state
    of the first token of every sequence is passed through dropout and a
    linear layer, and a sigmoid is applied to each output independently, so
    the model yields one value in ``[0, 1]`` per label.

    Args:
        model_name: Identifier or path handed to ``AutoModel.from_pretrained``.
        num_labels: Number of output units of the linear head.
        dropout: Drop probability applied to the pooled vector before the
            linear head (active only in training mode).
    """

    def __init__(
        self,
        model_name: str = "cahya/roberta-base-indonesian-522M",
        num_labels: int = 12,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        # The attribute names (bert / drop / out) determine the state_dict
        # keys; keep them stable so previously saved weights still load.
        self.bert = AutoModel.from_pretrained(model_name)
        self.drop = nn.Dropout(dropout)
        self.out = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return per-label probabilities for a batch of encoded sequences.

        Args:
            input_ids: Token ids forwarded to the encoder
                (presumably shaped ``(batch, seq_len)`` -- confirm with caller).
            attention_mask: Mask forwarded to the encoder alongside
                ``input_ids``.

        Returns:
            Tensor whose last dimension is ``num_labels``, with every entry
            already squashed by a sigmoid.

        Note:
            Because the sigmoid is applied here, the output is a probability,
            not a logit. Train against it with ``nn.BCELoss``; feeding it to
            ``nn.BCEWithLogitsLoss`` would apply the sigmoid twice.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 along the second axis: the leading token of each sequence
        # is used as the sequence-level representation.
        first_token = encoded.last_hidden_state[:, 0]
        scores = self.out(self.drop(first_token))
        return torch.sigmoid(scores)