#!/usr/bin/env python3
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput
try:
    from .model import hidden_size_from_config  # package-relative import
except ImportError:
    # Fall back to a flat import when the file is loaded outside the package.
    from model import hidden_size_from_config

@dataclass
class MultitaskSpanOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    token_logits: Optional[torch.Tensor] = None
    start_logits: Optional[torch.Tensor] = None
    end_logits: Optional[torch.Tensor] = None

class IrishCoreTokenSpanModel(PreTrainedModel):
    """Encoder with three sigmoid heads: per-token presence, span starts, span ends."""

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        num_span_labels = int(config.num_span_labels)
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        self.token_classifier = nn.Linear(hidden_size, num_span_labels)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)
        # Positive-class weights for the BCE losses; stored as non-persistent
        # buffers so they follow the model across devices without being saved.
        boundary_pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        presence_pos_weight = float(getattr(config, "token_positive_weight", 4.0))
        self.register_buffer("boundary_pos_weight", torch.full((num_span_labels,), boundary_pos_weight), persistent=False)
        self.register_buffer("presence_pos_weight", torch.full((num_span_labels,), presence_pos_weight), persistent=False)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        token_labels=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultitaskSpanOutput:
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        # DistilBERT's forward has no token_type_ids, and this RoBERTa setup
        # does not use segment ids, so drop them for those encoders.
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        hidden = self.dropout(outputs.last_hidden_state)
        token_logits = self.token_classifier(hidden)  # (batch, seq_len, num_span_labels)
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)
        loss = None
        if token_labels is not None and start_positions is not None and end_positions is not None:
            if token_mask is None:
                token_mask = attention_mask
            mask = token_mask.float().unsqueeze(-1)  # (batch, seq_len, 1)
            boundary_pos_weight = self.boundary_pos_weight.to(token_logits.device)
            presence_pos_weight = self.presence_pos_weight.to(token_logits.device)
            bce_boundary = nn.BCEWithLogitsLoss(reduction="none", pos_weight=boundary_pos_weight)
            bce_presence = nn.BCEWithLogitsLoss(reduction="none", pos_weight=presence_pos_weight)
            token_targets = token_labels.float()
            start_targets = start_positions.float()
            end_targets = end_positions.float()
            token_loss = bce_presence(token_logits, token_targets)
            start_loss = bce_boundary(start_logits, start_targets)
            end_loss = bce_boundary(end_logits, end_targets)
            # Optional focal modulation to down-weight easy examples.
            token_focal_gamma = float(getattr(self.config, "token_focal_gamma", getattr(self.config, "focal_gamma", 0.0)))
            boundary_focal_gamma = float(getattr(self.config, "boundary_focal_gamma", getattr(self.config, "focal_gamma", 0.0)))
            if token_focal_gamma > 0.0:
                token_loss = apply_focal_weight(token_loss, token_logits, token_targets, token_focal_gamma)
            if boundary_focal_gamma > 0.0:
                start_loss = apply_focal_weight(start_loss, start_logits, start_targets, boundary_focal_gamma)
                end_loss = apply_focal_weight(end_loss, end_logits, end_targets, boundary_focal_gamma)
            # Optional online hard-example mining: average only the hardest fraction.
            token_hard_fraction = float(getattr(self.config, "token_hard_fraction", getattr(self.config, "hard_fraction", 1.0)))
            boundary_hard_fraction = float(getattr(self.config, "boundary_hard_fraction", getattr(self.config, "hard_fraction", 1.0)))
            token_loss = reduce_masked_loss(token_loss, mask, token_hard_fraction)
            start_loss = reduce_masked_loss(start_loss, mask, boundary_hard_fraction)
            end_loss = reduce_masked_loss(end_loss, mask, boundary_hard_fraction)
            boundary_loss = 0.5 * (start_loss + end_loss)
            token_weight = float(getattr(self.config, "token_presence_weight", 1.0))
            boundary_weight = float(getattr(self.config, "boundary_loss_weight", 1.0))
            loss = token_weight * token_loss + boundary_weight * boundary_loss
        return MultitaskSpanOutput(
            loss=loss,
            token_logits=token_logits,
            start_logits=start_logits,
            end_logits=end_logits,
        )
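

# Not part of the original file: a minimal inference-time sketch of one way
# the three sigmoid heads could be combined into (start, end, label) spans
# for a single sequence. The decoding actually used for this checkpoint may
# differ; the 0.5 thresholds and the greedy matching below are assumptions.
def decode_spans_greedy(
    token_logits: torch.Tensor,
    start_logits: torch.Tensor,
    end_logits: torch.Tensor,
    threshold: float = 0.5,
) -> list[tuple[int, int, int]]:
    token_probs = torch.sigmoid(token_logits)  # (seq_len, num_labels)
    start_probs = torch.sigmoid(start_logits)
    end_probs = torch.sigmoid(end_logits)
    seq_len, num_labels = token_probs.shape
    spans: list[tuple[int, int, int]] = []
    for label in range(num_labels):
        for start in range(seq_len):
            if start_probs[start, label] < threshold:
                continue
            for end in range(start, seq_len):
                if end_probs[end, label] < threshold:
                    continue
                # Accept the first plausible end whose interior tokens also fire.
                if token_probs[start : end + 1, label].mean() >= threshold:
                    spans.append((start, end, label))
                break
    return spans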

def apply_focal_weight(loss: torch.Tensor, logits: torch.Tensor, targets: torch.Tensor, gamma: float) -> torch.Tensor:
    """Scale an elementwise BCE loss by the focal term (1 - p_t) ** gamma."""
    probs = torch.sigmoid(logits)
    pt = probs * targets + (1.0 - probs) * (1.0 - targets)
    return loss * (1.0 - pt).pow(gamma)
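
# Illustrative numbers for the focal term above (not from the original file):
# with gamma = 2, a confident correct prediction (p_t = 0.9) keeps only
# (1 - 0.9) ** 2 = 0.01 of its loss, while a hard example (p_t = 0.1) keeps
# (1 - 0.1) ** 2 = 0.81, so gradient signal concentrates on hard tokens.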

def reduce_masked_loss(loss: torch.Tensor, mask: torch.Tensor, hard_fraction: float) -> torch.Tensor:
    """Average the loss over unmasked positions, optionally keeping only the
    hardest ``hard_fraction`` of them (online hard-example mining via top-k)."""
    expanded_mask = mask.expand_as(loss).bool()
    masked = loss.masked_select(expanded_mask)
    if masked.numel() == 0:
        # No valid positions: return a zero that stays on the autograd graph.
        return loss.sum() * 0.0
    if 0.0 < hard_fraction < 1.0:
        keep = max(1, math.ceil(masked.numel() * hard_fraction))
        masked = torch.topk(masked, keep).values
    return masked.mean()
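

if __name__ == "__main__":
    # Hedged smoke test, not shipped with the original file. The checkpoint
    # name and label count are illustrative assumptions, and running this
    # requires the sibling model.py providing hidden_size_from_config.
    config = AutoConfig.from_pretrained("distilbert-base-uncased")
    config.num_span_labels = 4  # hypothetical number of span types
    model = IrishCoreTokenSpanModel(config)
    batch, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch, seq_len))
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
    targets = torch.zeros(batch, seq_len, config.num_span_labels)
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_labels=targets,
        start_positions=targets.clone(),
        end_positions=targets.clone(),
    )
    print(out.loss, out.token_logits.shape)  # scalar loss, torch.Size([2, 16, 4])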