#!/usr/bin/env python3
# Source snapshot: revision d6c2695, 5,943 bytes.
from __future__ import annotations
import math
from dataclasses import dataclass
from pathlib import Path
import sys
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput
# Put the directory two levels above this file's folder on sys.path so the
# absolute-import fallback below can resolve `experiments.…` when this module is
# not loaded as part of a package (presumably that directory is the repo root —
# TODO confirm against the project layout).
ROOT_DIR = Path(__file__).resolve().parents[2]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))
try:
    # Package-relative import: works when this file is imported as a package submodule.
    from ..irish_core_span_raw_only.model import hidden_size_from_config
except ImportError:
    # Fallback (e.g. direct script execution, where a relative import raises
    # ImportError); depends on the sys.path insertion performed just above.
    from experiments.irish_core_span_raw_only.model import hidden_size_from_config
@dataclass
class GlobalPointerSpanOutput(ModelOutput):
    """Return type of ``IrishCoreGlobalPointerModel.forward``.

    Both fields are optional and default to ``None``.
    """

    # Scalar training loss; left as None when no span labels are passed to forward().
    loss: Optional[torch.Tensor] = None
    # Per-label span scores indexed as (batch, label, start, end); invalid spans
    # (start > end or masked tokens) are filled with a large negative value.
    span_logits: Optional[torch.Tensor] = None
def build_rope_cache(
    seq_len: int,
    head_size: int,
    device,
    dtype,
    base: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build interleaved rotary-position-embedding (RoPE) sin/cos tables.

    For position ``p`` and pair index ``i`` the rotation angle is
    ``p * base ** (-2 * i / head_size)``. Every angle is repeated twice along the
    last axis (``[a0, a0, a1, a1, ...]``) so the tables line up element-wise with
    an interleaved ``(even, odd)`` pair layout.

    Args:
        seq_len: Number of positions.
        head_size: Size of the rotated feature dimension; must be even.
        device: Device of the returned tensors.
        dtype: dtype of the returned tensors.
        base: Frequency base; the default 10000.0 matches the previous hard-coded value.

    Returns:
        ``(sin, cos)``, each of shape ``(seq_len, head_size)`` and dtype ``dtype``.

    Raises:
        ValueError: If ``head_size`` is odd.
    """
    if head_size % 2 != 0:
        raise ValueError(f"head_size must be even for RoPE, got {head_size}")
    # Do the angle math in at least float32: bf16/fp16 represent integers exactly
    # only up to 256/2048, so a half-precision arange would merge distinct
    # positions (and the frequencies would lose precision as well).
    compute_dtype = torch.promote_types(dtype, torch.float32)
    position = torch.arange(seq_len, device=device, dtype=compute_dtype).unsqueeze(-1)
    index = torch.arange(head_size // 2, device=device, dtype=compute_dtype)
    theta = torch.pow(torch.tensor(base, device=device, dtype=compute_dtype), -2.0 * index / head_size)
    angles = position * theta  # (seq_len, head_size // 2)
    sin_base = torch.sin(angles)
    cos_base = torch.cos(angles)
    # Duplicate each column so entry 2i and 2i+1 share the same angle.
    sin = torch.stack((sin_base, sin_base), dim=-1).reshape(seq_len, head_size)
    cos = torch.stack((cos_base, cos_base), dim=-1).reshape(seq_len, head_size)
    return sin.to(dtype), cos.to(dtype)
def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent feature pairs of ``x`` with interleaved RoPE tables.

    Each pair ``(a, b) = (x[..., 2i], x[..., 2i + 1])`` becomes
    ``(a * cos - b * sin, b * cos + a * sin)``, using the per-element values of
    ``sin``/``cos``. The tables get singleton axes inserted at positions 0 and 2,
    so 2-D ``(seq, dim)`` tables broadcast against ``x`` laid out as
    ``(batch, seq, heads, dim)``. The last dimension of ``x`` must be even.
    """
    *lead, width = x.shape
    pairs = x.reshape(*lead, width // 2, 2)
    first, second = pairs.unbind(dim=-1)
    # (-b, a) for every pair, flattened back into the interleaved layout.
    swapped = torch.stack((-second, first), dim=-1).flatten(-2)
    cos_b = cos[None, :, None]
    sin_b = sin[None, :, None]
    return x * cos_b + swapped * sin_b
class IrishCoreGlobalPointerModel(PreTrainedModel):
    """Transformer encoder with a GlobalPointer head that scores (start, end) spans per label.

    A linear layer projects each token's hidden state to ``num_span_labels``
    query/key vectors of size ``head_size`` (optionally rotated with RoPE). The
    score of span ``(s, t)`` for label ``h`` is ``q[s, h] . k[t, h] / sqrt(head_size)``.

    Config attributes read (defaults in parentheses):
        num_span_labels (required), global_pointer_head_size (64),
        global_pointer_use_rope (True), global_pointer_negative_ratio (16),
        global_pointer_min_negatives (256), span_positive_weight (6.0),
        seq_classif_dropout, falling back to dropout (0.1).
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AutoModel.from_config(config)
        self.num_span_labels = int(getattr(config, "num_span_labels"))
        self.head_size = int(getattr(config, "global_pointer_head_size", 64))
        self.use_rope = bool(getattr(config, "global_pointer_use_rope", True))
        self.negative_ratio = int(getattr(config, "global_pointer_negative_ratio", 16))
        self.min_negatives = int(getattr(config, "global_pointer_min_negatives", 256))
        hidden_size = hidden_size_from_config(config)
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        # One query and one key vector per label, packed into a single projection.
        self.proj = nn.Linear(hidden_size, self.num_span_labels * self.head_size * 2)
        pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        # Per-label positive weight for BCE; non-persistent, so it is not saved in checkpoints.
        self.register_buffer("loss_pos_weight", torch.full((self.num_span_labels,), pos_weight), persistent=False)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        span_labels=None,
        token_mask=None,
        **kwargs,
    ) -> GlobalPointerSpanOutput:
        """Encode tokens and score every span for every label.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Encoder attention mask; also the default ``token_mask``.
            token_type_ids: Forwarded unless ``config.model_type`` is distilbert or roberta.
            span_labels: Optional targets; must have the span-logit shape
                ``(batch, num_span_labels, seq, seq)`` (required by the BCE loss).
            token_mask: Optional ``(batch, seq)`` mask of tokens allowed as span
                boundaries; defaults to ``attention_mask``, then to all ones.
            **kwargs: Passed through to the encoder unchanged.

        Returns:
            ``GlobalPointerSpanOutput`` whose ``span_logits`` have invalid pairs
            (start > end or a masked boundary token) filled with -1e4, and whose
            ``loss`` is set only when ``span_labels`` is given.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        # Index 0 is the last hidden state for both ModelOutput and tuple
        # (return_dict=False) encoder outputs.
        hidden = self.dropout(outputs[0])
        batch_size, seq_len, _ = hidden.shape

        projected = self.proj(hidden).view(batch_size, seq_len, self.num_span_labels, self.head_size * 2)
        query, key = torch.chunk(projected, 2, dim=-1)  # each (batch, seq, labels, head_size)
        if self.use_rope:
            sin, cos = build_rope_cache(seq_len, self.head_size, hidden.device, hidden.dtype)
            query = apply_rope(query, sin, cos)
            key = apply_rope(key, sin, cos)
        span_logits = torch.einsum("bshd,bthd->bhst", query, key) / math.sqrt(self.head_size)

        boundary_mask = token_mask if token_mask is not None else attention_mask
        pair_mask = self._span_pair_mask(boundary_mask, batch_size, seq_len, hidden)
        # -1e4 is representable in fp16/bf16 and saturates the sigmoid to ~0.
        masked_logits = span_logits.masked_fill(pair_mask <= 0.0, -1e4)

        loss = None
        if span_labels is not None:
            loss = self._global_pointer_loss(span_logits, span_labels, pair_mask > 0.0)
        return GlobalPointerSpanOutput(loss=loss, span_logits=masked_logits)

    @staticmethod
    def _span_pair_mask(token_mask, batch_size, seq_len, reference):
        """Return a (batch, 1, seq, seq) mask: 1 where start <= end and both tokens are unmasked."""
        if token_mask is None:
            token_mask = torch.ones((batch_size, seq_len), device=reference.device, dtype=reference.dtype)
        token_mask = token_mask.to(reference.dtype)
        pair_mask = token_mask[:, None, :, None] * token_mask[:, None, None, :]
        upper = torch.triu(torch.ones((seq_len, seq_len), device=reference.device, dtype=reference.dtype))
        return pair_mask * upper[None, None]

    def _global_pointer_loss(self, span_logits, span_labels, valid_mask):
        """Weighted BCE over valid spans with top-k hard-negative mining.

        Positives are valid spans whose label is > 0; every other valid span is a
        negative. When ``negative_ratio > 0`` only the
        ``max(min_negatives, positives * negative_ratio)`` highest-loss negatives are
        kept. The result is the mean of the positive-loss mean and the
        negative-loss mean (whichever are non-empty), or a graph-connected zero.
        """
        # Compute the loss in at least float32 so half-precision logits are not
        # mixed with float32 targets and the log-sigmoid keeps its precision.
        loss_dtype = torch.promote_types(span_logits.dtype, torch.float32)
        logits = span_logits.to(loss_dtype)
        targets = span_labels.to(loss_dtype)
        pos_weight = self.loss_pos_weight.to(device=logits.device, dtype=loss_dtype).view(1, self.num_span_labels, 1, 1)
        raw_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none", pos_weight=pos_weight)

        positive_mask = (targets > 0.0) & valid_mask
        negative_mask = (~positive_mask) & valid_mask
        positive_loss = raw_loss.masked_select(positive_mask)
        negative_loss = raw_loss.masked_select(negative_mask)
        if negative_loss.numel() > 0 and self.negative_ratio > 0:
            # .item() copies to the host (a device sync on CUDA) to size the top-k.
            positive_count = int(positive_mask.sum().item())
            keep_negatives = max(self.min_negatives, positive_count * self.negative_ratio)
            keep_negatives = min(keep_negatives, negative_loss.numel())
            negative_loss = torch.topk(negative_loss, keep_negatives).values

        parts = [part.mean() for part in (positive_loss, negative_loss) if part.numel() > 0]
        if not parts:
            # Nothing to supervise: return zero while keeping the autograd graph attached.
            return raw_loss.sum() * 0.0
        return sum(parts) / len(parts)
|