"""
model.py — GlobalPointer-based NER model on top of BERT
Changes vs previous version:
[FIX-1] Circle Loss: correct two-term formulation (Su Jianlin style),
with margin (m) and scale (gamma) params; no more logaddexp merging.
[FIX-2] Numerical safety: the positive logits are negated BEFORE the -inf
        span mask is applied, so the fill value can never flip sign into a
        huge positive score (no more -1e9 → +1e9).
[FIX-3] labels are cast to .float() inside forward (avoids dtype errors and
        NaNs when the dataset yields bool labels).
[FIX-4] valid_mask (bool, B×L) replaces attention_mask for span masking;
attention_mask is still passed to the encoder for self-attention.
[FIX-5] use_rope flag for GlobalPointer's span-level RoPE (independent of
BERT encoder internals).
"""
import json
from pathlib import Path
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
# ════════════════════════════════════════════════════════════════════════════
# EfficientGlobalPointer head
# - shared q/k projection (hidden -> 2D)
# - per-label token bias (hidden -> 2C) as start/end bias
# - final logits: base_span + start_bias + end_bias
# ════════════════════════════════════════════════════════════════════════════
class EfficientGlobalPointer(nn.Module):
"""
EfficientGlobalPointer span scorer (Su Jianlin style).
Differences vs standard GlobalPointer:
- q/k are shared across labels: hidden -> 2 * head_size
- label-specific bias per token: hidden -> 2 * num_labels
(start_bias and end_bias for each label)
- logits: (q @ k^T)/sqrt(D) expanded to C labels, then add biases
Output shape: (B, C, L, L)
"""
def __init__(
self,
hidden_size: int,
num_labels: int,
head_size: int = 64,
use_rope: bool = True,
dropout: float = 0.1,
):
super().__init__()
self.num_labels = num_labels
self.head_size = head_size
self.use_rope = use_rope
self.dropout = nn.Dropout(dropout)
# shared q/k: (H -> 2D)
self.dense_qk = nn.Linear(hidden_size, head_size * 2)
# label bias: (H -> 2C) => per token: start_bias + end_bias
self.dense_bias = nn.Linear(hidden_size, num_labels * 2)
if use_rope:
self.rope = RotaryEmbedding(head_size)
def forward(self, hidden: torch.Tensor) -> torch.Tensor:
"""
hidden: (B, L, H)
returns logits: (B, C, L, L)
"""
B, L, _ = hidden.shape
C = self.num_labels
D = self.head_size
hidden = self.dropout(hidden)
# ── shared q/k ───────────────────────────────────────────────────────
qk = self.dense_qk(hidden) # (B, L, 2D)
q, k = qk[..., :D], qk[..., D:] # each (B, L, D)
if self.use_rope:
emb = self.rope(L, hidden.device) # (L, D)
cos_ = emb.cos()[None, :, :] # (1, L, D)
sin_ = emb.sin()[None, :, :]
q = apply_rotary(q, cos_, sin_) # (B, L, D)
k = apply_rotary(k, cos_, sin_) # (B, L, D)
# base span score (shared across labels): (B, L, L)
base = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(D)
# ── per-label start/end bias ────────────────────────────────────────
bias = self.dense_bias(hidden) # (B, L, 2C)
bias = bias.view(B, L, C, 2) # (B, L, C, 2)
# start/end: (B, C, L)
start_bias = bias[..., 0].permute(0, 2, 1) # (B, C, L)
end_bias = bias[..., 1].permute(0, 2, 1) # (B, C, L)
# combine:
# base: (B, 1, L, L)
# start_bias: (B, C, L, 1)
# end_bias: (B, C, 1, L)
logits = (
base[:, None, :, :] +
start_bias[:, :, :, None] +
end_bias[:, :, None, :]
) # (B, C, L, L)
return logits
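
# A minimal shape-check sketch (illustrative only; `_demo_efficient_global_pointer`
# and the sizes below are hypothetical, not part of the model). It shows the
# head mapping (B, L, H) hidden states to (B, C, L, L) span logits.
def _demo_efficient_global_pointer() -> None:
    head = EfficientGlobalPointer(hidden_size=768, num_labels=5, head_size=64)
    hidden = torch.randn(2, 16, 768)   # (B=2, L=16, H=768)
    logits = head(hidden)              # (B, C, L, L)
    assert logits.shape == (2, 5, 16, 16)
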
# ════════════════════════════════════════════════════════════════════════════
# RoPE helper (span-level, applied to GlobalPointer q/k)
# ════════════════════════════════════════════════════════════════════════════
class RotaryEmbedding(nn.Module):
"""Rotary Position Embedding for GlobalPointer span scoring."""
def __init__(self, dim: int):
super().__init__()
assert dim % 2 == 0, "RoPE dim must be even"
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
"""Returns cos/sin interleaved tensor of shape (seq_len, dim)."""
t = torch.arange(seq_len, device=device).float()
freqs = torch.outer(t, self.inv_freq) # (L, dim/2)
emb = torch.cat([freqs, freqs], dim=-1) # (L, dim)
return emb # caller does cos/sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
half = x.shape[-1] // 2
x1, x2 = x[..., :half], x[..., half:]
return torch.cat([-x2, x1], dim=-1)
def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""x: (..., L, D) cos/sin: (L, D)"""
return x * cos + rotate_half(x) * sin
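
# Illustrative sketch of RoPE's key property (the helper name is hypothetical):
# after rotation, the dot product of q at position i with k at position j
# depends only on the offset j - i, so shifting both positions by the same
# amount leaves the span score unchanged.
def _demo_rope_relative_shift() -> None:
    dim = 64
    rope = RotaryEmbedding(dim)
    emb = rope(seq_len=32, device=torch.device("cpu"))  # (L, dim) angles
    cos_, sin_ = emb.cos(), emb.sin()
    q, k = torch.randn(dim), torch.randn(dim)

    def score(i: int, j: int) -> torch.Tensor:
        qi = q * cos_[i] + rotate_half(q) * sin_[i]
        kj = k * cos_[j] + rotate_half(k) * sin_[j]
        return qi @ kj

    # same offset (j - i = 3) at two different absolute positions
    assert torch.allclose(score(2, 5), score(10, 13), atol=1e-4)
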
# ════════════════════════════════════════════════════════════════════════════
# Loss functions
# ════════════════════════════════════════════════════════════════════════════
def multilabel_circle_loss(
logits: torch.Tensor, # (B, C, L, L) raw scores
labels: torch.Tensor, # (B, C, L, L) float 0/1
    mask2d: torch.Tensor,        # (B, 1, L, L) bool — True = valid span position
margin: float = 0.25,
gamma: float = 32.0,
) -> torch.Tensor:
"""
Su Jianlin–style Circle Loss for multi-label span classification.
    L = log(1 + Σ exp(γ·(s_neg + m))) + log(1 + Σ exp(−γ·(s_pos − m)))
    Two independent logsumexp terms keep the original loss geometry intact.
    The -inf mask is applied after negating the positive scores, so no sign
    flip can blow masked positions up to +inf.
Args:
logits: raw span scores, shape (B, C, L, L)
labels: float tensor {0, 1}, same shape
        mask2d: bool (B, 1, L, L) — True where span is valid (upper-tri + valid tokens)
margin: additive margin (default 0.25)
gamma: temperature / scale (default 32)
"""
B, C, L, _ = logits.shape
# ── expand mask to (B, C, L, L) ─────────────────────────────────────────
mask = mask2d.expand(B, C, L, L) # broadcast over C
# ── positions that are valid positive / valid negative ───────────────────
pos_mask = mask & (labels > 0.5) # bool
neg_mask = mask & (labels < 0.5) # bool
# ── scale logits ─────────────────────────────────────────────────────────
s = logits * gamma # (B, C, L, L)
    # ── negative term: log(1 + Σ exp(s_neg + γ·m)) ──────────────────────────
# Fill invalid & positive positions with -inf so they don't contribute
neg_scores = s.masked_fill(~neg_mask, float("-inf"))
# logsumexp over (L, L) for each (b, c)
neg_lse = torch.logsumexp(neg_scores.view(B, C, -1), dim=-1) # (B, C)
loss_neg = F.softplus(neg_lse + gamma * margin) # log(1+exp(...))
    # ── positive term: log(1 + Σ exp(−(s_pos − γ·m))) ───────────────────────
    # Negate FIRST, then fill invalid & negative positions with -inf: masked
    # entries land at exactly -inf, so no sign flip can turn the fill value
    # into a dominating positive score.
    neg_pos_scores = (-s).masked_fill(~pos_mask, float("-inf"))
pos_lse = torch.logsumexp(neg_pos_scores.view(B, C, -1), dim=-1) # (B, C)
loss_pos = F.softplus(pos_lse + gamma * margin)
    # ── mean over (B, C); cells with no valid positives (or negatives) add 0
    #    to the corresponding term, since softplus(-inf) = 0 ─────────────────
loss = (loss_neg + loss_pos).mean()
return loss
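
# Toy sanity check for the circle loss (a sketch; `_demo_circle_loss` is
# hypothetical). Well-separated scores (gold span high, negatives low) should
# give a much smaller loss than the inverted assignment.
def _demo_circle_loss() -> None:
    B, C, L = 1, 1, 4
    labels = torch.zeros(B, C, L, L)
    labels[0, 0, 0, 1] = 1.0                       # one gold span (start=0, end=1)
    mask2d = torch.triu(torch.ones(L, L)).bool().view(1, 1, L, L)
    good = torch.full((B, C, L, L), -5.0)
    good[0, 0, 0, 1] = 5.0                         # positive scored high
    bad = -good                                    # positive scored low
    assert multilabel_circle_loss(good, labels, mask2d) < multilabel_circle_loss(
        bad, labels, mask2d
    )
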
def multilabel_bce_loss(
logits: torch.Tensor, # (B, C, L, L)
labels: torch.Tensor, # (B, C, L, L) float
mask2d: torch.Tensor, # (B, 1, L, L) bool
) -> torch.Tensor:
mask = mask2d.expand_as(logits)
loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
loss = loss * mask.float()
return loss.sum() / mask.float().sum().clamp(min=1)
# ════════════════════════════════════════════════════════════════════════════
# GlobalPointer head
# ════════════════════════════════════════════════════════════════════════════
class GlobalPointer(nn.Module):
"""
GlobalPointer span scorer.
Projects encoder hidden states to per-label (q, k) vectors and computes
an (LΓ—L) score matrix per label. Optionally applies span-level RoPE.
    Note: whatever positional encoding the BERT encoder applies internally is
    entirely separate from this span-level RoPE; both can be active
    simultaneously.
"""
def __init__(
self,
hidden_size: int,
num_labels: int,
head_size: int = 64,
use_rope: bool = True,
dropout: float = 0.1,
):
super().__init__()
self.num_labels = num_labels
self.head_size = head_size
self.use_rope = use_rope
self.dropout = nn.Dropout(dropout)
# Project to 2 * num_labels * head_size (q and k for every label)
self.dense = nn.Linear(hidden_size, num_labels * head_size * 2)
if use_rope:
self.rope = RotaryEmbedding(head_size)
def forward(
self,
hidden: torch.Tensor, # (B, L, H)
) -> torch.Tensor: # (B, C, L, L)
B, L, H = hidden.shape
C = self.num_labels
D = self.head_size
hidden = self.dropout(hidden)
proj = self.dense(hidden) # (B, L, C*D*2)
proj = proj.view(B, L, C, D * 2) # (B, L, C, D*2)
q, k = proj[..., :D], proj[..., D:] # each (B, L, C, D)
if self.use_rope:
emb = self.rope(L, hidden.device) # (L, D)
cos_ = emb.cos()[None, :, None, :] # (1, L, 1, D)
sin_ = emb.sin()[None, :, None, :]
q = apply_rotary(q, cos_, sin_)
k = apply_rotary(k, cos_, sin_)
        # q: (B, L, C, D) → (B, C, L, D)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        # Score matrix: (B, C, L, D) × (B, C, D, L) → (B, C, L, L)
logits = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(D)
return logits
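
# Parameter-count comparison (illustrative sketch; `_demo_head_param_counts`
# is hypothetical). GlobalPointer projects each token to 2*C*D values, while
# EfficientGlobalPointer needs only 2*D + 2*C, which is the motivation for
# using the efficient head in EcomBertNER below.
def _demo_head_param_counts() -> None:
    H, C, D = 768, 23, 64

    def n_params(m: nn.Module) -> int:
        return sum(p.numel() for p in m.parameters())

    assert n_params(EfficientGlobalPointer(H, C, D)) < n_params(GlobalPointer(H, C, D))
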
# ════════════════════════════════════════════════════════════════════════════
# Full model
# ════════════════════════════════════════════════════════════════════════════
class EcomBertNER(nn.Module):
"""
BERT encoder + GlobalPointer head for span-based NER.
forward() signature:
        input_ids       (B, L)        — token ids
        attention_mask  (B, L)        — passed to encoder (1=real, 0=pad)
        labels          (B, C, L, L)  torch.bool, optional
        valid_mask      (B, L)        torch.bool, optional — True = valid token
                        (excludes CLS/SEP/PAD; from dataset collate_fn)
    If valid_mask is not provided, falls back to attention_mask.bool()
    (slightly less precise — includes CLS/SEP as negative spans).
"""
def __init__(
self,
model_name: str = "bert-base-chinese",
num_labels: int = 23,
head_size: int = 64,
loss_type: str = "circle", # "circle" | "bce"
use_rope: bool = True,
dropout: float = 0.1,
        cache_dir: str | None = None,
# Circle Loss hyper-params (ignored for BCE)
circle_margin: float = 0.25,
circle_gamma: float = 32.0,
):
super().__init__()
assert loss_type in ("circle", "bce"), \
f"loss_type must be 'circle' or 'bce', got {loss_type!r}"
self.loss_type = loss_type
self.circle_margin = circle_margin
self.circle_gamma = circle_gamma
self.encoder = AutoModel.from_pretrained(
model_name, cache_dir=cache_dir
)
hidden_size = self.encoder.config.hidden_size
self.global_pointer = EfficientGlobalPointer(
hidden_size = hidden_size,
num_labels = num_labels,
head_size = head_size,
use_rope = use_rope,
dropout = dropout,
)
self.model_name = model_name
self.num_labels = num_labels
self.head_size = head_size
self.use_rope = use_rope
self.dropout = dropout
# ── span validity mask ────────────────────────────────────────────────────
@staticmethod
def _build_span_mask(
valid_mask: torch.Tensor, # (B, L) bool
) -> torch.Tensor:
"""
Returns upper-triangular span mask (B, 1, L, L) where
        mask[b,0,i,j] = True iff i <= j and both tokens i and j are valid.
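        Example (illustrative): valid_mask = [[True, True, False]] gives
            mask[0, 0] = [[ True,  True, False],
                          [False,  True, False],
                          [False, False, False]]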
"""
        # row mask (B, 1, L, 1) & col mask (B, 1, 1, L) → (B, 1, L, L)
row = valid_mask[:, None, :, None] # (B, 1, L, 1)
col = valid_mask[:, None, None, :] # (B, 1, 1, L)
pair_mask = row & col # (B, 1, L, L)
L = valid_mask.size(1)
upper_tri = torch.triu(
torch.ones(L, L, dtype=torch.bool, device=valid_mask.device)
) # (L, L)
return pair_mask & upper_tri # (B, 1, L, L)
# ── forward ───────────────────────────────────────────────────────────────
def forward(
self,
input_ids: torch.Tensor, # (B, L)
attention_mask: torch.Tensor, # (B, L)
labels: torch.Tensor = None, # (B, C, L, L) bool
valid_mask: torch.Tensor = None, # (B, L) bool
) -> dict:
# ── encoder ─────────────────────────────────────────────────────────
encoder_out = self.encoder(
input_ids = input_ids,
attention_mask = attention_mask,
)
hidden = encoder_out.last_hidden_state # (B, L, H)
# ── GlobalPointer logits ─────────────────────────────────────────────
logits = self.global_pointer(hidden) # (B, C, L, L)
# ── span validity mask ───────────────────────────────────────────────
# [FIX-4] prefer valid_mask (excludes CLS/SEP) over attention_mask
if valid_mask is None:
valid_mask = attention_mask.bool()
mask2d = self._build_span_mask(valid_mask) # (B, 1, L, L)
        # Apply mask to logits for inference (fill invalid spans with -1e4,
        # which stays finite under fp16, unlike -1e9)
logits_masked = logits.masked_fill(
~mask2d.expand_as(logits), -1e4
)
# ── loss ─────────────────────────────────────────────────────────────
loss = None
if labels is not None:
# [FIX-3] ensure float regardless of bool input from dataset
labels_f = labels.float()
if self.loss_type == "circle":
loss = multilabel_circle_loss(
logits = logits, # raw (unmasked) scores
labels = labels_f,
mask2d = mask2d,
margin = self.circle_margin,
gamma = self.circle_gamma,
)
else:
loss = multilabel_bce_loss(
logits = logits,
labels = labels_f,
mask2d = mask2d,
)
return {
"loss": loss,
"logits": logits_masked, # (B, C, L, L)
}
def save_pretrained(self, save_directory: str | Path, *, extra_config: dict | None = None) -> None:
save_dir = Path(save_directory)
save_dir.mkdir(parents=True, exist_ok=True)
config = {
"architectures": [self.__class__.__name__],
"model_name": self.model_name,
"num_labels": self.num_labels,
"head_size": self.head_size,
"loss_type": self.loss_type,
"use_rope": self.use_rope,
"dropout": self.dropout,
"circle_margin": self.circle_margin,
"circle_gamma": self.circle_gamma,
}
if extra_config:
config.update(extra_config)
with open(save_dir / "config.json", "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
torch.save(self.state_dict(), save_dir / "pytorch_model.bin")
@classmethod
def from_pretrained(
cls,
model_dir: str | Path,
*,
device: torch.device | str | None = None,
cache_dir: str | None = None,
) -> tuple["EcomBertNER", dict]:
model_dir = Path(model_dir)
with open(model_dir / "config.json", "r", encoding="utf-8") as f:
cfg = json.load(f)
model = cls(
model_name=cfg.get("model_name", "bert-base-chinese"),
num_labels=int(cfg.get("num_labels", 23)),
head_size=int(cfg.get("head_size", 64)),
loss_type=str(cfg.get("loss_type", "circle")),
use_rope=bool(cfg.get("use_rope", True)),
dropout=float(cfg.get("dropout", 0.1)),
cache_dir=cache_dir,
circle_margin=float(cfg.get("circle_margin", 0.25)),
circle_gamma=float(cfg.get("circle_gamma", 32.0)),
)
        # the checkpoint is a plain state_dict, so weights_only=True is safe
        state = torch.load(model_dir / "pytorch_model.bin", map_location="cpu", weights_only=True)
model.load_state_dict(state)
if device is not None:
model.to(device)
model.eval()
return model, cfg
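
# Illustrative end-to-end usage sketch (assumptions: a local checkpoint
# directory "ckpt/" written by save_pretrained, and a tokenizer matching
# cfg["model_name"]). Spans are decoded by thresholding the masked logits at
# 0, the usual decision rule for circle-loss-trained GlobalPointer heads.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model, cfg = EcomBertNER.from_pretrained("ckpt", device="cpu")
    tok = AutoTokenizer.from_pretrained(cfg.get("model_name", "bert-base-chinese"))
    enc = tok("红色连衣裙", return_tensors="pt")  # "red dress", a toy query
    with torch.no_grad():
        out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    # every (batch, label, start, end) with score > 0 is a predicted entity
    for b, c, i, j in torch.nonzero(out["logits"] > 0):
        print(f"batch={b.item()} label={c.item()} span=({i.item()}, {j.item()})")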