| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import ( |
| DistilBertModel, |
| DistilBertPreTrainedModel, |
| AutoTokenizer, |
| ) |
|
|
| |
| |
# Coarse POS tag name -> short lower-case token string.
# Insertion order matters: the integer tag ids below are derived from it.
UNIVERSAL_TO_TOKEN = {
    "ADJ": "adj",
    "ADP": "adp",
    "ADV": "adv",
    "AUX": "aux",
    "CCONJ": "cc",
    "DET": "det",
    "INTJ": "ij",
    "NOUN": "nn",
    "NUM": "num",
    "PART": "pt",
    "PRON": "pro",
    "PROPN": "np",
    "PUNCT": "pun",
    "SCONJ": "sc",
    "SYM": "sym",
    "VERB": "vb",
    "X": "xx",
    "SPACE": "sp",
}
# NOTE(review): presumably the fallback token for tags absent from the table
# above; it is not referenced anywhere in this section -- confirm with callers.
UNK_POS_TOKEN = "unk"


# Tag vocabulary: "PAD" is placed first so that it receives id 0, followed by
# the tag names in the order they appear in UNIVERSAL_TO_TOKEN.
_POS_TAG_NAMES = ["PAD", *UNIVERSAL_TO_TOKEN]
NUM_POS_TAGS = len(_POS_TAG_NAMES)
# Tag name -> integer id (position in _POS_TAG_NAMES).
POS_TAG_TO_ID = dict(zip(_POS_TAG_NAMES, range(NUM_POS_TAGS)))
|
|
|
|
class ProsodyBoundaryModel(DistilBertPreTrainedModel):
    """
    DistilBERT-based multi-task token classifier for ToBI prosodic annotation.

    Data flow (as executed by ``forward``)
    --------------------------------------
    DistilBERT encoder (last hidden state)
      -> dropout (rate taken from config.seq_classif_dropout)
      -> [+ projected POS embedding; only when use_pos_embedding is enabled
          and pos_ids are passed in]
      -> boundary_head    Linear(hidden_size -> 2)   boundary / non-boundary
      -> intonation_head  Linear(hidden_size -> 3)   H% / L% / !H%
      -> break_idx_head   Linear(hidden_size -> 2)   index-3 / index-4

    hidden_size is read from config.hidden_size (the original notes quote 768).

    Notes
    -----
    This checkpoint is set to use_pos_embedding=False.
    All three heads are applied to every token; intonation and break index
    predictions are only meaningful at boundary positions.
    """

    def __init__(self, config):
        """
        Build the encoder, dropout, optional POS branch and the three heads.

        Parameters
        ----------
        config : model configuration. Reads ``seq_classif_dropout`` and
            ``hidden_size``; optionally ``use_pos_embedding`` (default False),
            ``pos_emb_dim`` (default 64) and ``num_pos_tags``
            (default NUM_POS_TAGS).
        """
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.seq_classif_dropout)
        hidden_size = config.hidden_size

        # Optional POS branch: the submodules exist only when the config
        # enables them, so their parameters are absent otherwise.
        self.use_pos_embedding = getattr(config, "use_pos_embedding", False)
        if self.use_pos_embedding:
            tag_count = getattr(config, "num_pos_tags", NUM_POS_TAGS)
            emb_width = getattr(config, "pos_emb_dim", 64)
            # Index 0 is the padding index of the embedding table.
            self.pos_embedding = nn.Embedding(
                num_embeddings=tag_count,
                embedding_dim=emb_width,
                padding_idx=0,
            )
            # Bias-free projection up to the encoder width so the result can
            # be added to the token representations.
            self.pos_proj = nn.Linear(emb_width, hidden_size, bias=False)

        # One linear classifier per task, all fed the same token features.
        self.boundary_head = nn.Linear(
            in_features=hidden_size, out_features=2
        )
        self.intonation_head = nn.Linear(
            in_features=hidden_size, out_features=3
        )
        self.break_idx_head = nn.Linear(
            in_features=hidden_size, out_features=2
        )
        self.post_init()

    def forward(self, input_ids, attention_mask, pos_ids=None, **kwargs):
        """
        Run the encoder and apply the three token-level heads.

        Parameters
        ----------
        input_ids      : (B, T)
        attention_mask : (B, T)
        pos_ids        : (B, T) LongTensor or None -- only used when
                         use_pos_embedding=True; ignored otherwise
        **kwargs       : accepted but not used by this method

        Returns
        -------
        dict with keys:
            boundary_logits   : (B, T, 2)
            intonation_logits : (B, T, 3)
            break_idx_logits  : (B, T, 2)
        """
        encoder_out = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        token_repr = self.dropout(encoder_out.last_hidden_state)

        # The POS signal is added after dropout, and only if both the model
        # was built with the POS branch and the caller supplied tag ids.
        if pos_ids is not None and self.use_pos_embedding:
            pos_vectors = self.pos_embedding(pos_ids)
            token_repr = token_repr + self.pos_proj(pos_vectors)

        heads = (
            ("boundary_logits", self.boundary_head),
            ("intonation_logits", self.intonation_head),
            ("break_idx_logits", self.break_idx_head),
        )
        return {key: head(token_repr) for key, head in heads}

    @classmethod
    def _can_set_experts_implementation(cls):
        """
        Always return False.

        NOTE(review): presumably a capability hook consulted by the
        transformers base class; the caller is not visible here -- confirm.
        """
        return False
|
|