Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from transformers import BertModel | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository ID; BertRE passes it to BertModel.from_pretrained
# when no other checkpoint is given.
repo_id: str = "aaljabari/arabic-relation-extraction-v1"
class BertRE(nn.Module):
    """BERT encoder with an entity-pair classification head for relation extraction.

    The encoder's hidden states at the subject and object token positions are
    concatenated, passed through dropout, and projected by a linear layer to
    one logit per relation label.
    """

    def __init__(self, num_labels, model_name_or_path=None):
        """Build the encoder and the classification head.

        Args:
            num_labels: Number of relation labels (output size of the classifier).
            model_name_or_path: Optional Hugging Face model ID or local directory
                passed to ``BertModel.from_pretrained``. When ``None`` (the
                default) the module-level ``repo_id`` is used, which matches the
                previous hard-coded behavior.
        """
        super().__init__()
        if model_name_or_path is None:
            model_name_or_path = repo_id
        self.bert = BertModel.from_pretrained(model_name_or_path)
        hidden = self.bert.config.hidden_size
        # Reuse the encoder's own dropout rate for the classification head.
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        # The head sees [subject_vec; object_vec], hence twice the hidden size.
        self.classifier = nn.Linear(hidden * 2, num_labels)

    def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
        """Return relation logits for each (subject, object) pair in the batch.

        Args:
            input_ids: Token IDs forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            sub_pos: Index of the subject token for each example; assumed to be
                a 1-D integer tensor of length ``batch`` (TODO confirm with the
                data collator).
            obj_pos: Index of the object token for each example; same form as
                ``sub_pos``.

        Returns:
            Unnormalized logits from ``self.classifier``; with 1-D positions this
            is shaped ``(batch, num_labels)``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        # Create the row index on the same device as the hidden states, so no
        # implicit CPU -> accelerator copy happens on every forward pass.
        rows = torch.arange(hidden.shape[0], device=hidden.device)
        # Pick the hidden state at each example's subject/object position.
        sub_vec = hidden[rows, sub_pos]
        obj_vec = hidden[rows, obj_pos]
        pair = torch.cat([sub_vec, obj_vec], dim=1)
        pair = self.dropout(pair)
        return self.classifier(pair)