| | """ |
| | model.py β GlobalPointer-based NER model on top of BERT |
| | |
| | Changes vs previous version: |
| | [FIX-1] Circle Loss: correct two-term formulation (Su Jianlin style), |
| | with margin (m) and scale (gamma) params; no more logaddexp merging. |
| | [FIX-2] Numerical safety: negated pos_logits no longer turns -1e9 β +1e9; |
| | we apply the mask BEFORE negation. |
| | [FIX-3] labels .float() cast inside forward (no silent runtime error / nan). |
| | [FIX-4] valid_mask (bool, BΓL) replaces attention_mask for span masking; |
| | attention_mask is still passed to the encoder for self-attention. |
| | [FIX-5] use_rope flag for GlobalPointer's span-level RoPE (independent of |
| | BERT encoder internals). |
| | """ |
| |
|
import json
import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
|
class EfficientGlobalPointer(nn.Module):
    """
    EfficientGlobalPointer span scorer (Su Jianlin style).

    Differences vs standard GlobalPointer:
      - q/k are shared across labels: hidden -> 2 * head_size
      - label-specific bias per token: hidden -> 2 * num_labels
        (start_bias and end_bias for each label)
      - logits: (q @ k^T) / sqrt(D) expanded to C labels, then add biases

    Output shape: (B, C, L, L)
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        head_size: int = 64,
        use_rope: bool = True,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.head_size = head_size
        self.use_rope = use_rope

        self.dropout = nn.Dropout(dropout)

        # shared q/k projection: hidden -> (q | k)
        self.dense_qk = nn.Linear(hidden_size, head_size * 2)

        # per-token, per-label start/end bias: hidden -> num_labels * (start | end)
        self.dense_bias = nn.Linear(hidden_size, num_labels * 2)

        if use_rope:
            self.rope = RotaryEmbedding(head_size)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        hidden: (B, L, H)
        returns logits: (B, C, L, L)
        """
        B, L, _ = hidden.shape
        C = self.num_labels
        D = self.head_size

        hidden = self.dropout(hidden)

        # shared query/key vectors, one per token: (B, L, D) each
        qk = self.dense_qk(hidden)
        q, k = qk[..., :D], qk[..., D:]

        if self.use_rope:
            emb = self.rope(L, hidden.device)
            cos_ = emb.cos()[None, :, :]
            sin_ = emb.sin()[None, :, :]
            q = apply_rotary(q, cos_, sin_)
            k = apply_rotary(k, cos_, sin_)

        # label-agnostic span scores: (B, L, L)
        base = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(D)

        # per-token start/end bias for each label: (B, L, C, 2)
        bias = self.dense_bias(hidden)
        bias = bias.view(B, L, C, 2)

        # (B, C, L): start bias indexed by row i, end bias by column j
        start_bias = bias[..., 0].permute(0, 2, 1)
        end_bias = bias[..., 1].permute(0, 2, 1)

        # broadcast-add: logits[b, c, i, j] = base[b, i, j]
        #                                   + start_bias[b, c, i]
        #                                   + end_bias[b, c, j]
        logits = (
            base[:, None, :, :] +
            start_bias[:, :, :, None] +
            end_bias[:, :, None, :]
        )

        return logits
|
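# Shape sketch (illustrative only; the sizes below are hypothetical, not tied to
# any dataset in this repo):
#
#     gp = EfficientGlobalPointer(hidden_size=768, num_labels=23, head_size=64)
#     h = torch.randn(2, 128, 768)      # (B, L, H) encoder output
#     scores = gp(h)                    # -> (2, 23, 128, 128), i.e. (B, C, L, L)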
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding for GlobalPointer span scoring."""

    def __init__(self, dim: int):
        super().__init__()
        assert dim % 2 == 0, "RoPE dim must be even"
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Returns the rotation-angle tensor of shape (seq_len, dim); callers take .cos()/.sin() of it."""
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim in half and rotate: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([-x2, x1], dim=-1)
|
|
def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """x: (..., L, D); cos/sin: broadcastable to x, e.g. (1, L, D) or (1, L, 1, D)."""
    return x * cos + rotate_half(x) * sin
|
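# Note (standard RoPE property, stated here for context): rotating q_i and k_j by
# position-dependent angles before the dot product makes q_i . k_j depend only on
# the relative offset j - i, which is what lets the span scores below encode span
# positions without any absolute-position features.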
|
|
def multilabel_circle_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    mask2d: torch.Tensor,
    margin: float = 0.25,
    gamma: float = 32.0,
) -> torch.Tensor:
    """
    Su Jianlin-style Circle Loss for multi-label span classification.

        L = log(1 + Σ exp(γ·(s_neg + m))) + log(1 + Σ exp(-γ·(s_pos - m)))

    Two independent logsumexp terms keep the original loss geometry intact:
    negative spans are pushed below -m and positive spans above +m, with gamma
    controlling how sharply violations are penalised.
    The mask is applied BEFORE any sign flip to avoid ±1e9 explosions.

    Args:
        logits: raw span scores, shape (B, C, L, L)
        labels: float tensor {0, 1}, same shape
        mask2d: bool (B, 1, L, L), True where the span is valid (upper-tri + valid tokens)
        margin: additive margin (default 0.25)
        gamma:  temperature / scale (default 32)
    """
    B, C, L, _ = logits.shape

    # (B, 1, L, L) -> (B, C, L, L): same span mask for every label
    mask = mask2d.expand(B, C, L, L)

    # positive / negative spans, restricted to valid positions
    pos_mask = mask & (labels > 0.5)
    neg_mask = mask & (labels < 0.5)

    # scale the raw scores
    s = logits * gamma

    # negative term: log(1 + Σ exp(γ·s_neg + γ·m))
    # (softplus(x) = log(1 + e^x) supplies the "+1" inside the log)
    neg_scores = s.masked_fill(~neg_mask, float("-inf"))
    neg_lse = torch.logsumexp(neg_scores.view(B, C, -1), dim=-1)
    loss_neg = F.softplus(neg_lse + gamma * margin)

    # positive term: log(1 + Σ exp(-γ·s_pos + γ·m))
    # mask first (to -inf), then negate; the second masked_fill restores -inf at
    # the positions that the negation flipped to +inf
    pos_scores = s.masked_fill(~pos_mask, float("-inf"))
    neg_pos_scores = (-pos_scores).masked_fill(~pos_mask, float("-inf"))
    pos_lse = torch.logsumexp(neg_pos_scores.view(B, C, -1), dim=-1)
    loss_pos = F.softplus(pos_lse + gamma * margin)

    # mean over batch and labels
    loss = (loss_neg + loss_pos).mean()
    return loss
|
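# Minimal usage sketch (hypothetical shapes, independent of any dataset):
#
#     logits = torch.randn(2, 23, 128, 128)               # (B, C, L, L) raw span scores
#     labels = torch.zeros_like(logits)                    # float {0, 1} gold spans
#     mask2d = torch.ones(2, 1, 128, 128).triu().bool()    # upper-triangular valid spans
#     loss = multilabel_circle_loss(logits, labels, mask2d, margin=0.25, gamma=32.0)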
|
def multilabel_bce_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    mask2d: torch.Tensor,
) -> torch.Tensor:
    """Masked BCE baseline: mean binary cross-entropy over valid spans only."""
    mask = mask2d.expand_as(logits)
    loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    loss = loss * mask.float()
    return loss.sum() / mask.float().sum().clamp(min=1)
|
|
class GlobalPointer(nn.Module):
    """
    GlobalPointer span scorer.

    Projects encoder hidden states to per-label (q, k) vectors and computes
    an (L, L) score matrix per label. Optionally applies span-level RoPE.

    Note: any positional encoding inside the encoder's self-attention layers is
    entirely separate from this span-level RoPE; both can be active simultaneously.
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        head_size: int = 64,
        use_rope: bool = True,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.head_size = head_size
        self.use_rope = use_rope

        self.dropout = nn.Dropout(dropout)
        # per-label q/k projection: hidden -> num_labels * (q | k)
        self.dense = nn.Linear(hidden_size, num_labels * head_size * 2)

        if use_rope:
            self.rope = RotaryEmbedding(head_size)

    def forward(
        self,
        hidden: torch.Tensor,
    ) -> torch.Tensor:
        B, L, H = hidden.shape
        C = self.num_labels
        D = self.head_size

        hidden = self.dropout(hidden)
        proj = self.dense(hidden)
        proj = proj.view(B, L, C, D * 2)
        q, k = proj[..., :D], proj[..., D:]

        if self.use_rope:
            emb = self.rope(L, hidden.device)
            cos_ = emb.cos()[None, :, None, :]
            sin_ = emb.sin()[None, :, None, :]
            q = apply_rotary(q, cos_, sin_)
            k = apply_rotary(k, cos_, sin_)

        # (B, L, C, D) -> (B, C, L, D)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)

        # (B, C, L, L) span scores
        logits = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(D)

        return logits
|
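# Sizing note (arithmetic only, using this file's defaults: hidden_size 768 for
# bert-base, num_labels 23, head_size 64): the standard GlobalPointer projection is
# 768 -> 23 * 64 * 2 = 2944 outputs (~2.26M weights), while EfficientGlobalPointer
# needs only 768 -> 128 plus 768 -> 46 (~0.13M weights), hence the "efficient" name.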
|
class EcomBertNER(nn.Module):
    """
    BERT encoder + GlobalPointer head for span-based NER.

    forward() signature:
        input_ids      (B, L)        token ids
        attention_mask (B, L)        passed to the encoder (1 = real, 0 = pad)
        labels         (B, C, L, L)  torch.bool, optional
        valid_mask     (B, L)        torch.bool, optional; True = valid token
                                     (excludes CLS/SEP/PAD; from dataset collate_fn)

    If valid_mask is not provided, falls back to attention_mask.bool()
    (slightly less precise: includes CLS/SEP as negative spans).
    """
|
    def __init__(
        self,
        model_name: str = "bert-base-chinese",
        num_labels: int = 23,
        head_size: int = 64,
        loss_type: str = "circle",
        use_rope: bool = True,
        dropout: float = 0.1,
        cache_dir: str | None = None,
        # Circle loss hyperparameters
        circle_margin: float = 0.25,
        circle_gamma: float = 32.0,
    ):
        super().__init__()
        assert loss_type in ("circle", "bce"), \
            f"loss_type must be 'circle' or 'bce', got {loss_type!r}"

        self.loss_type = loss_type
        self.circle_margin = circle_margin
        self.circle_gamma = circle_gamma

        self.encoder = AutoModel.from_pretrained(
            model_name, cache_dir=cache_dir
        )
        hidden_size = self.encoder.config.hidden_size

        self.global_pointer = EfficientGlobalPointer(
            hidden_size = hidden_size,
            num_labels  = num_labels,
            head_size   = head_size,
            use_rope    = use_rope,
            dropout     = dropout,
        )

        # kept as plain attributes for save_pretrained() / from_pretrained()
        self.model_name = model_name
        self.num_labels = num_labels
        self.head_size = head_size
        self.use_rope = use_rope
        self.dropout = dropout
|
|
    @staticmethod
    def _build_span_mask(
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Returns upper-triangular span mask (B, 1, L, L) where
        mask[b, 0, i, j] = True iff i <= j and both token i and token j are valid.
        """
        # (B, 1, L, 1) & (B, 1, 1, L) -> (B, 1, L, L)
        row = valid_mask[:, None, :, None]
        col = valid_mask[:, None, None, :]
        pair_mask = row & col

        L = valid_mask.size(1)
        upper_tri = torch.triu(
            torch.ones(L, L, dtype=torch.bool, device=valid_mask.device)
        )

        return pair_mask & upper_tri
|
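    # Worked example (illustrative): for one sequence with
    #     valid_mask = [True, True, True, False]   # last token is padding
    # the returned (1, 1, 4, 4) mask is True exactly at
    #     (0,0) (0,1) (0,2)
    #           (1,1) (1,2)
    #                 (2,2)
    # i.e. upper-triangular (start <= end) pairs whose endpoints are both valid.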
|
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor = None,
        valid_mask: torch.Tensor = None,
    ) -> dict:
        # encode tokens; attention_mask only controls self-attention over padding
        encoder_out = self.encoder(
            input_ids      = input_ids,
            attention_mask = attention_mask,
        )
        hidden = encoder_out.last_hidden_state

        # raw span scores: (B, C, L, L)
        logits = self.global_pointer(hidden)

        # span-level mask; fall back to attention_mask if the dataset did not
        # provide a more precise valid_mask
        if valid_mask is None:
            valid_mask = attention_mask.bool()

        mask2d = self._build_span_mask(valid_mask)

        # masked logits are what callers should decode from
        logits_masked = logits.masked_fill(
            ~mask2d.expand_as(logits), -1e4
        )

        loss = None
        if labels is not None:
            # labels arrive as bool; the losses expect float {0, 1}
            labels_f = labels.float()

            if self.loss_type == "circle":
                loss = multilabel_circle_loss(
                    logits = logits,
                    labels = labels_f,
                    mask2d = mask2d,
                    margin = self.circle_margin,
                    gamma  = self.circle_gamma,
                )
            else:
                loss = multilabel_bce_loss(
                    logits = logits,
                    labels = labels_f,
                    mask2d = mask2d,
                )

        return {
            "loss": loss,
            "logits": logits_masked,
        }
|
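    # Decoding sketch (one common GlobalPointer convention; the 0 threshold is an
    # assumption to tune on dev data, not something fixed by this module):
    #
    #     out = model(input_ids, attention_mask, valid_mask=valid_mask)
    #     spans = (out["logits"] > 0).nonzero()    # rows of (batch, label, start, end)
    #
    # Invalid spans were filled with -1e4 above, so they never cross the threshold.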
    def save_pretrained(self, save_directory: str | Path, *, extra_config: dict | None = None) -> None:
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        config = {
            "architectures": [self.__class__.__name__],
            "model_name": self.model_name,
            "num_labels": self.num_labels,
            "head_size": self.head_size,
            "loss_type": self.loss_type,
            "use_rope": self.use_rope,
            "dropout": self.dropout,
            "circle_margin": self.circle_margin,
            "circle_gamma": self.circle_gamma,
        }
        if extra_config:
            config.update(extra_config)

        with open(save_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        torch.save(self.state_dict(), save_dir / "pytorch_model.bin")
|
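    # The written config.json looks roughly like this (values reflect whatever the
    # model was constructed with; the ones below are just the defaults):
    #
    #     {
    #       "architectures": ["EcomBertNER"],
    #       "model_name": "bert-base-chinese",
    #       "num_labels": 23,
    #       "head_size": 64,
    #       "loss_type": "circle",
    #       "use_rope": true,
    #       "dropout": 0.1,
    #       "circle_margin": 0.25,
    #       "circle_gamma": 32.0
    #     }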
    @classmethod
    def from_pretrained(
        cls,
        model_dir: str | Path,
        *,
        device: torch.device | str | None = None,
        cache_dir: str | None = None,
    ) -> tuple["EcomBertNER", dict]:
        model_dir = Path(model_dir)
        with open(model_dir / "config.json", "r", encoding="utf-8") as f:
            cfg = json.load(f)

        model = cls(
            model_name=cfg.get("model_name", "bert-base-chinese"),
            num_labels=int(cfg.get("num_labels", 23)),
            head_size=int(cfg.get("head_size", 64)),
            loss_type=str(cfg.get("loss_type", "circle")),
            use_rope=bool(cfg.get("use_rope", True)),
            dropout=float(cfg.get("dropout", 0.1)),
            cache_dir=cache_dir,
            circle_margin=float(cfg.get("circle_margin", 0.25)),
            circle_gamma=float(cfg.get("circle_gamma", 32.0)),
        )
        state = torch.load(model_dir / "pytorch_model.bin", map_location="cpu", weights_only=False)
        model.load_state_dict(state)
        if device is not None:
            model.to(device)
        model.eval()
        return model, cfg
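

# Minimal smoke test (illustrative sketch: downloads "bert-base-chinese" on first run,
# and the batch below is random ids with a single fake gold span).
if __name__ == "__main__":
    model = EcomBertNER(num_labels=23, loss_type="circle")
    model.eval()

    B, L, C = 2, 16, 23
    input_ids = torch.randint(1, 1000, (B, L))
    attention_mask = torch.ones(B, L, dtype=torch.long)
    valid_mask = attention_mask.bool()
    valid_mask[:, 0] = False                       # pretend position 0 is [CLS]
    labels = torch.zeros(B, C, L, L, dtype=torch.bool)
    labels[0, 0, 3, 5] = True                      # one fake gold span

    with torch.no_grad():
        out = model(input_ids, attention_mask, labels=labels, valid_mask=valid_mask)

    print("loss:", out["loss"].item())
    print("logits shape:", tuple(out["logits"].shape))    # (B, C, L, L)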