# medieval-latin-span-ner / span_ner_model.py
# Uploaded by ERCDiDip ("Create span_ner_model.py", commit dfef0e5, verified)
"""
Medieval Latin NER - Custom Span-NER Architecture
=============================================================================
Core architecture for the Span-NER model utilizing a bi-encoder approach
with a frozen BGE-M3 semantic label space and XLM-RoBERTa-Large text encoder.
"""
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
# ---------------------------------------------------------------------------
# 1. CONFIGURATION
# ---------------------------------------------------------------------------
class Config:
    """Static hyper-parameters for the span-NER bi-encoder."""

    # Text (token) encoder and its output dimension, fed to SpanRepLayer.
    TEXT_MODEL: str = "FacebookAI/xlm-roberta-large"
    TEXT_DIM: int = 1024
    # Label-description encoder and its output dimension, fed to label_proj.
    LABEL_MODEL: str = "BAAI/bge-m3"
    LABEL_DIM: int = 1024
    # Maximum candidate span length, counted in words (see predict()).
    MAX_SPAN_WIDTH: int = 80
    # Size of the learned span-width embedding.
    WIDTH_EMB_DIM: int = 64
    # Shared projection dimension for span and label representations.
    SPAN_HIDDEN: int = 512
    # Number of heads in the span attention pooling.
    ATTENTION_HEADS: int = 4
    # Sub-word budget per encoder chunk (also drives the chunking stride).
    MAX_SEQ_LEN: int = 512
    # Temperature dividing the logits before the sigmoid at prediction time.
    PREDICT_TEMP: float = 1.35
# ---------------------------------------------------------------------------
# 2. LABEL DICTIONARY & PROMPTS
# ---------------------------------------------------------------------------
# Label key -> natural-language description. The descriptions act as label
# "prompts": they are embedded by the label encoder (see
# SpanNERModel._build_label_cache) to form the semantic label space.
LABEL_DICT = {
    "PER": "individual person name without any titles or roles, strictly the given name or family name",
    "ACTOR": "full noun phrase referring to a person including their name plus noble title, profession, geographic origin, or social status",
    "TITLE": "social rank, noble title, ecclesiastical office, profession, or papal rank such as comes, abbas, episcopus",
    "REL": "word or phrase indicating family, kinship, marriage, or social relationship like filius, uxor, frater",
    "LOC": "geographical place, settlement, city, diocese, region, or named territory",
    "INS": "monastery, abbey, church, cell, or religious order functioning as a corporate and legal body",
    "NAT": "natural landscape feature such as a river, stream, forest, mountain, or valley",
    "EST": "short physical plot of land, estate, farm, meadows, woods, vineyards, or courtyards",
    "PROP": "detailed boundary description of a property, grange, estate, or island including past owners, movables, and immovables",
    "LEG": "legal clause declaring rights, conditions, penalties, permissions, or papal commands",
    "TRANS": "verb or phrase denoting a core transaction, confirmation, transfer, sale, gift, or donation",
    "TIM": "time period, duration, general dating formula, indiction, or papal/royal regnal year",
    "DAT": "specific calendar date, precise year of incarnation often starting with Anno or Datum, or named liturgical feast day",
    "MON": "money, currency, coin, or monetary value such as libra, solidus, denarius, uncia, or marca",
    "TAX": "customary toll, legal tax, tithe, exaction, lucrum camere, or tribute paid to an authority",
    "COM": "harvested crops, food, physical goods, salt, wine, wax, gold, wood, or animals traded or given",
    "NUM": "number written as a word or roman numeral, including fractions and quantities",
    "MEA": "unit of measurement for land, volume, or weight such as mansus, carratas, aratrum, or talentum",
    "RELIC": "holy relic, cross, altar, or sacred object of veneration within a church",
}
# Derived lookup tables; label ids follow the dict's insertion order above.
LABEL_KEYS = list(LABEL_DICT.keys())
LABEL_DESCS = list(LABEL_DICT.values())
LABEL2ID = {k: i for i, k in enumerate(LABEL_KEYS)}
ID2LABEL = {i: k for k, i in LABEL2ID.items()}
NUM_LABELS = len(LABEL_DICT)
# Pre-compiled tokenizer pattern: a run of word characters, or any single
# character that is neither a word character nor whitespace (punctuation).
# Hoisted to module level so repeated calls skip the pattern-cache lookup.
_TOKEN_RE = re.compile(r'\w+|[^\w\s]')

def char_tokenize(text):
    """Split *text* into word/punctuation tokens with character offsets.

    Returns a list of dicts ``{"token", "start", "end"}`` where ``start`` and
    ``end`` are character offsets into *text* (``end`` exclusive, matching
    ``re.Match.end()``). Whitespace is skipped; an empty string yields ``[]``.
    """
    return [{"token": m.group(), "start": m.start(), "end": m.end()}
            for m in _TOKEN_RE.finditer(text)]
# ---------------------------------------------------------------------------
# 3. MODEL ARCHITECTURE
# ---------------------------------------------------------------------------
class SpanRepLayer(nn.Module):
    """Builds a fixed-size vector for every candidate span.

    Each span is represented by concatenating:
      * the encoder state at its first sub-word,
      * the encoder state at its last sub-word,
      * multi-head attention pooling over all sub-words inside the span,
      * a learned embedding of the span width.
    """

    def __init__(self, hidden, max_span_width, width_emb_dim, num_heads=4):
        super().__init__()
        self.max_span_width = max_span_width
        self.num_heads = num_heads
        # One embedding row per width value 0..max_span_width inclusive.
        self.width_emb = nn.Embedding(max_span_width + 1, width_emb_dim)
        # Produces one attention score per pooling head for each token.
        self.att_query = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, num_heads)
        )
        # start + end + per-head pooled states + width embedding.
        self.span_dim = 2 * hidden + (num_heads * hidden) + width_emb_dim

    def forward(self, seq_out, spans):
        """seq_out: (B, L, H) encoder states; spans: (B, S, 3) long tensor of
        [start_subword, end_subword, width]; returns (B, S, span_dim)."""
        n_batch, n_spans, _ = spans.shape
        seq_len, hidden = seq_out.size(1), seq_out.size(-1)
        batch_rows = torch.arange(n_batch).unsqueeze(1)
        # Boundary token states.
        h_start = seq_out[batch_rows, spans[:, :, 0]]
        h_end = seq_out[batch_rows, spans[:, :, 1]]
        # Width embedding, clamped into the table's valid index range.
        w_emb = self.width_emb(spans[:, :, 2].clamp(0, self.max_span_width))
        # Boolean membership mask over positions inside [start, end].
        positions = torch.arange(seq_len, device=seq_out.device).view(1, 1, seq_len)
        inside = (positions >= spans[:, :, 0:1]) & (positions <= spans[:, :, 1:2])
        # Per-token head scores, shared across spans, masked outside the span.
        scores = self.att_query(seq_out)
        scores = scores.unsqueeze(1).expand(n_batch, n_spans, seq_len, self.num_heads)
        inside = inside.unsqueeze(-1).expand(-1, -1, -1, self.num_heads)
        scores = scores.masked_fill(~inside, float('-inf'))
        weights = F.softmax(scores, dim=2)
        # Attention-weighted sum of token states per head, then flatten heads.
        pooled = torch.einsum('bslm,blh->bsmh', weights, seq_out)
        pooled = pooled.reshape(n_batch, n_spans, self.num_heads * hidden)
        return torch.cat([h_start, h_end, pooled, w_emb], dim=-1)
class SpanNERModel(nn.Module):
    """Bi-encoder span-NER model.

    Candidate spans from the XLM-RoBERTa text encoder are projected into the
    same space as the (cached) BGE-M3 embeddings of the label descriptions in
    ``LABEL_DESCS``; span/label similarity, scaled by a learned temperature,
    yields per-label logits for every candidate span.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Token-level text encoder; the pooler head is unused.
        self.text_enc = AutoModel.from_pretrained(cfg.TEXT_MODEL, add_pooling_layer=False)
        # Label-description encoder; only ever run under no_grad in
        # _build_label_cache, so this path never trains it.
        self.label_enc = AutoModel.from_pretrained(cfg.LABEL_MODEL)
        self.span_layer = SpanRepLayer(cfg.TEXT_DIM, cfg.MAX_SPAN_WIDTH, cfg.WIDTH_EMB_DIM, num_heads=cfg.ATTENTION_HEADS)
        # Projects pooled label embeddings (LABEL_DIM) into the shared space.
        self.label_proj = nn.Sequential(
            nn.Linear(cfg.LABEL_DIM, cfg.SPAN_HIDDEN),
            nn.GELU(),
            nn.LayerNorm(cfg.SPAN_HIDDEN),
            nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
        )
        # Projects span representations into the shared space.
        self.span_proj = nn.Sequential(
            nn.Linear(self.span_layer.span_dim, cfg.SPAN_HIDDEN),
            nn.GELU(),
            nn.LayerNorm(cfg.SPAN_HIDDEN),
            nn.Dropout(0.2),
            nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
        )
        # Learned similarity scale; exp'd and clamped at use time in predict().
        self.logit_scale = nn.Parameter(torch.tensor(1.0))
        # Lazily-built cache of pooled label embeddings (see _build_label_cache).
        self._raw_label_embs = None

    @torch.no_grad()
    def _build_label_cache(self, label_tokenizer, device):
        """Encode LABEL_DESCS once; cache mean-pooled, L2-normalized vectors.

        Stores the result in ``self._raw_label_embs`` (shape: NUM_LABELS x
        LABEL_DIM), detached from the graph.
        """
        enc = label_tokenizer(
            LABEL_DESCS, padding=True, truncation=True, max_length=128, return_tensors="pt"
        ).to(device)
        out = self.label_enc(**enc).last_hidden_state
        # Attention-mask-weighted mean pooling over non-padding tokens only.
        mask = enc["attention_mask"].unsqueeze(-1).float()
        pooled = F.normalize((out * mask).sum(1) / mask.sum(1), dim=-1, eps=1e-8)
        self._raw_label_embs = pooled.detach()

    def predict(self, text, label_tokenizer, text_tokenizer, threshold, flat_ner, device):
        """Run NER over *text* and return a list of entity dicts.

        Each result dict has keys: label, score, text, start_char, end_char,
        start_word, end_word. ``threshold`` is the minimum sigmoid score to
        keep a candidate; when ``flat_ner`` is true, overlapping spans are
        suppressed greedily by descending score.
        """
        self.eval()
        # Word-level tokenization with character offsets (see char_tokenize).
        tokens_info = char_tokenize(text)
        all_tokens = [t["token"] for t in tokens_info]
        n_total = len(all_tokens)
        if n_total == 0: return []
        # Overlapping chunks: stride is 50 words short of the budget, so
        # consecutive chunks share context; duplicates are merged below.
        stride = self.cfg.MAX_SEQ_LEN - 50
        chunk_starts = range(0, max(1, n_total), stride)
        all_candidates = []
        with torch.no_grad():
            # NOTE(review): a str `device` (or 'cuda' vs 'cuda:0') never
            # compares equal to a torch.device, which would only cause a
            # redundant cache rebuild here — confirm callers pass torch.device.
            if self._raw_label_embs is None or self._raw_label_embs.device != device:
                self._build_label_cache(label_tokenizer, device)
            # Project cached label embeddings into the shared span/label space.
            label_feat = self.label_proj(self._raw_label_embs.to(device))
            for start_idx in chunk_starts:
                # Reserve 2 positions of the sub-word budget (special tokens).
                end_idx = min(start_idx + self.cfg.MAX_SEQ_LEN - 2, n_total)
                tokens = all_tokens[start_idx:end_idx]
                n = len(tokens)
                if n == 0: continue
                enc = text_tokenizer(tokens, is_split_into_words=True, max_length=self.cfg.MAX_SEQ_LEN,
                truncation=True, padding=False, return_tensors="pt").to(device)
                # Map each word index to its first and last sub-word index.
                word_ids = enc.word_ids(batch_index=0)
                first_sw, last_sw = {}, {}
                for sw_idx, w_idx in enumerate(word_ids):
                    if w_idx is not None:
                        if w_idx not in first_sw: first_sw[w_idx] = sw_idx
                        last_sw[w_idx] = sw_idx
                # Enumerate all spans up to MAX_SPAN_WIDTH words; record their
                # sub-word bounds plus width, and their global word bounds.
                span_list, span_word_bounds = [], []
                for ws in range(n):
                    for we in range(ws, min(ws + self.cfg.MAX_SPAN_WIDTH, n)):
                        # Words truncated away by the tokenizer have no sub-words.
                        if ws not in first_sw or we not in last_sw: continue
                        span_list.append([first_sw[ws], last_sw[we], we - ws + 1])
                        span_word_bounds.append((start_idx + ws, start_idx + we))
                if not span_list: continue
                seq_out = self.text_enc(input_ids=enc["input_ids"].to(device),
                attention_mask=enc["attention_mask"].to(device)).last_hidden_state
                # Score spans in slices of 4096 to bound peak memory.
                chunk_logits = []
                for i in range(0, len(span_list), 4096):
                    chunk = torch.tensor(span_list[i:i+4096], dtype=torch.long).unsqueeze(0).to(device)
                    sf = self.span_proj(self.span_layer(seq_out, chunk))
                    # Cosine similarity between spans and labels, temperature-scaled.
                    sf_norm = F.normalize(sf, p=2, dim=-1)
                    lf_norm = F.normalize(label_feat, p=2, dim=-1)
                    scale = self.logit_scale.exp().clamp(max=120.0)
                    ch_logits = torch.einsum('bsd,ld->bsl', sf_norm, lf_norm) * scale
                    chunk_logits.append(ch_logits.squeeze(0).cpu())
                # Per-span, per-label probabilities after extra temperature.
                scores = torch.sigmoid(torch.cat(chunk_logits, dim=0) / self.cfg.PREDICT_TEMP)
                for si, (g_ws, g_we) in enumerate(span_word_bounds):
                    for li in range(NUM_LABELS):
                        score = scores[si, li].item()
                        if score >= threshold:
                            all_candidates.append((g_ws, g_we, ID2LABEL[li], score))
        # Merge duplicates from overlapping chunks, keeping the best score
        # per (start_word, end_word, label) triple.
        unique_cands = {}
        for ws, we, label, score in all_candidates:
            key = (ws, we, label)
            if key not in unique_cands or score > unique_cands[key]:
                unique_cands[key] = score
        final_candidates = [(ws, we, lbl, sc) for (ws, we, lbl), sc in unique_cands.items()]
        # Greedy decoding by descending score.
        final_candidates.sort(key=lambda x: -x[3])
        taken, result = set(), []
        for ws, we, label, score in final_candidates:
            covered = set(range(ws, we + 1))
            # Flat NER: drop any span overlapping an already-accepted one.
            if flat_ner and covered & taken: continue
            if flat_ner: taken |= covered
            start_char = tokens_info[ws]["start"]
            end_char = tokens_info[we]["end"]
            text_span = text[start_char:end_char]
            result.append({
                "label": label,
                "score": round(score, 4),
                "text": text_span,
                "start_char": start_char,
                "end_char": end_char,
                "start_word": ws,
                "end_word": we
            })
        return result