File size: 984 Bytes
3f110a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
from transformers import BertModel
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository holding the pretrained BERT encoder weights
# used by BertRE by default.
repo_id: str = "aaljabari/arabic-relation-extraction-v1"

class BertRE(nn.Module):
    """Relation-extraction classifier on top of a pretrained BERT encoder.

    The encoder's last hidden states at the subject and object positions are
    concatenated, passed through dropout, and fed to a linear layer that
    produces ``num_labels`` logits per example.

    Args:
        num_labels: Number of output labels (width of the logit vector).
        pretrained_model_name_or_path: Hub repo id or local path passed to
            ``BertModel.from_pretrained``. When ``None`` (the default), the
            module-level ``repo_id`` is used, matching the original behavior.
    """

    def __init__(self, num_labels, pretrained_model_name_or_path=None):
        super().__init__()
        # Resolve the default at call time so the class definition itself
        # does not depend on module-level state.
        source = (
            repo_id
            if pretrained_model_name_or_path is None
            else pretrained_model_name_or_path
        )
        self.bert = BertModel.from_pretrained(source)

        hidden_size = self.bert.config.hidden_size
        # Reuse the encoder's own hidden-layer dropout rate for the head.
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        # Input is [subject_vec ; object_vec], hence twice the hidden size.
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
        """Return classifier logits for each example in the batch.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            sub_pos: Per-example token index whose hidden state represents
                the subject. Assumed to be a 1-D integer tensor of length
                ``batch`` — TODO confirm against the caller.
            obj_pos: Per-example token index whose hidden state represents
                the object; same assumed shape as ``sub_pos``.

        Returns:
            Output of ``self.classifier`` applied to the concatenated
            subject/object vectors, one row per example.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        # Build the per-row index once, on the same device as the hidden
        # states, instead of allocating a CPU tensor twice per call.
        rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        sub_vec = hidden_states[rows, sub_pos]
        obj_vec = hidden_states[rows, obj_pos]

        pair = torch.cat([sub_vec, obj_vec], dim=1)
        pair = self.dropout(pair)
        return self.classifier(pair)