File size: 10,937 Bytes
1a5832c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfef0e5
1a5832c
 
 
 
dfef0e5
1a5832c
dfef0e5
1a5832c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfef0e5
1a5832c
dfef0e5
1a5832c
 
 
 
dfef0e5
1a5832c
dfef0e5
1a5832c
 
 
 
 
 
 
 
 
 
dfef0e5
1a5832c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfef0e5
1a5832c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfef0e5
1a5832c
 
 
dfef0e5
1a5832c
 
 
 
 
 
 
 
 
 
dfef0e5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
Medieval Latin NER - Custom Span-NER Architecture
=============================================================================
Core architecture for the Span-NER model utilizing a bi-encoder approach
with a frozen BGE-M3 semantic label space and XLM-RoBERTa-Large text encoder.
"""

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

# ---------------------------------------------------------------------------
# 1. CONFIGURATION
# ---------------------------------------------------------------------------
class Config:
    """Hyperparameters for the span-NER model (read as class attributes)."""

    # Text encoder: produces per-token hidden states for span building.
    TEXT_MODEL     = "FacebookAI/xlm-roberta-large"
    TEXT_DIM       = 1024  # hidden size of the text encoder
    # Label encoder: embeds the natural-language descriptions in LABEL_DICT.
    LABEL_MODEL    = "BAAI/bge-m3"
    LABEL_DIM      = 1024  # hidden size of the label encoder
    
    MAX_SPAN_WIDTH = 80    # maximum candidate span length, in words
    WIDTH_EMB_DIM  = 64    # size of the learned span-width embedding
    SPAN_HIDDEN    = 512   # shared projection dim for span and label vectors
    ATTENTION_HEADS = 4    # pooling heads in SpanRepLayer
    
    MAX_SEQ_LEN    = 512   # subword limit per encoded chunk
    PREDICT_TEMP   = 1.35  # temperature dividing logits before sigmoid at predict time

# ---------------------------------------------------------------------------
# 2. LABEL DICTIONARY & PROMPTS
# ---------------------------------------------------------------------------
# Entity tag -> natural-language description.  The descriptions are the actual
# model input: they are embedded by the label encoder (see
# SpanNERModel._build_label_cache), so their wording is part of model behavior
# and must not be edited casually.
LABEL_DICT = {
    "PER":   "individual person name without any titles or roles, strictly the given name or family name",
    "ACTOR": "full noun phrase referring to a person including their name plus noble title, profession, geographic origin, or social status",
    "TITLE": "social rank, noble title, ecclesiastical office, profession, or papal rank such as comes, abbas, episcopus",
    "REL":   "word or phrase indicating family, kinship, marriage, or social relationship like filius, uxor, frater",
    "LOC":   "geographical place, settlement, city, diocese, region, or named territory",
    "INS":   "monastery, abbey, church, cell, or religious order functioning as a corporate and legal body",
    "NAT":   "natural landscape feature such as a river, stream, forest, mountain, or valley",
    "EST":   "short physical plot of land, estate, farm, meadows, woods, vineyards, or courtyards",
    "PROP":  "detailed boundary description of a property, grange, estate, or island including past owners, movables, and immovables",
    "LEG":   "legal clause declaring rights, conditions, penalties, permissions, or papal commands",
    "TRANS": "verb or phrase denoting a core transaction, confirmation, transfer, sale, gift, or donation",
    "TIM":   "time period, duration, general dating formula, indiction, or papal/royal regnal year",
    "DAT":   "specific calendar date, precise year of incarnation often starting with Anno or Datum, or named liturgical feast day",
    "MON":   "money, currency, coin, or monetary value such as libra, solidus, denarius, uncia, or marca",
    "TAX":   "customary toll, legal tax, tithe, exaction, lucrum camere, or tribute paid to an authority",
    "COM":   "harvested crops, food, physical goods, salt, wine, wax, gold, wood, or animals traded or given",
    "NUM":   "number written as a word or roman numeral, including fractions and quantities",
    "MEA":   "unit of measurement for land, volume, or weight such as mansus, carratas, aratrum, or talentum",
    "RELIC": "holy relic, cross, altar, or sacred object of veneration within a church",
}

# Derived lookup tables; all are index-aligned with LABEL_DICT insertion order.
LABEL_KEYS  = list(LABEL_DICT.keys())           # tag names
LABEL_DESCS = list(LABEL_DICT.values())         # descriptions fed to the label encoder
LABEL2ID    = {k: i for i, k in enumerate(LABEL_KEYS)}
ID2LABEL    = {i: k for k, i in LABEL2ID.items()}
NUM_LABELS  = len(LABEL_DICT)

def char_tokenize(text):
    """Split *text* into word/punctuation tokens with character offsets.

    A token is either a maximal run of word characters or a single
    non-word, non-space character.  Each token is returned as a dict with
    keys "token", "start" and "end" (end offset is exclusive).
    """
    tokens = []
    for match in re.finditer(r'\w+|[^\w\s]', text):
        tokens.append({
            "token": match.group(),
            "start": match.start(),
            "end": match.end(),
        })
    return tokens

# ---------------------------------------------------------------------------
# 3. MODEL ARCHITECTURE
# ---------------------------------------------------------------------------
class SpanRepLayer(nn.Module):
    """Builds a fixed-size vector for every candidate span.

    Each span representation concatenates: the start-token state, the
    end-token state, a multi-head attention pooling over the span's tokens,
    and a learned embedding of the span width.
    """

    def __init__(self, hidden, max_span_width, width_emb_dim, num_heads=4):
        super().__init__()
        self.max_span_width = max_span_width
        self.num_heads = num_heads
        # +1 so the clamped width range [0, max_span_width] is fully indexable.
        self.width_emb = nn.Embedding(max_span_width + 1, width_emb_dim)

        # Scores each token once per pooling head; spans reuse these scores.
        self.att_query = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, num_heads)
        )
        # start + end + (one pooled vector per head) + width embedding.
        self.span_dim = 2 * hidden + (num_heads * hidden) + width_emb_dim

    def forward(self, seq_out, spans):
        """Compute span representations.

        Args:
            seq_out: (B, L, H) token hidden states from the text encoder.
            spans:   (B, S, 3) long tensor; per span: [start subword index,
                     end subword index (inclusive), width].  As built by
                     SpanNERModel.predict, the width is counted in words.

        Returns:
            (B, S, span_dim) tensor of span vectors.
        """
        B, S, _ = spans.shape
        L = seq_out.size(1)
        H = seq_out.size(-1)

        # Gather the hidden state at each span's start/end subword.
        h_start = seq_out[torch.arange(B).unsqueeze(1), spans[:,:,0]]
        h_end   = seq_out[torch.arange(B).unsqueeze(1), spans[:,:,1]]
        width   = spans[:,:,2].clamp(0, self.max_span_width)
        w_emb   = self.width_emb(width)

        # mask[b, s, l] is True iff token l lies inside span s (inclusive ends).
        idx = torch.arange(L, device=seq_out.device).view(1, 1, L)
        mask = (idx >= spans[:,:,0:1]) & (idx <= spans[:,:,1:2]) 

        # Per-token head scores, broadcast to every span: (B, S, L, heads).
        att_logits = self.att_query(seq_out) 
        att_logits = att_logits.unsqueeze(1).expand(B, S, L, self.num_heads)

        # -inf outside the span so softmax over L attends only within it.
        # Valid spans always contain >= 1 token, so the softmax is well-defined.
        mask_expanded = mask.unsqueeze(-1).expand(-1, -1, -1, self.num_heads)
        att_logits = att_logits.masked_fill(~mask_expanded, float('-inf'))
        att_weights = F.softmax(att_logits, dim=2) 

        # Weighted sum over tokens per head, then flatten heads: (B, S, heads*H).
        h_pool = torch.einsum('bslm,blh->bsmh', att_weights, seq_out) 
        h_pool = h_pool.reshape(B, S, self.num_heads * H)

        return torch.cat([h_start, h_end, h_pool, w_emb], dim=-1)

class SpanNERModel(nn.Module):
    """Bi-encoder span-classification NER model.

    A text encoder produces token states that SpanRepLayer turns into span
    vectors; a separate label encoder embeds the natural-language label
    descriptions (LABEL_DESCS).  Spans are scored against every label by
    scaled cosine similarity, so the label set can be changed without
    retraining a classification head.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # add_pooling_layer=False: only token-level states are used.
        self.text_enc = AutoModel.from_pretrained(cfg.TEXT_MODEL, add_pooling_layer=False)
        self.label_enc = AutoModel.from_pretrained(cfg.LABEL_MODEL)
        
        self.span_layer = SpanRepLayer(cfg.TEXT_DIM, cfg.MAX_SPAN_WIDTH, cfg.WIDTH_EMB_DIM, num_heads=cfg.ATTENTION_HEADS)

        # Project label and span vectors into the shared SPAN_HIDDEN space.
        self.label_proj = nn.Sequential(
            nn.Linear(cfg.LABEL_DIM, cfg.SPAN_HIDDEN),
            nn.GELU(),
            nn.LayerNorm(cfg.SPAN_HIDDEN),
            nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
        )
        self.span_proj = nn.Sequential(
            nn.Linear(self.span_layer.span_dim, cfg.SPAN_HIDDEN),
            nn.GELU(),
            nn.LayerNorm(cfg.SPAN_HIDDEN),
            nn.Dropout(0.2),
            nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
        )

        # Learnable temperature for the cosine-similarity logits (CLIP-style).
        self.logit_scale = nn.Parameter(torch.tensor(1.0)) 
        # Lazy cache of the label-description embeddings; filled on first predict.
        self._raw_label_embs = None

    @torch.no_grad()
    def _build_label_cache(self, label_tokenizer, device):
        """Encode LABEL_DESCS once and cache mean-pooled, L2-normalized vectors."""
        enc = label_tokenizer(
            LABEL_DESCS, padding=True, truncation=True, max_length=128, return_tensors="pt"
        ).to(device)
        out  = self.label_enc(**enc).last_hidden_state
        # Mask-aware mean pooling over real (non-padding) tokens.
        mask = enc["attention_mask"].unsqueeze(-1).float()
        pooled = F.normalize((out * mask).sum(1) / mask.sum(1), dim=-1, eps=1e-8)
        self._raw_label_embs = pooled.detach()

    def predict(self, text, label_tokenizer, text_tokenizer, threshold, flat_ner, device):
        """Run inference on raw *text* and return extracted entity spans.

        Args:
            text: input document (any length; long texts are chunked).
            label_tokenizer: tokenizer for the label encoder.
            text_tokenizer:  tokenizer for the text encoder (must be a fast
                tokenizer — word_ids() is required).
            threshold: minimum sigmoid score for a (span, label) candidate.
            flat_ner: if True, greedily forbid overlapping spans in the output.
            device: torch device (or device string) to run on.

        Returns:
            List of dicts with keys label / score / text / start_char /
            end_char / start_word / end_word, ordered by descending score.
        """
        self.eval()
        tokens_info = char_tokenize(text)
        all_tokens = [t["token"] for t in tokens_info]
        n_total = len(all_tokens)
        if n_total == 0: return []

        # Slide a word-level window with ~50 words of overlap between chunks
        # so entities near a chunk boundary appear whole in some chunk.
        # NOTE(review): the window is sized in *words* against a *subword*
        # limit; a dense chunk can exceed MAX_SEQ_LEN subwords and its tail
        # words are then dropped by truncation (handled below by skipping
        # spans whose words lost all subwords) — confirm the overlap suffices.
        stride = self.cfg.MAX_SEQ_LEN - 50
        chunk_starts = range(0, max(1, n_total), stride)
        all_candidates = []

        with torch.no_grad():
            # (Re)build the label cache on first use or after a device move.
            # NOTE(review): comparing a torch.device to a plain string device
            # argument may not match as intended — verify callers pass torch.device.
            if self._raw_label_embs is None or self._raw_label_embs.device != device:
                self._build_label_cache(label_tokenizer, device)
            label_feat = self.label_proj(self._raw_label_embs.to(device))

            for start_idx in chunk_starts:
                # -2 reserves room for the tokenizer's special tokens.
                end_idx = min(start_idx + self.cfg.MAX_SEQ_LEN - 2, n_total)
                tokens = all_tokens[start_idx:end_idx]
                n = len(tokens)
                if n == 0: continue

                enc = text_tokenizer(tokens, is_split_into_words=True, max_length=self.cfg.MAX_SEQ_LEN,
                                     truncation=True, padding=False, return_tensors="pt").to(device)
                word_ids = enc.word_ids(batch_index=0)

                # Map each word to its first and last subword position.
                first_sw, last_sw = {}, {}
                for sw_idx, w_idx in enumerate(word_ids):
                    if w_idx is not None:
                        if w_idx not in first_sw: first_sw[w_idx] = sw_idx
                        last_sw[w_idx] = sw_idx

                # Enumerate all candidate spans up to MAX_SPAN_WIDTH words.
                # span_list holds subword bounds + word width (SpanRepLayer input);
                # span_word_bounds holds the matching *global* word indices.
                span_list, span_word_bounds = [], []
                for ws in range(n):
                    for we in range(ws, min(ws + self.cfg.MAX_SPAN_WIDTH, n)):
                        # Skip words whose subwords were lost to truncation.
                        if ws not in first_sw or we not in last_sw: continue
                        span_list.append([first_sw[ws], last_sw[we], we - ws + 1])
                        span_word_bounds.append((start_idx + ws, start_idx + we))

                if not span_list: continue

                seq_out = self.text_enc(input_ids=enc["input_ids"].to(device),
                                        attention_mask=enc["attention_mask"].to(device)).last_hidden_state

                # Score spans in batches of 4096 to bound peak memory.
                chunk_logits = []
                for i in range(0, len(span_list), 4096):
                    chunk = torch.tensor(span_list[i:i+4096], dtype=torch.long).unsqueeze(0).to(device)
                    sf    = self.span_proj(self.span_layer(seq_out, chunk))

                    # Cosine similarity between span and label vectors,
                    # scaled by the (clamped) learned temperature.
                    sf_norm = F.normalize(sf, p=2, dim=-1)
                    lf_norm = F.normalize(label_feat, p=2, dim=-1)
                    scale = self.logit_scale.exp().clamp(max=120.0)

                    ch_logits = torch.einsum('bsd,ld->bsl', sf_norm, lf_norm) * scale
                    chunk_logits.append(ch_logits.squeeze(0).cpu())

                # Extra temperature softens the sigmoid at predict time.
                scores = torch.sigmoid(torch.cat(chunk_logits, dim=0) / self.cfg.PREDICT_TEMP)

                # Keep every (span, label) pair above threshold as a candidate.
                for si, (g_ws, g_we) in enumerate(span_word_bounds):
                    for li in range(NUM_LABELS):
                        score = scores[si, li].item()
                        if score >= threshold:
                            all_candidates.append((g_ws, g_we, ID2LABEL[li], score))

        # Chunks overlap, so the same (span, label) may be scored twice:
        # keep the maximum score per unique candidate.
        unique_cands = {}
        for ws, we, label, score in all_candidates:
            key = (ws, we, label)
            if key not in unique_cands or score > unique_cands[key]:
                unique_cands[key] = score

        final_candidates = [(ws, we, lbl, sc) for (ws, we, lbl), sc in unique_cands.items()]
        final_candidates.sort(key=lambda x: -x[3])

        # Greedy decoding in descending score order; with flat_ner, a word
        # already claimed by a higher-scoring span blocks later candidates.
        taken, result = set(), []
        for ws, we, label, score in final_candidates:
            covered = set(range(ws, we + 1))
            if flat_ner and covered & taken: continue
            if flat_ner: taken |= covered
            
            # Convert global word indices back to character offsets.
            start_char = tokens_info[ws]["start"]
            end_char = tokens_info[we]["end"]
            text_span = text[start_char:end_char]
            
            result.append({
                "label": label,
                "score": round(score, 4),
                "text": text_span,
                "start_char": start_char,
                "end_char": end_char,
                "start_word": ws,
                "end_word": we
            })

        return result