""" model.py — GlobalPointer-based NER model on top of BERT Changes vs previous version: [FIX-1] Circle Loss: correct two-term formulation (Su Jianlin style), with margin (m) and scale (gamma) params; no more logaddexp merging. [FIX-2] Numerical safety: negated pos_logits no longer turns -1e9 → +1e9; we apply the mask BEFORE negation. [FIX-3] labels .float() cast inside forward (no silent runtime error / nan). [FIX-4] valid_mask (bool, B×L) replaces attention_mask for span masking; attention_mask is still passed to the encoder for self-attention. [FIX-5] use_rope flag for GlobalPointer's span-level RoPE (independent of BERT encoder internals). """ import json from pathlib import Path import math import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel # ════════════════════════════════════════════════════════════════════════════ # EfficientGlobalPointer head # - shared q/k projection (hidden -> 2D) # - per-label token bias (hidden -> 2C) as start/end bias # - final logits: base_span + start_bias + end_bias # ════════════════════════════════════════════════════════════════════════════ class EfficientGlobalPointer(nn.Module): """ EfficientGlobalPointer span scorer (Su Jianlin style). Differences vs standard GlobalPointer: - q/k are shared across labels: hidden -> 2 * head_size - label-specific bias per token: hidden -> 2 * num_labels (start_bias and end_bias for each label) - logits: (q @ k^T)/sqrt(D) expanded to C labels, then add biases Output shape: (B, C, L, L) """ def __init__( self, hidden_size: int, num_labels: int, head_size: int = 64, use_rope: bool = True, dropout: float = 0.1, ): super().__init__() self.num_labels = num_labels self.head_size = head_size self.use_rope = use_rope self.dropout = nn.Dropout(dropout) # shared q/k: (H -> 2D) self.dense_qk = nn.Linear(hidden_size, head_size * 2) # label bias: (H -> 2C) => per token: start_bias + end_bias self.dense_bias = nn.Linear(hidden_size, num_labels * 2) if use_rope: self.rope = RotaryEmbedding(head_size) def forward(self, hidden: torch.Tensor) -> torch.Tensor: """ hidden: (B, L, H) returns logits: (B, C, L, L) """ B, L, _ = hidden.shape C = self.num_labels D = self.head_size hidden = self.dropout(hidden) # ── shared q/k ─────────────────────────────────────────────────────── qk = self.dense_qk(hidden) # (B, L, 2D) q, k = qk[..., :D], qk[..., D:] # each (B, L, D) if self.use_rope: emb = self.rope(L, hidden.device) # (L, D) cos_ = emb.cos()[None, :, :] # (1, L, D) sin_ = emb.sin()[None, :, :] q = apply_rotary(q, cos_, sin_) # (B, L, D) k = apply_rotary(k, cos_, sin_) # (B, L, D) # base span score (shared across labels): (B, L, L) base = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(D) # ── per-label start/end bias ──────────────────────────────────────── bias = self.dense_bias(hidden) # (B, L, 2C) bias = bias.view(B, L, C, 2) # (B, L, C, 2) # start/end: (B, C, L) start_bias = bias[..., 0].permute(0, 2, 1) # (B, C, L) end_bias = bias[..., 1].permute(0, 2, 1) # (B, C, L) # combine: # base: (B, 1, L, L) # start_bias: (B, C, L, 1) # end_bias: (B, C, 1, L) logits = ( base[:, None, :, :] + start_bias[:, :, :, None] + end_bias[:, :, None, :] ) # (B, C, L, L) return logits # ════════════════════════════════════════════════════════════════════════════ # RoPE helper (span-level, applied to GlobalPointer q/k) # ════════════════════════════════════════════════════════════════════════════ class RotaryEmbedding(nn.Module): """Rotary Position Embedding for GlobalPointer span scoring.""" def __init__(self, dim: int): super().__init__() assert dim % 2 == 0, "RoPE dim must be even" inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) def forward(self, seq_len: int, device: torch.device) -> torch.Tensor: """Returns cos/sin interleaved tensor of shape (seq_len, dim).""" t = torch.arange(seq_len, device=device).float() freqs = torch.outer(t, self.inv_freq) # (L, dim/2) emb = torch.cat([freqs, freqs], dim=-1) # (L, dim) return emb # caller does cos/sin def rotate_half(x: torch.Tensor) -> torch.Tensor: half = x.shape[-1] // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat([-x2, x1], dim=-1) def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: """x: (..., L, D) cos/sin: (L, D)""" return x * cos + rotate_half(x) * sin # ════════════════════════════════════════════════════════════════════════════ # Loss functions # ════════════════════════════════════════════════════════════════════════════ def multilabel_circle_loss( logits: torch.Tensor, # (B, C, L, L) raw scores labels: torch.Tensor, # (B, C, L, L) float 0/1 mask2d: torch.Tensor, # (B, 1, L, L) bool — True = valid span position margin: float = 0.25, gamma: float = 32.0, ) -> torch.Tensor: """ Su Jianlin–style Circle Loss for multi-label span classification. L = log(1 + Σ exp(γ·(s_neg + m))) + log(1 + Σ exp(−γ·(s_pos − m))) Two independent logsumexp terms keep the original loss geometry intact. Mask is applied BEFORE any sign flip to avoid ±1e9 explosions. Args: logits: raw span scores, shape (B, C, L, L) labels: float tensor {0, 1}, same shape mask2d: bool (B, 1, L, L) — True where span is valid (upper-tri + valid tokens) margin: additive margin (default 0.25) gamma: temperature / scale (default 32) """ B, C, L, _ = logits.shape # ── expand mask to (B, C, L, L) ───────────────────────────────────────── mask = mask2d.expand(B, C, L, L) # broadcast over C # ── positions that are valid positive / valid negative ─────────────────── pos_mask = mask & (labels > 0.5) # bool neg_mask = mask & (labels < 0.5) # bool # ── scale logits ───────────────────────────────────────────────────────── s = logits * gamma # (B, C, L, L) # ── negative term: log(1 + Σ exp(s_neg + γ·m)) ────────────────────────── # Fill invalid & positive positions with -inf so they don't contribute neg_scores = s.masked_fill(~neg_mask, float("-inf")) # logsumexp over (L, L) for each (b, c) neg_lse = torch.logsumexp(neg_scores.view(B, C, -1), dim=-1) # (B, C) loss_neg = F.softplus(neg_lse + gamma * margin) # log(1+exp(...)) # ── positive term: log(1 + Σ exp(−(s_pos − γ·m))) ─────────────────────── # Fill invalid & negative positions with -inf (in the negated domain) # To avoid -(-1e9) = +1e9: we mask FIRST, then negate. pos_scores = s.masked_fill(~pos_mask, float("-inf")) neg_pos_scores = (-pos_scores).masked_fill(~pos_mask, float("-inf")) pos_lse = torch.logsumexp(neg_pos_scores.view(B, C, -1), dim=-1) # (B, C) loss_pos = F.softplus(pos_lse + gamma * margin) # ── average over labels (skip labels with no positive AND no negative) ─── loss = (loss_neg + loss_pos).mean() return loss def multilabel_bce_loss( logits: torch.Tensor, # (B, C, L, L) labels: torch.Tensor, # (B, C, L, L) float mask2d: torch.Tensor, # (B, 1, L, L) bool ) -> torch.Tensor: mask = mask2d.expand_as(logits) loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none") loss = loss * mask.float() return loss.sum() / mask.float().sum().clamp(min=1) # ════════════════════════════════════════════════════════════════════════════ # GlobalPointer head # ════════════════════════════════════════════════════════════════════════════ class GlobalPointer(nn.Module): """ GlobalPointer span scorer. Projects encoder hidden states to per-label (q, k) vectors and computes an (L×L) score matrix per label. Optionally applies span-level RoPE. Note: encoder internals (inside self-attention layers) are entirely separate from this span-level RoPE — both can be active simultaneously. """ def __init__( self, hidden_size: int, num_labels: int, head_size: int = 64, use_rope: bool = True, dropout: float = 0.1, ): super().__init__() self.num_labels = num_labels self.head_size = head_size self.use_rope = use_rope self.dropout = nn.Dropout(dropout) # Project to 2 * num_labels * head_size (q and k for every label) self.dense = nn.Linear(hidden_size, num_labels * head_size * 2) if use_rope: self.rope = RotaryEmbedding(head_size) def forward( self, hidden: torch.Tensor, # (B, L, H) ) -> torch.Tensor: # (B, C, L, L) B, L, H = hidden.shape C = self.num_labels D = self.head_size hidden = self.dropout(hidden) proj = self.dense(hidden) # (B, L, C*D*2) proj = proj.view(B, L, C, D * 2) # (B, L, C, D*2) q, k = proj[..., :D], proj[..., D:] # each (B, L, C, D) if self.use_rope: emb = self.rope(L, hidden.device) # (L, D) cos_ = emb.cos()[None, :, None, :] # (1, L, 1, D) sin_ = emb.sin()[None, :, None, :] q = apply_rotary(q, cos_, sin_) k = apply_rotary(k, cos_, sin_) # q: (B, L, C, D) → (B, C, L, D) q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) # Score matrix: (B, C, L, D) × (B, C, D, L) → (B, C, L, L) logits = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(D) return logits # ════════════════════════════════════════════════════════════════════════════ # Full model # ════════════════════════════════════════════════════════════════════════════ class EcomBertNER(nn.Module): """ BERT encoder + GlobalPointer head for span-based NER. forward() signature: input_ids (B, L) — token ids attention_mask (B, L) — passed to encoder (1=real, 0=pad) labels (B, C, L, L) torch.bool, optional valid_mask (B, L) torch.bool, optional — True = valid token (excludes CLS/SEP/PAD; from dataset collate_fn) If valid_mask is not provided, falls back to attention_mask.bool() (slightly less precise — includes CLS/SEP as negative spans). """ def __init__( self, model_name: str = "bert-base-chinese", num_labels: int = 23, head_size: int = 64, loss_type: str = "circle", # "circle" | "bce" use_rope: bool = True, dropout: float = 0.1, cache_dir: str = None, # Circle Loss hyper-params (ignored for BCE) circle_margin: float = 0.25, circle_gamma: float = 32.0, ): super().__init__() assert loss_type in ("circle", "bce"), \ f"loss_type must be 'circle' or 'bce', got {loss_type!r}" self.loss_type = loss_type self.circle_margin = circle_margin self.circle_gamma = circle_gamma self.encoder = AutoModel.from_pretrained( model_name, cache_dir=cache_dir ) hidden_size = self.encoder.config.hidden_size self.global_pointer = EfficientGlobalPointer( hidden_size = hidden_size, num_labels = num_labels, head_size = head_size, use_rope = use_rope, dropout = dropout, ) self.model_name = model_name self.num_labels = num_labels self.head_size = head_size self.use_rope = use_rope self.dropout = dropout # ── span validity mask ──────────────────────────────────────────────────── @staticmethod def _build_span_mask( valid_mask: torch.Tensor, # (B, L) bool ) -> torch.Tensor: """ Returns upper-triangular span mask (B, 1, L, L) where mask[b,0,i,j] = True iff i<=j and both token i and j are valid. """ # row mask (B, 1, L, 1) & col mask (B, 1, 1, L) → (B, 1, L, L) row = valid_mask[:, None, :, None] # (B, 1, L, 1) col = valid_mask[:, None, None, :] # (B, 1, 1, L) pair_mask = row & col # (B, 1, L, L) L = valid_mask.size(1) upper_tri = torch.triu( torch.ones(L, L, dtype=torch.bool, device=valid_mask.device) ) # (L, L) return pair_mask & upper_tri # (B, 1, L, L) # ── forward ─────────────────────────────────────────────────────────────── def forward( self, input_ids: torch.Tensor, # (B, L) attention_mask: torch.Tensor, # (B, L) labels: torch.Tensor = None, # (B, C, L, L) bool valid_mask: torch.Tensor = None, # (B, L) bool ) -> dict: # ── encoder ───────────────────────────────────────────────────────── encoder_out = self.encoder( input_ids = input_ids, attention_mask = attention_mask, ) hidden = encoder_out.last_hidden_state # (B, L, H) # ── GlobalPointer logits ───────────────────────────────────────────── logits = self.global_pointer(hidden) # (B, C, L, L) # ── span validity mask ─────────────────────────────────────────────── # [FIX-4] prefer valid_mask (excludes CLS/SEP) over attention_mask if valid_mask is None: valid_mask = attention_mask.bool() mask2d = self._build_span_mask(valid_mask) # (B, 1, L, L) # Apply mask to logits for inference (fill invalid with -1e4) logits_masked = logits.masked_fill( ~mask2d.expand_as(logits), -1e4 ) # ── loss ───────────────────────────────────────────────────────────── loss = None if labels is not None: # [FIX-3] ensure float regardless of bool input from dataset labels_f = labels.float() if self.loss_type == "circle": loss = multilabel_circle_loss( logits = logits, # raw (unmasked) scores labels = labels_f, mask2d = mask2d, margin = self.circle_margin, gamma = self.circle_gamma, ) else: loss = multilabel_bce_loss( logits = logits, labels = labels_f, mask2d = mask2d, ) return { "loss": loss, "logits": logits_masked, # (B, C, L, L) } def save_pretrained(self, save_directory: str | Path, *, extra_config: dict | None = None) -> None: save_dir = Path(save_directory) save_dir.mkdir(parents=True, exist_ok=True) config = { "architectures": [self.__class__.__name__], "model_name": self.model_name, "num_labels": self.num_labels, "head_size": self.head_size, "loss_type": self.loss_type, "use_rope": self.use_rope, "dropout": self.dropout, "circle_margin": self.circle_margin, "circle_gamma": self.circle_gamma, } if extra_config: config.update(extra_config) with open(save_dir / "config.json", "w", encoding="utf-8") as f: json.dump(config, f, indent=2, ensure_ascii=False) torch.save(self.state_dict(), save_dir / "pytorch_model.bin") @classmethod def from_pretrained( cls, model_dir: str | Path, *, device: torch.device | str | None = None, cache_dir: str | None = None, ) -> tuple["EcomBertNER", dict]: model_dir = Path(model_dir) with open(model_dir / "config.json", "r", encoding="utf-8") as f: cfg = json.load(f) model = cls( model_name=cfg.get("model_name", "bert-base-chinese"), num_labels=int(cfg.get("num_labels", 23)), head_size=int(cfg.get("head_size", 64)), loss_type=str(cfg.get("loss_type", "circle")), use_rope=bool(cfg.get("use_rope", True)), dropout=float(cfg.get("dropout", 0.1)), cache_dir=cache_dir, circle_margin=float(cfg.get("circle_margin", 0.25)), circle_gamma=float(cfg.get("circle_gamma", 32.0)), ) state = torch.load(model_dir / "pytorch_model.bin", map_location="cpu", weights_only=False) model.load_state_dict(state) if device is not None: model.to(device) model.eval() return model, cfg