Spaces:
Runtime error
Runtime error
File size: 984 Bytes
3f110a5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import torch
import torch.nn as nn
from transformers import BertModel
from huggingface_hub import hf_hub_download
# Model identifier passed to ``BertModel.from_pretrained`` when ``BertRE``
# builds its encoder; presumably a Hugging Face Hub repo id ("owner/name").
repo_id: str = "aaljabari/arabic-relation-extraction-v1"
class BertRE(nn.Module):
    """BERT encoder with an entity-pair classification head.

    The encoder's final hidden states at two token positions per example
    (``sub_pos`` and ``obj_pos``) are concatenated, passed through dropout,
    and mapped to ``num_labels`` raw scores by a single linear layer.

    Args:
        num_labels: Number of outputs of the linear head.
        pretrained_model_name_or_path: Identifier or local path handed to
            ``BertModel.from_pretrained``. When ``None`` (the default), the
            module-level ``repo_id`` is used, read at construction time.
    """

    def __init__(self, num_labels, pretrained_model_name_or_path=None):
        super().__init__()
        if pretrained_model_name_or_path is None:
            pretrained_model_name_or_path = repo_id
        self.bert = BertModel.from_pretrained(pretrained_model_name_or_path)
        config = self.bert.config
        hidden_size = config.hidden_size
        # Reuse the encoder's own dropout rate for the classification head.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # The head sees the subject and object vectors side by side.
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

    @staticmethod
    def _gather_at(hidden_states, positions):
        """Return ``hidden_states[i, positions[i]]`` for every row ``i``.

        ``positions`` may be a tensor, a list, or an int. A ``(batch, 1)``
        tensor is flattened to ``(batch,)``. Otherwise it would broadcast
        against the row index into a ``(batch, batch)`` grid.
        """
        device = hidden_states.device
        positions = torch.as_tensor(positions, device=device)
        if positions.dim() == 2 and positions.size(-1) == 1:
            positions = positions.squeeze(-1)
        # Build the row index on the activations' device rather than on CPU.
        rows = torch.arange(hidden_states.size(0), device=device)
        return hidden_states[rows, positions]

    def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
        """Score the relation between two token positions of each example.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            sub_pos: Per-example token index whose hidden state represents
                the subject. Presumably an entity-marker position produced
                by preprocessing; confirm against the caller.
            obj_pos: Per-example token index for the object, same format.

        Returns:
            Raw classifier scores, one row per example with ``num_labels``
            columns.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Indexed below as (batch, seq_len, hidden_size).
        hidden_states = outputs.last_hidden_state
        sub_vec = self._gather_at(hidden_states, sub_pos)
        obj_vec = self._gather_at(hidden_states, obj_pos)
        pair = torch.cat([sub_vec, obj_vec], dim=1)
        pair = self.dropout(pair)
        return self.classifier(pair)