rain1024 and Claude Opus 4.5 committed
Commit b39f0e3 · 1 Parent(s): 3ca3932

Add PhoBERT-based dependency parser for Trankit reproduction

- bamboo1/models/: PhoBERT + Biaffine parser with MST decoding
- bamboo1/ud_corpus.py: UD Vietnamese VTB dataset loader
- scripts/train_phobert.py: Training with FP16, gradient accumulation
- scripts/run_phobert_runpod.sh: RunPod automation for cloud training
- scripts/runpod_setup.py: launch-fast command for H100 (<5 min training)

Target: Reproduce Trankit benchmark (70.96% UAS / 64.76% LAS on UD-VTB)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

bamboo1/models/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ Bamboo-1 Model implementations.
+
+ This module contains the transformer-based dependency parser using PhoBERT.
+ """
+
+ from bamboo1.models.transformer_parser import PhoBERTDependencyParser
+ from bamboo1.models.mst import mst_decode
+
+ __all__ = ["PhoBERTDependencyParser", "mst_decode"]
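
The package surface after this commit is just these two names; an illustrative import check (note that it pulls in torch, transformers, and numpy transitively) is:

    from bamboo1.models import PhoBERTDependencyParser, mst_decode
    print(PhoBERTDependencyParser, mst_decode)
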
bamboo1/models/mst.py ADDED
@@ -0,0 +1,244 @@
1
+ """
2
+ Minimum Spanning Tree (MST) decoding for dependency parsing.
3
+
4
+ Implements the Chu-Liu/Edmonds algorithm for finding the maximum spanning
5
+ arborescence, which ensures valid dependency tree structures.
6
+
7
+ Reference:
8
+ - Edmonds, J. (1967). Optimum branchings.
9
+ - Chu, Y.J. & Liu, T.H. (1965). On the shortest arborescence of a directed graph.
10
+ """
11
+
12
+ import numpy as np
13
+ from typing import List, Tuple, Optional
14
+
15
+
16
+ def mst_decode(scores: np.ndarray, length: Optional[int] = None) -> np.ndarray:
17
+ """
18
+ Decode the maximum spanning arborescence using Chu-Liu/Edmonds algorithm.
19
+
20
+ Args:
21
+ scores: Arc scores matrix of shape (seq_len, seq_len) where scores[i, j]
22
+ is the score for token i having token j as its head.
23
+ Index 0 is the root node.
24
+ length: Actual sequence length (excluding padding). If None, uses full matrix.
25
+
26
+ Returns:
27
+ heads: Array of shape (seq_len,) containing head indices for each token.
28
+ heads[0] is always 0 (root has no head).
29
+ """
30
+ if length is None:
31
+ length = scores.shape[0]
32
+
33
+ # Work on the actual tokens (excluding padding)
34
+ scores = scores[:length, :length].copy()
35
+
36
+ # Token 0 is root - root cannot have a head other than itself
37
+ scores[0, :] = float('-inf')
38
+ scores[0, 0] = 0
39
+
40
+ # No self-loops (except for root)
41
+ np.fill_diagonal(scores[1:, 1:], float('-inf'))
42
+
43
+ heads = _chu_liu_edmonds(scores)
44
+
45
+ return heads
46
+
47
+
48
+ def _chu_liu_edmonds(scores: np.ndarray) -> np.ndarray:
49
+ """
50
+ Chu-Liu/Edmonds algorithm for maximum spanning arborescence.
51
+
52
+ Args:
53
+ scores: Arc scores matrix of shape (n, n)
54
+
55
+ Returns:
56
+ heads: Array of head indices
57
+ """
58
+ n = scores.shape[0]
59
+
60
+ # Step 1: For each node (except root), select the maximum incoming arc
61
+ heads = np.argmax(scores, axis=1)
62
+ heads[0] = 0 # Root points to itself
63
+
64
+ # Step 2: Check for cycles
65
+ cycle = _find_cycle(heads)
66
+
67
+ if cycle is None:
68
+ # No cycle - we have a valid tree
69
+ return heads
70
+
71
+ # Step 3: Contract the cycle and recurse
72
+ cycle_set = set(cycle)
73
+ cycle_head = cycle[0] # Representative node for the contracted cycle
74
+
75
+ # Create mapping from old indices to new indices
76
+ # Cycle nodes (except representative) are removed
77
+ old_to_new = {}
78
+ new_to_old = {}
79
+ new_idx = 0
80
+ for i in range(n):
81
+ if i not in cycle_set or i == cycle_head:
82
+ old_to_new[i] = new_idx
83
+ new_to_old[new_idx] = i
84
+ new_idx += 1
85
+
86
+ # Number of nodes in contracted graph
87
+ n_contracted = new_idx
88
+
89
+ # Build contracted graph
90
+ contracted_scores = np.full((n_contracted, n_contracted), float('-inf'))
91
+
92
+ for i in range(n):
93
+ if i in cycle_set and i != cycle_head:
94
+ continue
95
+ new_i = old_to_new[i]
96
+
97
+ for j in range(n):
98
+ if j in cycle_set and j != cycle_head:
99
+ continue
100
+ new_j = old_to_new[j]
101
+
102
+ if new_i == new_j:
103
+ continue
104
+
105
+ if i == cycle_head:
106
+ # Incoming edges to cycle: find best way to enter cycle
107
+ if j not in cycle_set:
108
+ # Edge from outside to cycle
109
+ best_score = float('-inf')
110
+ for c in cycle:
111
+ # Score of edge j->c minus score of edge heads[c]->c
112
+ # (because we're replacing that edge)
113
+ score = scores[c, j] - scores[c, heads[c]]
114
+ if score > best_score:
115
+ best_score = score
116
+ contracted_scores[new_i, new_j] = best_score
117
+ else:
118
+ contracted_scores[new_i, new_j] = float('-inf')
119
+ elif j == cycle_head:
120
+ # Outgoing edges from cycle
121
+ if i not in cycle_set:
122
+ best_score = float('-inf')
123
+ for c in cycle:
124
+ if scores[i, c] > best_score:
125
+ best_score = scores[i, c]
126
+ contracted_scores[new_i, new_j] = best_score
127
+ else:
128
+ # Edge not involving cycle
129
+ contracted_scores[new_i, new_j] = scores[i, j]
130
+
131
+ # Recurse on contracted graph
132
+ contracted_heads = _chu_liu_edmonds(contracted_scores)
133
+
134
+ # Step 4: Expand the solution
135
+ final_heads = np.zeros(n, dtype=np.int64)
136
+
137
+ # First, set heads for non-cycle nodes
138
+ for new_i in range(n_contracted):
139
+ old_i = new_to_old[new_i]
140
+ if old_i != cycle_head:
141
+ new_head = contracted_heads[new_i]
142
+ old_head = new_to_old[new_head]
143
+
144
+ # If head is cycle representative, find which cycle node is actual head
145
+ if old_head == cycle_head:
146
+ best_score = float('-inf')
147
+ best_c = cycle_head
148
+ for c in cycle:
149
+ if scores[old_i, c] > best_score:
150
+ best_score = scores[old_i, c]
151
+ best_c = c
152
+ final_heads[old_i] = best_c
153
+ else:
154
+ final_heads[old_i] = old_head
155
+
156
+ # Find which node in cycle is entered from outside
157
+ new_cycle_head = contracted_heads[old_to_new[cycle_head]]
158
+ if new_cycle_head != old_to_new[cycle_head]: # Cycle has incoming edge from outside
159
+ outside_head = new_to_old[new_cycle_head]
160
+
161
+ # Find which cycle node is entered
162
+ best_score = float('-inf')
163
+ entered_node = cycle_head
164
+ for c in cycle:
165
+ score = scores[c, outside_head] - scores[c, heads[c]]
166
+ if score > best_score:
167
+ best_score = score
168
+ entered_node = c
169
+
170
+ # Set heads within cycle, breaking at entered node
171
+ for c in cycle:
172
+ if c == entered_node:
173
+ final_heads[c] = outside_head
174
+ else:
175
+ final_heads[c] = heads[c]
176
+ else:
177
+ # Cycle contains root (shouldn't happen in valid dependency parsing)
178
+ for c in cycle:
179
+ final_heads[c] = heads[c]
180
+
181
+ final_heads[0] = 0 # Root
182
+
183
+ return final_heads
184
+
185
+
186
+ def _find_cycle(heads: np.ndarray) -> Optional[List[int]]:
187
+ """
188
+ Find a cycle in the given head assignments.
189
+
190
+ Args:
191
+ heads: Array of head indices
192
+
193
+ Returns:
194
+ List of node indices forming a cycle, or None if no cycle exists
195
+ """
196
+ n = len(heads)
197
+ visited = np.zeros(n, dtype=np.int32)
198
+
199
+ for start in range(1, n): # Skip root
200
+ if visited[start] == 2: # Already processed
201
+ continue
202
+
203
+ path = []
204
+ node = start
205
+
206
+ while visited[node] == 0:
207
+ visited[node] = 1 # Mark as in current path
208
+ path.append(node)
209
+ node = heads[node]
210
+
211
+ if node == 0: # Reached root
212
+ break
213
+
214
+ if visited[node] == 1:
215
+ # Found cycle - extract it
216
+ cycle_start = path.index(node)
217
+ cycle = path[cycle_start:]
218
+ return cycle
219
+
220
+ # Mark all nodes in path as fully processed
221
+ for p in path:
222
+ visited[p] = 2
223
+
224
+ return None
225
+
226
+
227
+ def batch_mst_decode(scores: np.ndarray, lengths: np.ndarray) -> np.ndarray:
228
+ """
229
+ Batch version of MST decoding.
230
+
231
+ Args:
232
+ scores: Arc scores of shape (batch, seq_len, seq_len)
233
+ lengths: Sequence lengths of shape (batch,)
234
+
235
+ Returns:
236
+ heads: Head indices of shape (batch, seq_len)
237
+ """
238
+ batch_size, seq_len, _ = scores.shape
239
+ heads = np.zeros((batch_size, seq_len), dtype=np.int64)
240
+
241
+ for i in range(batch_size):
242
+ heads[i, :lengths[i]] = mst_decode(scores[i], lengths[i])
243
+
244
+ return heads
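
For orientation, a minimal usage sketch of the decoder added above (not part of the commit; the score values are invented purely for illustration). Here scores[i, j] is the score for token i taking token j as its head, with index 0 reserved for the root:

    import numpy as np
    from bamboo1.models.mst import mst_decode

    # Root plus three tokens; each row lists the head scores for that token.
    scores = np.array([
        [0.0, -1.0, -1.0, -1.0],   # row 0 is the root and is overwritten inside mst_decode
        [9.0,  0.0,  2.0,  1.0],   # token 1 prefers the root as its head
        [1.0,  8.0,  0.0,  2.0],   # token 2 prefers token 1
        [1.0,  7.0,  3.0,  0.0],   # token 3 prefers token 1
    ])
    heads = mst_decode(scores)
    print(heads)  # expected [0 0 1 1]: a valid arborescence rooted at token 0

batch_mst_decode applies the same routine sentence by sentence, using the lengths array to skip padded positions.
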
bamboo1/models/transformer_parser.py ADDED
@@ -0,0 +1,507 @@
1
+ """
2
+ Transformer-based Dependency Parser using PhoBERT.
3
+
4
+ This module implements a Biaffine dependency parser with PhoBERT as the encoder,
5
+ following the Trankit approach but using Vietnamese-specific PhoBERT.
6
+
7
+ Architecture:
8
+ Input → PhoBERT → Word-level pooling → MLP projections → Biaffine attention → MST decoding
9
+
10
+ Reference:
11
+ - Dozat & Manning (2017): Deep Biaffine Attention for Neural Dependency Parsing
12
+ - Nguyen & Nguyen (2020): PhoBERT: Pre-trained language models for Vietnamese
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from typing import List, Tuple, Optional, Dict, Any
19
+ import numpy as np
20
+
21
+ from bamboo1.models.mst import mst_decode, batch_mst_decode
22
+
23
+
24
+ class MLP(nn.Module):
25
+ """Multi-layer perceptron for biaffine scoring."""
26
+
27
+ def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
28
+ super().__init__()
29
+ self.linear = nn.Linear(input_dim, hidden_dim)
30
+ self.activation = nn.LeakyReLU(0.1)
31
+ self.dropout = nn.Dropout(dropout)
32
+
33
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
34
+ return self.dropout(self.activation(self.linear(x)))
35
+
36
+
37
+ class Biaffine(nn.Module):
38
+ """Biaffine attention layer for dependency scoring."""
39
+
40
+ def __init__(
41
+ self,
42
+ input_dim: int,
43
+ output_dim: int = 1,
44
+ bias_x: bool = True,
45
+ bias_y: bool = True
46
+ ):
47
+ super().__init__()
48
+ self.input_dim = input_dim
49
+ self.output_dim = output_dim
50
+ self.bias_x = bias_x
51
+ self.bias_y = bias_y
52
+
53
+ self.weight = nn.Parameter(
54
+ torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y)
55
+ )
56
+ nn.init.xavier_uniform_(self.weight)
57
+
58
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
59
+ """
60
+ Args:
61
+ x: (batch, seq_len, input_dim) - dependent representations
62
+ y: (batch, seq_len, input_dim) - head representations
63
+
64
+ Returns:
65
+ scores: (batch, seq_len, seq_len, output_dim) or (batch, seq_len, seq_len) if output_dim=1
66
+ """
67
+ if self.bias_x:
68
+ x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
69
+ if self.bias_y:
70
+ y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)
71
+
72
+ # (batch, seq_len, output_dim, input_dim+1)
73
+ x = torch.einsum('bxi,oij->bxoj', x, self.weight)
74
+ # (batch, seq_len, seq_len, output_dim)
75
+ scores = torch.einsum('bxoj,byj->bxyo', x, y)
76
+
77
+ if self.output_dim == 1:
78
+ scores = scores.squeeze(-1)
79
+
80
+ return scores
81
+
82
+
83
+ class PhoBERTDependencyParser(nn.Module):
84
+ """
85
+ PhoBERT-based Biaffine Dependency Parser.
86
+
87
+ Uses PhoBERT as encoder with first-subword pooling for word alignment,
88
+ followed by biaffine attention for arc and relation prediction.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ encoder_name: str = "vinai/phobert-base",
94
+ n_rels: int = 50,
95
+ arc_hidden: int = 500,
96
+ rel_hidden: int = 100,
97
+ dropout: float = 0.33,
98
+ use_mst: bool = True,
99
+ ):
100
+ """
101
+ Args:
102
+ encoder_name: HuggingFace model name for PhoBERT
103
+ n_rels: Number of dependency relations
104
+ arc_hidden: Hidden dimension for arc MLPs
105
+ rel_hidden: Hidden dimension for relation MLPs
106
+ dropout: Dropout rate
107
+ use_mst: Use MST decoding (True) or greedy decoding (False)
108
+ """
109
+ super().__init__()
110
+
111
+ from transformers import AutoModel, AutoTokenizer
112
+
113
+ self.encoder_name = encoder_name
114
+ self.n_rels = n_rels
115
+ self.use_mst = use_mst
116
+
117
+ # Load PhoBERT encoder
118
+ self.encoder = AutoModel.from_pretrained(encoder_name)
119
+ self.tokenizer = AutoTokenizer.from_pretrained(encoder_name)
120
+ self.hidden_size = self.encoder.config.hidden_size # 768 for phobert-base
121
+
122
+ # Dropout
123
+ self.dropout = nn.Dropout(dropout)
124
+
125
+ # MLP projections
126
+ self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
127
+ self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
128
+ self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
129
+ self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)
130
+
131
+ # Biaffine attention
132
+ self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
133
+ self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)
134
+
135
+ def _get_word_embeddings(
136
+ self,
137
+ input_ids: torch.Tensor,
138
+ attention_mask: torch.Tensor,
139
+ word_starts: torch.Tensor,
140
+ word_mask: torch.Tensor,
141
+ ) -> torch.Tensor:
142
+ """
143
+ Get word-level embeddings from subword encoder output.
144
+
145
+ Uses first-subword pooling strategy: each word is represented by
146
+ the embedding of its first subword token.
147
+
148
+ Args:
149
+ input_ids: (batch, subword_seq_len) - Subword token IDs
150
+ attention_mask: (batch, subword_seq_len) - Attention mask for subwords
151
+ word_starts: (batch, word_seq_len) - Indices of first subword for each word
152
+ word_mask: (batch, word_seq_len) - Mask for actual words
153
+
154
+ Returns:
155
+ word_embeddings: (batch, word_seq_len, hidden_size)
156
+ """
157
+ # Get encoder output
158
+ encoder_output = self.encoder(
159
+ input_ids=input_ids,
160
+ attention_mask=attention_mask,
161
+ return_dict=True
162
+ )
163
+ hidden_states = encoder_output.last_hidden_state # (batch, subword_seq_len, hidden)
164
+
165
+ # Apply dropout
166
+ hidden_states = self.dropout(hidden_states)
167
+
168
+ # Extract word embeddings using first-subword indices
169
+ batch_size, word_seq_len = word_starts.shape
170
+
171
+ # Gather word embeddings
172
+ # word_starts: (batch, word_seq_len) -> (batch, word_seq_len, hidden)
173
+ word_embeddings = torch.gather(
174
+ hidden_states,
175
+ dim=1,
176
+ index=word_starts.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
177
+ )
178
+
179
+ return word_embeddings
180
+
181
+ def forward(
182
+ self,
183
+ input_ids: torch.Tensor,
184
+ attention_mask: torch.Tensor,
185
+ word_starts: torch.Tensor,
186
+ word_mask: torch.Tensor,
187
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
188
+ """
189
+ Forward pass computing arc and relation scores.
190
+
191
+ Args:
192
+ input_ids: (batch, subword_seq_len) - Subword token IDs
193
+ attention_mask: (batch, subword_seq_len) - Attention mask for subwords
194
+ word_starts: (batch, word_seq_len) - Indices of first subword for each word
195
+ word_mask: (batch, word_seq_len) - Mask for actual words
196
+
197
+ Returns:
198
+ arc_scores: (batch, word_seq_len, word_seq_len) - Arc scores
199
+ rel_scores: (batch, word_seq_len, word_seq_len, n_rels) - Relation scores
200
+ """
201
+ # Get word-level embeddings
202
+ word_embeddings = self._get_word_embeddings(
203
+ input_ids, attention_mask, word_starts, word_mask
204
+ )
205
+
206
+ # MLP projections
207
+ arc_dep = self.mlp_arc_dep(word_embeddings)
208
+ arc_head = self.mlp_arc_head(word_embeddings)
209
+ rel_dep = self.mlp_rel_dep(word_embeddings)
210
+ rel_head = self.mlp_rel_head(word_embeddings)
211
+
212
+ # Biaffine attention
213
+ arc_scores = self.arc_attn(arc_dep, arc_head) # (batch, seq, seq)
214
+ rel_scores = self.rel_attn(rel_dep, rel_head) # (batch, seq, seq, n_rels)
215
+
216
+ return arc_scores, rel_scores
217
+
218
+ def loss(
219
+ self,
220
+ arc_scores: torch.Tensor,
221
+ rel_scores: torch.Tensor,
222
+ heads: torch.Tensor,
223
+ rels: torch.Tensor,
224
+ mask: torch.Tensor,
225
+ ) -> torch.Tensor:
226
+ """
227
+ Compute cross-entropy loss for arcs and relations.
228
+
229
+ Args:
230
+ arc_scores: (batch, seq_len, seq_len) - Arc scores
231
+ rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
232
+ heads: (batch, seq_len) - Gold head indices
233
+ rels: (batch, seq_len) - Gold relation indices
234
+ mask: (batch, seq_len) - Token mask (1 for real tokens, 0 for padding)
235
+
236
+ Returns:
237
+ Total loss (arc_loss + rel_loss)
238
+ """
239
+ batch_size, seq_len = mask.shape
240
+
241
+ # Mask invalid positions
242
+ arc_scores_masked = arc_scores.clone()
243
+ arc_scores_masked = arc_scores_masked.masked_fill(~mask.unsqueeze(2), float('-inf'))
244
+
245
+ # Arc loss: cross-entropy over possible heads
246
+ arc_loss = F.cross_entropy(
247
+ arc_scores_masked[mask].view(-1, seq_len),
248
+ heads[mask],
249
+ reduction='mean'
250
+ )
251
+
252
+ # Relation loss: cross-entropy conditioned on gold heads
253
+ batch_indices = torch.arange(batch_size, device=rel_scores.device).unsqueeze(1)
254
+ seq_indices = torch.arange(seq_len, device=rel_scores.device)
255
+ rel_scores_gold = rel_scores[batch_indices, seq_indices, heads]
256
+
257
+ rel_loss = F.cross_entropy(
258
+ rel_scores_gold[mask],
259
+ rels[mask],
260
+ reduction='mean'
261
+ )
262
+
263
+ return arc_loss + rel_loss
264
+
265
+ def decode(
266
+ self,
267
+ arc_scores: torch.Tensor,
268
+ rel_scores: torch.Tensor,
269
+ mask: torch.Tensor,
270
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
271
+ """
272
+ Decode predictions using MST or greedy decoding.
273
+
274
+ Args:
275
+ arc_scores: (batch, seq_len, seq_len) - Arc scores
276
+ rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
277
+ mask: (batch, seq_len) - Token mask
278
+
279
+ Returns:
280
+ arc_preds: (batch, seq_len) - Predicted head indices
281
+ rel_preds: (batch, seq_len) - Predicted relation indices
282
+ """
283
+ batch_size, seq_len = mask.shape
284
+ device = arc_scores.device
285
+
286
+ if self.use_mst:
287
+ # MST decoding for valid tree structure
288
+ lengths = mask.sum(dim=1).cpu().numpy()
289
+ arc_scores_np = arc_scores.cpu().numpy()
290
+ arc_preds_np = batch_mst_decode(arc_scores_np, lengths)
291
+ arc_preds = torch.from_numpy(arc_preds_np).to(device)
292
+ else:
293
+ # Greedy decoding
294
+ arc_preds = arc_scores.argmax(dim=-1)
295
+
296
+ # Get relation predictions for predicted heads
297
+ batch_indices = torch.arange(batch_size, device=device).unsqueeze(1)
298
+ seq_indices = torch.arange(seq_len, device=device)
299
+ rel_scores_pred = rel_scores[batch_indices, seq_indices, arc_preds]
300
+ rel_preds = rel_scores_pred.argmax(dim=-1)
301
+
302
+ return arc_preds, rel_preds
303
+
304
+ def predict(
305
+ self,
306
+ words: List[str],
307
+ return_probs: bool = False,
308
+ ) -> List[Tuple[str, int, str]]:
309
+ """
310
+ Predict dependencies for a single sentence.
311
+
312
+ Args:
313
+ words: List of words (pre-tokenized)
314
+ return_probs: Whether to return probability scores
315
+
316
+ Returns:
317
+ List of (word, head, deprel) tuples
318
+ """
319
+ self.eval()
320
+ device = next(self.parameters()).device
321
+
322
+ # Tokenize with word boundary tracking
323
+ encoded = self.tokenize_with_alignment([words])
324
+
325
+ # Move to device
326
+ input_ids = encoded['input_ids'].to(device)
327
+ attention_mask = encoded['attention_mask'].to(device)
328
+ word_starts = encoded['word_starts'].to(device)
329
+ word_mask = encoded['word_mask'].to(device)
330
+
331
+ with torch.no_grad():
332
+ arc_scores, rel_scores = self.forward(
333
+ input_ids, attention_mask, word_starts, word_mask
334
+ )
335
+ arc_preds, rel_preds = self.decode(arc_scores, rel_scores, word_mask)
336
+
337
+ # Convert to list of tuples
338
+ arc_preds = arc_preds[0].cpu().tolist()
339
+ rel_preds = rel_preds[0].cpu().tolist()
340
+
341
+ results = []
342
+ for i, word in enumerate(words):
343
+ head = arc_preds[i]
344
+ rel_idx = rel_preds[i]
345
+ rel = self.idx2rel.get(rel_idx, "dep")
346
+ results.append((word, head, rel))
347
+
348
+ return results
349
+
350
+ def tokenize_with_alignment(
351
+ self,
352
+ sentences: List[List[str]],
353
+ max_length: int = 256,
354
+ ) -> Dict[str, torch.Tensor]:
355
+ """
356
+ Tokenize sentences and track word-subword alignment.
357
+
358
+ Args:
359
+ sentences: List of sentences, where each sentence is a list of words
360
+ max_length: Maximum subword sequence length
361
+
362
+ Returns:
363
+ Dictionary with input_ids, attention_mask, word_starts, word_mask
364
+ """
365
+ batch_input_ids = []
366
+ batch_attention_mask = []
367
+ batch_word_starts = []
368
+ batch_word_mask = []
369
+
370
+ for words in sentences:
371
+ # Tokenize each word separately to track boundaries
372
+ word_starts = []
373
+ subword_ids = [self.tokenizer.cls_token_id]
374
+
375
+ for word in words:
376
+ word_starts.append(len(subword_ids))
377
+ word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
378
+ subword_ids.extend(word_tokens)
379
+
380
+ subword_ids.append(self.tokenizer.sep_token_id)
381
+
382
+ # Truncate if needed
383
+ if len(subword_ids) > max_length:
384
+ subword_ids = subword_ids[:max_length-1] + [self.tokenizer.sep_token_id]
385
+ # Truncate word_starts that go beyond
386
+ word_starts = [ws for ws in word_starts if ws < max_length - 1]
387
+
388
+ attention_mask = [1] * len(subword_ids)
389
+
390
+ batch_input_ids.append(subword_ids)
391
+ batch_attention_mask.append(attention_mask)
392
+ batch_word_starts.append(word_starts)
393
+ batch_word_mask.append([1] * len(word_starts))
394
+
395
+ # Pad sequences
396
+ max_subword_len = max(len(ids) for ids in batch_input_ids)
397
+ max_word_len = max(len(ws) for ws in batch_word_starts)
398
+
399
+ padded_input_ids = []
400
+ padded_attention_mask = []
401
+ padded_word_starts = []
402
+ padded_word_mask = []
403
+
404
+ for i in range(len(sentences)):
405
+ # Pad subwords
406
+ pad_len = max_subword_len - len(batch_input_ids[i])
407
+ padded_input_ids.append(
408
+ batch_input_ids[i] + [self.tokenizer.pad_token_id] * pad_len
409
+ )
410
+ padded_attention_mask.append(
411
+ batch_attention_mask[i] + [0] * pad_len
412
+ )
413
+
414
+ # Pad words
415
+ word_pad_len = max_word_len - len(batch_word_starts[i])
416
+ # Use 0 for padding word_starts (points to CLS token, but masked)
417
+ padded_word_starts.append(
418
+ batch_word_starts[i] + [0] * word_pad_len
419
+ )
420
+ padded_word_mask.append(
421
+ batch_word_mask[i] + [0] * word_pad_len
422
+ )
423
+
424
+ return {
425
+ 'input_ids': torch.tensor(padded_input_ids, dtype=torch.long),
426
+ 'attention_mask': torch.tensor(padded_attention_mask, dtype=torch.long),
427
+ 'word_starts': torch.tensor(padded_word_starts, dtype=torch.long),
428
+ 'word_mask': torch.tensor(padded_word_mask, dtype=torch.bool),
429
+ }
430
+
431
+ def save(self, path: str, vocab: Optional[Dict] = None):
432
+ """
433
+ Save model checkpoint.
434
+
435
+ Args:
436
+ path: Directory path to save the model
437
+ vocab: Vocabulary dict with rel2idx and idx2rel mappings
438
+ """
439
+ import os
440
+ os.makedirs(path, exist_ok=True)
441
+
442
+ # Save model state
443
+ checkpoint = {
444
+ 'model_state_dict': self.state_dict(),
445
+ 'config': {
446
+ 'encoder_name': self.encoder_name,
447
+ 'n_rels': self.n_rels,
448
+ 'arc_hidden': self.mlp_arc_dep.linear.out_features,
449
+ 'rel_hidden': self.mlp_rel_dep.linear.out_features,
450
+ 'dropout': self.dropout.p,
451
+ 'use_mst': self.use_mst,
452
+ },
453
+ }
454
+
455
+ if vocab is not None:
456
+ checkpoint['vocab'] = vocab
457
+
458
+ torch.save(checkpoint, os.path.join(path, 'model.pt'))
459
+
460
+ # Save tokenizer
461
+ self.tokenizer.save_pretrained(path)
462
+
463
+ @classmethod
464
+ def load(cls, path: str, device: str = 'cpu') -> 'PhoBERTDependencyParser':
465
+ """
466
+ Load model from checkpoint.
467
+
468
+ Args:
469
+ path: Directory path containing the saved model
470
+ device: Device to load the model to
471
+
472
+ Returns:
473
+ Loaded PhoBERTDependencyParser model
474
+ """
475
+ import os
476
+
477
+ checkpoint = torch.load(
478
+ os.path.join(path, 'model.pt'),
479
+ map_location=device,
480
+ weights_only=False
481
+ )
482
+
483
+ config = checkpoint['config']
484
+
485
+ # Create model
486
+ model = cls(
487
+ encoder_name=config['encoder_name'],
488
+ n_rels=config['n_rels'],
489
+ arc_hidden=config['arc_hidden'],
490
+ rel_hidden=config['rel_hidden'],
491
+ dropout=config['dropout'],
492
+ use_mst=config.get('use_mst', True),
493
+ )
494
+
495
+ # Load state dict
496
+ model.load_state_dict(checkpoint['model_state_dict'])
497
+
498
+ # Load vocabulary
499
+ if 'vocab' in checkpoint:
500
+ model.rel2idx = checkpoint['vocab'].get('rel2idx', {})
501
+ model.idx2rel = checkpoint['vocab'].get('idx2rel', {})
502
+ else:
503
+ model.rel2idx = {}
504
+ model.idx2rel = {}
505
+
506
+ model.to(device)
507
+ return model
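
The Biaffine docstring above pins down the expected shapes; a short self-contained shape check (illustrative only, needs just torch and the class defined in this file) is:

    import torch
    from bamboo1.models.transformer_parser import Biaffine

    batch, seq_len = 2, 7
    arc_attn = Biaffine(500, output_dim=1, bias_x=True, bias_y=False)   # as in the parser's arc scorer
    rel_attn = Biaffine(100, output_dim=50, bias_x=True, bias_y=True)   # as in the relation scorer

    x = torch.randn(batch, seq_len, 500)   # dependent representations
    y = torch.randn(batch, seq_len, 500)   # head representations
    print(arc_attn(x, y).shape)            # torch.Size([2, 7, 7])
    print(rel_attn(torch.randn(batch, seq_len, 100),
                   torch.randn(batch, seq_len, 100)).shape)  # torch.Size([2, 7, 7, 50])

Note that the PhoBERT weights are only downloaded when PhoBERTDependencyParser itself is instantiated, and predict() additionally needs idx2rel, which is only populated by load() from a saved checkpoint's vocabulary.
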
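
Both loss() and decode() rely on the same advanced-indexing pattern to select, for every token, the relation-score vector of the arc coming from its gold or predicted head. In isolation, with made-up shapes:

    import torch

    batch, seq_len, n_rels = 2, 5, 7
    rel_scores = torch.randn(batch, seq_len, seq_len, n_rels)
    heads = torch.randint(0, seq_len, (batch, seq_len))

    b_idx = torch.arange(batch).unsqueeze(1)   # (batch, 1), broadcasts over tokens
    s_idx = torch.arange(seq_len)              # (seq_len,)
    picked = rel_scores[b_idx, s_idx, heads]   # (batch, seq_len, n_rels)
    assert picked.shape == (batch, seq_len, n_rels)
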
bamboo1/ud_corpus.py ADDED
@@ -0,0 +1,229 @@
1
+ """
2
+ UD Vietnamese VTB Corpus loader for dependency parsing.
3
+
4
+ This module provides a corpus class that downloads the UD Vietnamese VTB dataset
5
+ from Universal Dependencies for comparison with Trankit benchmark results.
6
+
7
+ UD Vietnamese VTB:
8
+ - Treebank size: ~3,300 sentences
9
+ - Source: Vietnamese Language and Speech Processing (VLSP)
10
+ - Standard benchmark for Vietnamese dependency parsing
11
+ """
12
+
13
+ import os
14
+ import tarfile
15
+ import urllib.request
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+
20
+ class UDVietnameseVTB:
21
+ """
22
+ Corpus class for UD Vietnamese VTB dataset.
23
+
24
+ This class downloads the UD Vietnamese VTB treebank from Universal Dependencies
25
+ for fair comparison with Trankit's reported benchmark results.
26
+
27
+ Attributes:
28
+ train: Path to the training data file (CoNLL-U format)
29
+ dev: Path to the development/validation data file (CoNLL-U format)
30
+ test: Path to the test data file (CoNLL-U format)
31
+
32
+ Example:
33
+ >>> from bamboo1.ud_corpus import UDVietnameseVTB
34
+ >>> corpus = UDVietnameseVTB()
35
+ >>> print(corpus.train) # Path to train.conllu
36
+ """
37
+
38
+ name = "UD_Vietnamese-VTB"
39
+
40
+ # UD Vietnamese VTB release URL (v2.14)
41
+ UD_VERSION = "2.14"
42
+ UD_BASE_URL = "https://raw.githubusercontent.com/UniversalDependencies/UD_Vietnamese-VTB/master"
43
+
44
+ FILE_NAMES = {
45
+ "train": "vi_vtb-ud-train.conllu",
46
+ "dev": "vi_vtb-ud-dev.conllu",
47
+ "test": "vi_vtb-ud-test.conllu",
48
+ }
49
+
50
+ def __init__(self, data_dir: Optional[str] = None, force_download: bool = False):
51
+ """
52
+ Initialize the UD Vietnamese VTB corpus.
53
+
54
+ Args:
55
+ data_dir: Directory to store the CoNLL-U files.
56
+ Defaults to ./data/UD_Vietnamese-VTB
57
+ force_download: If True, re-download even if files exist.
58
+ """
59
+ if data_dir is None:
60
+ data_dir = Path(__file__).parent.parent / "data" / "UD_Vietnamese-VTB"
61
+ self.data_dir = Path(data_dir)
62
+ self.data_dir.mkdir(parents=True, exist_ok=True)
63
+
64
+ self._train = self.data_dir / self.FILE_NAMES["train"]
65
+ self._dev = self.data_dir / self.FILE_NAMES["dev"]
66
+ self._test = self.data_dir / self.FILE_NAMES["test"]
67
+
68
+ if force_download or not self._files_exist():
69
+ self._download()
70
+
71
+ def _files_exist(self) -> bool:
72
+ """Check if all required files exist."""
73
+ return self._train.exists() and self._dev.exists() and self._test.exists()
74
+
75
+ def _download(self):
76
+ """Download UD Vietnamese VTB files from GitHub."""
77
+ print(f"Downloading UD Vietnamese VTB from Universal Dependencies...")
78
+
79
+ for split, filename in self.FILE_NAMES.items():
80
+ url = f"{self.UD_BASE_URL}/{filename}"
81
+ output_path = self.data_dir / filename
82
+
83
+ print(f" Downloading {filename}...")
84
+ try:
85
+ urllib.request.urlretrieve(url, output_path)
86
+ except Exception as e:
87
+ print(f" Warning: Failed to download {filename}: {e}")
88
+ print(f" Trying alternative method...")
89
+ self._download_alternative()
90
+ return
91
+
92
+ print(f"Dataset saved to {self.data_dir}")
93
+ self._print_statistics()
94
+
95
+ def _download_alternative(self):
96
+ """Alternative download method using HuggingFace datasets."""
97
+ try:
98
+ from datasets import load_dataset
99
+
100
+ print(" Using HuggingFace datasets library...")
101
+ dataset = load_dataset("universal_dependencies", "vi_vtb")
102
+
103
+ for split_name, output_path in [
104
+ ("train", self._train),
105
+ ("validation", self._dev),
106
+ ("test", self._test),
107
+ ]:
108
+ self._convert_hf_split(dataset[split_name], output_path)
109
+
110
+ print(f"Dataset saved to {self.data_dir}")
111
+ self._print_statistics()
112
+
113
+ except Exception as e:
114
+ raise RuntimeError(
115
+ f"Failed to download UD Vietnamese VTB. "
116
+ f"Please download manually from: "
117
+ f"https://github.com/UniversalDependencies/UD_Vietnamese-VTB\n"
118
+ f"Error: {e}"
119
+ )
120
+
121
+ def _convert_hf_split(self, split, output_path: Path):
122
+ """Convert a HuggingFace dataset split to CoNLL-U format."""
123
+ with open(output_path, "w", encoding="utf-8") as f:
124
+ for idx, item in enumerate(split):
125
+ sent_id = item.get("idx", idx)
126
+ text = item.get("text", "")
127
+
128
+ f.write(f"# sent_id = {sent_id}\n")
129
+ if text:
130
+ f.write(f"# text = {text}\n")
131
+
132
+ tokens = item["tokens"]
133
+ lemmas = item.get("lemmas", ["_"] * len(tokens))
134
+ upos = item["upos"]
135
+ xpos = item.get("xpos", ["_"] * len(tokens))
136
+ feats = item.get("feats", [None] * len(tokens))
137
+ heads = item["head"]
138
+ deprels = item["deprel"]
139
+ deps = item.get("deps", [None] * len(tokens))
140
+ misc = item.get("misc", [None] * len(tokens))
141
+
142
+ for i in range(len(tokens)):
143
+ token_id = i + 1
144
+ form = tokens[i]
145
+ lemma = lemmas[i] if lemmas[i] else "_"
146
+ upos_tag = upos[i] if upos[i] else "_"
147
+ xpos_tag = xpos[i] if xpos[i] else "_"
148
+ feat = feats[i] if feats[i] else "_"
149
+ head = int(heads[i]) if heads[i] is not None else 0
150
+ deprel = deprels[i] if deprels[i] else "_"
151
+ dep = deps[i] if deps[i] else "_"
152
+ misc_val = misc[i] if misc[i] else "_"
153
+
154
+ line = f"{token_id}\t{form}\t{lemma}\t{upos_tag}\t{xpos_tag}\t{feat}\t{head}\t{deprel}\t{dep}\t{misc_val}"
155
+ f.write(line + "\n")
156
+
157
+ f.write("\n")
158
+
159
+ def _print_statistics(self):
160
+ """Print dataset statistics."""
161
+ for name, path in [("Train", self._train), ("Dev", self._dev), ("Test", self._test)]:
162
+ n_sents, n_tokens = self._count_sentences_tokens(path)
163
+ print(f" {name}: {n_sents} sentences, {n_tokens} tokens")
164
+
165
+ def _count_sentences_tokens(self, path: Path) -> tuple:
166
+ """Count sentences and tokens in a CoNLL-U file."""
167
+ n_sents = 0
168
+ n_tokens = 0
169
+
170
+ with open(path, "r", encoding="utf-8") as f:
171
+ for line in f:
172
+ line = line.strip()
173
+ if not line:
174
+ n_sents += 1
175
+ elif not line.startswith("#"):
176
+ parts = line.split("\t")
177
+ if "-" not in parts[0] and "." not in parts[0]:
178
+ n_tokens += 1
179
+
180
+ return n_sents, n_tokens
181
+
182
+ @property
183
+ def train(self) -> str:
184
+ """Path to training data file."""
185
+ return str(self._train)
186
+
187
+ @property
188
+ def dev(self) -> str:
189
+ """Path to development/validation data file."""
190
+ return str(self._dev)
191
+
192
+ @property
193
+ def test(self) -> str:
194
+ """Path to test data file."""
195
+ return str(self._test)
196
+
197
+ def get_statistics(self) -> dict:
198
+ """Get dataset statistics."""
199
+ stats = {}
200
+
201
+ for split_name, path in [
202
+ ("train", self._train),
203
+ ("dev", self._dev),
204
+ ("test", self._test)
205
+ ]:
206
+ n_sents, n_tokens = self._count_sentences_tokens(path)
207
+ stats[f"{split_name}_sentences"] = n_sents
208
+ stats[f"{split_name}_tokens"] = n_tokens
209
+
210
+ # Collect all POS tags and relations
211
+ all_upos = set()
212
+ all_deprels = set()
213
+
214
+ for path in [self._train, self._dev, self._test]:
215
+ with open(path, "r", encoding="utf-8") as f:
216
+ for line in f:
217
+ line = line.strip()
218
+ if line and not line.startswith("#"):
219
+ parts = line.split("\t")
220
+ if len(parts) >= 8 and "-" not in parts[0] and "." not in parts[0]:
221
+ all_upos.add(parts[3])
222
+ all_deprels.add(parts[7])
223
+
224
+ stats["num_upos_tags"] = len(all_upos)
225
+ stats["num_deprels"] = len(all_deprels)
226
+ stats["upos_tags"] = sorted(all_upos)
227
+ stats["deprels"] = sorted(all_deprels)
228
+
229
+ return stats
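
A sketch of the intended corpus workflow (the first call needs network access: it fetches the vi_vtb-ud-*.conllu files from the UD GitHub repository, falling back to the HuggingFace datasets loader):

    from bamboo1.ud_corpus import UDVietnameseVTB

    corpus = UDVietnameseVTB()                    # downloads into ./data/UD_Vietnamese-VTB if missing
    print(corpus.train, corpus.dev, corpus.test)  # paths to the CoNLL-U splits

    stats = corpus.get_statistics()
    print(stats["train_sentences"], stats["num_deprels"])
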
pyproject.toml CHANGED
@@ -9,7 +9,9 @@ dependencies = [
  "datasets>=2.14.0",
  "click>=8.0.0",
  "underthesea>=9.2.0",
- "transformers>=5.0.0",
+ "transformers>=4.30.0",
+ "tqdm>=4.60.0",
+ "numpy>=1.24.0",
  ]

  [project.optional-dependencies]
@@ -20,6 +22,9 @@ dev = [
  cloud = [
  "runpod>=1.6.0",
  ]
+ adapters = [
+ "adapters>=0.1.0",
+ ]

  [build-system]
  requires = ["hatchling"]
scripts/evaluate.py CHANGED
@@ -11,10 +11,15 @@
11
  """
12
  Evaluation script for Bamboo-1 Vietnamese Dependency Parser.
13
 
 
 
 
 
14
  Usage:
15
  uv run scripts/evaluate.py --model models/bamboo-1
16
- uv run scripts/evaluate.py --model models/bamboo-1 --split test
17
- uv run scripts/evaluate.py --model models/bamboo-1 --detailed
 
18
  """
19
 
20
  import sys
@@ -27,6 +32,7 @@ import click
27
  sys.path.insert(0, str(Path(__file__).parent.parent))
28
 
29
  from bamboo1.corpus import UDD1Corpus
 
30
 
31
 
32
  def read_conll_sentences(filepath: str):
@@ -103,12 +109,71 @@ def calculate_attachment_scores(gold_sentences, pred_sentences):
103
  }
104
 
105
 
  @click.command()
107
  @click.option(
108
  "--model", "-m",
109
  required=True,
110
  help="Path to trained model directory",
111
  )
112
  @click.option(
113
  "--split",
114
  type=click.Choice(["dev", "test", "both"]),
@@ -125,21 +190,32 @@ def calculate_attachment_scores(gold_sentences, pred_sentences):
125
  "--output", "-o",
126
  help="Save predictions to file (CoNLL-U format)",
127
  )
128
- def evaluate(model, split, detailed, output):
129
- """Evaluate Bamboo-1 Vietnamese Dependency Parser on UDD-1 dataset."""
130
- from underthesea.models.dependency_parser import DependencyParser
131
 
 
 
 
132
  click.echo("=" * 60)
133
  click.echo("Bamboo-1: Vietnamese Dependency Parser Evaluation")
134
  click.echo("=" * 60)
135
 
136
  # Load model
137
- click.echo(f"\nLoading model from {model}...")
138
- parser = DependencyParser.load(model)
 
 
 
 
 
 
139
 
140
  # Load corpus
141
- click.echo("Loading UDD-1 corpus...")
142
- corpus = UDD1Corpus()
 
 
 
143
 
144
  splits_to_eval = []
145
  if split == "both":
@@ -164,12 +240,11 @@ def evaluate(model, split, detailed, output):
164
  pred_sentences = []
165
 
166
  for gold_sent in gold_sentences:
167
- # Reconstruct text from tokens
168
  tokens = [tok["form"] for tok in gold_sent]
169
- text = " ".join(tokens)
170
 
171
  # Parse
172
- result = parser.predict(text)
173
 
174
  # Convert result to same format as gold
175
  pred_sent = []
 
11
  """
12
  Evaluation script for Bamboo-1 Vietnamese Dependency Parser.
13
 
14
+ Supports both BiLSTM and PhoBERT-based models, and multiple datasets:
15
+ - UDD-1: Main Vietnamese dependency dataset (~18K sentences)
16
+ - UD Vietnamese VTB: Universal Dependencies benchmark (~3.3K sentences)
17
+
18
  Usage:
19
  uv run scripts/evaluate.py --model models/bamboo-1
20
+ uv run scripts/evaluate.py --model models/bamboo-1-phobert --model-type phobert
21
+ uv run scripts/evaluate.py --model models/bamboo-1-phobert --dataset ud-vtb
22
+ uv run scripts/evaluate.py --model models/bamboo-1 --split test --detailed
23
  """
24
 
25
  import sys
 
32
  sys.path.insert(0, str(Path(__file__).parent.parent))
33
 
34
  from bamboo1.corpus import UDD1Corpus
35
+ from bamboo1.ud_corpus import UDVietnameseVTB
36
 
37
 
38
  def read_conll_sentences(filepath: str):
 
109
  }
110
 
111
 
112
+ def load_phobert_model(model_path, device='cuda'):
113
+ """Load PhoBERT-based model."""
114
+ import torch
115
+ from bamboo1.models.transformer_parser import PhoBERTDependencyParser
116
+
117
+ if not torch.cuda.is_available():
118
+ device = 'cpu'
119
+
120
+ return PhoBERTDependencyParser.load(model_path, device=device)
121
+
122
+
123
+ def predict_phobert(parser, words):
124
+ """Make predictions using PhoBERT model."""
125
+ import torch
126
+
127
+ parser.eval()
128
+ device = next(parser.parameters()).device
129
+
130
+ # Tokenize
131
+ encoded = parser.tokenize_with_alignment([words])
132
+ input_ids = encoded['input_ids'].to(device)
133
+ attention_mask = encoded['attention_mask'].to(device)
134
+ word_starts = encoded['word_starts'].to(device)
135
+ word_mask = encoded['word_mask'].to(device)
136
+
137
+ with torch.no_grad():
138
+ arc_scores, rel_scores = parser.forward(
139
+ input_ids, attention_mask, word_starts, word_mask
140
+ )
141
+ arc_preds, rel_preds = parser.decode(arc_scores, rel_scores, word_mask)
142
+
143
+ # Convert to list
144
+ arc_preds = arc_preds[0].cpu().tolist()
145
+ rel_preds = rel_preds[0].cpu().tolist()
146
+
147
+ results = []
148
+ for i, word in enumerate(words):
149
+ head = arc_preds[i]
150
+ rel_idx = rel_preds[i]
151
+ rel = parser.idx2rel.get(rel_idx, "dep")
152
+ results.append((word, head, rel))
153
+
154
+ return results
155
+
156
+
157
  @click.command()
158
  @click.option(
159
  "--model", "-m",
160
  required=True,
161
  help="Path to trained model directory",
162
  )
163
+ @click.option(
164
+ "--model-type",
165
+ type=click.Choice(["bilstm", "phobert"]),
166
+ default="bilstm",
167
+ help="Model type: bilstm (underthesea) or phobert (transformer)",
168
+ show_default=True,
169
+ )
170
+ @click.option(
171
+ "--dataset",
172
+ type=click.Choice(["udd1", "ud-vtb"]),
173
+ default="udd1",
174
+ help="Dataset: udd1 (UDD-1) or ud-vtb (UD Vietnamese VTB)",
175
+ show_default=True,
176
+ )
177
  @click.option(
178
  "--split",
179
  type=click.Choice(["dev", "test", "both"]),
 
190
  "--output", "-o",
191
  help="Save predictions to file (CoNLL-U format)",
192
  )
193
+ def evaluate(model, model_type, dataset, split, detailed, output):
194
+ """Evaluate Bamboo-1 Vietnamese Dependency Parser.
 
195
 
196
+ Supports both BiLSTM (underthesea) and PhoBERT-based models,
197
+ and evaluation on UDD-1 or UD Vietnamese VTB datasets.
198
+ """
199
  click.echo("=" * 60)
200
  click.echo("Bamboo-1: Vietnamese Dependency Parser Evaluation")
201
  click.echo("=" * 60)
202
 
203
  # Load model
204
+ click.echo(f"\nLoading {model_type} model from {model}...")
205
+ if model_type == "phobert":
206
+ parser = load_phobert_model(model)
207
+ predict_fn = lambda words: predict_phobert(parser, words)
208
+ else:
209
+ from underthesea.models.dependency_parser import DependencyParser
210
+ parser = DependencyParser.load(model)
211
+ predict_fn = lambda words: parser.predict(" ".join(words))
212
 
213
  # Load corpus
214
+ click.echo(f"Loading {dataset.upper()} corpus...")
215
+ if dataset == "udd1":
216
+ corpus = UDD1Corpus()
217
+ else:
218
+ corpus = UDVietnameseVTB()
219
 
220
  splits_to_eval = []
221
  if split == "both":
 
240
  pred_sentences = []
241
 
242
  for gold_sent in gold_sentences:
243
+ # Get tokens
244
  tokens = [tok["form"] for tok in gold_sent]
 
245
 
246
  # Parse
247
+ result = predict_fn(tokens)
248
 
249
  # Convert result to same format as gold
250
  pred_sent = []
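
For reference, UAS is the fraction of tokens whose predicted head matches gold, and LAS additionally requires the dependency label to match. A minimal illustration of the metrics the script reports (not the script's own implementation, which lives in calculate_attachment_scores):

    def uas_las(gold, pred):
        # gold, pred: aligned lists of (head, deprel) per token
        total = len(gold)
        uas = sum(gh == ph for (gh, _), (ph, _) in zip(gold, pred)) / total
        las = sum(gh == ph and gr == pr for (gh, gr), (ph, pr) in zip(gold, pred)) / total
        return uas, las

    gold = [(2, "nsubj"), (0, "root"), (2, "obj")]
    pred = [(2, "nsubj"), (0, "root"), (2, "nmod")]
    print(uas_las(gold, pred))  # (1.0, 0.666...)
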
scripts/run_phobert_runpod.sh ADDED
@@ -0,0 +1,322 @@
1
+ #!/bin/bash
2
+ # Run PhoBERT dependency parser training on RunPod
3
+ #
4
+ # Usage:
5
+ # ./scripts/run_phobert_runpod.sh setup # Install uv, clone repo, sync deps
6
+ # ./scripts/run_phobert_runpod.sh train # Train PhoBERT on UDD-1
7
+ # ./scripts/run_phobert_runpod.sh train-vtb # Train PhoBERT on UD Vietnamese VTB
8
+ # ./scripts/run_phobert_runpod.sh train-large # Train with PhoBERT-large
9
+ # ./scripts/run_phobert_runpod.sh eval # Evaluate trained model
10
+ # ./scripts/run_phobert_runpod.sh download # Download trained model
11
+ # ./scripts/run_phobert_runpod.sh ssh # Interactive SSH session
12
+ # ./scripts/run_phobert_runpod.sh <command> # Run custom command
13
+ #
14
+ # Environment variables:
15
+ # RUNPOD_HOST Pod IP address
16
+ # RUNPOD_PORT Pod SSH port
17
+ # WANDB_API_KEY (optional) W&B API key for logging
18
+
19
+ set -e
20
+
21
+ # Pod connection details (update these after launching pod)
22
+ POD_HOST="${RUNPOD_HOST:-213.173.99.13}"
23
+ POD_PORT="${RUNPOD_PORT:-11375}"
24
+ SSH_OPTS="-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR"
25
+
26
+ # Training defaults
27
+ MODEL_DIR="models/bamboo-1-phobert"
28
+ ENCODER="vinai/phobert-base"
29
+ EPOCHS=100
30
+ BATCH_SIZE="${BATCH_SIZE:-48}" # Auto: 48 for A5000, increase for larger GPUs
31
+ PATIENCE=10
32
+ FP16="--fp16" # Enable mixed precision for ~2x speedup
33
+
34
+ ssh_cmd() {
35
+ ssh $SSH_OPTS root@$POD_HOST -p $POD_PORT "$@"
36
+ }
37
+
38
+ scp_to_pod() {
39
+ scp $SSH_OPTS -P $POD_PORT "$1" root@$POD_HOST:"$2"
40
+ }
41
+
42
+ scp_from_pod() {
43
+ scp $SSH_OPTS -P $POD_PORT root@$POD_HOST:"$1" "$2"
44
+ }
45
+
46
+ terminate_pod() {
47
+ echo ""
48
+ echo "Terminating pod..."
49
+ cd "$(dirname "$0")/.." && uv run scripts/runpod_setup.py status 2>/dev/null | grep -oP '\(\K[a-z0-9]+(?=\))' | head -1 | xargs -I {} uv run scripts/runpod_setup.py terminate {} 2>/dev/null || echo "Could not auto-terminate. Run: uv run scripts/runpod_setup.py terminate <pod-id>"
50
+ }
51
+
52
+ # Build wandb flags if API key is set
53
+ get_wandb_flags() {
54
+ if [ -n "$WANDB_API_KEY" ]; then
55
+ echo "--wandb --wandb-project bamboo-1-phobert"
56
+ fi
57
+ }
58
+
59
+ # Build wandb env export for SSH commands
60
+ get_wandb_env() {
61
+ if [ -n "$WANDB_API_KEY" ]; then
62
+ echo "export WANDB_API_KEY=$WANDB_API_KEY && "
63
+ fi
64
+ }
65
+
66
+ case "${1:-help}" in
67
+ setup)
68
+ echo "Setting up environment on RunPod..."
69
+
70
+ # Install uv
71
+ ssh_cmd 'curl -LsSf https://astral.sh/uv/install.sh | sh'
72
+
73
+ # Clone repo
74
+ ssh_cmd 'source $HOME/.local/bin/env && git clone https://huggingface.co/undertheseanlp/bamboo-1 /workspace/bamboo-1 || true'
75
+
76
+ # Pull latest and sync dependencies
77
+ ssh_cmd 'source $HOME/.local/bin/env && cd /workspace/bamboo-1 && git pull && uv sync'
78
+
79
+ echo "Setup complete!"
80
+
81
+ if [ -n "$WANDB_API_KEY" ]; then
82
+ echo ""
83
+ echo "WANDB_API_KEY detected - it will be passed automatically during training."
84
+ fi
85
+ echo ""
86
+ echo "Next steps:"
87
+ echo " ./scripts/run_phobert_runpod.sh train # Train on UDD-1"
88
+ echo " ./scripts/run_phobert_runpod.sh train-vtb # Train on UD-VTB (Trankit benchmark)"
89
+ ;;
90
+
91
+ train)
92
+ echo "Training PhoBERT dependency parser on UDD-1..."
93
+ echo " Encoder: $ENCODER"
94
+ echo " Output: $MODEL_DIR"
95
+ echo " Epochs: $EPOCHS"
96
+
97
+ WANDB_FLAGS=$(get_wandb_flags)
98
+ WANDB_ENV=$(get_wandb_env)
99
+
100
+ ssh_cmd "${WANDB_ENV}source \$HOME/.local/bin/env && cd /workspace/bamboo-1 && \
101
+ uv run scripts/train_phobert.py \
102
+ --output $MODEL_DIR \
103
+ --encoder $ENCODER \
104
+ --dataset udd1 \
105
+ --epochs $EPOCHS \
106
+ --batch-size $BATCH_SIZE \
107
+ --patience $PATIENCE \
108
+ $FP16 \
109
+ $WANDB_FLAGS"
110
+
111
+ echo ""
112
+ echo "Training complete! Download model with:"
113
+ echo " ./scripts/run_phobert_runpod.sh download"
114
+ ;;
115
+
116
+ train-vtb)
117
+ echo "Training PhoBERT dependency parser on UD Vietnamese VTB..."
118
+ echo " (For comparison with Trankit benchmark)"
119
+ echo " Encoder: $ENCODER"
120
+ echo " Output: ${MODEL_DIR}-vtb"
121
+ echo " Epochs: $EPOCHS"
122
+
123
+ WANDB_FLAGS=$(get_wandb_flags)
124
+ WANDB_ENV=$(get_wandb_env)
125
+
126
+ ssh_cmd "${WANDB_ENV}source \$HOME/.local/bin/env && cd /workspace/bamboo-1 && \
127
+ uv run scripts/train_phobert.py \
128
+ --output ${MODEL_DIR}-vtb \
129
+ --encoder $ENCODER \
130
+ --dataset ud-vtb \
131
+ --epochs $EPOCHS \
132
+ --batch-size $BATCH_SIZE \
133
+ --patience $PATIENCE \
134
+ $FP16 \
135
+ $WANDB_FLAGS"
136
+
137
+ echo ""
138
+ echo "Training complete! Download model with:"
139
+ echo " ./scripts/run_phobert_runpod.sh download-vtb"
140
+ ;;
141
+
142
+ train-large)
143
+ echo "Training PhoBERT-large dependency parser on UDD-1..."
144
+ echo " Encoder: vinai/phobert-large"
145
+ echo " Output: ${MODEL_DIR}-large"
146
+ echo " Epochs: $EPOCHS"
147
+ echo " (Note: Requires GPU with >= 24GB VRAM)"
148
+
149
+ WANDB_FLAGS=$(get_wandb_flags)
150
+ WANDB_ENV=$(get_wandb_env)
151
+
152
+ ssh_cmd "${WANDB_ENV}source \$HOME/.local/bin/env && cd /workspace/bamboo-1 && \
153
+ uv run scripts/train_phobert.py \
154
+ --output ${MODEL_DIR}-large \
155
+ --encoder vinai/phobert-large \
156
+ --dataset udd1 \
157
+ --epochs $EPOCHS \
158
+ --batch-size 24 \
159
+ --patience $PATIENCE \
160
+ $FP16 \
161
+ $WANDB_FLAGS"
162
+
163
+ echo ""
164
+ echo "Training complete! Download model with:"
165
+ echo " ./scripts/run_phobert_runpod.sh download-large"
166
+ ;;
167
+
168
+ train-quick)
169
+ echo "Quick training run (100 samples) for testing..."
170
+
171
+ WANDB_FLAGS=$(get_wandb_flags)
172
+ WANDB_ENV=$(get_wandb_env)
173
+
174
+ ssh_cmd "${WANDB_ENV}source \$HOME/.local/bin/env && cd /workspace/bamboo-1 && \
175
+ uv run scripts/train_phobert.py \
176
+ --output ${MODEL_DIR}-test \
177
+ --encoder $ENCODER \
178
+ --dataset udd1 \
179
+ --epochs 5 \
180
+ --batch-size $BATCH_SIZE \
181
+ --sample 100 \
182
+ $FP16 \
183
+ $WANDB_FLAGS"
184
+ ;;
185
+
186
+ train-fast)
187
+ echo "FAST Trankit reproduction (<5 min) - H100 settings!"
188
+ echo " Dataset: UD Vietnamese VTB (Trankit benchmark)"
189
+ echo " Encoder: $ENCODER"
190
+ echo " Output: ${MODEL_DIR}-vtb"
191
+ echo " Settings: batch=256, epochs=30, patience=5"
192
+ echo ""
193
+ echo " Target: Trankit base 70.96% UAS / 64.76% LAS"
194
+
195
+ WANDB_FLAGS=$(get_wandb_flags)
196
+ WANDB_ENV=$(get_wandb_env)
197
+
198
+ ssh_cmd "${WANDB_ENV}source \$HOME/.local/bin/env && cd /workspace/bamboo-1 && \
199
+ uv run scripts/train_phobert.py \
200
+ --output ${MODEL_DIR}-vtb \
201
+ --encoder $ENCODER \
202
+ --dataset ud-vtb \
203
+ --epochs 30 \
204
+ --batch-size 256 \
205
+ --patience 5 \
206
+ --warmup-steps 50 \
207
+ --num-workers 8 \
208
+ $FP16 \
209
+ $WANDB_FLAGS"
210
+
211
+ echo ""
212
+ echo "Training complete! Download model with:"
213
+ echo " ./scripts/run_phobert_runpod.sh download-vtb"
214
+ ;;
215
+
216
+ eval)
217
+ echo "Evaluating PhoBERT model on UDD-1 test set..."
218
+
219
+ ssh_cmd "source \$HOME/.local/bin/env && cd /workspace/bamboo-1 && \
220
+ uv run scripts/evaluate.py \
221
+ --model $MODEL_DIR \
222
+ --model-type phobert \
223
+ --dataset udd1 \
224
+ --split test \
225
+ --detailed"
226
+ ;;
227
+
228
+ eval-vtb)
229
+ echo "Evaluating PhoBERT model on UD Vietnamese VTB test set..."
230
+ echo " (For comparison with Trankit: 70.96% UAS / 64.76% LAS)"
231
+
232
+ ssh_cmd "source \$HOME/.local/bin/env && cd /workspace/bamboo-1 && \
233
+ uv run scripts/evaluate.py \
234
+ --model ${MODEL_DIR}-vtb \
235
+ --model-type phobert \
236
+ --dataset ud-vtb \
237
+ --split test \
238
+ --detailed"
239
+ ;;
240
+
241
+ download)
242
+ echo "Downloading trained model from RunPod..."
243
+ mkdir -p models/bamboo-1-phobert
244
+ scp_from_pod "/workspace/bamboo-1/$MODEL_DIR/*" "models/bamboo-1-phobert/"
245
+ echo "Model downloaded to models/bamboo-1-phobert/"
246
+ ;;
247
+
248
+ download-vtb)
249
+ echo "Downloading VTB-trained model from RunPod..."
250
+ mkdir -p models/bamboo-1-phobert-vtb
251
+ scp_from_pod "/workspace/bamboo-1/${MODEL_DIR}-vtb/*" "models/bamboo-1-phobert-vtb/"
252
+ echo "Model downloaded to models/bamboo-1-phobert-vtb/"
253
+ ;;
254
+
255
+ download-large)
256
+ echo "Downloading PhoBERT-large model from RunPod..."
257
+ mkdir -p models/bamboo-1-phobert-large
258
+ scp_from_pod "/workspace/bamboo-1/${MODEL_DIR}-large/*" "models/bamboo-1-phobert-large/"
259
+ echo "Model downloaded to models/bamboo-1-phobert-large/"
260
+ ;;
261
+
262
+ logs)
263
+ echo "Tailing training logs..."
264
+ ssh_cmd "tail -f /workspace/bamboo-1/training.log 2>/dev/null || echo 'No log file found. Training may not have started yet.'"
265
+ ;;
266
+
267
+ gpu-status)
268
+ echo "GPU status on RunPod..."
269
+ ssh_cmd "nvidia-smi"
270
+ ;;
271
+
272
+ ssh)
273
+ echo "Connecting to RunPod..."
274
+ ssh $SSH_OPTS root@$POD_HOST -p $POD_PORT
275
+ ;;
276
+
277
+ help|--help|-h)
278
+ echo "Usage: $0 <command>"
279
+ echo ""
280
+ echo "PhoBERT Training Commands:"
281
+ echo " setup Install uv, clone repo, sync dependencies"
282
+ echo " train-fast FAST Trankit reproduction <5 min (H100, UD-VTB)"
283
+ echo " train Train PhoBERT on UDD-1 dataset (18K sentences)"
284
+ echo " train-vtb Train PhoBERT on UD Vietnamese VTB (Trankit benchmark)"
285
+ echo " train-large Train PhoBERT-large on UDD-1 (requires 24GB+ VRAM)"
286
+ echo " train-quick Quick test run with 100 samples"
287
+ echo ""
288
+ echo "Evaluation Commands:"
289
+ echo " eval Evaluate model on UDD-1 test set"
290
+ echo " eval-vtb Evaluate model on UD-VTB test set"
291
+ echo ""
292
+ echo "Utility Commands:"
293
+ echo " download Download trained model (UDD-1)"
294
+ echo " download-vtb Download trained model (UD-VTB)"
295
+ echo " download-large Download trained model (PhoBERT-large)"
296
+ echo " logs Tail training logs"
297
+ echo " gpu-status Show GPU utilization"
298
+ echo " ssh Interactive SSH session"
299
+ echo " <cmd> Run custom command on pod"
300
+ echo ""
301
+ echo "Environment variables:"
302
+ echo " RUNPOD_HOST Pod IP address (default: $POD_HOST)"
303
+ echo " RUNPOD_PORT Pod SSH port (default: $POD_PORT)"
304
+ echo " WANDB_API_KEY W&B API key for experiment tracking (optional)"
305
+ echo " BATCH_SIZE Override default batch size (default: 48)"
306
+ echo ""
307
+ echo "GPU Recommendations (for launch-phobert):"
308
+ echo " A4000 (16GB) - Budget, ~\$0.20/hr, batch_size=32"
309
+ echo " A5000 (24GB) - Recommended, ~\$0.30/hr, batch_size=48 (default)"
310
+ echo " A6000 (48GB) - Fast, ~\$0.50/hr, batch_size=64"
311
+ echo " A100 (80GB) - Fastest, ~\$1.50/hr, batch_size=128"
312
+ echo ""
313
+ echo "Trankit Benchmark Reference:"
314
+ echo " Trankit base: 70.96% UAS / 64.76% LAS (UD Vietnamese VTB)"
315
+ echo " Trankit large: 71.07% UAS / 65.37% LAS (UD Vietnamese VTB)"
316
+ ;;
317
+
318
+ *)
319
+ # Run custom command
320
+ ssh_cmd "source \$HOME/.local/bin/env && cd /workspace/bamboo-1 && $*"
321
+ ;;
322
+ esac
scripts/runpod_setup.py CHANGED
@@ -3,6 +3,7 @@
3
  # dependencies = [
4
  # "runpod>=1.6.0",
5
  # "requests>=2.28.0",
 
6
  # ]
7
  # ///
8
  """
@@ -29,9 +30,15 @@ Usage:
29
  """
30
 
31
  import os
 
 
32
  import click
33
  import runpod
34
  import requests
 
 
 
 
35
 
36
 
37
  @click.group()
@@ -147,7 +154,12 @@ def status():
147
 
148
  click.echo("Active pods:")
149
  for pod in pods:
150
- click.echo(f" - {pod['name']} ({pod['id']}): {pod.get('desiredStatus', 'UNKNOWN')}")
 
 
 
 
 
151
 
152
 
153
  @cli.command()
@@ -168,6 +180,139 @@ def terminate(pod_id):
168
  click.echo("Pod terminated.")
169
 
170
 
171
  # =============================================================================
172
  # Volume Management
173
  # =============================================================================
@@ -192,6 +337,117 @@ def _graphql_request(query: str, variables: dict = None) -> dict:
192
  return response.json()
193
 
194
 
195
  @cli.command("volume-list")
196
  def volume_list():
197
  """List all network volumes."""
 
3
  # dependencies = [
4
  # "runpod>=1.6.0",
5
  # "requests>=2.28.0",
6
+ # "python-dotenv>=1.0.0",
7
  # ]
8
  # ///
9
  """
 
30
  """
31
 
32
  import os
33
+ from pathlib import Path
34
+
35
  import click
36
  import runpod
37
  import requests
38
+ from dotenv import load_dotenv
39
+
40
+ # Load .env file from project root
41
+ load_dotenv(Path(__file__).parent.parent / ".env")
42
 
43
 
44
  @click.group()
 
154
 
155
  click.echo("Active pods:")
156
  for pod in pods:
157
+ click.echo(f"\n {pod['name']} ({pod['id']}): {pod.get('desiredStatus', 'UNKNOWN')}")
158
+ runtime = pod.get('runtime') or {}
159
+ ports = runtime.get('ports') or []
160
+ for p in ports:
161
+ if p.get('privatePort') == 22:
162
+ click.echo(f" SSH: ssh root@{p.get('ip')} -p {p.get('publicPort')}")
163
 
164
 
165
  @cli.command()
 
180
  click.echo("Pod terminated.")
181
 
182
 
183
+ GPU_RECOMMENDATIONS = {
184
+ "budget": "NVIDIA RTX A4000", # 16GB, $0.20/hr - Basic training
185
+ "balanced": "NVIDIA RTX A5000", # 24GB, $0.30/hr - Good balance (Recommended)
186
+ "fast": "NVIDIA RTX A6000", # 48GB, $0.50/hr - Larger batches, faster
187
+ "fastest": "NVIDIA A100 80GB PCIe", # 80GB, $1.50/hr - Best for production
188
+ }
189
+
190
+
191
+ @cli.command("launch-phobert")
192
+ @click.option("--gpu", default="NVIDIA RTX A5000",
193
+ help="GPU type: A4000 (budget), A5000 (balanced), A6000 (fast), A100 (fastest)")
194
+ @click.option("--image", default=DEFAULT_IMAGE, help="Docker image")
195
+ @click.option("--disk", default=30, type=int, help="Disk size in GB (PhoBERT needs more space)")
196
+ @click.option("--name", default="bamboo-1-phobert", help="Pod name")
197
+ @click.option("--volume", default=None, help="Network volume ID to attach")
198
+ @click.option("--wandb-key", envvar="WANDB_API_KEY", help="W&B API key for logging")
199
+ @click.option("--dataset", type=click.Choice(["udd1", "ud-vtb"]), default="udd1",
200
+ help="Dataset: udd1 or ud-vtb (Trankit benchmark)")
201
+ @click.option("--encoder", default="vinai/phobert-base",
202
+ help="Encoder: vinai/phobert-base or vinai/phobert-large")
203
+ @click.option("--epochs", default=100, type=int, help="Number of epochs")
204
+ @click.option("--sample", default=0, type=int, help="Sample N sentences (0=all)")
205
+ @click.option("--batch-size", default=0, type=int, help="Batch size (0=auto based on GPU)")
206
+ def launch_phobert(gpu, image, disk, name, volume, wandb_key, dataset, encoder, epochs, sample, batch_size):
207
+ """Launch a RunPod instance for PhoBERT training.
208
+
209
+ This launches a pod configured for training the PhoBERT-based dependency parser.
210
+ After the pod starts, SSH in and run the training command printed below.
211
+
212
+ GPU Recommendations:
213
+ A4000 (16GB) - Budget option, batch_size=32
214
+ A5000 (24GB) - Recommended balance, batch_size=48-64
215
+ A6000 (48GB) - Fast training, batch_size=64-96
216
+ A100 (80GB) - Fastest, batch_size=128+
217
+
218
+ Example:
219
+ uv run scripts/runpod_setup.py launch-phobert
220
+ uv run scripts/runpod_setup.py launch-phobert --gpu "NVIDIA RTX A6000" # Faster
221
+ uv run scripts/runpod_setup.py launch-phobert --dataset ud-vtb # Trankit benchmark
222
+ uv run scripts/runpod_setup.py launch-phobert --encoder vinai/phobert-large --gpu "NVIDIA RTX A6000"
223
+ """
224
+ # Auto-select batch size based on GPU if not specified
225
+ if batch_size == 0:
226
+ if "A100" in gpu or "H100" in gpu:
227
+ batch_size = 128
228
+ elif "A6000" in gpu:
229
+ batch_size = 64
230
+ elif "A5000" in gpu:
231
+ batch_size = 48
232
+ else: # A4000 or unknown
233
+ batch_size = 32
234
+
235
+ # Reduce batch size for large encoder
236
+ if "large" in encoder:
237
+ batch_size = batch_size // 2
238
+
239
+ click.echo("Launching RunPod instance for PhoBERT training...")
240
+ click.echo(f" GPU: {gpu}")
241
+ click.echo(f" Image: {image}")
242
+ click.echo(f" Disk: {disk}GB")
243
+ click.echo(f" Dataset: {dataset}")
244
+ click.echo(f" Encoder: {encoder}")
245
+ click.echo(f" Batch size: {batch_size}")
246
+
247
+ # Build training command with optimizations
248
+ train_cmd = f"uv run scripts/train_phobert.py --encoder {encoder} --dataset {dataset} --epochs {epochs} --batch-size {batch_size} --fp16"
249
+ if sample > 0:
250
+ train_cmd += f" --sample {sample}"
251
+ if wandb_key:
252
+ train_cmd += " --wandb --wandb-project bamboo-1-phobert"
253
+
254
+ # Output directory based on config
255
+ output_suffix = ""
256
+ if dataset == "ud-vtb":
257
+ output_suffix += "-vtb"
258
+ if "large" in encoder:
259
+ output_suffix += "-large"
260
+ train_cmd += f" --output models/bamboo-1-phobert{output_suffix}"
261
+
262
+ # Set environment variables
263
+ env_vars = {}
264
+ if wandb_key:
265
+ env_vars["WANDB_API_KEY"] = wandb_key
266
+
267
+ # Add SSH public key
268
+ ssh_key = get_ssh_public_key()
269
+ if ssh_key:
270
+ env_vars["PUBLIC_KEY"] = ssh_key
271
+ click.echo(" SSH key: configured")
272
+
273
+ if volume:
274
+ click.echo(f" Volume: {volume}")
275
+
276
+ pod = runpod.create_pod(
277
+ name=name,
278
+ image_name=image,
279
+ gpu_type_id=gpu,
280
+ volume_in_gb=disk,
281
+ env=env_vars if env_vars else None,
282
+ ports="22/tcp",
283
+ network_volume_id=volume,
284
+ )
285
+
286
+ click.echo("\nPod created!")
287
+ click.echo(f" ID: {pod['id']}")
288
+ click.echo(f" Status: {pod.get('desiredStatus', 'PENDING')}")
289
+ click.echo("\nMonitor at: https://runpod.io/console/pods")
290
+
291
+ # Generate setup and training commands
292
+ click.echo("\n" + "="*70)
293
+ click.echo("After SSH into the pod, run these commands:")
294
+ click.echo("="*70)
295
+
296
+ setup_cmd = """curl -LsSf https://astral.sh/uv/install.sh | sh && \\
297
+ source $HOME/.local/bin/env && \\
298
+ git clone https://huggingface.co/undertheseanlp/bamboo-1 /workspace/bamboo-1 && \\
299
+ cd /workspace/bamboo-1 && uv sync"""
300
+
301
+ click.echo("\n# 1. Setup (run once):")
302
+ click.echo(setup_cmd)
303
+
304
+ click.echo("\n# 2. Train:")
305
+ click.echo(f"cd /workspace/bamboo-1 && {train_cmd}")
306
+
307
+ click.echo("\n" + "="*70)
308
+
309
+ if dataset == "ud-vtb":
310
+ click.echo("\nTranskit benchmark reference:")
311
+ click.echo(" Trankit base: 70.96% UAS / 64.76% LAS")
312
+ click.echo(" Trankit large: 71.07% UAS / 65.37% LAS")
313
+ click.echo("")
314
+
315
+
316
  # =============================================================================
317
  # Volume Management
318
  # =============================================================================
 
337
  return response.json()
338
 
339
 
340
+ @cli.command("launch-fast")
341
+ @click.option("--gpu", default="NVIDIA H100 80GB HBM3", help="GPU type (H100 for fastest)")
342
+ @click.option("--image", default=DEFAULT_IMAGE, help="Docker image")
343
+ @click.option("--disk", default=30, type=int, help="Disk size in GB")
344
+ @click.option("--name", default="bamboo-1-trankit", help="Pod name")
345
+ @click.option("--volume", default=None, help="Network volume ID to attach")
346
+ @click.option("--wandb-key", envvar="WANDB_API_KEY", help="W&B API key for logging")
347
+ @click.option("--encoder", default="vinai/phobert-base", help="Encoder model")
348
+ def launch_fast(gpu, image, disk, name, volume, wandb_key, encoder):
349
+ """Launch pod for FAST Trankit reproduction (<5 minutes).
350
+
351
+ Trains on UD Vietnamese VTB to reproduce Trankit benchmark:
352
+ - Trankit base: 70.96% UAS / 64.76% LAS
353
+ - Trankit large: 71.07% UAS / 65.37% LAS
354
+
355
+ Uses H100 with aggressive settings for <5 min training.
356
+
357
+ Example:
358
+ uv run scripts/runpod_setup.py launch-fast
359
+ uv run scripts/runpod_setup.py launch-fast --encoder vinai/phobert-large
360
+ """
361
+ dataset = "ud-vtb" # Always use UD-VTB for Trankit reproduction
362
+
363
+ # Set batch size based on GPU
364
+ if "H100" in gpu:
365
+ batch_size = 256
366
+ epochs = 30
367
+ elif "A100" in gpu:
368
+ batch_size = 128
369
+ epochs = 40
370
+ else:
371
+ batch_size = 64
372
+ epochs = 50
373
+ click.echo("WARNING: For <5 min training, use H100!")
374
+
375
+ # Reduce batch for large model
376
+ if "large" in encoder:
377
+ batch_size = batch_size // 2
378
+
379
+ click.echo("Launching FAST Trankit reproduction (<5 minutes)...")
380
+ click.echo(f" GPU: {gpu}")
381
+ click.echo(f" Batch size: {batch_size}")
382
+ click.echo(f" Epochs: {epochs}")
383
+ click.echo(f" Dataset: {dataset} (UD Vietnamese VTB)")
384
+ click.echo(f" Encoder: {encoder}")
385
+ click.echo("")
386
+ click.echo(" Target: Trankit base 70.96% UAS / 64.76% LAS")
387
+
388
+ # Output name
389
+ output_name = "models/bamboo-1-phobert-vtb"
390
+ if "large" in encoder:
391
+ output_name += "-large"
392
+
393
+ # Build optimized training command
394
+ train_cmd = f"""uv run scripts/train_phobert.py \\
395
+ --encoder {encoder} \\
396
+ --dataset {dataset} \\
397
+ --output {output_name} \\
398
+ --epochs {epochs} \\
399
+ --batch-size {batch_size} \\
400
+ --patience 5 \\
401
+ --warmup-steps 50 \\
402
+ --num-workers 8 \\
403
+ --fp16"""
404
+
405
+ if wandb_key:
406
+ train_cmd += " --wandb --wandb-project bamboo-1-phobert"
407
+
408
+ # Set environment variables
409
+ env_vars = {}
410
+ if wandb_key:
411
+ env_vars["WANDB_API_KEY"] = wandb_key
412
+
413
+ ssh_key = get_ssh_public_key()
414
+ if ssh_key:
415
+ env_vars["PUBLIC_KEY"] = ssh_key
416
+ click.echo(" SSH key: configured")
417
+
418
+ if volume:
419
+ click.echo(f" Volume: {volume}")
420
+
421
+ pod = runpod.create_pod(
422
+ name=name,
423
+ image_name=image,
424
+ gpu_type_id=gpu,
425
+ volume_in_gb=disk,
426
+ env=env_vars if env_vars else None,
427
+ ports="22/tcp",
428
+ network_volume_id=volume,
429
+ )
430
+
431
+ click.echo(f"\nPod created!")
432
+ click.echo(f" ID: {pod['id']}")
433
+ click.echo(f" Status: {pod.get('desiredStatus', 'PENDING')}")
434
+ click.echo("\nMonitor at: https://runpod.io/console/pods")
435
+
436
+ # One-liner setup + train
437
+ click.echo("\n" + "="*70)
438
+ click.echo("SSH in and run this ONE command for <5 min training:")
439
+ click.echo("="*70)
440
+
441
+ one_liner = f"""curl -LsSf https://astral.sh/uv/install.sh | sh && \\
442
+ source $HOME/.local/bin/env && \\
443
+ git clone https://huggingface.co/undertheseanlp/bamboo-1 /workspace/bamboo-1 && \\
444
+ cd /workspace/bamboo-1 && uv sync && \\
445
+ {train_cmd}"""
446
+
447
+ click.echo(one_liner)
448
+ click.echo("="*70)
449
+
450
+
451
  @cli.command("volume-list")
452
  def volume_list():
453
  """List all network volumes."""
scripts/train_phobert.py ADDED
@@ -0,0 +1,603 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch>=2.0.0",
5
+ # "transformers>=4.30.0",
6
+ # "datasets>=2.14.0",
7
+ # "click>=8.0.0",
8
+ # "tqdm>=4.60.0",
9
+ # "wandb>=0.15.0",
10
+ # ]
11
+ # ///
12
+ """
13
+ Training script for PhoBERT-based Vietnamese Dependency Parser.
14
+
15
+ This script trains a transformer-based dependency parser using PhoBERT as the
16
+ encoder, following the Trankit approach for Vietnamese dependency parsing.
17
+
18
+ Architecture:
19
+ PhoBERT -> Word-level pooling -> Biaffine attention -> MST decoding
20
+
21
+ Usage:
22
+ uv run scripts/train_phobert.py
23
+ uv run scripts/train_phobert.py --output models/bamboo-1-phobert --epochs 100
24
+ uv run scripts/train_phobert.py --encoder vinai/phobert-large
25
+ uv run scripts/train_phobert.py --dataset ud-vtb # Use UD Vietnamese VTB
26
+ """
27
+
28
+ import sys
29
+ from pathlib import Path
30
+ from collections import Counter
31
+ from dataclasses import dataclass
32
+ from typing import List, Tuple, Optional, Dict
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ from torch.utils.data import Dataset, DataLoader
37
+ from torch.optim import AdamW
38
+ from tqdm import tqdm
39
+
40
+ import click
41
+
42
+ sys.path.insert(0, str(Path(__file__).parent.parent))
43
+ from bamboo1.corpus import UDD1Corpus
44
+ from bamboo1.ud_corpus import UDVietnameseVTB
45
+ from bamboo1.models.transformer_parser import PhoBERTDependencyParser
46
+ from scripts.cost_estimate import CostTracker, detect_hardware
47
+
48
+
49
+ # ============================================================================
50
+ # Data Processing
51
+ # ============================================================================
52
+
53
+ @dataclass
54
+ class Sentence:
55
+ """A dependency-parsed sentence."""
56
+ words: List[str]
57
+ heads: List[int]
58
+ rels: List[str]
59
+
60
+
61
+ def read_conllu(path: str) -> List[Sentence]:
62
+ """Read CoNLL-U file and return list of sentences."""
63
+ sentences = []
64
+ words, heads, rels = [], [], []
65
+
66
+ with open(path, 'r', encoding='utf-8') as f:
67
+ for line in f:
68
+ line = line.strip()
69
+ if not line:
70
+ if words:
71
+ sentences.append(Sentence(words, heads, rels))
72
+ words, heads, rels = [], [], []
73
+ elif line.startswith('#'):
74
+ continue
75
+ else:
76
+ parts = line.split('\t')
77
+ if '-' in parts[0] or '.' in parts[0]:
78
+ continue
79
+ words.append(parts[1])
80
+ heads.append(int(parts[6]))
81
+ rels.append(parts[7])
82
+
83
+ if words:
84
+ sentences.append(Sentence(words, heads, rels))
85
+
86
+ return sentences
87
+
88
+
89
+ class Vocabulary:
90
+ """Vocabulary for relations."""
91
+
92
+ def __init__(self):
93
+ self.rel2idx = {}
94
+ self.idx2rel = {}
95
+
96
+ def build(self, sentences: List[Sentence]):
97
+ """Build vocabulary from sentences."""
98
+ rel_counts = Counter()
99
+ for sent in sentences:
100
+ for rel in sent.rels:
101
+ rel_counts[rel] += 1
102
+
103
+ for rel in sorted(rel_counts.keys()):
104
+ if rel not in self.rel2idx:
105
+ idx = len(self.rel2idx)
106
+ self.rel2idx[rel] = idx
107
+ self.idx2rel[idx] = rel
108
+
109
+ @property
110
+ def n_rels(self) -> int:
111
+ return len(self.rel2idx)
112
+
113
+
114
+ class PhoBERTDependencyDataset(Dataset):
115
+ """Dataset for PhoBERT dependency parsing."""
116
+
117
+ def __init__(
118
+ self,
119
+ sentences: List[Sentence],
120
+ vocab: Vocabulary,
121
+ tokenizer,
122
+ max_length: int = 256,
123
+ ):
124
+ self.sentences = sentences
125
+ self.vocab = vocab
126
+ self.tokenizer = tokenizer
127
+ self.max_length = max_length
128
+
129
+ def __len__(self):
130
+ return len(self.sentences)
131
+
132
+ def __getitem__(self, idx):
133
+ sent = self.sentences[idx]
134
+
135
+ # Tokenize with word boundary tracking
136
+ word_starts = []
137
+ subword_ids = [self.tokenizer.cls_token_id]
138
+
139
+ for word in sent.words:
140
+ word_starts.append(len(subword_ids))
141
+ word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
142
+ if not word_tokens:
143
+ word_tokens = [self.tokenizer.unk_token_id]
144
+ subword_ids.extend(word_tokens)
145
+
146
+ subword_ids.append(self.tokenizer.sep_token_id)
147
+
148
+ # Truncate if needed
149
+ if len(subword_ids) > self.max_length:
150
+ subword_ids = subword_ids[:self.max_length-1] + [self.tokenizer.sep_token_id]
151
+ # Keep words that fit
152
+ valid_words = sum(1 for ws in word_starts if ws < self.max_length - 1)
153
+ word_starts = word_starts[:valid_words]
154
+ heads = sent.heads[:valid_words]
155
+ rels = sent.rels[:valid_words]
156
+ else:
157
+ heads = sent.heads
158
+ rels = sent.rels
159
+
160
+ # Encode relations
161
+ rel_ids = [self.vocab.rel2idx.get(r, 0) for r in rels]
162
+
163
+ return {
164
+ 'input_ids': subword_ids,
165
+ 'word_starts': word_starts,
166
+ 'heads': heads,
167
+ 'rels': rel_ids,
168
+ }
169
+
170
+
171
+ def collate_fn(batch):
172
+ """Collate function for DataLoader."""
173
+ # Get max lengths
174
+ max_subword_len = max(len(item['input_ids']) for item in batch)
175
+ max_word_len = max(len(item['word_starts']) for item in batch)
176
+
177
+ batch_size = len(batch)
178
+
179
+ # Initialize tensors
180
+ input_ids = torch.zeros(batch_size, max_subword_len, dtype=torch.long)
181
+ attention_mask = torch.zeros(batch_size, max_subword_len, dtype=torch.long)
182
+ word_starts = torch.zeros(batch_size, max_word_len, dtype=torch.long)
183
+ word_mask = torch.zeros(batch_size, max_word_len, dtype=torch.bool)
184
+ heads = torch.zeros(batch_size, max_word_len, dtype=torch.long)
185
+ rels = torch.zeros(batch_size, max_word_len, dtype=torch.long)
186
+
187
+ for i, item in enumerate(batch):
188
+ # Subwords
189
+ seq_len = len(item['input_ids'])
190
+ input_ids[i, :seq_len] = torch.tensor(item['input_ids'])
191
+ attention_mask[i, :seq_len] = 1
192
+
193
+ # Words
194
+ word_len = len(item['word_starts'])
195
+ word_starts[i, :word_len] = torch.tensor(item['word_starts'])
196
+ word_mask[i, :word_len] = True
197
+ heads[i, :word_len] = torch.tensor(item['heads'])
198
+ rels[i, :word_len] = torch.tensor(item['rels'])
199
+
200
+ return {
201
+ 'input_ids': input_ids,
202
+ 'attention_mask': attention_mask,
203
+ 'word_starts': word_starts,
204
+ 'word_mask': word_mask,
205
+ 'heads': heads,
206
+ 'rels': rels,
207
+ }
208
+
209
+
210
+ # ============================================================================
211
+ # Training
212
+ # ============================================================================
213
+
214
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
215
+ """Create scheduler with linear warmup and linear decay."""
216
+ def lr_lambda(current_step):
217
+ if current_step < num_warmup_steps:
218
+ return float(current_step) / float(max(1, num_warmup_steps))
219
+ return max(
220
+ 0.0,
221
+ float(num_training_steps - current_step) /
222
+ float(max(1, num_training_steps - num_warmup_steps))
223
+ )
224
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
225
+
226
+
227
+ def evaluate(model, dataloader, device):
228
+ """Evaluate model and return UAS/LAS."""
229
+ model.eval()
230
+
231
+ total_arcs = 0
232
+ correct_arcs = 0
233
+ correct_rels = 0
234
+
235
+ with torch.no_grad():
236
+ for batch in dataloader:
237
+ input_ids = batch['input_ids'].to(device)
238
+ attention_mask = batch['attention_mask'].to(device)
239
+ word_starts = batch['word_starts'].to(device)
240
+ word_mask = batch['word_mask'].to(device)
241
+ heads = batch['heads'].to(device)
242
+ rels = batch['rels'].to(device)
243
+
244
+ arc_scores, rel_scores = model(
245
+ input_ids, attention_mask, word_starts, word_mask
246
+ )
247
+ arc_preds, rel_preds = model.decode(arc_scores, rel_scores, word_mask)
248
+
249
+ # Count correct
250
+ arc_correct = (arc_preds == heads) & word_mask
251
+ rel_correct = (rel_preds == rels) & word_mask & arc_correct
252
+
253
+ total_arcs += word_mask.sum().item()
254
+ correct_arcs += arc_correct.sum().item()
255
+ correct_rels += rel_correct.sum().item()
256
+
257
+ uas = correct_arcs / total_arcs * 100 if total_arcs > 0 else 0
258
+ las = correct_rels / total_arcs * 100 if total_arcs > 0 else 0
259
+
260
+ return uas, las
261
+
262
+
263
+ @click.command()
264
+ @click.option('--output', '-o', default='models/bamboo-1-phobert', help='Output directory')
265
+ @click.option('--encoder', default='vinai/phobert-base', help='PhoBERT encoder model')
266
+ @click.option('--dataset', type=click.Choice(['udd1', 'ud-vtb']), default='udd1',
267
+ help='Dataset to use: udd1 (UDD-1) or ud-vtb (UD Vietnamese VTB)')
268
+ @click.option('--epochs', default=100, type=int, help='Number of epochs')
269
+ @click.option('--batch-size', default=32, type=int, help='Batch size')
270
+ @click.option('--bert-lr', default=1e-5, type=float, help='Learning rate for BERT layers')
271
+ @click.option('--head-lr', default=1e-3, type=float, help='Learning rate for parser head')
272
+ @click.option('--warmup-steps', default=1000, type=int, help='Warmup steps')
273
+ @click.option('--weight-decay', default=0.01, type=float, help='Weight decay')
274
+ @click.option('--max-grad-norm', default=5.0, type=float, help='Max gradient norm for clipping')
275
+ @click.option('--arc-hidden', default=500, type=int, help='Arc MLP hidden size')
276
+ @click.option('--rel-hidden', default=100, type=int, help='Relation MLP hidden size')
277
+ @click.option('--dropout', default=0.33, type=float, help='Dropout rate')
278
+ @click.option('--patience', default=10, type=int, help='Early stopping patience')
279
+ @click.option('--use-mst/--no-mst', default=True, help='Use MST decoding')
280
+ @click.option('--force-download', is_flag=True, help='Force re-download dataset')
281
+ @click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
282
+ @click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
283
+ @click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
284
+ @click.option('--wandb-project', default='bamboo-1-phobert', help='W&B project name')
285
+ @click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
286
+ @click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
287
+ @click.option('--freeze-bert', default=0, type=int, help='Freeze BERT for first N epochs')
288
+ @click.option('--fp16/--no-fp16', default=True, help='Use mixed precision training (FP16)')
289
+ @click.option('--num-workers', default=4, type=int, help='DataLoader workers')
290
+ @click.option('--grad-accum', default=1, type=int, help='Gradient accumulation steps')
291
+ def train(
292
+ output, encoder, dataset, epochs, batch_size, bert_lr, head_lr, warmup_steps,
293
+ weight_decay, max_grad_norm, arc_hidden, rel_hidden, dropout, patience,
294
+ use_mst, force_download, gpu_type, cost_interval, use_wandb, wandb_project,
295
+ max_time, sample, freeze_bert, fp16, num_workers, grad_accum
296
+ ):
297
+ """Train PhoBERT-based Vietnamese Dependency Parser."""
298
+
299
+ # Detect hardware
300
+ hardware = detect_hardware()
301
+ detected_gpu_type = hardware.get_gpu_type()
302
+
303
+ if gpu_type == "RTX_A4000":
304
+ gpu_type = detected_gpu_type
305
+
306
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
307
+ click.echo(f"Using device: {device}")
308
+ click.echo(f"Hardware: {hardware}")
309
+
310
+ # Mixed precision training
311
+ use_amp = fp16 and torch.cuda.is_available()
312
+ scaler = torch.cuda.amp.GradScaler() if use_amp else None
313
+ if use_amp:
314
+ click.echo(f"Mixed precision (FP16): enabled")
315
+
316
+ # Initialize wandb
317
+ if use_wandb:
318
+ import wandb
319
+ wandb.init(
320
+ project=wandb_project,
321
+ config={
322
+ "encoder": encoder,
323
+ "dataset": dataset,
324
+ "epochs": epochs,
325
+ "batch_size": batch_size,
326
+ "bert_lr": bert_lr,
327
+ "head_lr": head_lr,
328
+ "warmup_steps": warmup_steps,
329
+ "weight_decay": weight_decay,
330
+ "arc_hidden": arc_hidden,
331
+ "rel_hidden": rel_hidden,
332
+ "dropout": dropout,
333
+ "patience": patience,
334
+ "use_mst": use_mst,
335
+ "gpu_type": gpu_type,
336
+ "hardware": hardware.to_dict(),
337
+ }
338
+ )
339
+ click.echo(f"W&B logging enabled: {wandb.run.url}")
340
+
341
+ click.echo("=" * 60)
342
+ click.echo("Bamboo-1: PhoBERT Vietnamese Dependency Parser")
343
+ click.echo("=" * 60)
344
+
345
+ # Load corpus
346
+ click.echo(f"\nLoading {dataset.upper()} corpus...")
347
+ if dataset == 'udd1':
348
+ corpus = UDD1Corpus(force_download=force_download)
349
+ else:
350
+ corpus = UDVietnameseVTB(force_download=force_download)
351
+
352
+ train_sents = read_conllu(corpus.train)
353
+ dev_sents = read_conllu(corpus.dev)
354
+ test_sents = read_conllu(corpus.test)
355
+
356
+ if sample > 0:
357
+ train_sents = train_sents[:sample]
358
+ dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
359
+ test_sents = test_sents[:min(sample // 2, len(test_sents))]
360
+ click.echo(f" Sampling {sample} sentences...")
361
+
362
+ click.echo(f" Train: {len(train_sents)} sentences")
363
+ click.echo(f" Dev: {len(dev_sents)} sentences")
364
+ click.echo(f" Test: {len(test_sents)} sentences")
365
+
366
+ # Build vocabulary
367
+ click.echo("\nBuilding vocabulary...")
368
+ vocab = Vocabulary()
369
+ vocab.build(train_sents)
370
+ click.echo(f" Relations: {vocab.n_rels}")
371
+
372
+ # Create model
373
+ click.echo(f"\nInitializing model with {encoder}...")
374
+ model = PhoBERTDependencyParser(
375
+ encoder_name=encoder,
376
+ n_rels=vocab.n_rels,
377
+ arc_hidden=arc_hidden,
378
+ rel_hidden=rel_hidden,
379
+ dropout=dropout,
380
+ use_mst=use_mst,
381
+ ).to(device)
382
+
383
+ # Set relation mappings
384
+ model.rel2idx = vocab.rel2idx
385
+ model.idx2rel = vocab.idx2rel
386
+
387
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
388
+ n_bert_params = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
389
+ n_head_params = n_params - n_bert_params
390
+ click.echo(f" Total parameters: {n_params:,}")
391
+ click.echo(f" BERT parameters: {n_bert_params:,}")
392
+ click.echo(f" Head parameters: {n_head_params:,}")
393
+
394
+ # Create datasets
395
+ train_dataset = PhoBERTDependencyDataset(train_sents, vocab, model.tokenizer)
396
+ dev_dataset = PhoBERTDependencyDataset(dev_sents, vocab, model.tokenizer)
397
+ test_dataset = PhoBERTDependencyDataset(test_sents, vocab, model.tokenizer)
398
+
399
+ # DataLoader with optimizations
400
+ loader_kwargs = {
401
+ 'collate_fn': collate_fn,
402
+ 'num_workers': num_workers,
403
+ 'pin_memory': torch.cuda.is_available(),
404
+ 'persistent_workers': num_workers > 0,
405
+ }
406
+ train_loader = DataLoader(
407
+ train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
408
+ )
409
+ dev_loader = DataLoader(
410
+ dev_dataset, batch_size=batch_size, **loader_kwargs
411
+ )
412
+ test_loader = DataLoader(
413
+ test_dataset, batch_size=batch_size, **loader_kwargs
414
+ )
415
+
416
+ # Effective batch size with gradient accumulation
417
+ effective_batch_size = batch_size * grad_accum
418
+ if grad_accum > 1:
419
+ click.echo(f" Effective batch size: {effective_batch_size} (batch={batch_size} x accum={grad_accum})")
420
+
421
+ # Optimizer with differential learning rates
422
+ no_decay = ['bias', 'LayerNorm.weight', 'LayerNorm.bias']
423
+ optimizer_grouped_parameters = [
424
+ # BERT parameters with weight decay
425
+ {
426
+ 'params': [p for n, p in model.encoder.named_parameters()
427
+ if not any(nd in n for nd in no_decay)],
428
+ 'lr': bert_lr,
429
+ 'weight_decay': weight_decay,
430
+ },
431
+ # BERT parameters without weight decay
432
+ {
433
+ 'params': [p for n, p in model.encoder.named_parameters()
434
+ if any(nd in n for nd in no_decay)],
435
+ 'lr': bert_lr,
436
+ 'weight_decay': 0.0,
437
+ },
438
+ # Parser head parameters
439
+ {
440
+ 'params': [p for n, p in model.named_parameters()
441
+ if not n.startswith('encoder.')],
442
+ 'lr': head_lr,
443
+ 'weight_decay': weight_decay,
444
+ },
445
+ ]
446
+ optimizer = AdamW(optimizer_grouped_parameters)
447
+
448
+ # Learning rate scheduler with warmup
449
+ total_steps = len(train_loader) * epochs
450
+ scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
451
+
452
+ # Training
453
+ click.echo(f"\nTraining for {epochs} epochs...")
454
+ if freeze_bert > 0:
455
+ click.echo(f" Freezing BERT for first {freeze_bert} epochs")
456
+ if max_time > 0:
457
+ click.echo(f" Time limit: {max_time} minutes")
458
+
459
+ output_path = Path(output)
460
+ output_path.mkdir(parents=True, exist_ok=True)
461
+
462
+ # Cost tracking
463
+ cost_tracker = CostTracker(gpu_type=gpu_type)
464
+ cost_tracker.report_interval = cost_interval
465
+ cost_tracker.start()
466
+ click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")
467
+
468
+ best_las = -1
469
+ no_improve = 0
470
+ time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')
471
+
472
+ for epoch in range(1, epochs + 1):
473
+ # Check time limit
474
+ if cost_tracker.elapsed_seconds() >= time_limit_seconds:
475
+ click.echo(f"\nTime limit reached ({max_time} minutes)")
476
+ break
477
+
478
+ # Freeze/unfreeze BERT
479
+ if epoch <= freeze_bert:
480
+ for param in model.encoder.parameters():
481
+ param.requires_grad = False
482
+ elif epoch == freeze_bert + 1:
483
+ click.echo(" Unfreezing BERT parameters...")
484
+ for param in model.encoder.parameters():
485
+ param.requires_grad = True
486
+
487
+ model.train()
488
+ total_loss = 0
489
+ optimizer.zero_grad()
490
+
491
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
492
+ for step, batch in enumerate(pbar):
493
+ input_ids = batch['input_ids'].to(device, non_blocking=True)
494
+ attention_mask = batch['attention_mask'].to(device, non_blocking=True)
495
+ word_starts = batch['word_starts'].to(device, non_blocking=True)
496
+ word_mask = batch['word_mask'].to(device, non_blocking=True)
497
+ heads = batch['heads'].to(device, non_blocking=True)
498
+ rels = batch['rels'].to(device, non_blocking=True)
499
+
500
+ # Mixed precision forward pass
501
+ with torch.cuda.amp.autocast(enabled=use_amp):
502
+ arc_scores, rel_scores = model(
503
+ input_ids, attention_mask, word_starts, word_mask
504
+ )
505
+ loss = model.loss(arc_scores, rel_scores, heads, rels, word_mask)
506
+ loss = loss / grad_accum # Scale for gradient accumulation
507
+
508
+ # Backward pass with gradient scaling
509
+ if use_amp:
510
+ scaler.scale(loss).backward()
511
+ else:
512
+ loss.backward()
513
+
514
+ # Optimizer step (every grad_accum steps)
515
+ if (step + 1) % grad_accum == 0 or (step + 1) == len(train_loader):
516
+ if use_amp:
517
+ scaler.unscale_(optimizer)
518
+ nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
519
+ scaler.step(optimizer)
520
+ scaler.update()
521
+ else:
522
+ nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
523
+ optimizer.step()
524
+ scheduler.step()
525
+ optimizer.zero_grad()
526
+
527
+ total_loss += loss.item() * grad_accum
528
+ pbar.set_postfix({'loss': f'{loss.item() * grad_accum:.4f}'})
529
+
530
+ # Evaluate
531
+ dev_uas, dev_las = evaluate(model, dev_loader, device)
532
+
533
+ # Cost update
534
+ progress = epoch / epochs
535
+ current_cost = cost_tracker.current_cost()
536
+ estimated_total_cost = cost_tracker.estimate_total_cost(progress)
537
+ elapsed_minutes = cost_tracker.elapsed_seconds() / 60
538
+
539
+ cost_status = cost_tracker.update(epoch, epochs)
540
+ if cost_status:
541
+ click.echo(f" [{cost_status}]")
542
+
543
+ avg_loss = total_loss / len(train_loader)
544
+ current_lr = scheduler.get_last_lr()[0]
545
+ click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
546
+ f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}% | "
547
+ f"LR: {current_lr:.2e}")
548
+
549
+ # Log to wandb
550
+ if use_wandb:
551
+ wandb.log({
552
+ "epoch": epoch,
553
+ "train/loss": avg_loss,
554
+ "dev/uas": dev_uas,
555
+ "dev/las": dev_las,
556
+ "lr": current_lr,
557
+ "cost/current_usd": current_cost,
558
+ "cost/estimated_total_usd": estimated_total_cost,
559
+ "cost/elapsed_minutes": elapsed_minutes,
560
+ })
561
+
562
+ # Save best model
563
+ if dev_las >= best_las:
564
+ best_las = dev_las
565
+ no_improve = 0
566
+ model.save(
567
+ str(output_path),
568
+ vocab={'rel2idx': vocab.rel2idx, 'idx2rel': vocab.idx2rel}
569
+ )
570
+ click.echo(f" -> Saved best model (LAS: {best_las:.2f}%)")
571
+ else:
572
+ no_improve += 1
573
+ if no_improve >= patience:
574
+ click.echo(f"\nEarly stopping after {patience} epochs without improvement")
575
+ break
576
+
577
+ # Final evaluation
578
+ click.echo("\nLoading best model for final evaluation...")
579
+ model = PhoBERTDependencyParser.load(str(output_path), device=str(device))
580
+
581
+ test_uas, test_las = evaluate(model, test_loader, device)
582
+ click.echo(f"\nTest Results:")
583
+ click.echo(f" UAS: {test_uas:.2f}%")
584
+ click.echo(f" LAS: {test_las:.2f}%")
585
+
586
+ click.echo(f"\nModel saved to: {output_path}")
587
+
588
+ # Final cost summary
589
+ final_cost = cost_tracker.current_cost()
590
+ click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")
591
+
592
+ # Log final metrics to wandb
593
+ if use_wandb:
594
+ wandb.log({
595
+ "test/uas": test_uas,
596
+ "test/las": test_las,
597
+ "cost/final_usd": final_cost,
598
+ })
599
+ wandb.finish()
600
+
601
+
602
+ if __name__ == '__main__':
603
+ train()
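Note: train_phobert.py passes word_starts (the subword index of each word's first piece) into the parser; the "Word-level pooling" step named in the architecture then collapses PhoBERT's subword states to one vector per word. That pooling lives in bamboo1/models/transformer_parser.py, which is not shown in this diff; a minimal sketch of first-subword pooling, assuming a (batch, subwords, hidden) encoder output:

import torch

def pool_first_subword(hidden_states: torch.Tensor, word_starts: torch.Tensor) -> torch.Tensor:
    """Gather the hidden state of each word's first subword.

    hidden_states: (batch, subword_len, hidden) encoder output
    word_starts:   (batch, word_len) index of each word's first subword
    returns:       (batch, word_len, hidden) word-level representations
    """
    hidden = hidden_states.size(-1)
    index = word_starts.unsqueeze(-1).expand(-1, -1, hidden)
    return hidden_states.gather(dim=1, index=index)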
uv.lock CHANGED
@@ -22,6 +22,19 @@ resolution-markers = [
22
  "python_full_version < '3.11' and sys_platform != 'linux'",
23
  ]
24
 
25
  [[package]]
26
  name = "aiodns"
27
  version = "4.0.0"
@@ -367,12 +380,18 @@ source = { editable = "." }
367
  dependencies = [
368
  { name = "click" },
369
  { name = "datasets" },
 
 
370
  { name = "torch" },
 
371
  { name = "transformers" },
372
  { name = "underthesea" },
373
  ]
374
 
375
  [package.optional-dependencies]
 
 
 
376
  cloud = [
377
  { name = "runpod" },
378
  ]
@@ -383,16 +402,19 @@ dev = [
383
 
384
  [package.metadata]
385
  requires-dist = [
 
386
  { name = "click", specifier = ">=8.0.0" },
387
  { name = "datasets", specifier = ">=2.14.0" },
 
388
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" },
389
  { name = "runpod", marker = "extra == 'cloud'", specifier = ">=1.6.0" },
390
  { name = "torch", specifier = ">=2.0.0" },
391
- { name = "transformers", specifier = ">=5.0.0" },
 
392
  { name = "underthesea", specifier = ">=9.2.0" },
393
  { name = "wandb", marker = "extra == 'dev'", specifier = ">=0.15.0" },
394
  ]
395
- provides-extras = ["dev", "cloud"]
396
 
397
  [[package]]
398
  name = "bcrypt"
@@ -1384,23 +1406,21 @@ wheels = [
1384
 
1385
  [[package]]
1386
  name = "huggingface-hub"
1387
- version = "1.3.5"
1388
  source = { registry = "https://pypi.org/simple" }
1389
  dependencies = [
1390
  { name = "filelock" },
1391
  { name = "fsspec" },
1392
- { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
1393
- { name = "httpx" },
1394
  { name = "packaging" },
1395
  { name = "pyyaml" },
1396
- { name = "shellingham" },
1397
  { name = "tqdm" },
1398
- { name = "typer-slim" },
1399
  { name = "typing-extensions" },
1400
  ]
1401
- sdist = { url = "https://files.pythonhosted.org/packages/67/e9/2658cb9bc4c72a67b7f87650e827266139befaf499095883d30dabc4d49f/huggingface_hub-1.3.5.tar.gz", hash = "sha256:8045aca8ddab35d937138f3c386c6d43a275f53437c5c64cdc9aa8408653b4ed", size = 627456, upload-time = "2026-01-29T10:34:19.687Z" }
1402
  wheels = [
1403
- { url = "https://files.pythonhosted.org/packages/f9/84/a579b95c46fe8e319f89dc700c087596f665141575f4dcf136aaa97d856f/huggingface_hub-1.3.5-py3-none-any.whl", hash = "sha256:fe332d7f86a8af874768452295c22cd3f37730fb2463cf6cc3295e26036f8ef9", size = 536675, upload-time = "2026-01-29T10:34:17.713Z" },
1404
  ]
1405
 
1406
  [[package]]
@@ -3763,32 +3783,27 @@ wheels = [
3763
 
3764
  [[package]]
3765
  name = "tokenizers"
3766
- version = "0.22.2"
3767
  source = { registry = "https://pypi.org/simple" }
3768
  dependencies = [
3769
  { name = "huggingface-hub" },
3770
  ]
3771
- sdist = { url = "https://files.pythonhosted.org/packages/73/6f/f80cfef4a312e1fb34baf7d85c72d4411afde10978d4657f8cdd811d3ccc/tokenizers-0.22.2.tar.gz", hash = "sha256:473b83b915e547aa366d1eee11806deaf419e17be16310ac0a14077f1e28f917", size = 372115, upload-time = "2026-01-05T10:45:15.988Z" }
3772
- wheels = [
3773
- { url = "https://files.pythonhosted.org/packages/92/97/5dbfabf04c7e348e655e907ed27913e03db0923abb5dfdd120d7b25630e1/tokenizers-0.22.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:544dd704ae7238755d790de45ba8da072e9af3eea688f698b137915ae959281c", size = 3100275, upload-time = "2026-01-05T10:41:02.158Z" },
3774
- { url = "https://files.pythonhosted.org/packages/2e/47/174dca0502ef88b28f1c9e06b73ce33500eedfac7a7692108aec220464e7/tokenizers-0.22.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:1e418a55456beedca4621dbab65a318981467a2b188e982a23e117f115ce5001", size = 2981472, upload-time = "2026-01-05T10:41:00.276Z" },
3775
- { url = "https://files.pythonhosted.org/packages/d6/84/7990e799f1309a8b87af6b948f31edaa12a3ed22d11b352eaf4f4b2e5753/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2249487018adec45d6e3554c71d46eb39fa8ea67156c640f7513eb26f318cec7", size = 3290736, upload-time = "2026-01-05T10:40:32.165Z" },
3776
- { url = "https://files.pythonhosted.org/packages/78/59/09d0d9ba94dcd5f4f1368d4858d24546b4bdc0231c2354aa31d6199f0399/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25b85325d0815e86e0bac263506dd114578953b7b53d7de09a6485e4a160a7dd", size = 3168835, upload-time = "2026-01-05T10:40:38.847Z" },
3777
- { url = "https://files.pythonhosted.org/packages/47/50/b3ebb4243e7160bda8d34b731e54dd8ab8b133e50775872e7a434e524c28/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfb88f22a209ff7b40a576d5324bf8286b519d7358663db21d6246fb17eea2d5", size = 3521673, upload-time = "2026-01-05T10:40:56.614Z" },
3778
- { url = "https://files.pythonhosted.org/packages/e0/fa/89f4cb9e08df770b57adb96f8cbb7e22695a4cb6c2bd5f0c4f0ebcf33b66/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c774b1276f71e1ef716e5486f21e76333464f47bece56bbd554485982a9e03e", size = 3724818, upload-time = "2026-01-05T10:40:44.507Z" },
3779
- { url = "https://files.pythonhosted.org/packages/64/04/ca2363f0bfbe3b3d36e95bf67e56a4c88c8e3362b658e616d1ac185d47f2/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df6c4265b289083bf710dff49bc51ef252f9d5be33a45ee2bed151114a56207b", size = 3379195, upload-time = "2026-01-05T10:40:51.139Z" },
3780
- { url = "https://files.pythonhosted.org/packages/2e/76/932be4b50ef6ccedf9d3c6639b056a967a86258c6d9200643f01269211ca/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:369cc9fc8cc10cb24143873a0d95438bb8ee257bb80c71989e3ee290e8d72c67", size = 3274982, upload-time = "2026-01-05T10:40:58.331Z" },
3781
- { url = "https://files.pythonhosted.org/packages/1d/28/5f9f5a4cc211b69e89420980e483831bcc29dade307955cc9dc858a40f01/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:29c30b83d8dcd061078b05ae0cb94d3c710555fbb44861139f9f83dcca3dc3e4", size = 9478245, upload-time = "2026-01-05T10:41:04.053Z" },
3782
- { url = "https://files.pythonhosted.org/packages/6c/fb/66e2da4704d6aadebf8cb39f1d6d1957df667ab24cff2326b77cda0dcb85/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:37ae80a28c1d3265bb1f22464c856bd23c02a05bb211e56d0c5301a435be6c1a", size = 9560069, upload-time = "2026-01-05T10:45:10.673Z" },
3783
- { url = "https://files.pythonhosted.org/packages/16/04/fed398b05caa87ce9b1a1bb5166645e38196081b225059a6edaff6440fac/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:791135ee325f2336f498590eb2f11dc5c295232f288e75c99a36c5dbce63088a", size = 9899263, upload-time = "2026-01-05T10:45:12.559Z" },
3784
- { url = "https://files.pythonhosted.org/packages/05/a1/d62dfe7376beaaf1394917e0f8e93ee5f67fea8fcf4107501db35996586b/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38337540fbbddff8e999d59970f3c6f35a82de10053206a7562f1ea02d046fa5", size = 10033429, upload-time = "2026-01-05T10:45:14.333Z" },
3785
- { url = "https://files.pythonhosted.org/packages/fd/18/a545c4ea42af3df6effd7d13d250ba77a0a86fb20393143bbb9a92e434d4/tokenizers-0.22.2-cp39-abi3-win32.whl", hash = "sha256:a6bf3f88c554a2b653af81f3204491c818ae2ac6fbc09e76ef4773351292bc92", size = 2502363, upload-time = "2026-01-05T10:45:20.593Z" },
3786
- { url = "https://files.pythonhosted.org/packages/65/71/0670843133a43d43070abeb1949abfdef12a86d490bea9cd9e18e37c5ff7/tokenizers-0.22.2-cp39-abi3-win_amd64.whl", hash = "sha256:c9ea31edff2968b44a88f97d784c2f16dc0729b8b143ed004699ebca91f05c48", size = 2747786, upload-time = "2026-01-05T10:45:18.411Z" },
3787
- { url = "https://files.pythonhosted.org/packages/72/f4/0de46cfa12cdcbcd464cc59fde36912af405696f687e53a091fb432f694c/tokenizers-0.22.2-cp39-abi3-win_arm64.whl", hash = "sha256:9ce725d22864a1e965217204946f830c37876eee3b2ba6fc6255e8e903d5fcbc", size = 2612133, upload-time = "2026-01-05T10:45:17.232Z" },
3788
- { url = "https://files.pythonhosted.org/packages/84/04/655b79dbcc9b3ac5f1479f18e931a344af67e5b7d3b251d2dcdcd7558592/tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:753d47ebd4542742ef9261d9da92cd545b2cacbb48349a1225466745bb866ec4", size = 3282301, upload-time = "2026-01-05T10:40:34.858Z" },
3789
- { url = "https://files.pythonhosted.org/packages/46/cd/e4851401f3d8f6f45d8480262ab6a5c8cb9c4302a790a35aa14eeed6d2fd/tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e10bf9113d209be7cd046d40fbabbaf3278ff6d18eb4da4c500443185dc1896c", size = 3161308, upload-time = "2026-01-05T10:40:40.737Z" },
3790
- { url = "https://files.pythonhosted.org/packages/6f/6e/55553992a89982cd12d4a66dddb5e02126c58677ea3931efcbe601d419db/tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64d94e84f6660764e64e7e0b22baa72f6cd942279fdbb21d46abd70d179f0195", size = 3718964, upload-time = "2026-01-05T10:40:46.56Z" },
3791
- { url = "https://files.pythonhosted.org/packages/59/8c/b1c87148aa15e099243ec9f0cf9d0e970cc2234c3257d558c25a2c5304e6/tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f01a9c019878532f98927d2bacb79bbb404b43d3437455522a00a30718cdedb5", size = 3373542, upload-time = "2026-01-05T10:40:52.803Z" },
3792
  ]
3793
 
3794
  [[package]]
@@ -3942,7 +3957,7 @@ wheels = [
3942
 
3943
  [[package]]
3944
  name = "transformers"
3945
- version = "5.0.0"
3946
  source = { registry = "https://pypi.org/simple" }
3947
  dependencies = [
3948
  { name = "filelock" },
@@ -3952,14 +3967,14 @@ dependencies = [
3952
  { name = "packaging" },
3953
  { name = "pyyaml" },
3954
  { name = "regex" },
 
3955
  { name = "safetensors" },
3956
  { name = "tokenizers" },
3957
  { name = "tqdm" },
3958
- { name = "typer-slim" },
3959
  ]
3960
- sdist = { url = "https://files.pythonhosted.org/packages/bc/79/845941711811789c85fb7e2599cea425a14a07eda40f50896b9d3fda7492/transformers-5.0.0.tar.gz", hash = "sha256:5f5634efed6cf76ad068cc5834c7adbc32db78bbd6211fb70df2325a9c37dec8", size = 8424830, upload-time = "2026-01-26T10:46:46.813Z" }
3961
  wheels = [
3962
- { url = "https://files.pythonhosted.org/packages/52/f3/ac976fa8e305c9e49772527e09fbdc27cc6831b8a2f6b6063406626be5dd/transformers-5.0.0-py3-none-any.whl", hash = "sha256:587086f249ce64c817213cf36afdb318d087f790723e9b3d4500b97832afd52d", size = 10142091, upload-time = "2026-01-26T10:46:43.88Z" },
3963
  ]
3964
 
3965
  [[package]]
@@ -3991,19 +4006,6 @@ wheels = [
3991
  { url = "https://files.pythonhosted.org/packages/a0/1d/d9257dd49ff2ca23ea5f132edf1281a0c4f9de8a762b9ae399b670a59235/typer-0.21.1-py3-none-any.whl", hash = "sha256:7985e89081c636b88d172c2ee0cfe33c253160994d47bdfdc302defd7d1f1d01", size = 47381, upload-time = "2026-01-06T11:21:09.824Z" },
3992
  ]
3993
 
3994
- [[package]]
3995
- name = "typer-slim"
3996
- version = "0.21.1"
3997
- source = { registry = "https://pypi.org/simple" }
3998
- dependencies = [
3999
- { name = "click" },
4000
- { name = "typing-extensions" },
4001
- ]
4002
- sdist = { url = "https://files.pythonhosted.org/packages/17/d4/064570dec6358aa9049d4708e4a10407d74c99258f8b2136bb8702303f1a/typer_slim-0.21.1.tar.gz", hash = "sha256:73495dd08c2d0940d611c5a8c04e91c2a0a98600cbd4ee19192255a233b6dbfd", size = 110478, upload-time = "2026-01-06T11:21:11.176Z" }
4003
- wheels = [
4004
- { url = "https://files.pythonhosted.org/packages/c8/0a/4aca634faf693e33004796b6cee0ae2e1dba375a800c16ab8d3eff4bb800/typer_slim-0.21.1-py3-none-any.whl", hash = "sha256:6e6c31047f171ac93cc5a973c9e617dbc5ab2bddc4d0a3135dc161b4e2020e0d", size = 47444, upload-time = "2026-01-06T11:21:12.441Z" },
4005
- ]
4006
-
4007
  [[package]]
4008
  name = "typing-extensions"
4009
  version = "4.15.0"
 
22
  "python_full_version < '3.11' and sys_platform != 'linux'",
23
  ]
24
 
25
+ [[package]]
26
+ name = "adapters"
27
+ version = "1.2.0"
28
+ source = { registry = "https://pypi.org/simple" }
29
+ dependencies = [
30
+ { name = "packaging" },
31
+ { name = "transformers" },
32
+ ]
33
+ sdist = { url = "https://files.pythonhosted.org/packages/5f/c7/96580e5b7417b0838bd3e41a416939be63a549f22cfe0bcf8cdc62fd2ed8/adapters-1.2.0.tar.gz", hash = "sha256:40db5c5e0789603859980229f7acbae51168abf1999efdb65e5a7778e81a104e", size = 226695, upload-time = "2025-05-20T19:27:07.202Z" }
34
+ wheels = [
35
+ { url = "https://files.pythonhosted.org/packages/dc/e5/91cb0ea212558443b3d62e2b8d8537647549b9c6d34d613847a9fb2fcc58/adapters-1.2.0-py3-none-any.whl", hash = "sha256:fa55ddd9a99577ad0bacb16bebd0a26b6c5db2eae8730b5bfe4403eb917f2e22", size = 302180, upload-time = "2025-05-20T19:27:05.323Z" },
36
+ ]
37
+
38
  [[package]]
39
  name = "aiodns"
40
  version = "4.0.0"
 
380
  dependencies = [
381
  { name = "click" },
382
  { name = "datasets" },
383
+ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
384
+ { name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
385
  { name = "torch" },
386
+ { name = "tqdm" },
387
  { name = "transformers" },
388
  { name = "underthesea" },
389
  ]
390
 
391
  [package.optional-dependencies]
392
+ adapters = [
393
+ { name = "adapters" },
394
+ ]
395
  cloud = [
396
  { name = "runpod" },
397
  ]
 
402
 
403
  [package.metadata]
404
  requires-dist = [
405
+ { name = "adapters", marker = "extra == 'adapters'", specifier = ">=0.1.0" },
406
  { name = "click", specifier = ">=8.0.0" },
407
  { name = "datasets", specifier = ">=2.14.0" },
408
+ { name = "numpy", specifier = ">=1.24.0" },
409
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" },
410
  { name = "runpod", marker = "extra == 'cloud'", specifier = ">=1.6.0" },
411
  { name = "torch", specifier = ">=2.0.0" },
412
+ { name = "tqdm", specifier = ">=4.60.0" },
413
+ { name = "transformers", specifier = ">=4.30.0" },
414
  { name = "underthesea", specifier = ">=9.2.0" },
415
  { name = "wandb", marker = "extra == 'dev'", specifier = ">=0.15.0" },
416
  ]
417
+ provides-extras = ["dev", "cloud", "adapters"]
418
 
419
  [[package]]
420
  name = "bcrypt"
 
1406
 
1407
  [[package]]
1408
  name = "huggingface-hub"
1409
+ version = "0.36.0"
1410
  source = { registry = "https://pypi.org/simple" }
1411
  dependencies = [
1412
  { name = "filelock" },
1413
  { name = "fsspec" },
1414
+ { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
 
1415
  { name = "packaging" },
1416
  { name = "pyyaml" },
1417
+ { name = "requests" },
1418
  { name = "tqdm" },
 
1419
  { name = "typing-extensions" },
1420
  ]
1421
+ sdist = { url = "https://files.pythonhosted.org/packages/98/63/4910c5fa9128fdadf6a9c5ac138e8b1b6cee4ca44bf7915bbfbce4e355ee/huggingface_hub-0.36.0.tar.gz", hash = "sha256:47b3f0e2539c39bf5cde015d63b72ec49baff67b6931c3d97f3f84532e2b8d25", size = 463358, upload-time = "2025-10-23T12:12:01.413Z" }
1422
  wheels = [
1423
+ { url = "https://files.pythonhosted.org/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl", hash = "sha256:7bcc9ad17d5b3f07b57c78e79d527102d08313caa278a641993acddcb894548d", size = 566094, upload-time = "2025-10-23T12:11:59.557Z" },
1424
  ]
1425
 
1426
  [[package]]
 
3783
 
3784
  [[package]]
3785
  name = "tokenizers"
3786
+ version = "0.21.4"
3787
  source = { registry = "https://pypi.org/simple" }
3788
  dependencies = [
3789
  { name = "huggingface-hub" },
3790
  ]
3791
+ sdist = { url = "https://files.pythonhosted.org/packages/c2/2f/402986d0823f8d7ca139d969af2917fefaa9b947d1fb32f6168c509f2492/tokenizers-0.21.4.tar.gz", hash = "sha256:fa23f85fbc9a02ec5c6978da172cdcbac23498c3ca9f3645c5c68740ac007880", size = 351253, upload-time = "2025-07-28T15:48:54.325Z" }
3792
+ wheels = [
3793
+ { url = "https://files.pythonhosted.org/packages/98/c6/fdb6f72bf6454f52eb4a2510be7fb0f614e541a2554d6210e370d85efff4/tokenizers-0.21.4-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:2ccc10a7c3bcefe0f242867dc914fc1226ee44321eb618cfe3019b5df3400133", size = 2863987, upload-time = "2025-07-28T15:48:44.877Z" },
3794
+ { url = "https://files.pythonhosted.org/packages/8d/a6/28975479e35ddc751dc1ddc97b9b69bf7fcf074db31548aab37f8116674c/tokenizers-0.21.4-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:5e2f601a8e0cd5be5cc7506b20a79112370b9b3e9cb5f13f68ab11acd6ca7d60", size = 2732457, upload-time = "2025-07-28T15:48:43.265Z" },
3795
+ { url = "https://files.pythonhosted.org/packages/aa/8f/24f39d7b5c726b7b0be95dca04f344df278a3fe3a4deb15a975d194cbb32/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b376f5a1aee67b4d29032ee85511bbd1b99007ec735f7f35c8a2eb104eade5", size = 3012624, upload-time = "2025-07-28T13:22:43.895Z" },
3796
+ { url = "https://files.pythonhosted.org/packages/58/47/26358925717687a58cb74d7a508de96649544fad5778f0cd9827398dc499/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2107ad649e2cda4488d41dfd031469e9da3fcbfd6183e74e4958fa729ffbf9c6", size = 2939681, upload-time = "2025-07-28T13:22:47.499Z" },
3797
+ { url = "https://files.pythonhosted.org/packages/99/6f/cc300fea5db2ab5ddc2c8aea5757a27b89c84469899710c3aeddc1d39801/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c73012da95afafdf235ba80047699df4384fdc481527448a078ffd00e45a7d9", size = 3247445, upload-time = "2025-07-28T15:48:39.711Z" },
3798
+ { url = "https://files.pythonhosted.org/packages/be/bf/98cb4b9c3c4afd8be89cfa6423704337dc20b73eb4180397a6e0d456c334/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f23186c40395fc390d27f519679a58023f368a0aad234af145e0f39ad1212732", size = 3428014, upload-time = "2025-07-28T13:22:49.569Z" },
3799
+ { url = "https://files.pythonhosted.org/packages/75/c7/96c1cc780e6ca7f01a57c13235dd05b7bc1c0f3588512ebe9d1331b5f5ae/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc88bb34e23a54cc42713d6d98af5f1bf79c07653d24fe984d2d695ba2c922a2", size = 3193197, upload-time = "2025-07-28T13:22:51.471Z" },
3800
+ { url = "https://files.pythonhosted.org/packages/f2/90/273b6c7ec78af547694eddeea9e05de771278bd20476525ab930cecaf7d8/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51b7eabb104f46c1c50b486520555715457ae833d5aee9ff6ae853d1130506ff", size = 3115426, upload-time = "2025-07-28T15:48:41.439Z" },
3801
+ { url = "https://files.pythonhosted.org/packages/91/43/c640d5a07e95f1cf9d2c92501f20a25f179ac53a4f71e1489a3dcfcc67ee/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:714b05b2e1af1288bd1bc56ce496c4cebb64a20d158ee802887757791191e6e2", size = 9089127, upload-time = "2025-07-28T15:48:46.472Z" },
3802
+ { url = "https://files.pythonhosted.org/packages/44/a1/dd23edd6271d4dca788e5200a807b49ec3e6987815cd9d0a07ad9c96c7c2/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:1340ff877ceedfa937544b7d79f5b7becf33a4cfb58f89b3b49927004ef66f78", size = 9055243, upload-time = "2025-07-28T15:48:48.539Z" },
3803
+ { url = "https://files.pythonhosted.org/packages/21/2b/b410d6e9021c4b7ddb57248304dc817c4d4970b73b6ee343674914701197/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:3c1f4317576e465ac9ef0d165b247825a2a4078bcd01cba6b54b867bdf9fdd8b", size = 9298237, upload-time = "2025-07-28T15:48:50.443Z" },
3804
+ { url = "https://files.pythonhosted.org/packages/b7/0a/42348c995c67e2e6e5c89ffb9cfd68507cbaeb84ff39c49ee6e0a6dd0fd2/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:c212aa4e45ec0bb5274b16b6f31dd3f1c41944025c2358faaa5782c754e84c24", size = 9461980, upload-time = "2025-07-28T15:48:52.325Z" },
3805
+ { url = "https://files.pythonhosted.org/packages/3d/d3/dacccd834404cd71b5c334882f3ba40331ad2120e69ded32cf5fda9a7436/tokenizers-0.21.4-cp39-abi3-win32.whl", hash = "sha256:6c42a930bc5f4c47f4ea775c91de47d27910881902b0f20e4990ebe045a415d0", size = 2329871, upload-time = "2025-07-28T15:48:56.841Z" },
3806
+ { url = "https://files.pythonhosted.org/packages/41/f2/fd673d979185f5dcbac4be7d09461cbb99751554ffb6718d0013af8604cb/tokenizers-0.21.4-cp39-abi3-win_amd64.whl", hash = "sha256:475d807a5c3eb72c59ad9b5fcdb254f6e17f53dfcbb9903233b0dfa9c943b597", size = 2507568, upload-time = "2025-07-28T15:48:55.456Z" },
 
 
 
 
 
3807
  ]
3808
 
3809
  [[package]]
 
3957
 
3958
  [[package]]
3959
  name = "transformers"
3960
+ version = "4.51.3"
3961
  source = { registry = "https://pypi.org/simple" }
3962
  dependencies = [
3963
  { name = "filelock" },
 
3967
  { name = "packaging" },
3968
  { name = "pyyaml" },
3969
  { name = "regex" },
3970
+ { name = "requests" },
3971
  { name = "safetensors" },
3972
  { name = "tokenizers" },
3973
  { name = "tqdm" },
 
3974
  ]
3975
+ sdist = { url = "https://files.pythonhosted.org/packages/f1/11/7414d5bc07690002ce4d7553602107bf969af85144bbd02830f9fb471236/transformers-4.51.3.tar.gz", hash = "sha256:e292fcab3990c6defe6328f0f7d2004283ca81a7a07b2de9a46d67fd81ea1409", size = 8941266, upload-time = "2025-04-14T08:15:00.485Z" }
3976
  wheels = [
3977
+ { url = "https://files.pythonhosted.org/packages/a9/b6/5257d04ae327b44db31f15cce39e6020cc986333c715660b1315a9724d82/transformers-4.51.3-py3-none-any.whl", hash = "sha256:fd3279633ceb2b777013234bbf0b4f5c2d23c4626b05497691f00cfda55e8a83", size = 10383940, upload-time = "2025-04-14T08:13:43.023Z" },
3978
  ]
3979
 
3980
  [[package]]
 
4006
  { url = "https://files.pythonhosted.org/packages/a0/1d/d9257dd49ff2ca23ea5f132edf1281a0c4f9de8a762b9ae399b670a59235/typer-0.21.1-py3-none-any.whl", hash = "sha256:7985e89081c636b88d172c2ee0cfe33c253160994d47bdfdc302defd7d1f1d01", size = 47381, upload-time = "2026-01-06T11:21:09.824Z" },
4007
  ]
4008
 
4009
  [[package]]
4010
  name = "typing-extensions"
4011
  version = "4.15.0"