rain1024 commited on
Commit
48ba615
·
verified ·
1 Parent(s): 0bf906a

Upload src/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/train.py +1006 -0
src/train.py ADDED
@@ -0,0 +1,1006 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch>=2.0.0",
5
+ # "transformers>=4.30.0",
6
+ # "datasets>=2.14.0",
7
+ # "click>=8.0.0",
8
+ # "tqdm>=4.60.0",
9
+ # "wandb>=0.15.0",
10
+ # "python-dotenv>=1.0.0",
11
+ # ]
12
+ # ///
13
+ """
14
+ Training script for Bamboo-1 Vietnamese Dependency Parser.
15
+
16
+ Supports multiple methods:
17
+ - baseline: BiLSTM + Biaffine (Dozat & Manning, 2017)
18
+ - trankit: XLM-RoBERTa + Biaffine (Nguyen et al., 2021)
19
+
20
+ Usage:
21
+ uv run scripts/train.py # Default baseline
22
+ uv run scripts/train.py --method trankit # Reproduce Trankit
23
+ uv run scripts/train.py --method trankit --dataset ud-vtb # Trankit on VTB
24
+ """
25
+
26
+ import sys
27
+ from pathlib import Path
28
+ from collections import Counter
29
+ from dataclasses import dataclass
30
+ from typing import List, Tuple, Optional
31
+
32
+ # Load environment variables
33
+ from dotenv import load_dotenv
34
+ load_dotenv()
35
+
36
+ import torch
37
+ import torch.nn as nn
38
+ import torch.nn.functional as F
39
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
40
+ from torch.utils.data import Dataset, DataLoader
41
+ from torch.optim import Adam, AdamW
42
+ from torch.optim.lr_scheduler import ExponentialLR
43
+ from tqdm import tqdm
44
+
45
+ import click
46
+
47
+ sys.path.insert(0, str(Path(__file__).parent.parent))
48
+ from src.corpus import UDD1Corpus
49
+ from src.ud_corpus import UDVietnameseVTB
50
+ from src.vndt_corpus import VnDTCorpus
51
+ from src.cost_estimate import CostTracker, detect_hardware
52
+
53
+
54
+ # ============================================================================
55
+ # Data Processing
56
+ # ============================================================================
57
+
58
+ @dataclass
59
+ class Sentence:
60
+ """A dependency-parsed sentence."""
61
+ words: List[str]
62
+ heads: List[int]
63
+ rels: List[str]
64
+
65
+
66
+ def read_conllu(path: str) -> List[Sentence]:
67
+ """Read CoNLL-U file and return list of sentences."""
68
+ sentences = []
69
+ words, heads, rels = [], [], []
70
+
71
+ with open(path, 'r', encoding='utf-8') as f:
72
+ for line in f:
73
+ line = line.strip()
74
+ if not line:
75
+ if words:
76
+ sentences.append(Sentence(words, heads, rels))
77
+ words, heads, rels = [], [], []
78
+ elif line.startswith('#'):
79
+ continue
80
+ else:
81
+ parts = line.split('\t')
82
+ if '-' in parts[0] or '.' in parts[0]: # Skip multi-word tokens
83
+ continue
84
+ words.append(parts[1]) # FORM
85
+ heads.append(int(parts[6])) # HEAD
86
+ rels.append(parts[7]) # DEPREL
87
+
88
+ if words:
89
+ sentences.append(Sentence(words, heads, rels))
90
+
91
+ return sentences
92
+
93
+
94
+ class Vocabulary:
95
+ """Vocabulary for words, characters, and relations."""
96
+ PAD = '<pad>'
97
+ UNK = '<unk>'
98
+
99
+ def __init__(self, min_freq: int = 2):
100
+ self.min_freq = min_freq
101
+ self.word2idx = {self.PAD: 0, self.UNK: 1}
102
+ self.char2idx = {self.PAD: 0, self.UNK: 1}
103
+ self.rel2idx = {}
104
+ self.idx2rel = {}
105
+
106
+ def build(self, sentences: List[Sentence]):
107
+ """Build vocabulary from sentences."""
108
+ word_counts = Counter()
109
+ char_counts = Counter()
110
+ rel_counts = Counter()
111
+
112
+ for sent in sentences:
113
+ for word in sent.words:
114
+ word_counts[word.lower()] += 1
115
+ for char in word:
116
+ char_counts[char] += 1
117
+ for rel in sent.rels:
118
+ rel_counts[rel] += 1
119
+
120
+ # Words
121
+ for word, count in word_counts.items():
122
+ if count >= self.min_freq and word not in self.word2idx:
123
+ self.word2idx[word] = len(self.word2idx)
124
+
125
+ # Characters
126
+ for char, count in char_counts.items():
127
+ if char not in self.char2idx:
128
+ self.char2idx[char] = len(self.char2idx)
129
+
130
+ # Relations
131
+ for rel in rel_counts:
132
+ if rel not in self.rel2idx:
133
+ idx = len(self.rel2idx)
134
+ self.rel2idx[rel] = idx
135
+ self.idx2rel[idx] = rel
136
+
137
+ def encode_word(self, word: str) -> int:
138
+ return self.word2idx.get(word.lower(), self.word2idx[self.UNK])
139
+
140
+ def encode_char(self, char: str) -> int:
141
+ return self.char2idx.get(char, self.char2idx[self.UNK])
142
+
143
+ def encode_rel(self, rel: str) -> int:
144
+ return self.rel2idx.get(rel, 0)
145
+
146
+ @property
147
+ def n_words(self) -> int:
148
+ return len(self.word2idx)
149
+
150
+ @property
151
+ def n_chars(self) -> int:
152
+ return len(self.char2idx)
153
+
154
+ @property
155
+ def n_rels(self) -> int:
156
+ return len(self.rel2idx)
157
+
158
+
159
+ class DependencyDataset(Dataset):
160
+ """Dataset for dependency parsing."""
161
+
162
+ def __init__(self, sentences: List[Sentence], vocab: Vocabulary):
163
+ self.sentences = sentences
164
+ self.vocab = vocab
165
+
166
+ def __len__(self):
167
+ return len(self.sentences)
168
+
169
+ def __getitem__(self, idx):
170
+ sent = self.sentences[idx]
171
+
172
+ # Encode words
173
+ word_ids = [self.vocab.encode_word(w) for w in sent.words]
174
+
175
+ # Encode characters
176
+ char_ids = [[self.vocab.encode_char(c) for c in w] for w in sent.words]
177
+
178
+ # Heads and relations
179
+ heads = sent.heads
180
+ rels = [self.vocab.encode_rel(r) for r in sent.rels]
181
+
182
+ return word_ids, char_ids, heads, rels
183
+
184
+
185
+ def collate_fn(batch):
186
+ """Collate function for DataLoader."""
187
+ word_ids, char_ids, heads, rels = zip(*batch)
188
+
189
+ # Get lengths
190
+ lengths = [len(w) for w in word_ids]
191
+ max_len = max(lengths)
192
+
193
+ # Pad words
194
+ word_ids_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
195
+ for i, wids in enumerate(word_ids):
196
+ word_ids_padded[i, :len(wids)] = torch.tensor(wids)
197
+
198
+ # Pad characters
199
+ max_word_len = max(max(len(c) for c in chars) for chars in char_ids)
200
+ char_ids_padded = torch.zeros(len(batch), max_len, max_word_len, dtype=torch.long)
201
+ for i, chars in enumerate(char_ids):
202
+ for j, c in enumerate(chars):
203
+ char_ids_padded[i, j, :len(c)] = torch.tensor(c)
204
+
205
+ # Pad heads
206
+ heads_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
207
+ for i, h in enumerate(heads):
208
+ heads_padded[i, :len(h)] = torch.tensor(h)
209
+
210
+ # Pad rels
211
+ rels_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
212
+ for i, r in enumerate(rels):
213
+ rels_padded[i, :len(r)] = torch.tensor(r)
214
+
215
+ # Mask
216
+ mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
217
+ for i, l in enumerate(lengths):
218
+ mask[i, :l] = True
219
+
220
+ lengths = torch.tensor(lengths)
221
+
222
+ return word_ids_padded, char_ids_padded, heads_padded, rels_padded, mask, lengths
223
+
224
+
225
+ # ============================================================================
226
+ # Model
227
+ # ============================================================================
228
+
229
+ class CharLSTM(nn.Module):
230
+ """Character-level LSTM embeddings."""
231
+
232
+ def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100):
233
+ super().__init__()
234
+ self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0)
235
+ self.lstm = nn.LSTM(char_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
236
+ self.hidden_dim = hidden_dim
237
+
238
+ def forward(self, chars):
239
+ """
240
+ Args:
241
+ chars: (batch, seq_len, max_word_len)
242
+ Returns:
243
+ (batch, seq_len, hidden_dim)
244
+ """
245
+ batch, seq_len, max_word_len = chars.shape
246
+
247
+ # Flatten
248
+ chars_flat = chars.view(-1, max_word_len) # (batch * seq_len, max_word_len)
249
+
250
+ # Get word lengths
251
+ word_lens = (chars_flat != 0).sum(dim=1)
252
+ word_lens = word_lens.clamp(min=1)
253
+
254
+ # Embed
255
+ char_embeds = self.embed(chars_flat) # (batch * seq_len, max_word_len, char_dim)
256
+
257
+ # Pack and run LSTM
258
+ packed = pack_padded_sequence(char_embeds, word_lens.cpu(), batch_first=True, enforce_sorted=False)
259
+ _, (hidden, _) = self.lstm(packed)
260
+
261
+ # Concatenate forward and backward hidden states
262
+ hidden = torch.cat([hidden[0], hidden[1]], dim=-1) # (batch * seq_len, hidden_dim)
263
+
264
+ return hidden.view(batch, seq_len, self.hidden_dim)
265
+
266
+
267
+ class MLP(nn.Module):
268
+ """Multi-layer perceptron."""
269
+
270
+ def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
271
+ super().__init__()
272
+ self.linear = nn.Linear(input_dim, hidden_dim)
273
+ self.activation = nn.LeakyReLU(0.1)
274
+ self.dropout = nn.Dropout(dropout)
275
+
276
+ def forward(self, x):
277
+ return self.dropout(self.activation(self.linear(x)))
278
+
279
+
280
+ class Biaffine(nn.Module):
281
+ """Biaffine attention layer."""
282
+
283
+ def __init__(self, input_dim: int, output_dim: int = 1, bias_x: bool = True, bias_y: bool = True):
284
+ super().__init__()
285
+ self.input_dim = input_dim
286
+ self.output_dim = output_dim
287
+ self.bias_x = bias_x
288
+ self.bias_y = bias_y
289
+
290
+ self.weight = nn.Parameter(torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y))
291
+ nn.init.xavier_uniform_(self.weight)
292
+
293
+ def forward(self, x, y):
294
+ """
295
+ Args:
296
+ x: (batch, seq_len, input_dim) - dependent
297
+ y: (batch, seq_len, input_dim) - head
298
+ Returns:
299
+ (batch, seq_len, seq_len, output_dim) or (batch, seq_len, seq_len) if output_dim=1
300
+ """
301
+ if self.bias_x:
302
+ x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
303
+ if self.bias_y:
304
+ y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)
305
+
306
+ # (batch, seq_len, output_dim, input_dim+1)
307
+ x = torch.einsum('bxi,oij->bxoj', x, self.weight)
308
+ # (batch, seq_len, seq_len, output_dim)
309
+ scores = torch.einsum('bxoj,byj->bxyo', x, y)
310
+
311
+ if self.output_dim == 1:
312
+ scores = scores.squeeze(-1)
313
+
314
+ return scores
315
+
316
+
317
+ class BiaffineDependencyParser(nn.Module):
318
+ """Biaffine Dependency Parser (Dozat & Manning, 2017)."""
319
+
320
+ def __init__(
321
+ self,
322
+ n_words: int,
323
+ n_chars: int,
324
+ n_rels: int,
325
+ word_dim: int = 100,
326
+ char_dim: int = 50,
327
+ char_hidden: int = 100,
328
+ lstm_hidden: int = 400,
329
+ lstm_layers: int = 3,
330
+ arc_hidden: int = 500,
331
+ rel_hidden: int = 100,
332
+ dropout: float = 0.33,
333
+ ):
334
+ super().__init__()
335
+
336
+ self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0)
337
+ self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden)
338
+
339
+ input_dim = word_dim + char_hidden
340
+
341
+ self.lstm = nn.LSTM(
342
+ input_dim, lstm_hidden // 2,
343
+ num_layers=lstm_layers,
344
+ batch_first=True,
345
+ bidirectional=True,
346
+ dropout=dropout if lstm_layers > 1 else 0
347
+ )
348
+
349
+ self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout)
350
+ self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout)
351
+ self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout)
352
+ self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout)
353
+
354
+ self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
355
+ self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)
356
+
357
+ self.dropout = nn.Dropout(dropout)
358
+ self.n_rels = n_rels
359
+
360
+ def forward(self, words, chars, mask):
361
+ """
362
+ Args:
363
+ words: (batch, seq_len)
364
+ chars: (batch, seq_len, max_word_len)
365
+ mask: (batch, seq_len)
366
+ Returns:
367
+ arc_scores: (batch, seq_len, seq_len)
368
+ rel_scores: (batch, seq_len, seq_len, n_rels)
369
+ """
370
+ # Embeddings
371
+ word_embeds = self.word_embed(words)
372
+ char_embeds = self.char_lstm(chars)
373
+ embeds = torch.cat([word_embeds, char_embeds], dim=-1)
374
+ embeds = self.dropout(embeds)
375
+
376
+ # BiLSTM
377
+ lengths = mask.sum(dim=1).cpu()
378
+ packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
379
+ lstm_out, _ = self.lstm(packed)
380
+ lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1))
381
+ lstm_out = self.dropout(lstm_out)
382
+
383
+ # MLP
384
+ arc_dep = self.mlp_arc_dep(lstm_out)
385
+ arc_head = self.mlp_arc_head(lstm_out)
386
+ rel_dep = self.mlp_rel_dep(lstm_out)
387
+ rel_head = self.mlp_rel_head(lstm_out)
388
+
389
+ # Biaffine
390
+ arc_scores = self.arc_attn(arc_dep, arc_head) # (batch, seq_len, seq_len)
391
+ rel_scores = self.rel_attn(rel_dep, rel_head) # (batch, seq_len, seq_len, n_rels)
392
+
393
+ return arc_scores, rel_scores
394
+
395
+ def loss(self, arc_scores, rel_scores, heads, rels, mask):
396
+ """Compute loss."""
397
+ batch_size, seq_len = mask.shape
398
+
399
+ # Arc loss
400
+ arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), float('-inf'))
401
+ arc_loss = F.cross_entropy(
402
+ arc_scores[mask].view(-1, seq_len),
403
+ heads[mask],
404
+ reduction='mean'
405
+ )
406
+
407
+ # Rel loss - select scores for gold heads
408
+ rel_scores_gold = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), heads]
409
+ rel_loss = F.cross_entropy(
410
+ rel_scores_gold[mask],
411
+ rels[mask],
412
+ reduction='mean'
413
+ )
414
+
415
+ return arc_loss + rel_loss
416
+
417
+ def decode(self, arc_scores, rel_scores, mask):
418
+ """Decode predictions."""
419
+ # Greedy decoding
420
+ arc_preds = arc_scores.argmax(dim=-1)
421
+
422
+ batch_size, seq_len = mask.shape
423
+ rel_scores_pred = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), arc_preds]
424
+ rel_preds = rel_scores_pred.argmax(dim=-1)
425
+
426
+ return arc_preds, rel_preds
427
+
428
+
429
+ # ============================================================================
430
+ # Trankit-style Transformer Parser (XLM-RoBERTa + Biaffine)
431
+ # ============================================================================
432
+
433
+ class TransformerDependencyParser(nn.Module):
434
+ """
435
+ Trankit-style dependency parser using XLM-RoBERTa.
436
+
437
+ Architecture follows Nguyen et al. 2021 EACL:
438
+ - XLM-RoBERTa encoder
439
+ - Word-level pooling (first subword)
440
+ - Biaffine attention for arc/rel prediction
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ n_rels: int,
446
+ encoder: str = "xlm-roberta-base",
447
+ arc_hidden: int = 500,
448
+ rel_hidden: int = 100,
449
+ dropout: float = 0.33,
450
+ ):
451
+ super().__init__()
452
+ from transformers import AutoModel, AutoTokenizer
453
+
454
+ self.encoder_name = encoder
455
+ self.tokenizer = AutoTokenizer.from_pretrained(encoder)
456
+ self.encoder = AutoModel.from_pretrained(encoder)
457
+ self.hidden_size = self.encoder.config.hidden_size
458
+
459
+ # Biaffine layers
460
+ self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
461
+ self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
462
+ self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
463
+ self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)
464
+
465
+ self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
466
+ self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)
467
+
468
+ self.dropout = nn.Dropout(dropout)
469
+ self.n_rels = n_rels
470
+
471
+ def encode_batch(self, sentences: List[List[str]], device):
472
+ """Tokenize and encode sentences, return word-level representations."""
473
+ batch_size = len(sentences)
474
+ max_words = max(len(s) for s in sentences)
475
+
476
+ # Tokenize each word and track subword positions
477
+ all_input_ids = []
478
+ all_attention_mask = []
479
+ word_starts = [] # (batch, max_words) -> position of first subword
480
+
481
+ for sent in sentences:
482
+ input_ids = [self.tokenizer.cls_token_id]
483
+ starts = []
484
+
485
+ for word in sent:
486
+ starts.append(len(input_ids))
487
+ tokens = self.tokenizer.encode(word, add_special_tokens=False)
488
+ input_ids.extend(tokens if tokens else [self.tokenizer.unk_token_id])
489
+
490
+ input_ids.append(self.tokenizer.sep_token_id)
491
+ all_input_ids.append(input_ids)
492
+ word_starts.append(starts)
493
+
494
+ # Pad sequences
495
+ max_len = max(len(ids) for ids in all_input_ids)
496
+ padded_ids = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
497
+ attention_mask = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
498
+
499
+ for i, ids in enumerate(all_input_ids):
500
+ padded_ids[i, :len(ids)] = torch.tensor(ids)
501
+ attention_mask[i, :len(ids)] = 1
502
+
503
+ # Encode with transformer
504
+ outputs = self.encoder(padded_ids, attention_mask=attention_mask)
505
+ hidden = outputs.last_hidden_state # (batch, seq_len, hidden)
506
+
507
+ # Extract word-level representations (first subword)
508
+ word_hidden = torch.zeros(batch_size, max_words, self.hidden_size, device=device)
509
+ word_mask = torch.zeros(batch_size, max_words, dtype=torch.bool, device=device)
510
+
511
+ for i, starts in enumerate(word_starts):
512
+ for j, pos in enumerate(starts):
513
+ word_hidden[i, j] = hidden[i, pos]
514
+ word_mask[i, j] = True
515
+
516
+ return word_hidden, word_mask
517
+
518
+ def forward(self, word_hidden, word_mask):
519
+ """Compute arc and relation scores from word representations."""
520
+ word_hidden = self.dropout(word_hidden)
521
+
522
+ # Biaffine scoring
523
+ arc_dep = self.mlp_arc_dep(word_hidden)
524
+ arc_head = self.mlp_arc_head(word_hidden)
525
+ rel_dep = self.mlp_rel_dep(word_hidden)
526
+ rel_head = self.mlp_rel_head(word_hidden)
527
+
528
+ arc_scores = self.arc_attn(arc_dep, arc_head)
529
+ rel_scores = self.rel_attn(rel_dep, rel_head)
530
+
531
+ return arc_scores, rel_scores
532
+
533
+ def loss(self, arc_scores, rel_scores, heads, rels, mask):
534
+ """Compute cross-entropy loss."""
535
+ batch_size, seq_len = mask.shape
536
+
537
+ # Arc loss
538
+ arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), float('-inf'))
539
+ arc_loss = F.cross_entropy(
540
+ arc_scores[mask].view(-1, seq_len),
541
+ heads[mask],
542
+ reduction='mean'
543
+ )
544
+
545
+ # Rel loss
546
+ rel_scores_gold = rel_scores[torch.arange(batch_size, device=mask.device).unsqueeze(1),
547
+ torch.arange(seq_len, device=mask.device), heads]
548
+ rel_loss = F.cross_entropy(
549
+ rel_scores_gold[mask],
550
+ rels[mask],
551
+ reduction='mean'
552
+ )
553
+
554
+ return arc_loss + rel_loss
555
+
556
+ def decode(self, arc_scores, rel_scores, mask):
557
+ """Greedy decoding."""
558
+ arc_preds = arc_scores.argmax(dim=-1)
559
+
560
+ batch_size, seq_len = mask.shape
561
+ rel_scores_pred = rel_scores[torch.arange(batch_size, device=mask.device).unsqueeze(1),
562
+ torch.arange(seq_len, device=mask.device), arc_preds]
563
+ rel_preds = rel_scores_pred.argmax(dim=-1)
564
+
565
+ return arc_preds, rel_preds
566
+
567
+
568
+ class TransformerDataset(Dataset):
569
+ """Dataset for transformer-based parser (stores raw sentences)."""
570
+
571
+ def __init__(self, sentences: List[Sentence], vocab):
572
+ self.sentences = sentences
573
+ self.vocab = vocab
574
+
575
+ def __len__(self):
576
+ return len(self.sentences)
577
+
578
+ def __getitem__(self, idx):
579
+ sent = self.sentences[idx]
580
+ heads = sent.heads
581
+ rels = [self.vocab.encode_rel(r) for r in sent.rels]
582
+ return sent.words, heads, rels
583
+
584
+
585
+ def transformer_collate_fn(batch):
586
+ """Collate for transformer-based parser."""
587
+ words_list, heads_list, rels_list = zip(*batch)
588
+
589
+ max_len = max(len(w) for w in words_list)
590
+ batch_size = len(batch)
591
+
592
+ # Pad heads and rels
593
+ heads_padded = torch.zeros(batch_size, max_len, dtype=torch.long)
594
+ rels_padded = torch.zeros(batch_size, max_len, dtype=torch.long)
595
+ mask = torch.zeros(batch_size, max_len, dtype=torch.bool)
596
+
597
+ for i, (h, r) in enumerate(zip(heads_list, rels_list)):
598
+ heads_padded[i, :len(h)] = torch.tensor(h)
599
+ rels_padded[i, :len(r)] = torch.tensor(r)
600
+ mask[i, :len(h)] = True
601
+
602
+ return list(words_list), heads_padded, rels_padded, mask
603
+
604
+
605
+ def evaluate_transformer(model, dataloader, device):
606
+ """Evaluate transformer-based model."""
607
+ model.eval()
608
+
609
+ total_arcs = 0
610
+ correct_arcs = 0
611
+ correct_rels = 0
612
+
613
+ with torch.no_grad():
614
+ for words_list, heads, rels, mask in dataloader:
615
+ heads = heads.to(device)
616
+ rels = rels.to(device)
617
+ mask = mask.to(device)
618
+
619
+ word_hidden, word_mask = model.encode_batch(words_list, device)
620
+ arc_scores, rel_scores = model(word_hidden, word_mask)
621
+ arc_preds, rel_preds = model.decode(arc_scores, rel_scores, word_mask)
622
+
623
+ arc_correct = (arc_preds == heads) & mask
624
+ rel_correct = (rel_preds == rels) & mask & arc_correct
625
+
626
+ total_arcs += mask.sum().item()
627
+ correct_arcs += arc_correct.sum().item()
628
+ correct_rels += rel_correct.sum().item()
629
+
630
+ uas = correct_arcs / total_arcs * 100
631
+ las = correct_rels / total_arcs * 100
632
+
633
+ return uas, las
634
+
635
+
636
+ # ============================================================================
637
+ # Training
638
+ # ============================================================================
639
+
640
+ def evaluate(model, dataloader, device):
641
+ """Evaluate model and return UAS/LAS."""
642
+ model.eval()
643
+
644
+ total_arcs = 0
645
+ correct_arcs = 0
646
+ correct_rels = 0
647
+
648
+ with torch.no_grad():
649
+ for batch in dataloader:
650
+ words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]
651
+
652
+ arc_scores, rel_scores = model(words, chars, mask)
653
+ arc_preds, rel_preds = model.decode(arc_scores, rel_scores, mask)
654
+
655
+ # Count correct
656
+ arc_correct = (arc_preds == heads) & mask
657
+ rel_correct = (rel_preds == rels) & mask & arc_correct
658
+
659
+ total_arcs += mask.sum().item()
660
+ correct_arcs += arc_correct.sum().item()
661
+ correct_rels += rel_correct.sum().item()
662
+
663
+ uas = correct_arcs / total_arcs * 100
664
+ las = correct_rels / total_arcs * 100
665
+
666
+ return uas, las
667
+
668
+
669
+ @click.command()
670
+ @click.option('--method', type=click.Choice(['baseline', 'trankit']), default='baseline',
671
+ help='Parser method: baseline (BiLSTM) or trankit (XLM-RoBERTa)')
672
+ @click.option('--dataset', type=click.Choice(['udd1', 'ud-vtb', 'vndt']), default='udd1',
673
+ help='Dataset: udd1 (UDD-1), ud-vtb (UD Vietnamese VTB), or vndt (VnDT v1.1)')
674
+ @click.option('--encoder', default='xlm-roberta-base',
675
+ help='Transformer encoder for trankit method')
676
+ @click.option('--output', '-o', default='models/bamboo-1', help='Output directory')
677
+ @click.option('--epochs', default=100, type=int, help='Number of epochs')
678
+ @click.option('--batch-size', default=32, type=int, help='Batch size')
679
+ @click.option('--lr', default=2e-3, type=float, help='Learning rate for baseline')
680
+ @click.option('--bert-lr', default=1e-5, type=float, help='Encoder learning rate for trankit')
681
+ @click.option('--head-lr', default=1e-4, type=float, help='Head learning rate for trankit')
682
+ @click.option('--warmup-steps', default=500, type=int, help='Warmup steps for trankit')
683
+ @click.option('--lstm-hidden', default=400, type=int, help='LSTM hidden size (baseline)')
684
+ @click.option('--lstm-layers', default=3, type=int, help='LSTM layers (baseline)')
685
+ @click.option('--patience', default=10, type=int, help='Early stopping patience')
686
+ @click.option('--force-download', is_flag=True, help='Force re-download dataset')
687
+ @click.option('--data-dir', default=None, help='Custom data directory')
688
+ @click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
689
+ @click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
690
+ @click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
691
+ @click.option('--wandb-project', default='bamboo-1', help='W&B project name')
692
+ @click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
693
+ @click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
694
+ @click.option('--eval-every', default=1, type=int, help='Evaluate every N epochs')
695
+ @click.option('--fp16', is_flag=True, default=True, help='Use mixed precision training')
696
+ def train(method, dataset, encoder, output, epochs, batch_size, lr, bert_lr, head_lr, warmup_steps,
697
+ lstm_hidden, lstm_layers, patience, force_download, data_dir, gpu_type, cost_interval,
698
+ use_wandb, wandb_project, max_time, sample, eval_every, fp16):
699
+ """Train Bamboo-1 Vietnamese Dependency Parser."""
700
+
701
+ # Detect hardware
702
+ hardware = detect_hardware()
703
+ detected_gpu_type = hardware.get_gpu_type()
704
+
705
+ if gpu_type == "RTX_A4000":
706
+ gpu_type = detected_gpu_type
707
+
708
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
709
+ click.echo(f"Using device: {device}")
710
+ click.echo(f"Hardware: {hardware}")
711
+
712
+ # CUDA optimizations
713
+ if torch.cuda.is_available():
714
+ torch.backends.cudnn.benchmark = True
715
+ torch.backends.cuda.matmul.allow_tf32 = True
716
+ torch.backends.cudnn.allow_tf32 = True
717
+
718
+ # Mixed precision
719
+ use_amp = fp16 and torch.cuda.is_available()
720
+ scaler = torch.amp.GradScaler('cuda') if use_amp else None
721
+ if use_amp:
722
+ click.echo("Mixed precision (FP16): enabled")
723
+
724
+ # Initialize wandb
725
+ if use_wandb:
726
+ import wandb
727
+ wandb.init(
728
+ project=wandb_project,
729
+ config={
730
+ "method": method,
731
+ "dataset": dataset,
732
+ "encoder": encoder if method == "trankit" else "bilstm",
733
+ "epochs": epochs,
734
+ "batch_size": batch_size,
735
+ "lr": lr if method == "baseline" else bert_lr,
736
+ "head_lr": head_lr if method == "trankit" else None,
737
+ "lstm_hidden": lstm_hidden if method == "baseline" else None,
738
+ "lstm_layers": lstm_layers if method == "baseline" else None,
739
+ "patience": patience,
740
+ "gpu_type": gpu_type,
741
+ "hardware": hardware.to_dict(),
742
+ }
743
+ )
744
+ click.echo(f"W&B logging enabled: {wandb.run.url}")
745
+
746
+ click.echo("=" * 60)
747
+ click.echo(f"Bamboo-1: Vietnamese Dependency Parser ({method.upper()})")
748
+ click.echo("=" * 60)
749
+
750
+ # Load corpus
751
+ click.echo(f"\nLoading {dataset.upper()} corpus...")
752
+ if dataset == 'udd1':
753
+ corpus = UDD1Corpus(data_dir=data_dir, force_download=force_download)
754
+ elif dataset == 'ud-vtb':
755
+ corpus = UDVietnameseVTB(data_dir=data_dir, force_download=force_download)
756
+ else: # vndt
757
+ corpus = VnDTCorpus(data_dir=data_dir, force_download=force_download)
758
+
759
+ train_sents = read_conllu(corpus.train)
760
+ dev_sents = read_conllu(corpus.dev)
761
+ test_sents = read_conllu(corpus.test)
762
+
763
+ # Sample subset if requested
764
+ if sample > 0:
765
+ train_sents = train_sents[:sample]
766
+ dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
767
+ test_sents = test_sents[:min(sample // 2, len(test_sents))]
768
+ click.echo(f" Sampling {sample} sentences...")
769
+
770
+ click.echo(f" Train: {len(train_sents)} sentences")
771
+ click.echo(f" Dev: {len(dev_sents)} sentences")
772
+ click.echo(f" Test: {len(test_sents)} sentences")
773
+
774
+ # Build vocabulary
775
+ click.echo("\nBuilding vocabulary...")
776
+ vocab = Vocabulary(min_freq=2)
777
+ vocab.build(train_sents)
778
+ if method == "baseline":
779
+ click.echo(f" Words: {vocab.n_words}")
780
+ click.echo(f" Chars: {vocab.n_chars}")
781
+ click.echo(f" Relations: {vocab.n_rels}")
782
+
783
+ # Create datasets and model based on method
784
+ if method == "trankit":
785
+ # Trankit method: XLM-RoBERTa + Biaffine
786
+ train_dataset = TransformerDataset(train_sents, vocab)
787
+ dev_dataset = TransformerDataset(dev_sents, vocab)
788
+ test_dataset = TransformerDataset(test_sents, vocab)
789
+
790
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
791
+ collate_fn=transformer_collate_fn, num_workers=0)
792
+ dev_loader = DataLoader(dev_dataset, batch_size=batch_size,
793
+ collate_fn=transformer_collate_fn, num_workers=0)
794
+ test_loader = DataLoader(test_dataset, batch_size=batch_size,
795
+ collate_fn=transformer_collate_fn, num_workers=0)
796
+
797
+ click.echo(f"\nInitializing model with {encoder}...")
798
+ model = TransformerDependencyParser(
799
+ n_rels=vocab.n_rels,
800
+ encoder=encoder,
801
+ ).to(device)
802
+
803
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
804
+ encoder_params = sum(p.numel() for p in model.encoder.parameters())
805
+ head_params = n_params - encoder_params
806
+ click.echo(f" Total parameters: {n_params:,}")
807
+ click.echo(f" Encoder parameters: {encoder_params:,}")
808
+ click.echo(f" Head parameters: {head_params:,}")
809
+
810
+ # Differential learning rates
811
+ encoder_params_list = list(model.encoder.parameters())
812
+ head_params_list = [p for n, p in model.named_parameters() if 'encoder' not in n]
813
+ optimizer = AdamW([
814
+ {'params': encoder_params_list, 'lr': bert_lr},
815
+ {'params': head_params_list, 'lr': head_lr},
816
+ ], weight_decay=0.01)
817
+
818
+ # Learning rate scheduler with warmup
819
+ total_steps = len(train_loader) * epochs
820
+ def lr_lambda(step):
821
+ if step < warmup_steps:
822
+ return step / warmup_steps
823
+ return max(0.0, (total_steps - step) / (total_steps - warmup_steps))
824
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
825
+
826
+ eval_fn = evaluate_transformer
827
+ else:
828
+ # Baseline method: BiLSTM + Biaffine
829
+ train_dataset = DependencyDataset(train_sents, vocab)
830
+ dev_dataset = DependencyDataset(dev_sents, vocab)
831
+ test_dataset = DependencyDataset(test_sents, vocab)
832
+
833
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
834
+ dev_loader = DataLoader(dev_dataset, batch_size=batch_size, collate_fn=collate_fn)
835
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)
836
+
837
+ click.echo("\nInitializing BiLSTM model...")
838
+ model = BiaffineDependencyParser(
839
+ n_words=vocab.n_words,
840
+ n_chars=vocab.n_chars,
841
+ n_rels=vocab.n_rels,
842
+ lstm_hidden=lstm_hidden,
843
+ lstm_layers=lstm_layers,
844
+ ).to(device)
845
+
846
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
847
+ click.echo(f" Parameters: {n_params:,}")
848
+
849
+ optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.9))
850
+ scheduler = ExponentialLR(optimizer, gamma=0.75 ** (1 / 5000))
851
+
852
+ eval_fn = evaluate
853
+
854
+ # Training
855
+ click.echo(f"\nTraining for {epochs} epochs...")
856
+ if max_time > 0:
857
+ click.echo(f"Time limit: {max_time} minutes")
858
+ output_path = Path(output)
859
+ output_path.mkdir(parents=True, exist_ok=True)
860
+
861
+ # Cost tracking
862
+ cost_tracker = CostTracker(gpu_type=gpu_type)
863
+ cost_tracker.report_interval = cost_interval
864
+ cost_tracker.start()
865
+ click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")
866
+
867
+ best_las = -1
868
+ no_improve = 0
869
+ time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')
870
+
871
+ for epoch in range(1, epochs + 1):
872
+ # Check time limit
873
+ if cost_tracker.elapsed_seconds() >= time_limit_seconds:
874
+ click.echo(f"\nTime limit reached ({max_time} minutes)")
875
+ break
876
+ model.train()
877
+ total_loss = 0
878
+
879
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
880
+ for batch in pbar:
881
+ optimizer.zero_grad()
882
+
883
+ if method == "trankit":
884
+ words_list, heads, rels, mask = batch
885
+ heads = heads.to(device)
886
+ rels = rels.to(device)
887
+ mask = mask.to(device)
888
+
889
+ with torch.amp.autocast('cuda', enabled=use_amp):
890
+ word_hidden, word_mask = model.encode_batch(words_list, device)
891
+ arc_scores, rel_scores = model(word_hidden, word_mask)
892
+ loss = model.loss(arc_scores, rel_scores, heads, rels, mask)
893
+ else:
894
+ words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]
895
+ arc_scores, rel_scores = model(words, chars, mask)
896
+ loss = model.loss(arc_scores, rel_scores, heads, rels, mask)
897
+
898
+ if use_amp and scaler:
899
+ scaler.scale(loss).backward()
900
+ scaler.unscale_(optimizer)
901
+ nn.utils.clip_grad_norm_(model.parameters(), 5.0)
902
+ scaler.step(optimizer)
903
+ scaler.update()
904
+ else:
905
+ loss.backward()
906
+ nn.utils.clip_grad_norm_(model.parameters(), 5.0)
907
+ optimizer.step()
908
+
909
+ scheduler.step()
910
+ total_loss += loss.item()
911
+ pbar.set_postfix({'loss': f'{loss.item():.4f}'})
912
+
913
+ # Evaluate (skip if not eval epoch, unless last epoch)
914
+ if epoch % eval_every != 0 and epoch != epochs:
915
+ avg_loss = total_loss / len(train_loader)
916
+ current_lr = optimizer.param_groups[0]['lr']
917
+ click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
918
+ continue
919
+
920
+ dev_uas, dev_las = eval_fn(model, dev_loader, device)
921
+
922
+ # Cost update
923
+ progress = epoch / epochs
924
+ current_cost = cost_tracker.current_cost()
925
+ estimated_total_cost = cost_tracker.estimate_total_cost(progress)
926
+ elapsed_minutes = cost_tracker.elapsed_seconds() / 60
927
+
928
+ cost_status = cost_tracker.update(epoch, epochs)
929
+ if cost_status:
930
+ click.echo(f" [{cost_status}]")
931
+
932
+ avg_loss = total_loss / len(train_loader)
933
+ click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
934
+ f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}%")
935
+
936
+ # Log to wandb
937
+ if use_wandb:
938
+ wandb.log({
939
+ "epoch": epoch,
940
+ "train/loss": avg_loss,
941
+ "dev/uas": dev_uas,
942
+ "dev/las": dev_las,
943
+ "cost/current_usd": current_cost,
944
+ "cost/estimated_total_usd": estimated_total_cost,
945
+ "cost/elapsed_minutes": elapsed_minutes,
946
+ })
947
+
948
+ # Save best model
949
+ if dev_las >= best_las:
950
+ best_las = dev_las
951
+ no_improve = 0
952
+ if method == "trankit":
953
+ config = {
954
+ 'method': 'trankit',
955
+ 'encoder': encoder,
956
+ 'n_rels': vocab.n_rels,
957
+ }
958
+ else:
959
+ config = {
960
+ 'method': 'baseline',
961
+ 'n_words': vocab.n_words,
962
+ 'n_chars': vocab.n_chars,
963
+ 'n_rels': vocab.n_rels,
964
+ 'lstm_hidden': lstm_hidden,
965
+ 'lstm_layers': lstm_layers,
966
+ }
967
+ torch.save({
968
+ 'model': model.state_dict(),
969
+ 'vocab': vocab,
970
+ 'config': config,
971
+ }, output_path / 'model.pt')
972
+ click.echo(f" -> Saved best model (LAS: {best_las:.2f}%)")
973
+ else:
974
+ no_improve += 1
975
+ if no_improve >= patience:
976
+ click.echo(f"\nEarly stopping after {patience} epochs without improvement")
977
+ break
978
+
979
+ # Final evaluation
980
+ click.echo("\nLoading best model for final evaluation...")
981
+ checkpoint = torch.load(output_path / 'model.pt', weights_only=False)
982
+ model.load_state_dict(checkpoint['model'])
983
+
984
+ test_uas, test_las = eval_fn(model, test_loader, device)
985
+ click.echo(f"\nTest Results:")
986
+ click.echo(f" UAS: {test_uas:.2f}%")
987
+ click.echo(f" LAS: {test_las:.2f}%")
988
+
989
+ click.echo(f"\nModel saved to: {output_path}")
990
+
991
+ # Final cost summary
992
+ final_cost = cost_tracker.current_cost()
993
+ click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")
994
+
995
+ # Log final metrics to wandb
996
+ if use_wandb:
997
+ wandb.log({
998
+ "test/uas": test_uas,
999
+ "test/las": test_las,
1000
+ "cost/final_usd": final_cost,
1001
+ })
1002
+ wandb.finish()
1003
+
1004
+
1005
+ if __name__ == '__main__':
1006
+ train()