| """ |
| Medieval Latin NER - Custom Span-NER Architecture |
| ============================================================================= |
| Core architecture for the Span-NER model utilizing a bi-encoder approach |
| with a frozen BGE-M3 semantic label space and XLM-RoBERTa-Large text encoder. |
| """ |
|
|
| import re |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModel |
|
|
| |
| |
| |
class Config:
    """Hyper-parameters for the span-NER bi-encoder."""
    # Text encoder (token-level states for span building) and its hidden size.
    TEXT_MODEL = "FacebookAI/xlm-roberta-large"
    TEXT_DIM = 1024
    # Label encoder (embeds the natural-language label descriptions) and its hidden size.
    LABEL_MODEL = "BAAI/bge-m3"
    LABEL_DIM = 1024

    MAX_SPAN_WIDTH = 80   # longest candidate span, in words
    WIDTH_EMB_DIM = 64    # size of the learned span-width embedding
    SPAN_HIDDEN = 512     # shared projection dim for span and label features
    ATTENTION_HEADS = 4   # heads used by SpanRepLayer's attentive pooling

    MAX_SEQ_LEN = 512     # subword window per encoder pass; chunk stride derives from it
    PREDICT_TEMP = 1.35   # logits are divided by this before the sigmoid at inference
|
|
| |
| |
| |
# Entity ontology for medieval charters: tag -> natural-language description.
# The descriptions (values) are what the label encoder embeds; the insertion
# order of this dict fixes the integer label ids used everywhere below.
LABEL_DICT = {
    "PER": "individual person name without any titles or roles, strictly the given name or family name",
    "ACTOR": "full noun phrase referring to a person including their name plus noble title, profession, geographic origin, or social status",
    "TITLE": "social rank, noble title, ecclesiastical office, profession, or papal rank such as comes, abbas, episcopus",
    "REL": "word or phrase indicating family, kinship, marriage, or social relationship like filius, uxor, frater",
    "LOC": "geographical place, settlement, city, diocese, region, or named territory",
    "INS": "monastery, abbey, church, cell, or religious order functioning as a corporate and legal body",
    "NAT": "natural landscape feature such as a river, stream, forest, mountain, or valley",
    "EST": "short physical plot of land, estate, farm, meadows, woods, vineyards, or courtyards",
    "PROP": "detailed boundary description of a property, grange, estate, or island including past owners, movables, and immovables",
    "LEG": "legal clause declaring rights, conditions, penalties, permissions, or papal commands",
    "TRANS": "verb or phrase denoting a core transaction, confirmation, transfer, sale, gift, or donation",
    "TIM": "time period, duration, general dating formula, indiction, or papal/royal regnal year",
    "DAT": "specific calendar date, precise year of incarnation often starting with Anno or Datum, or named liturgical feast day",
    "MON": "money, currency, coin, or monetary value such as libra, solidus, denarius, uncia, or marca",
    "TAX": "customary toll, legal tax, tithe, exaction, lucrum camere, or tribute paid to an authority",
    "COM": "harvested crops, food, physical goods, salt, wine, wax, gold, wood, or animals traded or given",
    "NUM": "number written as a word or roman numeral, including fractions and quantities",
    "MEA": "unit of measurement for land, volume, or weight such as mansus, carratas, aratrum, or talentum",
    "RELIC": "holy relic, cross, altar, or sacred object of veneration within a church",
}


# Derived lookup tables, all driven by LABEL_DICT's key order.
LABEL_KEYS = [tag for tag in LABEL_DICT]
LABEL_DESCS = [LABEL_DICT[tag] for tag in LABEL_KEYS]
ID2LABEL = dict(enumerate(LABEL_KEYS))
LABEL2ID = {tag: idx for idx, tag in ID2LABEL.items()}
NUM_LABELS = len(LABEL_KEYS)
|
|
def char_tokenize(text):
    """Split *text* into word and punctuation tokens with character offsets.

    Returns a list of dicts ``{"token", "start", "end"}`` where ``start``/``end``
    are character indices into the original string (end exclusive). Whitespace
    is discarded; each punctuation character becomes its own token.
    """
    tokens = []
    for match in re.finditer(r'\w+|[^\w\s]', text):
        tokens.append({"token": match.group(), "start": match.start(), "end": match.end()})
    return tokens
|
|
| |
| |
| |
| class SpanRepLayer(nn.Module): |
| def __init__(self, hidden, max_span_width, width_emb_dim, num_heads=4): |
| super().__init__() |
| self.max_span_width = max_span_width |
| self.num_heads = num_heads |
| self.width_emb = nn.Embedding(max_span_width + 1, width_emb_dim) |
|
|
| self.att_query = nn.Sequential( |
| nn.Linear(hidden, hidden // 2), |
| nn.GELU(), |
| nn.Linear(hidden // 2, num_heads) |
| ) |
| self.span_dim = 2 * hidden + (num_heads * hidden) + width_emb_dim |
|
|
| def forward(self, seq_out, spans): |
| B, S, _ = spans.shape |
| L = seq_out.size(1) |
| H = seq_out.size(-1) |
|
|
| h_start = seq_out[torch.arange(B).unsqueeze(1), spans[:,:,0]] |
| h_end = seq_out[torch.arange(B).unsqueeze(1), spans[:,:,1]] |
| width = spans[:,:,2].clamp(0, self.max_span_width) |
| w_emb = self.width_emb(width) |
|
|
| idx = torch.arange(L, device=seq_out.device).view(1, 1, L) |
| mask = (idx >= spans[:,:,0:1]) & (idx <= spans[:,:,1:2]) |
|
|
| att_logits = self.att_query(seq_out) |
| att_logits = att_logits.unsqueeze(1).expand(B, S, L, self.num_heads) |
|
|
| mask_expanded = mask.unsqueeze(-1).expand(-1, -1, -1, self.num_heads) |
| att_logits = att_logits.masked_fill(~mask_expanded, float('-inf')) |
| att_weights = F.softmax(att_logits, dim=2) |
|
|
| h_pool = torch.einsum('bslm,blh->bsmh', att_weights, seq_out) |
| h_pool = h_pool.reshape(B, S, self.num_heads * H) |
|
|
| return torch.cat([h_start, h_end, h_pool, w_emb], dim=-1) |
|
|
class SpanNERModel(nn.Module):
    """Bi-encoder span NER model.

    Text side:  text_enc (cfg.TEXT_MODEL) -> SpanRepLayer -> span_proj.
    Label side: label_enc (cfg.LABEL_MODEL) mean-pooled over the label
    descriptions, cached, -> label_proj.
    A (span, label) score is the scaled cosine similarity between the two
    projected vectors, squashed by a temperature-scaled sigmoid at predict time.
    """
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # add_pooling_layer=False: only token-level hidden states are used
        self.text_enc = AutoModel.from_pretrained(cfg.TEXT_MODEL, add_pooling_layer=False)
        self.label_enc = AutoModel.from_pretrained(cfg.LABEL_MODEL)

        self.span_layer = SpanRepLayer(cfg.TEXT_DIM, cfg.MAX_SPAN_WIDTH, cfg.WIDTH_EMB_DIM, num_heads=cfg.ATTENTION_HEADS)

        # projects pooled label-description embeddings into the matching space
        self.label_proj = nn.Sequential(
            nn.Linear(cfg.LABEL_DIM, cfg.SPAN_HIDDEN),
            nn.GELU(),
            nn.LayerNorm(cfg.SPAN_HIDDEN),
            nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
        )
        # projects span representations into the same matching space
        self.span_proj = nn.Sequential(
            nn.Linear(self.span_layer.span_dim, cfg.SPAN_HIDDEN),
            nn.GELU(),
            nn.LayerNorm(cfg.SPAN_HIDDEN),
            nn.Dropout(0.2),
            nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
        )

        # learnable log-temperature for the cosine logits (exp() + clamp at use site)
        self.logit_scale = nn.Parameter(torch.tensor(1.0))
        # lazily-built cache of L2-normalized, mean-pooled label embeddings
        self._raw_label_embs = None

    @torch.no_grad()
    def _build_label_cache(self, label_tokenizer, device):
        """Encode all LABEL_DESCS once and cache the pooled, normalized embeddings."""
        enc = label_tokenizer(
            LABEL_DESCS, padding=True, truncation=True, max_length=128, return_tensors="pt"
        ).to(device)
        out = self.label_enc(**enc).last_hidden_state
        # attention-mask-weighted mean pooling over subword states
        mask = enc["attention_mask"].unsqueeze(-1).float()
        pooled = F.normalize((out * mask).sum(1) / mask.sum(1), dim=-1, eps=1e-8)
        self._raw_label_embs = pooled.detach()

    def predict(self, text, label_tokenizer, text_tokenizer, threshold, flat_ner, device):
        """Run inference on raw `text` and return scored entity spans.

        Args:
            text: raw document string.
            label_tokenizer: tokenizer matching `label_enc`.
            text_tokenizer: (fast) tokenizer matching `text_enc`; must support
                `is_split_into_words` and `word_ids()`.
            threshold: minimum sigmoid score to keep a (span, label) candidate.
            flat_ner: if True, greedily keep only non-overlapping spans
                (highest score first).
            device: torch device for the forward passes.

        Returns:
            List of dicts with keys: label, score, text, start_char, end_char,
            start_word, end_word.
        """
        self.eval()
        tokens_info = char_tokenize(text)
        all_tokens = [t["token"] for t in tokens_info]
        n_total = len(all_tokens)
        if n_total == 0: return []

        # Overlapping word-level chunks (50-word overlap) so spans near a
        # chunk boundary are seen whole by at least one chunk.
        stride = self.cfg.MAX_SEQ_LEN - 50
        chunk_starts = range(0, max(1, n_total), stride)
        all_candidates = []

        with torch.no_grad():
            # (Re)build the label cache if missing or on the wrong device.
            if self._raw_label_embs is None or self._raw_label_embs.device != device:
                self._build_label_cache(label_tokenizer, device)
            label_feat = self.label_proj(self._raw_label_embs.to(device))

            for start_idx in chunk_starts:
                end_idx = min(start_idx + self.cfg.MAX_SEQ_LEN - 2, n_total)
                tokens = all_tokens[start_idx:end_idx]
                n = len(tokens)
                if n == 0: continue

                enc = text_tokenizer(tokens, is_split_into_words=True, max_length=self.cfg.MAX_SEQ_LEN,
                                     truncation=True, padding=False, return_tensors="pt").to(device)
                word_ids = enc.word_ids(batch_index=0)

                # Map each word index to its first and last subword positions.
                first_sw, last_sw = {}, {}
                for sw_idx, w_idx in enumerate(word_ids):
                    if w_idx is not None:
                        if w_idx not in first_sw: first_sw[w_idx] = sw_idx
                        last_sw[w_idx] = sw_idx

                # Enumerate all candidate spans up to MAX_SPAN_WIDTH words.
                # NOTE(review): words whose subwords were cut off by truncation
                # have no first_sw/last_sw entry and are skipped here; a very
                # subword-dense tail in the final chunk could thus be missed.
                span_list, span_word_bounds = [], []
                for ws in range(n):
                    for we in range(ws, min(ws + self.cfg.MAX_SPAN_WIDTH, n)):
                        if ws not in first_sw or we not in last_sw: continue
                        # [first subword, last subword, width in words]
                        span_list.append([first_sw[ws], last_sw[we], we - ws + 1])
                        span_word_bounds.append((start_idx + ws, start_idx + we))

                if not span_list: continue

                seq_out = self.text_enc(input_ids=enc["input_ids"].to(device),
                                        attention_mask=enc["attention_mask"].to(device)).last_hidden_state

                # Score spans against labels in batches of 4096 to bound memory.
                chunk_logits = []
                for i in range(0, len(span_list), 4096):
                    chunk = torch.tensor(span_list[i:i+4096], dtype=torch.long).unsqueeze(0).to(device)
                    sf = self.span_proj(self.span_layer(seq_out, chunk))

                    # cosine similarity scaled by the learned temperature
                    # NOTE(review): lf_norm and scale are loop-invariant;
                    # recomputing per 4096-span chunk is cheap but redundant.
                    sf_norm = F.normalize(sf, p=2, dim=-1)
                    lf_norm = F.normalize(label_feat, p=2, dim=-1)
                    scale = self.logit_scale.exp().clamp(max=120.0)

                    ch_logits = torch.einsum('bsd,ld->bsl', sf_norm, lf_norm) * scale
                    chunk_logits.append(ch_logits.squeeze(0).cpu())

                # temperature-smoothed sigmoid -> per-label scores in [0, 1]
                scores = torch.sigmoid(torch.cat(chunk_logits, dim=0) / self.cfg.PREDICT_TEMP)

                # Keep every (span, label) pair over the threshold, with
                # word bounds expressed in global (document) coordinates.
                for si, (g_ws, g_we) in enumerate(span_word_bounds):
                    for li in range(NUM_LABELS):
                        score = scores[si, li].item()
                        if score >= threshold:
                            all_candidates.append((g_ws, g_we, ID2LABEL[li], score))

        # Deduplicate spans seen in overlapping chunks, keeping the best score.
        unique_cands = {}
        for ws, we, label, score in all_candidates:
            key = (ws, we, label)
            if key not in unique_cands or score > unique_cands[key]:
                unique_cands[key] = score

        final_candidates = [(ws, we, lbl, sc) for (ws, we, lbl), sc in unique_cands.items()]
        final_candidates.sort(key=lambda x: -x[3])

        # Greedy decoding: in flat mode a word may belong to at most one entity.
        taken, result = set(), []
        for ws, we, label, score in final_candidates:
            covered = set(range(ws, we + 1))
            if flat_ner and covered & taken: continue
            if flat_ner: taken |= covered

            start_char = tokens_info[ws]["start"]
            end_char = tokens_info[we]["end"]
            text_span = text[start_char:end_char]

            result.append({
                "label": label,
                "score": round(score, 4),
                "text": text_span,
                "start_char": start_char,
                "end_char": end_char,
                "start_word": ws,
                "end_word": we
            })

        return result