File size: 5,943 Bytes
d6c2695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
from __future__ import annotations

import math
from dataclasses import dataclass
from pathlib import Path
import sys
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput

# Directory three levels up from this file (``parents[2]``) -- presumably the
# repository root, given the absolute ``experiments.`` import below. It is put at
# the front of ``sys.path`` (once) so that fallback import can resolve.
ROOT_DIR = Path(__file__).resolve().parents[2]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

# Prefer the package-relative import; when it raises ImportError (e.g. this
# module was not imported as part of a package), fall back to the absolute path.
try:
    from ..irish_core_span_raw_only.model import hidden_size_from_config
except ImportError:
    from experiments.irish_core_span_raw_only.model import hidden_size_from_config


@dataclass
class GlobalPointerSpanOutput(ModelOutput):
    """Result container for ``IrishCoreGlobalPointerModel.forward``.

    Both fields default to ``None`` so the object can be built with only the
    tensors that were actually produced. Field order is significant for
    ``ModelOutput`` (tuple/index access), so ``loss`` must stay first.
    """

    # Training loss; left as ``None`` when no span labels were passed to forward.
    loss: Optional[torch.Tensor] = None
    # Span scores indexed as (batch, label, start, end); pairs outside the valid
    # upper-triangular region are filled with -1e4 by forward.
    span_logits: Optional[torch.Tensor] = None


def build_rope_cache(
    seq_len: int,
    head_size: int,
    device,
    dtype,
    base: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build interleaved rotary-position (RoPE) sine/cosine tables.

    Frequency ``i`` is ``base ** (-2 * i / head_size)`` for ``i`` in
    ``[0, head_size // 2)``. Each angle is duplicated so that entries ``2i`` and
    ``2i + 1`` share a value, matching the pairwise rotation in ``apply_rope``.

    Args:
        seq_len: Number of positions (rows) to generate.
        head_size: Feature width per head; must be even (the final reshape
            fails otherwise).
        device: Device for the returned tensors.
        dtype: dtype of the returned tensors. ``None`` yields float32.
        base: RoPE wavelength base; defaults to the conventional 10000.

    Returns:
        ``(sin, cos)``, each of shape ``(seq_len, head_size)``.
    """
    # Do the arithmetic in at least float32: in fp16/bf16, torch.arange cannot
    # represent large positions exactly and the angle products lose precision.
    # float32/float64 callers get exactly the same computation as before.
    compute_dtype = torch.float32 if dtype is None else torch.promote_types(dtype, torch.float32)
    out_dtype = compute_dtype if dtype is None else dtype

    position = torch.arange(seq_len, device=device, dtype=compute_dtype).unsqueeze(-1)
    index = torch.arange(head_size // 2, device=device, dtype=compute_dtype)
    theta = torch.pow(torch.tensor(base, device=device, dtype=compute_dtype), -2.0 * index / head_size)
    angles = position * theta  # (seq_len, head_size // 2)
    sin_base = torch.sin(angles)
    cos_base = torch.cos(angles)
    # Interleave [a0, a0, a1, a1, ...] so adjacent feature pairs share an angle.
    sin = torch.stack((sin_base, sin_base), dim=-1).reshape(seq_len, head_size)
    cos = torch.stack((cos_base, cos_base), dim=-1).reshape(seq_len, head_size)
    return sin.to(out_dtype), cos.to(out_dtype)


def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent feature pairs of ``x`` by the given RoPE angles.

    ``sin``/``cos`` get a new leading axis and a new axis at position 2, so for
    the ``(seq_len, head_size)`` tables from ``build_rope_cache`` they broadcast
    over dims 0 and 2 of ``x`` (assumed ``(batch, seq, heads, head_size)``).
    Each pair ``(a, b) = (x[..., 2i], x[..., 2i + 1])`` becomes
    ``(a * cos - b * sin, b * cos + a * sin)``.
    """
    cos_b = cos[None].unsqueeze(2)
    sin_b = sin[None].unsqueeze(2)
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    # Pair partner with sign flip: (a, b) -> (-b, a), re-interleaved to head_size.
    partner = torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)
    return x * cos_b + partner * sin_b


class IrishCoreGlobalPointerModel(PreTrainedModel):
    """GlobalPointer span scorer on top of a Hugging Face encoder.

    For every span label, each encoder state is projected into a query and a key
    vector of ``head_size`` features (optionally rotated with RoPE). The scaled
    dot product between the query at position ``s`` and the key at position
    ``t`` is the logit for the span ``(s, t)``.

    Config attributes read (defaults in parentheses): ``num_span_labels``
    (required), ``global_pointer_head_size`` (64), ``global_pointer_use_rope``
    (True), ``global_pointer_negative_ratio`` (16),
    ``global_pointer_min_negatives`` (256), ``span_positive_weight`` (6.0), and
    ``seq_classif_dropout`` falling back to ``dropout`` (0.1).
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AutoModel.from_config(config)
        self.num_span_labels = int(getattr(config, "num_span_labels"))
        self.head_size = int(getattr(config, "global_pointer_head_size", 64))
        self.use_rope = bool(getattr(config, "global_pointer_use_rope", True))
        # Hard-negative mining: keep max(min_negatives, positives * ratio) negatives.
        self.negative_ratio = int(getattr(config, "global_pointer_negative_ratio", 16))
        self.min_negatives = int(getattr(config, "global_pointer_min_negatives", 256))
        hidden_size = hidden_size_from_config(config)
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        # Per label: head_size query features followed by head_size key features.
        self.proj = nn.Linear(hidden_size, self.num_span_labels * self.head_size * 2)
        pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        # Non-persistent: rebuilt from config, not stored in the state dict.
        self.register_buffer("loss_pos_weight", torch.full((self.num_span_labels,), pos_weight), persistent=False)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        span_labels=None,
        token_mask=None,
        **kwargs,
    ) -> GlobalPointerSpanOutput:
        """Score every (start, end) span and optionally compute the training loss.

        Args:
            input_ids, attention_mask, **kwargs: Forwarded to the encoder.
            token_type_ids: Forwarded unless ``config.model_type`` is
                ``"distilbert"`` or ``"roberta"``.
            span_labels: Targets (``> 0`` means positive); must broadcast against
                the ``(batch, num_span_labels, seq_len, seq_len)`` logits.
            token_mask: Positions allowed as span boundaries; defaults to
                ``attention_mask``, then to all ones.

        Returns:
            ``GlobalPointerSpanOutput`` whose ``span_logits`` has shape
            ``(batch, num_span_labels, seq_len, seq_len)``, with pairs outside
            the valid upper triangle (start <= end, both unmasked) set to -1e4.
            ``loss`` is populated only when ``span_labels`` is given.
        """
        hidden = self._encode_tokens(input_ids, attention_mask, token_type_ids, kwargs)
        span_logits = self._score_spans(hidden)
        pair_mask = self._span_pair_mask(token_mask, attention_mask, hidden)
        masked_logits = span_logits.masked_fill(pair_mask <= 0.0, -1e4)

        loss = None
        if span_labels is not None:
            # The loss uses the unmasked logits, restricted to valid pairs.
            loss = self._span_loss(span_logits, span_labels, pair_mask > 0.0)

        return GlobalPointerSpanOutput(loss=loss, span_logits=masked_logits)

    def _encode_tokens(self, input_ids, attention_mask, token_type_ids, extra_kwargs):
        """Run the encoder and return its dropout-regularized last hidden state."""
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **extra_kwargs,
        }
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        return self.dropout(outputs.last_hidden_state)

    def _score_spans(self, hidden: torch.Tensor) -> torch.Tensor:
        """Return raw span logits of shape (batch, labels, seq_len, seq_len)."""
        batch_size, seq_len, _ = hidden.shape
        projected = self.proj(hidden).view(batch_size, seq_len, self.num_span_labels, self.head_size * 2)
        query, key = torch.chunk(projected, 2, dim=-1)

        if self.use_rope:
            sin, cos = build_rope_cache(seq_len, self.head_size, hidden.device, hidden.dtype)
            query = apply_rope(query, sin, cos)
            key = apply_rope(key, sin, cos)

        # s = start (query) position, t = end (key) position, h = span label.
        return torch.einsum("bshd,bthd->bhst", query, key) / math.sqrt(self.head_size)

    def _span_pair_mask(self, token_mask, attention_mask, hidden: torch.Tensor) -> torch.Tensor:
        """Return a (batch, 1, seq_len, seq_len) mask; > 0 marks valid spans."""
        batch_size, seq_len, _ = hidden.shape
        if token_mask is None:
            token_mask = attention_mask
        if token_mask is None:
            token_mask = torch.ones((batch_size, seq_len), device=hidden.device, dtype=hidden.dtype)
        token_mask = token_mask.to(hidden.dtype)
        pair_mask = token_mask[:, None, :, None] * token_mask[:, None, None, :]
        # Only start <= end spans are candidates.
        upper_mask = torch.triu(torch.ones((seq_len, seq_len), device=hidden.device, dtype=hidden.dtype))
        return pair_mask * upper_mask.unsqueeze(0).unsqueeze(0)

    def _span_loss(self, span_logits: torch.Tensor, span_labels: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        """Weighted BCE averaged over positives and the hardest mined negatives."""
        # Compute in float32 so half-precision encoders do not degrade the
        # per-pair losses or the hard-negative ranking below.
        logits = span_logits.float()
        targets = span_labels.float()
        pos_weight = self.loss_pos_weight.to(device=logits.device, dtype=torch.float32)
        pos_weight = pos_weight.view(1, self.num_span_labels, 1, 1)
        raw_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none", pos_weight=pos_weight)

        positive_mask = (targets > 0.0) & valid_mask
        negative_mask = (~positive_mask) & valid_mask
        positive_loss = raw_loss.masked_select(positive_mask)
        negative_loss = raw_loss.masked_select(negative_mask)

        if negative_loss.numel() > 0 and self.negative_ratio > 0:
            # Keep only the highest-loss negatives (hard-negative mining).
            keep_negatives = max(self.min_negatives, positive_loss.numel() * self.negative_ratio)
            keep_negatives = min(keep_negatives, negative_loss.numel())
            negative_loss = torch.topk(negative_loss, keep_negatives).values

        # Positives and negatives contribute equally regardless of their counts.
        parts = [part.mean() for part in (positive_loss, negative_loss) if part.numel() > 0]
        if not parts:
            # Zero that stays attached to the graph when nothing is valid.
            return raw_loss.sum() * 0.0
        return sum(parts) / len(parts)