import torch
import torch.nn as nn
from transformers import BertModel
from huggingface_hub import hf_hub_download

# Hugging Face Hub checkpoint used by default for the BERT encoder.
repo_id = "aaljabari/arabic-relation-extraction-v1"


class BertRE(nn.Module):
    """Relation classifier over a BERT encoder.

    The contextual vectors at the subject and object token positions are
    concatenated, passed through dropout, and scored by a linear layer.
    """

    def __init__(self, num_labels, model_name_or_path=repo_id):
        """Build the encoder and the classification head.

        Args:
            num_labels: Number of relation classes (output width of the
                linear classifier).
            model_name_or_path: Checkpoint passed to
                ``BertModel.from_pretrained``. Defaults to ``repo_id`` so
                existing callers keep their behavior.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name_or_path)
        hidden = self.bert.config.hidden_size
        # Reuse the encoder's own hidden-dropout rate for the head.
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        # Input is [subject_vec ; object_vec], hence 2 * hidden features.
        self.classifier = nn.Linear(hidden * 2, num_labels)

    def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
        """Return relation logits for a batch.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            sub_pos: Per-example token index of the subject. Presumably a
                1-D integer tensor of length ``batch``; TODO confirm with
                the data pipeline.
            obj_pos: Per-example token index of the object (same
                assumption as ``sub_pos``).

        Returns:
            Output of the linear classifier applied to the concatenated
            subject/object vectors. Its shape is (batch, num_labels) when
            the position arguments are 1-D.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state
        # Build the row index on the same device as the activations.
        # This avoids a host->device copy on every call, and is shared by
        # both gathers below.
        batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
        # Pick one token vector per example: hidden[i, pos[i], :].
        sub_vec = hidden[batch_idx, sub_pos]
        obj_vec = hidden[batch_idx, obj_pos]
        pair = torch.cat([sub_vec, obj_vec], dim=1)
        pair = self.dropout(pair)
        return self.classifier(pair)