#!/usr/bin/env python3 from __future__ import annotations import math from dataclasses import dataclass from pathlib import Path import sys from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoConfig, AutoModel, PreTrainedModel from transformers.utils import ModelOutput ROOT_DIR = Path(__file__).resolve().parents[2] if str(ROOT_DIR) not in sys.path: sys.path.insert(0, str(ROOT_DIR)) try: from ..irish_core_span_raw_only.model import hidden_size_from_config except ImportError: from experiments.irish_core_span_raw_only.model import hidden_size_from_config @dataclass class GlobalPointerSpanOutput(ModelOutput): loss: Optional[torch.Tensor] = None span_logits: Optional[torch.Tensor] = None def build_rope_cache(seq_len: int, head_size: int, device, dtype) -> tuple[torch.Tensor, torch.Tensor]: position = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(-1) index = torch.arange(head_size // 2, device=device, dtype=dtype) theta = torch.pow(torch.tensor(10000.0, device=device, dtype=dtype), -2.0 * index / head_size) angles = position * theta sin_base = torch.sin(angles) cos_base = torch.cos(angles) sin = torch.stack((sin_base, sin_base), dim=-1).reshape(seq_len, head_size) cos = torch.stack((cos_base, cos_base), dim=-1).reshape(seq_len, head_size) return sin, cos def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor: x_even = x[..., ::2] x_odd = x[..., 1::2] rotated = torch.stack((-x_odd, x_even), dim=-1).reshape_as(x) return x * cos.unsqueeze(0).unsqueeze(2) + rotated * sin.unsqueeze(0).unsqueeze(2) class IrishCoreGlobalPointerModel(PreTrainedModel): config_class = AutoConfig base_model_prefix = "encoder" def __init__(self, config): super().__init__(config) self.encoder = AutoModel.from_config(config) self.num_span_labels = int(getattr(config, "num_span_labels")) self.head_size = int(getattr(config, "global_pointer_head_size", 64)) self.use_rope = bool(getattr(config, "global_pointer_use_rope", True)) self.negative_ratio = int(getattr(config, "global_pointer_negative_ratio", 16)) self.min_negatives = int(getattr(config, "global_pointer_min_negatives", 256)) hidden_size = hidden_size_from_config(config) dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1))) self.dropout = nn.Dropout(dropout) self.proj = nn.Linear(hidden_size, self.num_span_labels * self.head_size * 2) pos_weight = float(getattr(config, "span_positive_weight", 6.0)) self.register_buffer("loss_pos_weight", torch.full((self.num_span_labels,), pos_weight), persistent=False) self.post_init() def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, span_labels=None, token_mask=None, **kwargs, ) -> GlobalPointerSpanOutput: encoder_kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, **kwargs, } if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}: encoder_kwargs["token_type_ids"] = token_type_ids outputs = self.encoder(**encoder_kwargs) hidden = self.dropout(outputs.last_hidden_state) batch_size, seq_len, _ = hidden.shape projected = self.proj(hidden).view(batch_size, seq_len, self.num_span_labels, self.head_size * 2) query, key = torch.chunk(projected, 2, dim=-1) if self.use_rope: sin, cos = build_rope_cache(seq_len, self.head_size, hidden.device, hidden.dtype) query = apply_rope(query, sin, cos) key = apply_rope(key, sin, cos) span_logits = torch.einsum("bshd,bthd->bhst", query, key) / math.sqrt(self.head_size) if token_mask is None: token_mask = attention_mask if token_mask is None: token_mask = torch.ones((batch_size, seq_len), device=hidden.device, dtype=hidden.dtype) token_mask = token_mask.to(hidden.dtype) pair_mask = token_mask[:, None, :, None] * token_mask[:, None, None, :] upper_mask = torch.triu(torch.ones((seq_len, seq_len), device=hidden.device, dtype=hidden.dtype)) pair_mask = pair_mask * upper_mask.unsqueeze(0).unsqueeze(0) masked_logits = span_logits.masked_fill(pair_mask <= 0.0, -1e4) loss = None if span_labels is not None: targets = span_labels.float() pos_weight = self.loss_pos_weight.to(hidden.device).view(1, self.num_span_labels, 1, 1) raw_loss = F.binary_cross_entropy_with_logits(span_logits, targets, reduction="none", pos_weight=pos_weight) valid_mask = pair_mask > 0.0 positive_mask = (targets > 0.0) & valid_mask negative_mask = (~positive_mask) & valid_mask positive_loss = raw_loss.masked_select(positive_mask) negative_loss = raw_loss.masked_select(negative_mask) if negative_loss.numel() > 0 and self.negative_ratio > 0: positive_count = int(positive_mask.sum().item()) keep_negatives = max(self.min_negatives, positive_count * self.negative_ratio) keep_negatives = min(keep_negatives, negative_loss.numel()) negative_loss = torch.topk(negative_loss, keep_negatives).values parts = [] if positive_loss.numel() > 0: parts.append(positive_loss.mean()) if negative_loss.numel() > 0: parts.append(negative_loss.mean()) loss = sum(parts) / len(parts) if parts else raw_loss.sum() * 0.0 return GlobalPointerSpanOutput(loss=loss, span_logits=masked_logits)