File size: 19,837 Bytes
be5f706
efb213a
 
 
 
 
be5f706
 
efb213a
 
 
 
 
 
 
 
 
 
 
8c50d16
ed49faa
be5f706
 
efb213a
 
be5f706
efb213a
 
 
 
 
 
 
 
 
 
be5f706
efb213a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be5f706
 
 
 
 
 
 
 
 
efb213a
 
ed49faa
be5f706
efb213a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed49faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efb213a
 
 
 
 
 
 
be5f706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efb213a
be5f706
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
"""
Tiny BERT models for anime filename token classification.

The default linear token-classification head is kept for compatibility.  A
learned linear-chain CRF head is also available for structural sequence-label
training while preserving the same emission logits used by the thin runtime.
"""

from __future__ import annotations

import os
from typing import List, Optional

import torch
from torch import nn
from transformers import BertConfig, BertForTokenClassification, BertModel, BertPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.modeling_utils import PreTrainedModel

from .config import Config
from .labels import infer_legacy_id2label, label_migration_sources


class LinearChainCRF(nn.Module):
    """A small batched linear-chain CRF for BIO token classification.

    The layer owns three learned score tables -- ``start_transitions``,
    ``end_transitions`` and ``transitions[prev, next]`` -- plus two boolean
    buffers, ``start_allowed`` and ``transition_allowed``, that describe which
    moves are legal under the BIO scheme.  The buffers are applied only by
    :meth:`decode`; :meth:`neg_log_likelihood` scores every label sequence.

    Masks do not have to be left-aligned.  The selected positions of each row
    are compacted to the front (order preserved) before scoring or decoding,
    so ignored positions at the start or in the middle of a row are skipped
    instead of being scored as if they carried label 0.
    """

    def __init__(self, num_labels: int, id2label: Optional[dict] = None) -> None:
        """Create zero-initialised transition tables for ``num_labels`` labels.

        Args:
            num_labels: Size of the label set.
            id2label: Optional mapping from label id to label name.  When it is
                given (and non-empty) the BIO legality buffers are derived
                from it; otherwise every start and transition is allowed.
        """
        super().__init__()
        self.num_labels = num_labels
        self.start_transitions = nn.Parameter(torch.zeros(num_labels))
        self.end_transitions = nn.Parameter(torch.zeros(num_labels))
        self.transitions = nn.Parameter(torch.zeros(num_labels, num_labels))
        self.register_buffer("start_allowed", torch.ones(num_labels, dtype=torch.bool))
        self.register_buffer("transition_allowed", torch.ones(num_labels, num_labels, dtype=torch.bool))
        if id2label:
            self._configure_bio_masks(id2label)

    @staticmethod
    def _normalize_label_map(id2label: dict) -> dict[int, str]:
        """Return ``id2label`` with ``int`` keys and ``str`` values."""
        return {int(label_id): str(label) for label_id, label in id2label.items()}

    def _configure_bio_masks(self, id2label: dict) -> None:
        """Fill the legality buffers from the label names.

        A sequence may not start with an ``I-`` label, and ``I-X`` may only
        follow ``B-X`` or ``I-X``.  Ids missing from ``id2label`` are treated
        as ``"O"``.
        """
        label_map = self._normalize_label_map(id2label)
        for prev_id in range(self.num_labels):
            prev_label = label_map.get(prev_id, "O")
            self.start_allowed[prev_id] = not prev_label.startswith("I-")
            for next_id in range(self.num_labels):
                next_label = label_map.get(next_id, "O")
                if next_label.startswith("I-"):
                    entity = next_label[2:]
                    allowed = prev_label in {f"B-{entity}", f"I-{entity}"}
                else:
                    allowed = True
                self.transition_allowed[prev_id, next_id] = allowed

    @staticmethod
    def _left_align(
        emissions: torch.Tensor,
        mask: torch.Tensor,
        tags: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Move the selected positions of every row to the front.

        The relative order of selected positions is preserved.  For a mask
        that is already a contiguous prefix the permutation is the identity,
        so the inputs come back unchanged.

        Returns:
            ``(emissions, tags, mask, lengths)`` where ``mask`` is the
            left-aligned prefix mask and ``lengths`` counts the selected
            positions per row.  ``tags`` is ``None`` when no tags were given.
        """
        sequence_length = mask.shape[1]
        positions = torch.arange(sequence_length, device=emissions.device).unsqueeze(0)
        # Selected positions keep their index as the sort key; ignored ones are
        # pushed past the end.  Keys are unique per row, so no stable sort is needed.
        sort_keys = torch.where(mask, positions, positions + sequence_length)
        order = sort_keys.argsort(dim=1)
        lengths = mask.long().sum(dim=1)
        aligned_mask = positions < lengths.unsqueeze(1)
        gather_index = order.unsqueeze(-1).expand(-1, -1, emissions.shape[-1])
        aligned_emissions = emissions.gather(1, gather_index)
        aligned_tags = tags.gather(1, order) if tags is not None else None
        return aligned_emissions, aligned_tags, aligned_mask, lengths

    def neg_log_likelihood(
        self,
        emissions: torch.Tensor,
        tags: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return mean negative log likelihood for a padded batch.

        Args:
            emissions: Per-position label scores, shape ``[batch, seq, labels]``.
            tags: Gold label ids, shape ``[batch, seq]``.  Values at ignored
                positions are never used (they are replaced by 0 internally).
            mask: Truthy where a position takes part in the sequence, shape
                ``[batch, seq]``.  Does not need to be left-aligned.

        Raises:
            ValueError: On a shape mismatch or when a row selects no position.
        """
        if emissions.ndim != 3:
            raise ValueError("emissions must have shape [batch, seq, labels]")
        if tags.shape != emissions.shape[:2]:
            raise ValueError("tags must have shape [batch, seq]")
        if mask.shape != tags.shape:
            raise ValueError("mask must have shape [batch, seq]")

        mask = mask.bool()
        if torch.any(mask.long().sum(dim=1) == 0):
            raise ValueError("CRF received an empty token sequence")

        # Ignored positions may hold sentinel ids such as -100; make them indexable.
        safe_tags = tags.masked_fill(~mask, 0)
        emissions, safe_tags, mask, lengths = self._left_align(emissions, mask, safe_tags)
        log_partition = self._compute_log_partition(emissions, mask)
        gold_score = self._compute_gold_score(emissions, safe_tags, mask, lengths)
        return (log_partition - gold_score).mean()

    def _compute_log_partition(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Forward algorithm: log-sum of the scores of all label sequences.

        Expects a left-aligned ``mask`` (position 0 is always scored).
        """
        emissions = emissions.float()
        transition_scores = self.transitions.float()
        scores = self.start_transitions.float() + emissions[:, 0]

        for idx in range(1, emissions.shape[1]):
            next_scores = (
                scores.unsqueeze(2)
                + transition_scores.unsqueeze(0)
                + emissions[:, idx].unsqueeze(1)
            )
            next_scores = torch.logsumexp(next_scores, dim=1)
            # Rows that already ended keep their running score.
            scores = torch.where(mask[:, idx].unsqueeze(1), next_scores, scores)

        scores = scores + self.end_transitions.float()
        return torch.logsumexp(scores, dim=1)

    def _compute_gold_score(
        self,
        emissions: torch.Tensor,
        tags: torch.Tensor,
        mask: torch.Tensor,
        lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Score the gold label sequence of every row.

        Expects a left-aligned ``mask`` and in-range ``tags`` at every
        position; ``lengths`` is the number of selected positions per row.
        """
        emissions = emissions.float()
        start_transitions = self.start_transitions.float()
        transition_scores = self.transitions.float()
        end_transitions = self.end_transitions.float()
        batch_indices = torch.arange(emissions.shape[0], device=emissions.device)
        score = start_transitions[tags[:, 0]]
        score = score + emissions[batch_indices, 0, tags[:, 0]]

        for idx in range(1, emissions.shape[1]):
            transition_score = transition_scores[tags[:, idx - 1], tags[:, idx]]
            emission_score = emissions[batch_indices, idx, tags[:, idx]]
            # Multiplying by the mask zeroes the contribution of padded steps.
            score = score + (transition_score + emission_score) * mask[:, idx]

        last_tag_indices = (lengths - 1).unsqueeze(1)
        last_tags = tags.gather(1, last_tag_indices).squeeze(1)
        return score + end_transitions[last_tags]

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Viterbi decode a padded batch and return variable-length label IDs.

        Disallowed starts and transitions (see ``start_allowed`` and
        ``transition_allowed``) are scored as ``-inf``.

        Args:
            emissions: Per-position label scores, shape ``[batch, seq, labels]``.
            mask: Truthy where a position takes part in the sequence, shape
                ``[batch, seq]``.  Does not need to be left-aligned.

        Returns:
            One list of label ids per row, with one id per selected position
            in original order.

        Raises:
            ValueError: On a shape mismatch or when a row selects no position.
        """
        if emissions.ndim != 3:
            raise ValueError("emissions must have shape [batch, seq, labels]")
        if mask.shape != emissions.shape[:2]:
            raise ValueError("mask must have shape [batch, seq]")
        mask = mask.bool()
        if torch.any(mask.long().sum(dim=1) == 0):
            raise ValueError("CRF received an empty token sequence")

        emissions, _unused_tags, mask, lengths = self._left_align(emissions, mask)
        emissions = emissions.float()
        start_transitions = self.start_transitions.float().masked_fill(~self.start_allowed, float("-inf"))
        transition_scores = self.transitions.float().masked_fill(~self.transition_allowed, float("-inf"))
        scores = start_transitions + emissions[:, 0]
        history: List[torch.Tensor] = []

        for idx in range(1, emissions.shape[1]):
            next_scores = scores.unsqueeze(2) + transition_scores.unsqueeze(0)
            best_scores, best_tags = next_scores.max(dim=1)
            best_scores = best_scores + emissions[:, idx]
            scores = torch.where(mask[:, idx].unsqueeze(1), best_scores, scores)
            history.append(best_tags)

        scores = scores + self.end_transitions.float()

        # Pull everything to Python lists once instead of one ``.item()`` per element.
        final_tags = scores.argmax(dim=1).tolist()
        row_lengths = lengths.tolist()
        backpointers = [step.tolist() for step in history]

        paths: List[List[int]] = []
        for batch_idx, (length, best_tag) in enumerate(zip(row_lengths, final_tags)):
            path = [best_tag]
            # Only the first ``length - 1`` steps belong to this row.
            for step in reversed(backpointers[: max(0, length - 1)]):
                best_tag = step[batch_idx][best_tag]
                path.append(best_tag)
            path.reverse()
            paths.append(path)
        return paths


class BertCrfForTokenClassification(BertPreTrainedModel):
    """BERT emission classifier trained with a learned CRF sequence loss.

    ``forward`` returns the full-sequence emission logits (same shape as a
    plain linear head would produce); the CRF is used for the training loss
    and, through :meth:`decode`, for structured inference.
    """

    config_class = BertConfig

    def __init__(self, config: BertConfig) -> None:
        """Build encoder, dropout, linear emission layer and CRF from ``config``."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)

        # Prefer the dedicated classifier dropout; fall back to the hidden dropout.
        dropout_prob = getattr(config, "classifier_dropout", None)
        if dropout_prob is None:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)

        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = LinearChainCRF(config.num_labels, getattr(config, "id2label", None))
        self.post_init()

        # Keep CRF transitions neutral when bootstrapping from a linear checkpoint.
        for crf_parameter in (
            self.crf.start_transitions,
            self.crf.end_transitions,
            self.crf.transitions,
        ):
            nn.init.zeros_(crf_parameter)

    def _crf_inputs(
        self,
        logits: torch.Tensor,
        labels: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Slice off the first and last positions and derive the CRF mask.

        Returns ``(emissions, tags, mask)``.  ``tags`` is ``None`` when no
        labels are given.  With labels, a position is kept when its label is
        not ``-100`` (and it is attended, if an attention mask is given).
        Without labels, the mask covers the first ``attended - 2`` inner
        positions of each row, or every inner position when no attention
        mask is given.

        Raises:
            ValueError: If the sequence has two positions or fewer.
        """
        if logits.shape[1] <= 2:
            raise ValueError("CRF token classification expects CLS, tokens, and SEP positions")

        emissions = logits[:, 1:-1, :]

        if labels is not None:
            tags = labels[:, 1:-1]
            mask = tags.ne(-100)
            if attention_mask is not None:
                mask = attention_mask[:, 1:-1].bool() & mask
            return emissions, tags, mask

        if attention_mask is None:
            everything = torch.ones(emissions.shape[:2], dtype=torch.bool, device=logits.device)
            return emissions, None, everything

        # Attended count minus the two special positions, never below zero.
        inner_lengths = attention_mask.long().sum(dim=1).clamp_min(2) - 2
        inner_positions = torch.arange(emissions.shape[1], device=logits.device).unsqueeze(0)
        return emissions, None, inner_positions < inner_lengths.unsqueeze(1)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        """Run the encoder and emission layer; add the CRF loss when labels are given.

        Returns a ``TokenClassifierOutput`` or, when ``return_dict`` resolves
        to a falsy value, a tuple ``(loss?, logits, *extra encoder outputs)``.
        """
        use_return_dict = return_dict
        if use_return_dict is None:
            use_return_dict = getattr(self.config, "return_dict", True)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=use_return_dict,
        )
        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            emissions, tags, mask = self._crf_inputs(logits, labels, attention_mask)
            if tags is None:
                raise ValueError("labels are required for CRF loss")
            loss = self.crf.neg_log_likelihood(emissions, tags, mask)

        if use_return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        tuple_output = (logits,) + outputs[2:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output

    def decode(self, logits: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> List[List[int]]:
        """Decode full-sequence logits, excluding CLS/SEP and padding positions."""
        inner_emissions, _no_tags, inner_mask = self._crf_inputs(logits, None, attention_mask)
        return self.crf.decode(inner_emissions, inner_mask)


def build_bert_config(config: Config) -> BertConfig:
    """Build the Hugging Face BERT config shared by both model heads.

    Every listed attribute is read from ``config`` and forwarded to
    ``BertConfig`` under the same keyword name.
    """
    forwarded_fields = (
        "vocab_size",
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "intermediate_size",
        "max_position_embeddings",
        "num_labels",
        "hidden_dropout_prob",
        "attention_probs_dropout_prob",
        "id2label",
        "label2id",
        "label_schema_version",
    )
    bert_kwargs = {field: getattr(config, field) for field in forwarded_fields}
    return BertConfig(**bert_kwargs)


def normalize_model_head(model_head: Optional[str]) -> str:
    """Return the canonical head name: ``"linear"`` or ``"crf"``.

    A falsy value (``None`` or ``""``) selects ``"linear"``.  Surrounding
    whitespace and letter case are ignored.

    Raises:
        ValueError: If the name is neither ``linear`` nor ``crf``.
    """
    requested = model_head if model_head else "linear"
    canonical = requested.strip().lower()
    if canonical in ("linear", "crf"):
        return canonical
    raise ValueError(f"Unsupported model head: {model_head}")


def create_model(config: Config, model_head: str = "linear") -> PreTrainedModel:
    """
    Create a Tiny BERT model for token classification.

    The chosen head is also recorded on the BERT config (``model_head`` and
    ``architectures``) before the model is instantiated.

    Args:
        config: Config object with model hyperparameters.
        model_head: ``linear`` for Hugging Face's standard token classifier or
            ``crf`` for a learned linear-chain CRF sequence head.
    """
    head = normalize_model_head(model_head)
    bert_config = build_bert_config(config)
    bert_config.model_head = head

    if head == "crf":
        architecture_name = "BertCrfForTokenClassification"
        model_class = BertCrfForTokenClassification
    else:
        architecture_name = "BertForTokenClassification"
        model_class = BertForTokenClassification

    bert_config.architectures = [architecture_name]
    return model_class(bert_config)


def infer_model_head(config: BertConfig) -> str:
    """Work out which head a saved config describes.

    A truthy ``model_head`` attribute wins and is normalized.  Otherwise the
    head is ``"crf"`` when any entry of ``architectures`` contains ``Crf`` or
    ``CRF``, and ``"linear"`` in every other case.
    """
    explicit_head = getattr(config, "model_head", None)
    if explicit_head:
        return normalize_model_head(str(explicit_head))

    declared = getattr(config, "architectures", None) or ()
    for architecture_name in map(str, declared):
        if "Crf" in architecture_name or "CRF" in architecture_name:
            return "crf"
    return "linear"


def load_model(model_dir: str, model_head: Optional[str] = None) -> PreTrainedModel:
    """Load a linear or CRF token classifier from a Hugging Face checkpoint.

    An explicit ``model_head`` takes precedence; when it is ``None`` the head
    is inferred from the checkpoint's config.
    """
    checkpoint_config = BertConfig.from_pretrained(model_dir)
    if model_head is None:
        head = infer_model_head(checkpoint_config)
    else:
        head = normalize_model_head(model_head)

    model_class = BertCrfForTokenClassification if head == "crf" else BertForTokenClassification
    return model_class.from_pretrained(model_dir)


def _model_id2label_for_migration(model: PreTrainedModel) -> dict[int, str]:
    """Return the id-to-label map that matches the model's classifier rows.

    The config's ``id2label`` (normalized to ``int`` keys and ``str`` values)
    is used unless its size disagrees with the classifier's ``out_features``.
    On a mismatch, a legacy map inferred from the row count is returned when
    one is available.
    """
    configured = getattr(model.config, "id2label", None) or {}
    label_map = {int(label_id): str(label) for label_id, label in configured.items()}

    row_count = getattr(getattr(model, "classifier", None), "out_features", None)
    if row_count is None or len(label_map) == int(row_count):
        return label_map

    legacy_map = infer_legacy_id2label(int(row_count))
    return label_map if legacy_map is None else legacy_map


def migrate_token_classifier_labels(
    model: PreTrainedModel,
    target_label2id: dict[str, int],
    target_id2label: dict[int, str],
) -> dict[str, object]:
    """
    Expand or reorder token-classification label rows for the shared schema.

    Exact labels are copied by name. Legacy 15-label TITLE rows initialize all
    title-like rows, and legacy SEASON rows initialize PATH_SEASON.

    The model is modified in place: its ``classifier`` (and ``crf``, when it
    has one) is replaced and the label fields of its config are updated.  Rows
    with no source keep their fresh initialization.  The returned dict
    summarizes what happened (``changed`` plus counts, or a ``reason``).
    """
    head = getattr(model, "classifier", None)
    if not isinstance(head, nn.Linear):
        return {"changed": False, "reason": "no_linear_classifier"}

    wanted_id2label = {int(label_id): str(label) for label_id, label in target_id2label.items()}
    wanted_label2id = {str(label): int(label_id) for label, label_id in target_label2id.items()}
    current_id2label = _model_id2label_for_migration(model)
    current_label2id = {label: label_id for label_id, label in current_id2label.items()}
    current_rows = int(head.out_features)
    wanted_rows = len(wanted_label2id)

    # Nothing to move when both schemas list the same label at every id.
    schema_unchanged = current_rows == wanted_rows and all(
        current_id2label.get(label_id) == wanted_id2label.get(label_id)
        for label_id in range(wanted_rows)
    )
    if schema_unchanged:
        model.config.num_labels = wanted_rows
        model.config.id2label = wanted_id2label
        model.config.label2id = wanted_label2id
        return {"changed": False, "copied": wanted_rows, "target_labels": wanted_rows}

    has_bias = head.bias is not None
    source_weight = head.weight.detach()
    source_bias = head.bias.detach() if has_bias else None

    replacement = nn.Linear(
        head.in_features,
        wanted_rows,
        bias=has_bias,
        device=source_weight.device,
        dtype=source_weight.dtype,
    )
    nn.init.normal_(
        replacement.weight,
        mean=0.0,
        std=getattr(model.config, "initializer_range", 0.02),
    )
    if replacement.bias is not None:
        nn.init.zeros_(replacement.bias)

    # target row id -> source row id, for every row that found a source.
    row_sources: dict[int, int] = {}
    copied = 0
    for wanted_label, wanted_id in wanted_label2id.items():
        # First candidate source label that exists among the old rows, if any.
        source_id = next(
            (
                candidate_id
                for candidate_id in (
                    current_label2id.get(candidate_label)
                    for candidate_label in label_migration_sources(wanted_label)
                )
                if candidate_id is not None and candidate_id < current_rows
            ),
            None,
        )
        if source_id is None:
            continue
        replacement.weight.data[wanted_id].copy_(source_weight[source_id])
        if has_bias:
            replacement.bias.data[wanted_id].copy_(source_bias[source_id])
        row_sources[wanted_id] = source_id
        copied += 1

    model.classifier = replacement
    model.num_labels = wanted_rows
    model.config.num_labels = wanted_rows
    model.config.id2label = wanted_id2label
    model.config.label2id = wanted_label2id

    if hasattr(model, "crf"):
        previous_crf = model.crf
        rebuilt_crf = LinearChainCRF(wanted_rows, wanted_id2label).to(
            device=source_weight.device,
            dtype=source_weight.dtype,
        )
        for table in (
            rebuilt_crf.start_transitions,
            rebuilt_crf.end_transitions,
            rebuilt_crf.transitions,
        ):
            nn.init.zeros_(table)

        known_labels = previous_crf.start_transitions.shape[0]
        known_from, known_to = previous_crf.transitions.shape
        migrated_rows = list(row_sources.items())
        with torch.no_grad():
            for wanted_id, source_id in migrated_rows:
                if source_id >= known_labels:
                    continue
                rebuilt_crf.start_transitions[wanted_id].copy_(previous_crf.start_transitions[source_id])
                rebuilt_crf.end_transitions[wanted_id].copy_(previous_crf.end_transitions[source_id])
            # Carry over every transition whose two endpoints both have a source row.
            for wanted_from, source_from in migrated_rows:
                if source_from >= known_from:
                    continue
                for wanted_to, source_to in migrated_rows:
                    if source_to >= known_to:
                        continue
                    rebuilt_crf.transitions[wanted_from, wanted_to].copy_(
                        previous_crf.transitions[source_from, source_to]
                    )
        model.crf = rebuilt_crf

    return {
        "changed": True,
        "source_labels": current_rows,
        "target_labels": wanted_rows,
        "copied": copied,
    }


def save_model_head_config(model: PreTrainedModel, model_head: str) -> None:
    """Persist the selected head in config.json for later auto-loading.

    Only the in-memory ``model.config`` is updated here (``model_head`` and
    ``architectures``); nothing is written to disk by this function.
    """
    head = normalize_model_head(model_head)
    if head == "crf":
        architecture_name = "BertCrfForTokenClassification"
    else:
        architecture_name = "BertForTokenClassification"
    model.config.model_head = head
    model.config.architectures = [architecture_name]


def count_parameters(model) -> int:
    """Count every parameter element in a model, trainable or frozen."""
    total = 0
    for parameter in model.parameters():
        total += parameter.numel()
    return total


def print_model_summary(model):
    """Print parameter counts and whether the model fits the 5M budget.

    A model with exactly 5,000,000 parameters is reported as within the
    limit (0 remaining); only a strictly larger model is reported as
    exceeding it.

    Args:
        model: Any object whose ``parameters()`` yields items with ``numel()``
            and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        The total number of parameter elements, trainable or not.
    """
    parameter_limit = 5_000_000

    # One pass over the parameters collects both totals.
    total_params = 0
    trainable_params = 0
    for parameter in model.parameters():
        size = parameter.numel()
        total_params += size
        if parameter.requires_grad:
            trainable_params += size

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Parameter limit: {parameter_limit:,}")
    if total_params <= parameter_limit:
        print(f"[OK] Within 5M limit ({parameter_limit - total_params:,} remaining)")
    else:
        print(f"[FAIL] Exceeds 5M limit by {total_params - parameter_limit:,}")
    return total_params


if __name__ == "__main__":
    # Smoke check: build a small-vocabulary model and print its parameter budget.
    # The head comes from ANIFILEBERT_MODEL_HEAD and defaults to the linear head.
    selected_head = os.environ.get("ANIFILEBERT_MODEL_HEAD", "linear")
    demo_config = Config()
    demo_config.vocab_size = 3000
    print_model_summary(create_model(demo_config, model_head=selected_head))