# temsa's picture
# Publish rc15 release
# 6631226 verified
#!/usr/bin/env python3
from __future__ import annotations
import math
from dataclasses import dataclass
from pathlib import Path
import sys
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput
# ROOT_DIR is the third ancestor of this file's resolved path (two levels
# above the directory that contains it). It is put at the front of sys.path
# so that the absolute ``experiments.`` import below can be resolved.
# NOTE(review): presumably ROOT_DIR is the repository root that contains the
# ``experiments`` package -- confirm against the actual layout.
ROOT_DIR = Path(__file__).resolve().parents[2]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

# Try the package-relative import first; if it raises ImportError (for example
# when this file is executed without a parent package), fall back to the
# absolute import that relies on the sys.path entry added above.
try:
    from ..irish_core_span_raw_only.model import hidden_size_from_config
except ImportError:
    from experiments.irish_core_span_raw_only.model import hidden_size_from_config
@dataclass
class GlobalPointerSpanOutput(ModelOutput):
    """Result container returned by ``IrishCoreGlobalPointerModel.forward``.

    Both fields default to ``None``.
    """

    # Training loss; ``forward`` leaves this as ``None`` when it is called
    # without ``span_labels``.
    loss: Optional[torch.Tensor] = None
    # Pairwise span scores produced by ``forward`` with axes
    # (batch, span label, first position, second position); pairs excluded by
    # the token / upper-triangular mask are filled with -1e4.
    span_logits: Optional[torch.Tensor] = None
def build_rope_cache(seq_len: int, head_size: int, device, dtype) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the sine and cosine tables used by ``apply_rope``.

    For position ``p`` and frequency index ``i`` (``0 <= i < head_size // 2``)
    the angle is ``p * 10000 ** (-2 * i / head_size)``.  Each angle's sine and
    cosine is written to two adjacent feature slots (``2 * i`` and
    ``2 * i + 1``) so the tables line up with the interleaved even/odd pairing
    used by ``apply_rope``.

    Args:
        seq_len: Number of positions to tabulate.
        head_size: Feature width of the vectors to rotate.  Must be even;
            with an odd value the final reshape to ``(seq_len, head_size)``
            fails.
        device: Device on which the tables are created.
        dtype: Dtype of the returned tables.

    Returns:
        A ``(sin, cos)`` pair, each of shape ``(seq_len, head_size)`` and of
        dtype ``dtype``.
    """
    # Half-precision dtypes cannot hold the position index itself: bfloat16
    # represents consecutive integers only up to 256, so later positions would
    # collapse onto the same angle.  Do the arithmetic in at least float32
    # (float64 is kept as float64) and cast only the finished tables.
    compute_dtype = torch.promote_types(dtype, torch.float32)

    position = torch.arange(seq_len, device=device, dtype=compute_dtype).unsqueeze(-1)
    index = torch.arange(head_size // 2, device=device, dtype=compute_dtype)
    theta = torch.pow(
        torch.tensor(10000.0, device=device, dtype=compute_dtype),
        -2.0 * index / head_size,
    )
    angles = position * theta  # (seq_len, head_size // 2)

    sin_base = torch.sin(angles)
    cos_base = torch.cos(angles)
    # Duplicate every column into adjacent slots: [a0, a0, a1, a1, ...].
    sin = torch.stack((sin_base, sin_base), dim=-1).reshape(seq_len, head_size)
    cos = torch.stack((cos_base, cos_base), dim=-1).reshape(seq_len, head_size)
    return sin.to(dtype), cos.to(dtype)
def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent feature pairs of ``x`` using the given sine/cosine tables.

    The last axis of ``x`` is treated as interleaved pairs ``(x0, x1),
    (x2, x3), ...``; each pair ``(a, b)`` contributes ``(-b, a)`` to the
    rotated companion, and the result is ``x * cos + companion * sin``.

    ``sin`` and ``cos`` get new axes inserted at positions 0 and 2 before
    broadcasting.  In this file they are ``(seq_len, head_size)`` tables and
    ``x`` is ``(batch, seq_len, num_span_labels, head_size)``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    evens, odds = x[..., 0::2], x[..., 1::2]
    # Pair up (-odd, even) and merge the pair axis back into the feature axis.
    companion = torch.stack((-odds, evens), dim=-1).flatten(-2)
    cos_b = cos[None, :, None]
    sin_b = sin[None, :, None]
    return x * cos_b + companion * sin_b
class IrishCoreGlobalPointerModel(PreTrainedModel):
    """Encoder with a pairwise span-scoring head.

    Every token's encoder state is projected to one query vector and one key
    vector per span label.  The score of a token pair is the dot product of
    the first token's query with the second token's key, divided by
    ``sqrt(head_size)``.  When ``span_labels`` are given, an element-wise
    weighted binary cross-entropy is computed and reduced over the valid
    (unmasked, upper-triangular) pairs, optionally keeping only the
    highest-loss negatives.

    Config attributes read, with their defaults: ``num_span_labels``
    (required), ``global_pointer_head_size`` (64), ``global_pointer_use_rope``
    (True), ``global_pointer_negative_ratio`` (16),
    ``global_pointer_min_negatives`` (256), ``seq_classif_dropout`` falling
    back to ``dropout`` (0.1), and ``span_positive_weight`` (6.0).
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        """Build the encoder from ``config`` and attach the span head."""
        super().__init__(config)
        self.encoder = AutoModel.from_config(config)

        # Head hyper-parameters, all taken from the config.
        self.num_span_labels = int(config.num_span_labels)
        self.head_size = int(getattr(config, "global_pointer_head_size", 64))
        self.use_rope = bool(getattr(config, "global_pointer_use_rope", True))
        self.negative_ratio = int(getattr(config, "global_pointer_negative_ratio", 16))
        self.min_negatives = int(getattr(config, "global_pointer_min_negatives", 256))

        encoder_width = hidden_size_from_config(config)
        fallback_dropout = getattr(config, "dropout", 0.1)
        dropout_prob = float(getattr(config, "seq_classif_dropout", fallback_dropout))
        self.dropout = nn.Dropout(dropout_prob)

        # One query and one key of width head_size for every span label.
        projection_width = self.num_span_labels * self.head_size * 2
        self.proj = nn.Linear(encoder_width, projection_width)

        # Per-label positive weight for the BCE loss; not saved in checkpoints.
        positive_weight = float(getattr(config, "span_positive_weight", 6.0))
        self.register_buffer(
            "loss_pos_weight",
            torch.full((self.num_span_labels,), positive_weight),
            persistent=False,
        )
        self.post_init()

    def _mined_loss(self, raw_loss, positive_mask, negative_mask):
        """Reduce element-wise losses to a scalar with hard-negative selection.

        Positives and negatives are averaged separately and the available
        averages are then averaged together.  When ``negative_ratio`` is
        positive, only the largest negative losses are kept:
        ``max(min_negatives, positives * negative_ratio)`` of them, capped at
        the number of negatives.  With no selected entries at all the result
        is ``raw_loss.sum() * 0.0``.
        """
        positive_loss = raw_loss.masked_select(positive_mask)
        negative_loss = raw_loss.masked_select(negative_mask)

        if negative_loss.numel() > 0 and self.negative_ratio > 0:
            positive_count = int(positive_mask.sum().item())
            budget = max(self.min_negatives, positive_count * self.negative_ratio)
            budget = min(budget, negative_loss.numel())
            negative_loss = torch.topk(negative_loss, budget).values

        group_means = [
            group.mean()
            for group in (positive_loss, negative_loss)
            if group.numel() > 0
        ]
        if not group_means:
            return raw_loss.sum() * 0.0
        return sum(group_means) / len(group_means)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        span_labels=None,
        token_mask=None,
        **kwargs,
    ) -> GlobalPointerSpanOutput:
        """Score all token pairs per span label and optionally compute the loss.

        Args:
            input_ids: Forwarded to the encoder.
            attention_mask: Forwarded to the encoder; also used as the token
                mask when ``token_mask`` is not given.
            token_type_ids: Forwarded to the encoder unless it is ``None`` or
                the config's ``model_type`` is ``distilbert`` or ``roberta``.
            span_labels: Optional targets compared element-wise against the
                ``(batch, label, position, position)`` scores.
            token_mask: Optional per-token mask selecting which positions may
                take part in a pair; defaults to ``attention_mask``, or to
                all ones when that is missing too.
            **kwargs: Passed through to the encoder unchanged.

        Returns:
            ``GlobalPointerSpanOutput`` whose ``span_logits`` have excluded
            pairs filled with -1e4 and whose ``loss`` is ``None`` when
            ``span_labels`` is ``None``.
        """
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        encoder_inputs.update(kwargs)
        if token_type_ids is not None:
            model_type = getattr(self.config, "model_type", "")
            if model_type not in {"distilbert", "roberta"}:
                encoder_inputs["token_type_ids"] = token_type_ids

        encoded = self.encoder(**encoder_inputs)
        hidden = self.dropout(encoded.last_hidden_state)
        batch_size, seq_len, _ = hidden.shape
        device, dtype = hidden.device, hidden.dtype

        # (batch, seq, label, 2 * head) -> query / key halves of width head.
        stacked = self.proj(hidden).view(
            batch_size, seq_len, self.num_span_labels, 2 * self.head_size
        )
        query, key = stacked.chunk(2, dim=-1)
        if self.use_rope:
            sin, cos = build_rope_cache(seq_len, self.head_size, device, dtype)
            query, key = apply_rope(query, sin, cos), apply_rope(key, sin, cos)
        span_logits = torch.einsum("bshd,bthd->bhst", query, key) / math.sqrt(self.head_size)

        # Resolve the token mask: explicit mask, then attention mask, then ones.
        if token_mask is not None:
            mask = token_mask
        elif attention_mask is not None:
            mask = attention_mask
        else:
            mask = torch.ones((batch_size, seq_len), device=device, dtype=dtype)
        mask = mask.to(dtype)

        # A pair is valid when both tokens are unmasked and first <= second.
        first_ok = mask[:, None, :, None]
        second_ok = mask[:, None, None, :]
        upper = torch.ones((seq_len, seq_len), device=device, dtype=dtype).triu()
        pair_mask = (first_ok * second_ok) * upper[None, None]
        masked_logits = span_logits.masked_fill(pair_mask <= 0.0, -1e4)

        if span_labels is None:
            return GlobalPointerSpanOutput(loss=None, span_logits=masked_logits)

        # The loss uses the unmasked scores; invalid pairs are dropped by the
        # boolean selections below rather than by the -1e4 fill.
        targets = span_labels.float()
        label_weight = self.loss_pos_weight.to(device).view(1, self.num_span_labels, 1, 1)
        raw_loss = F.binary_cross_entropy_with_logits(
            span_logits, targets, pos_weight=label_weight, reduction="none"
        )
        valid = pair_mask > 0.0
        is_positive = (targets > 0.0) & valid
        is_negative = valid & ~is_positive
        loss = self._mined_loss(raw_loss, is_positive, is_negative)
        return GlobalPointerSpanOutput(loss=loss, span_logits=masked_logits)