import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from transformers import AutoModel, AutoConfig

class PredicateAwareSRL(nn.Module):
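    """Predicate-aware SRL tagger: BERT encoding, first-subword word pooling,
    optional predicate-indicator and distance-to-predicate embeddings, a
    single-layer BiLSTM, and an MLP classifier over [word ; predicate] states."""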
    def __init__(self,
                 bert_name: str,
                 num_labels: int,
                 use_indicator: bool = True,
                 indicator_dim: int = 10,          # CHANGED: 10-dim predicate indicator
                 lstm_hidden: int = 768,           # CHANGED: BiLSTM hidden size = 768 (bidirectional)
                 mlp_hidden: int = 300,            # CHANGED: MLP hidden size = 300
                 dropout: float = 0.1,
                 use_distance: bool = True,        # NEW: enable relative position (distance) embeddings
                 pos_dim: int = 50,                # NEW: size of position embedding (random init, trainable)
                 max_distance: int = 128):         # NEW: clamp |i - p| to this range for bucketing
        super().__init__()
        self.config = AutoConfig.from_pretrained(bert_name)
        self.bert = AutoModel.from_pretrained(bert_name)
        self.use_indicator = use_indicator

        # --- Input dim to BiLSTM = BERT_dim + (indicator_dim) + (pos_dim)
        bert_dim = self.config.hidden_size
        in_dim = bert_dim + (indicator_dim if use_indicator else 0)

        # The indicator embedding needs two rows: index 0 = not a predicate, index 1 = predicate.
        # nn.Embedding(num_embeddings, embedding_dim):
        #   num_embeddings – size of the dictionary of embeddings
        #   embedding_dim  – the size of each embedding vector

        if use_indicator:
            self.indicator_emb = nn.Embedding(2, indicator_dim)  # 0/1 → 10-dim (random init, trainable)  # CHANGED

        self.use_distance = use_distance                         # NEW
        self.max_distance = max_distance                          # NEW
        if use_distance:
            # Distance buckets: [-max_distance .. +max_distance] → indices [0 .. 2*max_distance]
            self.pos_emb = nn.Embedding(2 * max_distance + 1, pos_dim)  # NEW (random init, trainable)
            in_dim += pos_dim                                         # NEW
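
        # With the defaults and a bert-base-sized encoder (hidden size 768), the
        # BiLSTM input is therefore 768 + 10 (indicator) + 50 (distance) = 828 dims.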

        # BiLSTM (bidirectional): total output dim = lstm_hidden
        self.bilstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=lstm_hidden // 2,  # bi → half per direction
            num_layers=1,
            dropout=0.0,
            bidirectional=True,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)

        # Classifier: concat(g_i, gp) so input dim = 2 * lstm_hidden
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, mlp_hidden),   # CHANGED (mlp_hidden=300)
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, num_labels)
        )
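        # With the defaults (lstm_hidden=768, mlp_hidden=300) this is
        # Linear(1536 → 300) → ReLU → Dropout → Linear(300 → num_labels).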

        self.pad_label_id = -100

    def forward(self,
                input_ids: torch.Tensor,           # [B, L]
                token_type_ids: torch.Tensor,      # [B, L]
                attention_mask: torch.Tensor,      # [B, L]
                word_first_wp_fullidx: torch.Tensor,  # [B, max_n] (positions in full seq; -1 for pad)
                sentence_mask: torch.Tensor,       # [B, max_n] (bool)
                sent_lens: torch.Tensor,           # [B]
                pred_word_idx: torch.Tensor,       # [B]
                indicator: torch.Tensor | None = None,  # [B, max_n] 0/1
                labels: torch.Tensor | None = None):    # [B, max_n]

        B, L = input_ids.size()
        device = input_ids.device

        # ---- BERT encoder
        bert_out = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )
        H = bert_out.last_hidden_state  # [B, L, D]

        # ---- Subword → word pooling (first subword)

        # Gather word-level representations by taking the FIRST subword's hidden state per word.
        # Prepare indices (replace -1 with 0 to avoid an out-of-bounds gather; pads are masked out below).
        # This step is needed so BIO/role labels are predicted per word rather than per subword.
        # .clone() copies the index tensor so the in-place edit below leaves the original data untouched.

        gather_idx = word_first_wp_fullidx.clone()
        gather_idx[gather_idx < 0] = 0
        gather_idx = gather_idx.unsqueeze(-1).expand(-1, -1, H.size(-1))  # [B, max_n, D]
        H_words = torch.gather(H, dim=1, index=gather_idx)                 # [B, max_n, D]
        H_words = H_words * sentence_mask.unsqueeze(-1)                    # zero out pads

        # Worked example of the pooling step above:
        #   word_first_wp_fullidx:  [1, 2, 3, -1, -1]
        #   gather_idx after clamp: [1, 2, 3, 0, 0]   # 0 points to [CLS], just a placeholder
        #   H_words = gather(H, ...)                  # grabs vectors at positions 1, 2, 3, 0, 0
        #   sentence_mask:          [1, 1, 1, 0, 0]
        #   H_words *= mask →       [vec1, vec2, vec3, 0, 0]   # padded slots zeroed out

        # ---- Concatenate predicate indicator (0/1 → embedding)
        X = H_words
        if self.use_indicator and indicator is not None:
            ind_emb = self.indicator_emb(indicator.clamp(0, 1))            # [B, max_n, 10]  # CHANGED
            X = torch.cat([X, ind_emb], dim=-1)

        # ---- NEW: Relative position (distance-to-predicate) embeddings
        if self.use_distance:
            # positions: 0..max_n-1 per sentence
            max_n = X.size(1)
            positions = torch.arange(max_n, device=device).unsqueeze(0).expand(B, -1)  # [B, max_n]
            rel = positions - pred_word_idx.unsqueeze(1)                               # [B, max_n], can be <0
            rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance # shift to [0 .. 2*max_distance]
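            # e.g. with max_distance=128: a word three positions left of the predicate
            # has rel = -3 → bucket index 125; the predicate itself maps to index 128.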
            pos_feats = self.pos_emb(rel)                                              # [B, max_n, pos_dim]  # NEW
            X = torch.cat([X, pos_feats], dim=-1)                                      # [B, max_n, in_dim]  # NEW

        # ---- BiLSTM (packed)
        lengths = sent_lens.detach().cpu()
        packed = pack_padded_sequence(X, lengths=lengths, batch_first=True, enforce_sorted=False)
        G_packed, _ = self.bilstm(packed)
        G, _ = pad_packed_sequence(G_packed, batch_first=True)      # [B, max_n, lstm_hidden]
        G = self.dropout(G)

        # ---- Predicate hidden (word-level) and concat to every position
        batch_idx = torch.arange(B, device=device)
        gp = G[batch_idx, pred_word_idx.clamp(min=0), :]            # [B, lstm_hidden]
        gp_expanded = gp.unsqueeze(1).expand(-1, G.size(1), -1)     # [B, max_n, lstm_hidden]

        logits = self.classifier(torch.cat([G, gp_expanded], dim=-1))  # [B, max_n, num_labels]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad_label_id)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return logits, loss
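

# Minimal smoke-test sketch (illustrative only): the checkpoint name, tensor shapes,
# and the dummy index values below are assumptions, not values from any training script.
if __name__ == "__main__":
    model = PredicateAwareSRL("bert-base-uncased", num_labels=5)
    model.eval()

    B, L, max_n = 2, 12, 6
    input_ids = torch.randint(100, 1000, (B, L))
    token_type_ids = torch.zeros(B, L, dtype=torch.long)
    attention_mask = torch.ones(B, L, dtype=torch.long)

    # First-subword position of every word inside the full wordpiece sequence (-1 = pad)
    word_first_wp_fullidx = torch.tensor([[1, 2, 4, 5, 7, 9],
                                          [1, 3, 4, -1, -1, -1]])
    sentence_mask = word_first_wp_fullidx >= 0                 # [B, max_n] bool
    sent_lens = sentence_mask.sum(dim=1)                       # [6, 3]
    pred_word_idx = torch.tensor([2, 0])                       # predicate word index per sentence
    indicator = torch.zeros(B, max_n, dtype=torch.long)
    indicator[torch.arange(B), pred_word_idx] = 1

    with torch.no_grad():
        logits, loss = model(input_ids, token_type_ids, attention_mask,
                             word_first_wp_fullidx, sentence_mask, sent_lens,
                             pred_word_idx, indicator=indicator)
    print(logits.shape)   # torch.Size([2, 6, 5]); loss is None because no labels were passed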