File size: 17,235 Bytes
7186695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
"""
Transformer-based Dependency Parser using PhoBERT.

This module implements a Biaffine dependency parser with PhoBERT as the encoder,
following the Trankit approach but using Vietnamese-specific PhoBERT.

Architecture:
    Input → PhoBERT → Word-level pooling → MLP projections → Biaffine attention → MST decoding

Reference:
- Dozat & Manning (2017): Deep Biaffine Attention for Neural Dependency Parsing
- Nguyen & Nguyen (2020): PhoBERT: Pre-trained language models for Vietnamese
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict, Any
import numpy as np

from bamboo1.models.mst import mst_decode, batch_mst_decode


class MLP(nn.Module):
    """Single-layer projection applied before biaffine scoring.

    The module computes ``Dropout(LeakyReLU_0.1(Linear(x)))`` over the last
    dimension of its input.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Size of the last dimension of the output.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.activation = nn.LeakyReLU(negative_slope=0.1)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x``, apply the leaky ReLU, then dropout."""
        projected = self.linear(x)
        activated = self.activation(projected)
        return self.dropout(activated)


class Biaffine(nn.Module):
    """Biaffine scorer: ``scores[b, i, j, o] = [x_i; 1]^T W_o [y_j; 1]``.

    The appended constant ``1`` on either side is optional and controlled by
    ``bias_x`` / ``bias_y``.

    Args:
        input_dim: Feature size of both inputs.
        output_dim: Number of score channels (one weight matrix per channel).
        bias_x: Append a constant-1 feature to ``x`` before scoring.
        bias_y: Append a constant-1 feature to ``y`` before scoring.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        bias_x: bool = True,
        bias_y: bool = True
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y

        # One (rows x cols) matrix per output channel; each bias adds one
        # extra row / column for the appended constant feature.
        n_rows = input_dim + bias_x
        n_cols = input_dim + bias_y
        self.weight = nn.Parameter(torch.zeros(output_dim, n_rows, n_cols))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Score every (x position, y position) pair.

        Args:
            x: (batch, seq_len, input_dim) - dependent representations
            y: (batch, seq_len, input_dim) - head representations

        Returns:
            scores: (batch, seq_len, seq_len, output_dim), or
                (batch, seq_len, seq_len) when output_dim == 1
        """
        if self.bias_x:
            ones_x = torch.ones_like(x[..., :1])
            x = torch.cat([x, ones_x], dim=-1)
        if self.bias_y:
            ones_y = torch.ones_like(y[..., :1])
            y = torch.cat([y, ones_y], dim=-1)

        # Contract x with the weight first:
        # (batch, seq_len, output_dim, input_dim + bias_y)
        left = torch.einsum('bxi,oij->bxoj', x, self.weight)
        # Then contract with y: (batch, seq_len, seq_len, output_dim)
        scores = torch.einsum('bxoj,byj->bxyo', left, y)

        # Drop the channel axis for single-channel (arc) scoring.
        return scores.squeeze(-1) if self.output_dim == 1 else scores


class PhoBERTDependencyParser(nn.Module):
    """
    PhoBERT-based Biaffine Dependency Parser.

    Uses PhoBERT as encoder with first-subword pooling for word alignment,
    followed by biaffine attention for arc and relation prediction.

    Attributes set by ``__init__`` that other code relies on:
        encoder / tokenizer: HuggingFace model and tokenizer for ``encoder_name``.
        rel2idx / idx2rel: relation vocabulary mappings; empty until filled by
            ``load`` or by the caller.
    """

    def __init__(
        self,
        encoder_name: str = "vinai/phobert-base",
        n_rels: int = 50,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
        use_mst: bool = True,
    ):
        """
        Args:
            encoder_name: HuggingFace model name for PhoBERT
            n_rels: Number of dependency relations
            arc_hidden: Hidden dimension for arc MLPs
            rel_hidden: Hidden dimension for relation MLPs
            dropout: Dropout rate
            use_mst: Use MST decoding (True) or greedy decoding (False)
        """
        super().__init__()

        # Imported lazily so the module can be imported without transformers.
        from transformers import AutoModel, AutoTokenizer

        self.encoder_name = encoder_name
        self.n_rels = n_rels
        self.use_mst = use_mst

        # Relation vocabulary. Initialised here so that predict() works on a
        # freshly constructed model instead of raising AttributeError.
        self.rel2idx: Dict[str, int] = {}
        self.idx2rel: Dict[int, str] = {}

        # Load PhoBERT encoder
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.tokenizer = AutoTokenizer.from_pretrained(encoder_name)
        self.hidden_size = self.encoder.config.hidden_size  # 768 for phobert-base

        # Dropout applied to the encoder output
        self.dropout = nn.Dropout(dropout)

        # MLP projections (separate dependent / head views for arcs and relations)
        self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
        self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)

        # Biaffine attention
        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

    def _get_word_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_starts: torch.Tensor,
        word_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get word-level embeddings from subword encoder output.

        Uses first-subword pooling strategy: each word is represented by
        the embedding of its first subword token.

        Args:
            input_ids: (batch, subword_seq_len) - Subword token IDs
            attention_mask: (batch, subword_seq_len) - Attention mask for subwords
            word_starts: (batch, word_seq_len) - Indices of first subword for each word
            word_mask: (batch, word_seq_len) - Mask for actual words. Not used
                here; padded word slots are gathered like any other index and
                must be masked by the caller.

        Returns:
            word_embeddings: (batch, word_seq_len, hidden_size)
        """
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # (batch, subword_seq_len, hidden)
        hidden_states = self.dropout(encoder_output.last_hidden_state)

        # Broadcast each first-subword index across the hidden dimension so
        # gather picks the full embedding vector:
        # (batch, word_seq_len) -> (batch, word_seq_len, hidden)
        gather_index = word_starts.unsqueeze(-1).expand(
            -1, -1, hidden_states.size(-1)
        )
        return torch.gather(hidden_states, dim=1, index=gather_index)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_starts: torch.Tensor,
        word_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass computing arc and relation scores.

        Args:
            input_ids: (batch, subword_seq_len) - Subword token IDs
            attention_mask: (batch, subword_seq_len) - Attention mask for subwords
            word_starts: (batch, word_seq_len) - Indices of first subword for each word
            word_mask: (batch, word_seq_len) - Mask for actual words

        Returns:
            arc_scores: (batch, word_seq_len, word_seq_len) - Arc scores,
                indexed [batch, dependent, head]
            rel_scores: (batch, word_seq_len, word_seq_len, n_rels) - Relation scores
        """
        # Get word-level embeddings
        word_embeddings = self._get_word_embeddings(
            input_ids, attention_mask, word_starts, word_mask
        )

        # MLP projections
        arc_dep = self.mlp_arc_dep(word_embeddings)
        arc_head = self.mlp_arc_head(word_embeddings)
        rel_dep = self.mlp_rel_dep(word_embeddings)
        rel_head = self.mlp_rel_head(word_embeddings)

        # Biaffine attention
        arc_scores = self.arc_attn(arc_dep, arc_head)  # (batch, seq, seq)
        rel_scores = self.rel_attn(rel_dep, rel_head)  # (batch, seq, seq, n_rels)

        return arc_scores, rel_scores

    def loss(
        self,
        arc_scores: torch.Tensor,
        rel_scores: torch.Tensor,
        heads: torch.Tensor,
        rels: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute cross-entropy loss for arcs and relations.

        Padded positions are excluded both as dependents (their rows are
        dropped) and as candidate heads (their columns are set to ``-inf``
        before the softmax).

        Args:
            arc_scores: (batch, seq_len, seq_len) - Arc scores
            rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
            heads: (batch, seq_len) - Gold head indices
            rels: (batch, seq_len) - Gold relation indices
            mask: (batch, seq_len) - Token mask (1 for real tokens, 0 for padding)

        Returns:
            Total loss (arc_loss + rel_loss). A zero that is still attached to
            the graph is returned when the mask selects no token at all.
        """
        # `~mask` is only a logical NOT on boolean tensors.
        mask = mask.bool()
        batch_size, seq_len = mask.shape

        if not bool(mask.any()):
            # Mean cross-entropy over zero rows would be NaN.
            return (arc_scores.sum() + rel_scores.sum()) * 0.0

        # Candidate-head mask over the *column* axis. Position 0 is always
        # kept: it is either the first real word or, if the caller reserves a
        # leading ROOT slot and masks it out of the dependents, the ROOT.
        head_mask = mask.clone()
        head_mask[:, 0] = True
        arc_scores_masked = arc_scores.masked_fill(
            ~head_mask.unsqueeze(1), float('-inf')
        )

        # Arc loss: cross-entropy over possible heads, real dependents only
        arc_loss = F.cross_entropy(
            arc_scores_masked[mask],
            heads[mask],
            reduction='mean'
        )

        # Relation loss: cross-entropy conditioned on gold heads.
        # Advanced indexing picks rel_scores[b, i, heads[b, i]] for every (b, i).
        batch_indices = torch.arange(batch_size, device=rel_scores.device).unsqueeze(1)
        seq_indices = torch.arange(seq_len, device=rel_scores.device)
        rel_scores_gold = rel_scores[batch_indices, seq_indices, heads]

        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss

    def decode(
        self,
        arc_scores: torch.Tensor,
        rel_scores: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode predictions using MST or greedy decoding.

        Args:
            arc_scores: (batch, seq_len, seq_len) - Arc scores
            rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
            mask: (batch, seq_len) - Token mask

        Returns:
            arc_preds: (batch, seq_len) - Predicted head indices
            rel_preds: (batch, seq_len) - Predicted relation indices

        Predictions at padded positions are not meaningful and should be
        ignored by the caller.
        """
        mask = mask.bool()
        batch_size, seq_len = mask.shape
        device = arc_scores.device

        if self.use_mst:
            # MST decoding for valid tree structure. The decoder works on
            # NumPy arrays and receives the true sentence lengths.
            lengths = mask.sum(dim=1).cpu().numpy()
            # detach() so decoding also works on tensors that require grad.
            arc_scores_np = arc_scores.detach().cpu().numpy()
            arc_preds_np = batch_mst_decode(arc_scores_np, lengths)
            # long() because the result is used as an index tensor below.
            arc_preds = torch.from_numpy(arc_preds_np).long().to(device)
        else:
            # Greedy decoding. Padded positions must not win the argmax;
            # position 0 is always kept as a candidate (first word, or a
            # ROOT slot the caller masks out of the dependents).
            head_mask = mask.clone()
            head_mask[:, 0] = True
            arc_preds = arc_scores.masked_fill(
                ~head_mask.unsqueeze(1), float('-inf')
            ).argmax(dim=-1)

        # Get relation predictions for predicted heads:
        # rel_scores[b, i, arc_preds[b, i]] -> (batch, seq_len, n_rels)
        batch_indices = torch.arange(batch_size, device=device).unsqueeze(1)
        seq_indices = torch.arange(seq_len, device=device)
        rel_scores_pred = rel_scores[batch_indices, seq_indices, arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)

        return arc_preds, rel_preds

    def predict(
        self,
        words: List[str],
        return_probs: bool = False,
    ) -> List[Tuple[str, int, str]]:
        """
        Predict dependencies for a single sentence.

        Switches the model to eval mode as a side effect.

        Args:
            words: List of words (pre-tokenized)
            return_probs: Accepted for interface compatibility; currently
                ignored (no probabilities are returned).

        Returns:
            List of (word, head, deprel) tuples. ``head`` is the raw index
            returned by ``decode``. ``deprel`` is looked up in ``idx2rel`` and
            falls back to "dep" for unknown indices. If the sentence exceeds
            the tokenizer length limit, only the words that survived
            truncation are returned and a warning is emitted.

        NOTE(review): whether head index 0 denotes ROOT or the first word
        depends on how training data and the MST decoder index positions;
        confirm against the training pipeline.
        """
        import warnings

        self.eval()
        if not words:
            return []

        device = next(self.parameters()).device

        # Tokenize with word boundary tracking
        encoded = self.tokenize_with_alignment([words])

        # Move to device
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)
        word_starts = encoded['word_starts'].to(device)
        word_mask = encoded['word_mask'].to(device)

        with torch.no_grad():
            arc_scores, rel_scores = self(
                input_ids, attention_mask, word_starts, word_mask
            )
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, word_mask)

        head_list = arc_preds[0].cpu().tolist()
        rel_list = rel_preds[0].cpu().tolist()

        # tokenize_with_alignment drops words whose first subword falls past
        # max_length, so there may be fewer predictions than input words.
        if len(head_list) < len(words):
            warnings.warn(
                f"Sentence truncated: parsed {len(head_list)} of "
                f"{len(words)} words",
                stacklevel=2,
            )

        return [
            (word, head, self.idx2rel.get(rel_idx, "dep"))
            for word, head, rel_idx in zip(words, head_list, rel_list)
        ]

    def tokenize_with_alignment(
        self,
        sentences: List[List[str]],
        max_length: int = 256,
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenize sentences and track word-subword alignment.

        Each word is encoded on its own so the index of its first subword is
        known. Sequences are wrapped in CLS ... SEP, truncated to
        ``max_length`` subwords (words starting beyond the cut are dropped),
        and padded to the longest sequence in the batch.

        Args:
            sentences: List of sentences, where each sentence is a list of words
            max_length: Maximum subword sequence length

        Returns:
            Dictionary with input_ids, attention_mask, word_starts (long
            tensors) and word_mask (bool tensor). For an empty batch every
            tensor has shape (0, 0).
        """
        if not sentences:
            # max() below would raise on an empty batch.
            return {
                'input_ids': torch.zeros((0, 0), dtype=torch.long),
                'attention_mask': torch.zeros((0, 0), dtype=torch.long),
                'word_starts': torch.zeros((0, 0), dtype=torch.long),
                'word_mask': torch.zeros((0, 0), dtype=torch.bool),
            }

        tokenizer = self.tokenizer
        cls_id = tokenizer.cls_token_id
        sep_id = tokenizer.sep_token_id
        pad_id = tokenizer.pad_token_id

        batch_input_ids = []
        batch_word_starts = []

        for words in sentences:
            # Tokenize each word separately to track boundaries
            word_starts = []
            subword_ids = [cls_id]

            for word in words:
                word_starts.append(len(subword_ids))
                subword_ids.extend(
                    tokenizer.encode(word, add_special_tokens=False)
                )

            subword_ids.append(sep_id)

            # Truncate if needed, keeping SEP as the final token
            if len(subword_ids) > max_length:
                subword_ids = subword_ids[:max_length - 1] + [sep_id]
                # Drop words whose first subword was cut off
                word_starts = [ws for ws in word_starts if ws < max_length - 1]

            batch_input_ids.append(subword_ids)
            batch_word_starts.append(word_starts)

        # Pad sequences
        max_subword_len = max(len(ids) for ids in batch_input_ids)
        max_word_len = max(len(ws) for ws in batch_word_starts)

        padded_input_ids = []
        padded_attention_mask = []
        padded_word_starts = []
        padded_word_mask = []

        for subword_ids, word_starts in zip(batch_input_ids, batch_word_starts):
            # Pad subwords
            pad_len = max_subword_len - len(subword_ids)
            padded_input_ids.append(subword_ids + [pad_id] * pad_len)
            padded_attention_mask.append([1] * len(subword_ids) + [0] * pad_len)

            # Pad words. Use 0 for padding word_starts (points to CLS token,
            # but masked out by word_mask).
            word_pad_len = max_word_len - len(word_starts)
            padded_word_starts.append(word_starts + [0] * word_pad_len)
            padded_word_mask.append([1] * len(word_starts) + [0] * word_pad_len)

        return {
            'input_ids': torch.tensor(padded_input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(padded_attention_mask, dtype=torch.long),
            'word_starts': torch.tensor(padded_word_starts, dtype=torch.long),
            'word_mask': torch.tensor(padded_word_mask, dtype=torch.bool),
        }

    def save(self, path: str, vocab: Optional[Dict] = None):
        """
        Save model checkpoint.

        Writes ``model.pt`` (state dict, constructor config and, when
        available, the vocabulary) plus the tokenizer files into ``path``.

        Args:
            path: Directory path to save the model
            vocab: Vocabulary dict with rel2idx and idx2rel mappings. When
                omitted, the model's own non-empty ``rel2idx`` / ``idx2rel``
                are saved instead.
        """
        import os
        os.makedirs(path, exist_ok=True)

        # The config mirrors the constructor arguments so load() can rebuild
        # the same architecture before restoring the weights.
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': {
                'encoder_name': self.encoder_name,
                'n_rels': self.n_rels,
                'arc_hidden': self.mlp_arc_dep.linear.out_features,
                'rel_hidden': self.mlp_rel_dep.linear.out_features,
                'dropout': self.dropout.p,
                'use_mst': self.use_mst,
            },
        }

        # Fall back to the vocabulary carried by the model so a
        # load() -> save() round trip does not lose the label mapping.
        if vocab is None and (self.rel2idx or self.idx2rel):
            vocab = {'rel2idx': self.rel2idx, 'idx2rel': self.idx2rel}

        if vocab is not None:
            checkpoint['vocab'] = vocab

        torch.save(checkpoint, os.path.join(path, 'model.pt'))

        # Save tokenizer
        self.tokenizer.save_pretrained(path)

    @classmethod
    def load(cls, path: str, device: str = 'cpu') -> 'PhoBERTDependencyParser':
        """
        Load model from checkpoint.

        Args:
            path: Directory path containing the saved model
            device: Device to load the model to

        Returns:
            Loaded PhoBERTDependencyParser model with ``rel2idx`` / ``idx2rel``
            restored from the checkpoint (empty dicts if it has no vocab).
        """
        import os

        # SECURITY: weights_only=False unpickles arbitrary objects (needed for
        # the config/vocab dicts as saved). Only load checkpoints you trust.
        checkpoint = torch.load(
            os.path.join(path, 'model.pt'),
            map_location=device,
            weights_only=False
        )

        config = checkpoint['config']

        # Rebuild the architecture. This instantiates the pretrained encoder
        # named in the config; its weights are then overwritten below.
        model = cls(
            encoder_name=config['encoder_name'],
            n_rels=config['n_rels'],
            arc_hidden=config['arc_hidden'],
            rel_hidden=config['rel_hidden'],
            dropout=config['dropout'],
            use_mst=config.get('use_mst', True),
        )

        # Load state dict
        model.load_state_dict(checkpoint['model_state_dict'])

        # Load vocabulary
        vocab = checkpoint.get('vocab') or {}
        model.rel2idx = vocab.get('rel2idx', {})
        model.idx2rel = vocab.get('idx2rel', {})

        model.to(device)
        return model