| """Inference API for Bamboo-1 Vietnamese Dependency Parser.""" |
|
|
| import sys |
| from collections import Counter |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence |
| from huggingface_hub import hf_hub_download |
|
|
|
|
| |
| |
| |
|
|
class Vocabulary:
    """Maps words, characters, and dependency relations to integer ids."""

    PAD = '<pad>'
    UNK = '<unk>'

    def __init__(self, min_freq: int = 2):
        """Create empty maps; PAD/UNK are pre-registered at indices 0 and 1."""
        self.min_freq = min_freq
        self.word2idx = {self.PAD: 0, self.UNK: 1}
        self.char2idx = {self.PAD: 0, self.UNK: 1}
        self.rel2idx = {}
        self.idx2rel = {}

    def build(self, sentences):
        """Populate the maps from an iterable of sentences.

        Each sentence must expose ``words`` (list of str) and ``rels``
        (list of str). Words are lowercased and kept only when their
        frequency reaches ``min_freq``; characters and relations are
        always kept. Insertion order determines the assigned ids.
        """
        word_counts, char_counts, rel_counts = Counter(), Counter(), Counter()

        for sent in sentences:
            for word in sent.words:
                word_counts[word.lower()] += 1
                char_counts.update(word)
            rel_counts.update(sent.rels)

        for word, count in word_counts.items():
            if count >= self.min_freq and word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)

        for char in char_counts:
            if char not in self.char2idx:
                self.char2idx[char] = len(self.char2idx)

        for rel in rel_counts:
            if rel not in self.rel2idx:
                next_idx = len(self.rel2idx)
                self.rel2idx[rel] = next_idx
                self.idx2rel[next_idx] = rel

    def encode_word(self, word: str) -> int:
        """Return the id of the lowercased word, falling back to UNK."""
        unk_id = self.word2idx[self.UNK]
        return self.word2idx.get(word.lower(), unk_id)

    def encode_char(self, char: str) -> int:
        """Return the id of the character, falling back to UNK."""
        return self.char2idx.get(char, self.char2idx[self.UNK])

    def encode_rel(self, rel: str) -> int:
        """Return the id of the relation; unknown relations map to 0."""
        return self.rel2idx.get(rel, 0)

    @property
    def n_words(self) -> int:
        """Number of known words, including PAD and UNK."""
        return len(self.word2idx)

    @property
    def n_chars(self) -> int:
        """Number of known characters, including PAD and UNK."""
        return len(self.char2idx)

    @property
    def n_rels(self) -> int:
        """Number of known dependency relations."""
        return len(self.rel2idx)
|
|
|
|
| |
| |
| |
|
|
class CharLSTM(nn.Module):
    """Builds a fixed-size embedding for each word from its characters
    using a single-layer bidirectional LSTM."""

    def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100):
        super().__init__()
        # Construction order (Embedding, then LSTM) matches the original so
        # that seeded initialization reproduces identical weights.
        self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0)
        self.lstm = nn.LSTM(char_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
        self.hidden_dim = hidden_dim

    def forward(self, chars):
        """Encode characters into per-word vectors.

        Args:
            chars: (batch, seq_len, max_word_len) char-id tensor; 0 = padding.

        Returns:
            (batch, seq_len, hidden_dim) word representations.
        """
        batch, seq_len, max_word_len = chars.shape
        flat = chars.view(-1, max_word_len)
        # clamp(min=1) keeps fully-padded words from producing zero-length packs
        lengths = (flat != 0).sum(dim=1).clamp(min=1)
        packed = pack_padded_sequence(
            self.embed(flat), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)
        # Concatenate the final forward and backward hidden states.
        word_repr = torch.cat([h_n[0], h_n[1]], dim=-1)
        return word_repr.view(batch, seq_len, self.hidden_dim)
|
|
|
|
class MLP(nn.Module):
    """Single linear projection followed by LeakyReLU and dropout."""

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Project ``x`` to hidden_dim, activate, then apply dropout."""
        projected = self.linear(x)
        activated = self.activation(projected)
        return self.dropout(activated)
|
|
|
|
class Biaffine(nn.Module):
    """Biaffine scoring between two sets of vectors (Dozat & Manning, 2017).

    Optionally augments x and/or y with a constant 1 bias dimension
    before applying the bilinear form per output channel.
    """

    def __init__(self, input_dim: int, output_dim: int = 1, bias_x: bool = True, bias_y: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y
        # bool bias flags add 0/1 to each side's dimensionality.
        self.weight = nn.Parameter(torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, y):
        """Score every pair of x- and y-vectors.

        Returns:
            (batch, len_x, len_y) when output_dim == 1, otherwise
            (batch, len_x, len_y, output_dim).
        """
        if self.bias_x:
            ones = torch.ones_like(x[..., :1])
            x = torch.cat([x, ones], dim=-1)
        if self.bias_y:
            ones = torch.ones_like(y[..., :1])
            y = torch.cat([y, ones], dim=-1)
        # Two einsums realize x W y^T for each output channel o.
        projected = torch.einsum('bxi,oij->bxoj', x, self.weight)
        scores = torch.einsum('bxoj,byj->bxyo', projected, y)
        return scores.squeeze(-1) if self.output_dim == 1 else scores
|
|
|
|
class BiaffineDependencyParser(nn.Module):
    """Biaffine Dependency Parser (Dozat & Manning, 2017).

    Word + character embeddings feed a BiLSTM; four MLPs produce
    dependent/head views of each token, which biaffine attention scores
    for arc attachment and relation labeling.
    """

    def __init__(
        self,
        n_words: int,
        n_chars: int,
        n_rels: int,
        word_dim: int = 100,
        char_dim: int = 50,
        char_hidden: int = 100,
        lstm_hidden: int = 400,
        lstm_layers: int = 3,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
    ):
        super().__init__()
        self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0)
        self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden)
        input_dim = word_dim + char_hidden

        self.lstm = nn.LSTM(
            input_dim, lstm_hidden // 2,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout is only meaningful for stacked LSTMs;
            # PyTorch warns if set with num_layers == 1.
            dropout=dropout if lstm_layers > 1 else 0
        )

        # Separate dependent/head projections for arc and relation scoring.
        self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout)
        self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout)

        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

        self.dropout = nn.Dropout(dropout)
        self.n_rels = n_rels

    def forward(self, words, chars, mask):
        """Score all head/dependent pairs.

        Args:
            words: (batch, seq) word ids.
            chars: (batch, seq, max_word_len) character ids.
            mask: (batch, seq) bool mask of real tokens.

        Returns:
            arc_scores: (batch, seq, seq) head-attachment scores.
            rel_scores: (batch, seq, seq, n_rels) relation scores.
        """
        word_embeds = self.word_embed(words)
        char_embeds = self.char_lstm(chars)
        embeds = torch.cat([word_embeds, char_embeds], dim=-1)
        embeds = self.dropout(embeds)

        # pack/pad keeps the LSTM from reading padded positions;
        # total_length restores the original padded width.
        lengths = mask.sum(dim=1).cpu()
        packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(packed)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1))
        lstm_out = self.dropout(lstm_out)

        arc_dep = self.mlp_arc_dep(lstm_out)
        arc_head = self.mlp_arc_head(lstm_out)
        rel_dep = self.mlp_rel_dep(lstm_out)
        rel_head = self.mlp_rel_head(lstm_out)

        arc_scores = self.arc_attn(arc_dep, arc_head)
        rel_scores = self.rel_attn(rel_dep, rel_head)

        return arc_scores, rel_scores

    def decode(self, arc_scores, rel_scores, mask):
        """Greedy decoding: pick the best head per token, then the best
        relation for that predicted head.

        Returns:
            (arc_preds, rel_preds), each of shape (batch, seq).
        """
        arc_preds = arc_scores.argmax(dim=-1)
        batch_size, seq_len = mask.shape
        # Index tensors are created on mask.device so decoding also works when
        # the scores live on GPU (consistent with
        # TransformerDependencyParser.decode).
        rel_scores_pred = rel_scores[
            torch.arange(batch_size, device=mask.device).unsqueeze(1),
            torch.arange(seq_len, device=mask.device),
            arc_preds,
        ]
        rel_preds = rel_scores_pred.argmax(dim=-1)
        return arc_preds, rel_preds
|
|
|
|
class TransformerDependencyParser(nn.Module):
    """Trankit-style dependency parser using XLM-RoBERTa.

    A pretrained transformer encoder produces subword representations;
    each word is represented by the hidden state of its first subword,
    and the same biaffine scoring heads as the baseline parser predict
    arcs and relation labels.
    """

    def __init__(
        self,
        n_rels: int,
        encoder: str = "xlm-roberta-base",
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
    ):
        super().__init__()
        # Lazy import: `transformers` is only required for this model variant.
        from transformers import AutoModel, AutoTokenizer

        self.encoder_name = encoder
        self.tokenizer = AutoTokenizer.from_pretrained(encoder)
        self.encoder = AutoModel.from_pretrained(encoder)
        self.hidden_size = self.encoder.config.hidden_size

        # Separate dependent/head projections for arc and relation scoring.
        self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
        self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)

        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

        self.dropout = nn.Dropout(dropout)
        self.n_rels = n_rels

    def encode_batch(self, sentences: list[list[str]], device):
        """Tokenize and encode sentences, return word-level representations.

        Args:
            sentences: Batch of pre-tokenized sentences (lists of words).
            device: Device on which to allocate input/output tensors.

        Returns:
            word_hidden: (batch, max_words, hidden_size) — each word is
                represented by its first subword's hidden state.
            word_mask: (batch, max_words) bool mask of real words.
        """
        batch_size = len(sentences)
        max_words = max(len(s) for s in sentences)

        all_input_ids = []
        word_starts = []  # per sentence: subword index of each word's first piece

        for sent in sentences:
            input_ids = [self.tokenizer.cls_token_id]
            starts = []

            for word in sent:
                starts.append(len(input_ids))
                tokens = self.tokenizer.encode(word, add_special_tokens=False)
                # A word can tokenize to nothing; fall back to UNK so it
                # still occupies one subword position.
                input_ids.extend(tokens if tokens else [self.tokenizer.unk_token_id])

            input_ids.append(self.tokenizer.sep_token_id)
            all_input_ids.append(input_ids)
            word_starts.append(starts)

        max_len = max(len(ids) for ids in all_input_ids)
        # Padded positions use id 0; they are zeroed out via attention_mask,
        # so the exact pad id does not affect the encoder output we read.
        padded_ids = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
        attention_mask = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)

        for i, ids in enumerate(all_input_ids):
            padded_ids[i, :len(ids)] = torch.tensor(ids)
            attention_mask[i, :len(ids)] = 1

        outputs = self.encoder(padded_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        # Pool subwords to words: copy each word's first-subword state.
        word_hidden = torch.zeros(batch_size, max_words, self.hidden_size, device=device)
        word_mask = torch.zeros(batch_size, max_words, dtype=torch.bool, device=device)

        for i, starts in enumerate(word_starts):
            for j, pos in enumerate(starts):
                word_hidden[i, j] = hidden[i, pos]
                word_mask[i, j] = True

        return word_hidden, word_mask

    def forward(self, word_hidden, word_mask):
        """Compute arc and relation scores from word representations.

        Returns:
            arc_scores: (batch, seq, seq) head-attachment scores.
            rel_scores: (batch, seq, seq, n_rels) relation scores.
        """
        word_hidden = self.dropout(word_hidden)

        arc_dep = self.mlp_arc_dep(word_hidden)
        arc_head = self.mlp_arc_head(word_hidden)
        rel_dep = self.mlp_rel_dep(word_hidden)
        rel_head = self.mlp_rel_head(word_hidden)

        arc_scores = self.arc_attn(arc_dep, arc_head)
        rel_scores = self.rel_attn(rel_dep, rel_head)

        return arc_scores, rel_scores

    def decode(self, arc_scores, rel_scores, mask):
        """Greedy decoding: best head per token, then the best relation
        for that predicted head.

        Returns:
            (arc_preds, rel_preds), each of shape (batch, seq).
        """
        arc_preds = arc_scores.argmax(dim=-1)
        batch_size, seq_len = mask.shape
        rel_scores_pred = rel_scores[torch.arange(batch_size, device=mask.device).unsqueeze(1),
                                     torch.arange(seq_len, device=mask.device), arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)
        return arc_preds, rel_preds
|
|
|
|
| |
| |
| |
|
|
@dataclass
class Token:
    """A token with its head attachment and dependency relation."""

    id: int      # 1-based position within the sentence
    form: str    # surface text of the token
    head: int    # id of the head token; 0 means attached to ROOT
    deprel: str  # dependency relation label

    @property
    def head_form(self) -> str:
        """Return 'ROOT' for root tokens; otherwise '' (resolving the real
        head form requires the enclosing sentence — see ParsedSentence.get_head)."""
        if self.head == 0:
            return "ROOT"
        return ""

    def to_conllu(self) -> str:
        """Render this token as a 10-column CoNLL-U line ('_' for unused columns)."""
        columns = (
            str(self.id), self.form,
            '_', '_', '_', '_',
            str(self.head), self.deprel,
            '_', '_',
        )
        return '\t'.join(columns)
|
|
|
|
@dataclass
class ParsedSentence:
    """A parsed sentence: the original text plus its dependency-tree tokens."""

    text: str
    tokens: list[Token]

    def __iter__(self):
        """Iterate over tokens in sentence order."""
        return iter(self.tokens)

    def __len__(self):
        """Number of tokens in the sentence."""
        return len(self.tokens)

    def __getitem__(self, idx):
        """Index into the token list."""
        return self.tokens[idx]

    def get_head(self, token: Token) -> Optional[Token]:
        """Return the head token of ``token``, or None when it attaches to ROOT."""
        # Token ids are 1-based; head == 0 denotes the artificial ROOT.
        return None if token.head == 0 else self.tokens[token.head - 1]

    def get_dependents(self, token: Token) -> list[Token]:
        """Return every token whose head is ``token``."""
        return [candidate for candidate in self.tokens if candidate.head == token.id]

    def get_root(self) -> Optional[Token]:
        """Return the first token attached to ROOT, or None if there is none."""
        return next((tok for tok in self.tokens if tok.head == 0), None)

    def to_conllu(self, sent_id: Optional[int] = None) -> str:
        """Render the sentence as a CoNLL-U block with comment headers."""
        lines = []
        if sent_id is not None:
            lines.append(f"# sent_id = {sent_id}")
        lines.append(f"# text = {self.text}")
        lines.extend(tok.to_conllu() for tok in self.tokens)
        return "\n".join(lines)
|
|
|
|
| |
# Backward-compatible alias: callers may import this class as `Sentence`.
Sentence = ParsedSentence
|
|
|
|
class Parser:
    """Vietnamese Dependency Parser using Bamboo-1 model.

    Wraps checkpoint loading (local file, directory, or Hugging Face Hub)
    and single-sentence inference for either the BiLSTM baseline or the
    transformer-based model, selected by the checkpoint's config.
    """

    def __init__(self, model_path: str | Path):
        """Load the parser from a model file or Hugging Face Hub.

        Args:
            model_path: Path to the trained model file, directory, or HF repo ID
                (e.g., "undertheseanlp/bamboo-1").
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # A spec containing "/" that does not exist locally is treated as a
        # Hugging Face repo id. MODEL_FILENAME is a module-level constant
        # defined further down this file; it resolves at call time.
        model_path_str = str(model_path)
        if "/" in model_path_str and not Path(model_path_str).exists():
            self.model_path = Path(hf_hub_download(
                repo_id=model_path_str,
                filename=MODEL_FILENAME,
            ))
        else:
            self.model_path = Path(model_path)
            # A directory spec is resolved to its default checkpoint file.
            if self.model_path.is_dir():
                self.model_path = self.model_path / 'model.pt'

        # Checkpoints were pickled from a script where Vocabulary lived in
        # __main__; re-expose it there so unpickling can resolve the class.
        import __main__
        __main__.Vocabulary = Vocabulary

        # weights_only=False because the checkpoint embeds a pickled
        # Vocabulary object — only load trusted model files.
        checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=False)

        self.vocab = checkpoint['vocab']
        self.config = checkpoint.get('config', {})

        # 'trankit' selects the transformer parser; anything else falls back
        # to the BiLSTM baseline.
        self.method = self.config.get('method', 'baseline')

        if self.method == 'trankit':
            encoder = self.config.get('encoder', 'xlm-roberta-base')
            self.model = TransformerDependencyParser(
                n_rels=self.config.get('n_rels', self.vocab.n_rels),
                encoder=encoder,
            )
        else:
            self.model = BiaffineDependencyParser(
                n_words=self.config.get('n_words', self.vocab.n_words),
                n_chars=self.config.get('n_chars', self.vocab.n_chars),
                n_rels=self.config.get('n_rels', self.vocab.n_rels),
                lstm_hidden=self.config.get('lstm_hidden', 400),
                lstm_layers=self.config.get('lstm_layers', 3),
            )

        self.model.load_state_dict(checkpoint['model'])
        self.model.to(self.device)
        self.model.eval()

    def _tokenize(self, text: str) -> list[str]:
        """Simple whitespace tokenization.

        NOTE(review): assumes the input is pre-segmented Vietnamese text
        where syllables/words are space-separated — confirm against training
        preprocessing.
        """
        return text.strip().split()

    def _prepare_input_baseline(self, words: list[str]):
        """Prepare model input tensors for the baseline model.

        Returns (word_tensor, char_tensor, mask) for a batch of size 1,
        all allocated on self.device.
        """
        word_ids = [self.vocab.encode_word(w) for w in words]
        char_ids = [[self.vocab.encode_char(c) for c in w] for w in words]
        max_word_len = max(len(c) for c in char_ids) if char_ids else 1

        word_tensor = torch.tensor([word_ids], dtype=torch.long, device=self.device)
        # char_tensor is zero-initialized, so unfilled tail positions are PAD (0).
        char_tensor = torch.zeros(1, len(words), max_word_len, dtype=torch.long, device=self.device)
        for i, chars in enumerate(char_ids):
            char_tensor[0, i, :len(chars)] = torch.tensor(chars)

        mask = torch.ones(1, len(words), dtype=torch.bool, device=self.device)
        return word_tensor, char_tensor, mask

    def parse(self, text: str) -> ParsedSentence:
        """Parse a single sentence.

        Args:
            text: Vietnamese text to parse.

        Returns:
            ParsedSentence object with tokens and dependency information.
            Empty/whitespace-only input yields a sentence with no tokens.
        """
        words = self._tokenize(text)
        if not words:
            return ParsedSentence(text=text, tokens=[])

        with torch.no_grad():
            if self.method == 'trankit':
                word_hidden, mask = self.model.encode_batch([words], self.device)
                arc_scores, rel_scores = self.model(word_hidden, mask)
                arc_preds, rel_preds = self.model.decode(arc_scores, rel_scores, mask)
            else:
                word_tensor, char_tensor, mask = self._prepare_input_baseline(words)
                arc_scores, rel_scores = self.model(word_tensor, char_tensor, mask)
                arc_preds, rel_preds = self.model.decode(arc_scores, rel_scores, mask)

        # Convert the batch-of-one predictions into Token objects (1-based ids).
        tokens = []
        for i, word in enumerate(words):
            head = arc_preds[0, i].item()
            rel_idx = rel_preds[0, i].item()
            # Unknown relation ids fall back to the generic 'dep' label.
            deprel = self.vocab.idx2rel.get(rel_idx, 'dep')
            tokens.append(Token(id=i + 1, form=word, head=head, deprel=deprel))

        return ParsedSentence(text=text, tokens=tokens)

    def parse_batch(self, texts: list[str]) -> list[ParsedSentence]:
        """Parse multiple sentences.

        Args:
            texts: List of Vietnamese texts to parse.

        Returns:
            List of ParsedSentence objects.

        Note: currently parses sentences one at a time rather than batching
        them through the model.
        """
        return [self.parse(text) for text in texts]

    def __call__(self, text: str) -> ParsedSentence:
        """Parse a sentence (shorthand for parse())."""
        return self.parse(text)
|
|
|
|
| |
# Released model metadata. NOTE: these constants are defined after the Parser
# class but are referenced by Parser.__init__; Python resolves module-level
# names at call time, so this works as long as no Parser is constructed
# during import of this module.
MODEL_VERSION = "1.0.0"
MODEL_DATE = "20260202"
MODEL_FILENAME = f"bamboo-{MODEL_VERSION}-{MODEL_DATE}.pt"
REPO_ID = "undertheseanlp/bamboo-1-model"
DEFAULT_MODEL = REPO_ID


# Lazily-created singleton used by the module-level parse() helper.
_default_parser: Optional[Parser] = None
|
|
|
|
def load(model: str | Path = DEFAULT_MODEL) -> Parser:
    """Construct a :class:`Parser` from a model spec.

    Args:
        model: Local checkpoint file, a directory containing ``model.pt``,
            or a Hugging Face Hub repo ID (e.g., "undertheseanlp/bamboo-1").

    Returns:
        A ready-to-use Parser instance.

    Example:
        >>> parser = load("undertheseanlp/bamboo-1")  # from Hugging Face
        >>> parser = load("models/bamboo-1")          # from local directory
    """
    return Parser(model)
|
|
|
|
def parse(text: str, model: str | Path = DEFAULT_MODEL) -> ParsedSentence:
    """Parse a Vietnamese sentence using a cached module-level parser.

    Args:
        text: Vietnamese text to parse.
        model: Path to the model or HF repo ID (uses "undertheseanlp/bamboo-1" if not specified).

    Returns:
        ParsedSentence object with tokens and dependency information.

    Example:
        >>> from src import parse
        >>> sent = parse("Tôi yêu Việt Nam")
        >>> for token in sent:
        ...     print(f"{token.form} -> {sent.get_head(token).form if sent.get_head(token) else 'ROOT'}")
    """
    global _default_parser
    model_str = str(model)
    # Compare against the spec the cached parser was *requested* with, not
    # Parser.model_path: __init__ rewrites model_path (HF repo id -> downloaded
    # cache file, directory -> dir/model.pt), so comparing the resolved path
    # to the raw spec would never match and force a full reload on every call.
    cached_spec = getattr(_default_parser, '_requested_model', None)
    if _default_parser is None or cached_spec != model_str:
        _default_parser = Parser(model)
        _default_parser._requested_model = model_str
    return _default_parser.parse(text)
|
|