# Source: text2tobi / model.py on the Hugging Face Hub, uploaded by lemmatix
# ("Upload folder using huggingface_hub", commit 06b9041, 3.94 kB).
# model.py
# ProsodyBoundaryModel — custom DistilBERT multi-task token classifier.
# The libri+peoples+sbc checkpoint was trained with use_pos_embedding=False.
# The POS embedding path is present in the class but inactive for this model.
import torch
import torch.nn as nn
from transformers import (
DistilBertModel,
DistilBertPreTrainedModel,
AutoTokenizer,
)
# ── POS tag vocabulary (Universal Dependencies / spaCy UPOS) ─────────────────
# Kept for checkpoint compatibility only; the libri+peoples+sbc checkpoint
# does not use POS features. Ids are assigned by position, so reordering or
# inserting tags would shift the id of every later tag.

# UPOS tag -> short token string.
UNIVERSAL_TO_TOKEN = {
    "ADJ": "adj",      # adjective
    "ADP": "adp",      # adposition
    "ADV": "adv",      # adverb
    "AUX": "aux",      # auxiliary
    "CCONJ": "cc",     # coordinating conjunction
    "DET": "det",      # determiner
    "INTJ": "ij",      # interjection
    "NOUN": "nn",      # noun
    "NUM": "num",      # numeral
    "PART": "pt",      # particle
    "PRON": "pro",     # pronoun
    "PROPN": "np",     # proper noun
    "PUNCT": "pun",    # punctuation
    "SCONJ": "sc",     # subordinating conjunction
    "SYM": "sym",      # symbol
    "VERB": "vb",      # verb
    "X": "xx",         # other
    "SPACE": "sp",     # whitespace (spaCy extension, not a core UD tag)
}

# Token string for unknown tags. NOTE(review): "unk" has no entry in
# POS_TAG_TO_ID below, so it cannot be mapped to an embedding id here.
UNK_POS_TOKEN = "unk"

# Id 0 is reserved for padding. The remaining ids follow the insertion order
# of UNIVERSAL_TO_TOKEN (ADJ=1 ... SPACE=18).
_POS_TAG_NAMES = ["PAD", *UNIVERSAL_TO_TOKEN]
POS_TAG_TO_ID = dict(zip(_POS_TAG_NAMES, range(len(_POS_TAG_NAMES))))
NUM_POS_TAGS = len(_POS_TAG_NAMES)  # 19 = PAD + 18 UPOS tags
class ProsodyBoundaryModel(DistilBertPreTrainedModel):
    """
    Multi-task token classifier for ToBI prosodic annotation.

    Architecture
    ────────────
    DistilBERT encoder
      └─► dropout (seq_classif_dropout)
            └─► [+ optional POS embedding, added AFTER dropout]
                  ├─► boundary_head    Linear(hidden → 2)  boundary / non-boundary
                  ├─► intonation_head  Linear(hidden → 3)  H% / L% / !H%
                  └─► break_idx_head   Linear(hidden → 2)  index-3 / index-4

    ``hidden`` is ``config.hidden_size`` (768 for DistilBERT-base). The
    optional POS branch is ``pos_proj(pos_embedding(pos_ids))``. It is built
    only when ``config.use_pos_embedding`` is true. Because it is added after
    dropout, the POS features themselves are never dropped out.

    The libri+peoples+sbc checkpoint was trained with use_pos_embedding=False.
    For that checkpoint the heads read the dropped-out encoder output directly.

    All three heads are applied to every token; intonation and break-index
    predictions are only meaningful at boundary positions.

    Config attributes read: ``seq_classif_dropout``, ``hidden_size``,
    ``use_pos_embedding`` (default False), ``pos_emb_dim`` (default 64),
    ``num_pos_tags`` (default NUM_POS_TAGS).
    """

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        # Defaults to False when the config has no such attribute.
        self.use_pos_embedding = getattr(config, "use_pos_embedding", False)
        if self.use_pos_embedding:
            _pos_emb_dim = getattr(config, "pos_emb_dim", 64)
            _num_pos_tags = getattr(config, "num_pos_tags", NUM_POS_TAGS)
            # padding_idx=0 matches POS_TAG_TO_ID["PAD"] == 0; that row is
            # not updated during training.
            self.pos_embedding = nn.Embedding(
                _num_pos_tags, _pos_emb_dim, padding_idx=0
            )
            # Lift the small POS embedding to the encoder width so it can be
            # added element-wise to the hidden states.
            self.pos_proj = nn.Linear(_pos_emb_dim, config.hidden_size, bias=False)

        self.boundary_head = nn.Linear(config.hidden_size, 2)
        self.intonation_head = nn.Linear(config.hidden_size, 3)
        self.break_idx_head = nn.Linear(config.hidden_size, 2)

        # Standard Hugging Face hook: weight initialisation / final setup.
        self.post_init()

    def forward(self, input_ids, attention_mask, pos_ids=None, **kwargs):
        """
        Run the encoder and all three classification heads.

        Parameters
        ----------
        input_ids : (B, T)
        attention_mask : (B, T)
        pos_ids : (B, T) LongTensor | None
            POS-tag ids (see POS_TAG_TO_ID). Used only when
            use_pos_embedding=True. If the model was built with POS
            embeddings but pos_ids is None, the POS term is skipped silently.
        **kwargs
            Accepted and ignored (presumably so extra batch keys such as
            labels do not break the call).

        Returns
        -------
        dict with keys:
            boundary_logits   : (B, T, 2)
            intonation_logits : (B, T, 3)
            break_idx_logits  : (B, T, 2)
        """
        outputs = self.distilbert(input_ids=input_ids,
                                  attention_mask=attention_mask)
        # Dropout is applied to the encoder output before any POS term is
        # added (see the class docstring).
        seq_out = self.dropout(outputs.last_hidden_state)  # (B, T, H)

        if self.use_pos_embedding and pos_ids is not None:
            pos_emb = self.pos_proj(self.pos_embedding(pos_ids))  # (B, T, H)
            seq_out = seq_out + pos_emb

        # Every head runs on every token; callers should read intonation and
        # break-index outputs only where the boundary head predicts a boundary.
        return {
            "boundary_logits": self.boundary_head(seq_out),  # (B, T, 2)
            "intonation_logits": self.intonation_head(seq_out),  # (B, T, 3)
            "break_idx_logits": self.break_idx_head(seq_out),  # (B, T, 2)
        }

    @classmethod
    def _can_set_experts_implementation(cls):
        # Hook queried by transformers. Returning False presumably opts this
        # dense (non-mixture-of-experts) model out of experts-implementation
        # selection. NOTE(review): confirm against the installed
        # transformers version.
        return False