| |
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from pathlib import Path |
| import sys |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoConfig, AutoModel, PreTrainedModel |
| from transformers.utils import ModelOutput |
|
|
| ROOT_DIR = Path(__file__).resolve().parents[2] |
| if str(ROOT_DIR) not in sys.path: |
| sys.path.insert(0, str(ROOT_DIR)) |
|
|
| try: |
| from ..irish_core_span_raw_only.model import hidden_size_from_config |
| except ImportError: |
| from experiments.irish_core_span_raw_only.model import hidden_size_from_config |
|
|
|
|
@dataclass
class GlobalPointerSpanOutput(ModelOutput):
    """Result container produced by the span-scoring model's forward pass.

    Both fields default to ``None`` so the container can be built with only
    the values that were actually computed.
    """

    # Training loss; left as None when no span labels were passed to forward.
    loss: Optional[torch.Tensor] = None
    # Per-label (start, end) pair scores after invalid pairs were masked out.
    span_logits: Optional[torch.Tensor] = None
|
|
|
|
def build_rope_cache(seq_len: int, head_size: int, device, dtype) -> tuple[torch.Tensor, torch.Tensor]:
    """Build interleaved sine/cosine tables for rotary position embeddings.

    The angle for position ``p`` and frequency index ``i`` is
    ``p * 10000 ** (-2 * i / head_size)``. Each frequency is duplicated for two
    adjacent channels, so channels ``2i`` and ``2i + 1`` share the same angle.

    Args:
        seq_len: Number of positions to generate (rows of the tables).
        head_size: Channel width of each query/key vector; must be even.
        device: Device on which the tables are created.
        dtype: Dtype of the returned tables.

    Returns:
        A ``(sin, cos)`` pair, each of shape ``(seq_len, head_size)``.

    Raises:
        ValueError: If ``head_size`` is odd, since channels are rotated in pairs.
    """
    if head_size % 2 != 0:
        raise ValueError(f"head_size must be even for rotary embeddings, got {head_size}")

    # Half-precision types cannot represent larger integer positions exactly
    # (bfloat16 only has 8 significant bits), which corrupts the angles for long
    # sequences. Do the trigonometry in float32 and cast only the final tables.
    if dtype is None or dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32
    else:
        compute_dtype = dtype
    out_dtype = compute_dtype if dtype is None else dtype

    position = torch.arange(seq_len, device=device, dtype=compute_dtype).unsqueeze(-1)
    index = torch.arange(head_size // 2, device=device, dtype=compute_dtype)
    base = torch.tensor(10000.0, device=device, dtype=compute_dtype)
    theta = torch.pow(base, -2.0 * index / head_size)
    angles = position * theta  # (seq_len, head_size // 2)

    sin_base = torch.sin(angles)
    cos_base = torch.cos(angles)
    # Stack + reshape interleaves the copies: [a0, a0, a1, a1, ...].
    sin = torch.stack((sin_base, sin_base), dim=-1).reshape(seq_len, head_size)
    cos = torch.stack((cos_base, cos_base), dim=-1).reshape(seq_len, head_size)
    return sin.to(out_dtype), cos.to(out_dtype)
|
|
|
|
def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent channel pairs of ``x`` by the angles encoded in ``sin``/``cos``.

    Each pair ``(x[2i], x[2i + 1])`` is combined with its quarter-turn partner
    ``(-x[2i + 1], x[2i])`` as ``x * cos + partner * sin``.

    Args:
        x: Tensor whose last dimension holds the channels to rotate.
        sin: Sine table; a new axis is inserted at positions 0 and 2 so it
            broadcasts against ``x``.
        cos: Cosine table, broadcast the same way as ``sin``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # Insert broadcast axes at dim 0 and dim 2 of the tables.
    cos_b = cos[None, :, None]
    sin_b = sin[None, :, None]

    # Build the quarter-turn partner: (-odd, even) re-interleaved channel-wise.
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    partner = torch.stack((-odds, evens), dim=-1).reshape(x.shape)

    return x * cos_b + partner * sin_b
|
|
|
|
class IrishCoreGlobalPointerModel(PreTrainedModel):
    """Transformer encoder with a pairwise (start, end) span-scoring head.

    Every token is projected to one query and one key vector per span label.
    The logit for label ``h`` and span ``(s, t)`` is the dot product of the
    query at ``s`` with the key at ``t``, scaled by ``1 / sqrt(head_size)``.
    Rotary position embeddings are applied to queries and keys when enabled.

    Config attributes read in ``__init__`` (default in parentheses):
        num_span_labels (required), global_pointer_head_size (64),
        global_pointer_use_rope (True), global_pointer_negative_ratio (16),
        global_pointer_min_negatives (256), seq_classif_dropout or dropout
        (0.1), span_positive_weight (6.0).
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        """Build the encoder from ``config`` and attach the span head."""
        super().__init__(config)
        self.encoder = AutoModel.from_config(config)

        # Span-head hyper-parameters, all taken from the config.
        self.num_span_labels = int(config.num_span_labels)
        self.head_size = int(getattr(config, "global_pointer_head_size", 64))
        self.use_rope = bool(getattr(config, "global_pointer_use_rope", True))
        self.negative_ratio = int(getattr(config, "global_pointer_negative_ratio", 16))
        self.min_negatives = int(getattr(config, "global_pointer_min_negatives", 256))

        encoder_width = hidden_size_from_config(config)
        # "seq_classif_dropout" wins over "dropout"; 0.1 if neither is set.
        fallback_rate = getattr(config, "dropout", 0.1)
        drop_rate = float(getattr(config, "seq_classif_dropout", fallback_rate))
        self.dropout = nn.Dropout(drop_rate)

        # One query and one key of width head_size per label, from one linear map.
        self.proj = nn.Linear(encoder_width, self.num_span_labels * self.head_size * 2)

        # Per-label positive-class weight for the BCE loss; excluded from checkpoints.
        positive_weight = float(getattr(config, "span_positive_weight", 6.0))
        self.register_buffer(
            "loss_pos_weight",
            torch.full((self.num_span_labels,), positive_weight),
            persistent=False,
        )
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        span_labels=None,
        token_mask=None,
        **kwargs,
    ) -> GlobalPointerSpanOutput:
        """Score every (start, end) token pair for each span label.

        Args:
            input_ids: Passed to the encoder unchanged.
            attention_mask: Passed to the encoder; also the fallback for
                ``token_mask``.
            token_type_ids: Passed to the encoder unless the config's
                ``model_type`` is "distilbert" or "roberta".
            span_labels: Optional targets; presumably shaped like the logits,
                ``(batch, num_span_labels, seq, seq)`` (confirm against the
                data collator). Values > 0 count as positives.
            token_mask: Optional ``(batch, seq)`` mask of tokens allowed to be
                span boundaries. All ones when both masks are missing.
            **kwargs: Extra keyword arguments passed to the encoder.

        Returns:
            ``GlobalPointerSpanOutput`` whose ``span_logits`` has shape
            ``(batch, num_span_labels, seq, seq)`` with invalid pairs set to
            ``-1e4``, and whose ``loss`` is ``None`` without ``span_labels``.
        """
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        encoder_inputs.update(kwargs)
        model_type = getattr(self.config, "model_type", "")
        if model_type not in {"distilbert", "roberta"} and token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids

        hidden = self.dropout(self.encoder(**encoder_inputs).last_hidden_state)
        batch, length, _ = hidden.shape
        device, dtype = hidden.device, hidden.dtype

        # (batch, seq, labels, 2 * head) -> query and key of width head_size each.
        stacked = self.proj(hidden).view(batch, length, self.num_span_labels, self.head_size * 2)
        query, key = stacked.chunk(2, dim=-1)
        if self.use_rope:
            sin, cos = build_rope_cache(length, self.head_size, device, dtype)
            query, key = apply_rope(query, sin, cos), apply_rope(key, sin, cos)

        span_logits = torch.einsum("bshd,bthd->bhst", query, key) / math.sqrt(self.head_size)

        # Resolve the boundary mask: explicit token_mask > attention_mask > all ones.
        mask = attention_mask if token_mask is None else token_mask
        if mask is None:
            mask = torch.ones((batch, length), device=device, dtype=dtype)
        mask = mask.to(dtype)

        # A pair is valid when both endpoints are unmasked and start <= end.
        upper = torch.triu(torch.ones((length, length), device=device, dtype=dtype))
        pair_mask = mask[:, None, :, None] * mask[:, None, None, :]
        pair_mask = pair_mask * upper[None, None]

        masked_logits = span_logits.masked_fill(pair_mask <= 0.0, -1e4)
        loss = None
        if span_labels is not None:
            loss = self._span_loss(span_logits, span_labels, pair_mask)
        return GlobalPointerSpanOutput(loss=loss, span_logits=masked_logits)

    def _span_loss(self, span_logits, span_labels, pair_mask):
        """Weighted BCE over valid pairs, keeping only the hardest negatives.

        Positives and kept negatives are averaged separately and the two means
        are then averaged, so a large number of negatives cannot drown out the
        positives. Returns a zero that is still attached to the graph when no
        valid pair exists.
        """
        targets = span_labels.float()
        weight = self.loss_pos_weight.to(span_logits.device).view(1, self.num_span_labels, 1, 1)
        per_pair = F.binary_cross_entropy_with_logits(
            span_logits, targets, reduction="none", pos_weight=weight
        )

        valid = pair_mask > 0.0
        is_positive = (targets > 0.0) & valid
        is_negative = (~is_positive) & valid
        positive_terms = per_pair.masked_select(is_positive)
        negative_terms = per_pair.masked_select(is_negative)

        if negative_terms.numel() > 0 and self.negative_ratio > 0:
            # Keep the highest-loss negatives: at least min_negatives, otherwise
            # negative_ratio per positive, capped at the number available.
            num_positive = int(is_positive.sum().item())
            budget = max(self.min_negatives, num_positive * self.negative_ratio)
            budget = min(budget, negative_terms.numel())
            negative_terms = torch.topk(negative_terms, budget).values

        group_means = [
            terms.mean() for terms in (positive_terms, negative_terms) if terms.numel() > 0
        ]
        if not group_means:
            return per_pair.sum() * 0.0
        return sum(group_means) / len(group_means)
|
|