# Source: Hugging Face model repo (user: temsa), commit 49a55aa ("Add raw-only rc8 release with ONNX dynamic q8").
#!/usr/bin/env python3
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput
from model import hidden_size_from_config
@dataclass
class MultitaskSpanOutput(ModelOutput):
    """Model output with one logit tensor per prediction head.

    All logit tensors share the shape produced by the per-token linear
    heads (batch, sequence, num_span_labels); values are raw logits
    intended for sigmoid/BCE, not softmax.
    """

    # Scalar training loss; None when labels were not supplied to forward().
    loss: Optional[torch.Tensor] = None
    # Per-token "label present" logits.
    token_logits: Optional[torch.Tensor] = None
    # Per-token span-start logits.
    start_logits: Optional[torch.Tensor] = None
    # Per-token span-end logits.
    end_logits: Optional[torch.Tensor] = None
class IrishCoreTokenSpanModel(PreTrainedModel):
    """Multitask span-tagging model: a Hugging Face encoder plus three
    per-token linear heads (token presence, span start, span end).

    Each head emits independent per-label logits; training uses masked,
    positive-weighted BCE-with-logits on all three heads.

    Required config fields: ``num_span_labels``. Optional fields:
    ``seq_classif_dropout``/``dropout`` (default 0.1),
    ``span_positive_weight`` (default 6.0), ``token_positive_weight``
    (default 4.0), ``token_presence_weight`` and ``boundary_loss_weight``
    (loss mixing weights, default 1.0 each).
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        # Plain attribute access: raises AttributeError if the config is
        # missing this required field (same behavior as a default-less getattr).
        num_span_labels = int(config.num_span_labels)
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        # Prefer the DistilBERT-style name, then a generic `dropout`, then 0.1.
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        self.token_classifier = nn.Linear(hidden_size, num_span_labels)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)
        # Per-label positive-class weights for BCE; registered as
        # non-persistent buffers so they move with the module across devices
        # but are not written to checkpoints.
        boundary_pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        presence_pos_weight = float(getattr(config, "token_positive_weight", 4.0))
        self.register_buffer("boundary_pos_weight", torch.full((num_span_labels,), boundary_pos_weight), persistent=False)
        self.register_buffer("presence_pos_weight", torch.full((num_span_labels,), presence_pos_weight), persistent=False)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        token_labels=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultitaskSpanOutput:
        """Run the encoder and heads; compute the multitask loss when all
        three label tensors (`token_labels`, `start_positions`,
        `end_positions`) are provided.

        `token_mask` selects which positions contribute to the loss; it
        falls back to `attention_mask`, and to an all-ones mask if both
        are None. Labels are assumed to be multi-hot with the same
        (batch, seq, num_span_labels) shape as the logits — TODO confirm
        against the training pipeline.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        # DistilBERT/RoBERTa encoders do not accept token_type_ids.
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        hidden = self.dropout(outputs.last_hidden_state)
        token_logits = self.token_classifier(hidden)
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)

        loss = None
        if token_labels is not None and start_positions is not None and end_positions is not None:
            if token_mask is None:
                token_mask = attention_mask
            if token_mask is None:
                # Bug fix: previously crashed with AttributeError when both
                # token_mask and attention_mask were None; treat every
                # position as valid instead.
                token_mask = token_logits.new_ones(token_logits.shape[:-1])
            mask = token_mask.float().unsqueeze(-1)
            # .to() is a no-op when the buffer already lives on the logits'
            # device; kept as a cheap guard for multi-device setups.
            boundary_pos_weight = self.boundary_pos_weight.to(token_logits.device)
            presence_pos_weight = self.presence_pos_weight.to(token_logits.device)
            # Functional BCE avoids building two nn.Module loss objects on
            # every forward pass (same math as nn.BCEWithLogitsLoss).
            token_loss = nn.functional.binary_cross_entropy_with_logits(
                token_logits, token_labels.float(), pos_weight=presence_pos_weight, reduction="none"
            ) * mask
            start_loss = nn.functional.binary_cross_entropy_with_logits(
                start_logits, start_positions.float(), pos_weight=boundary_pos_weight, reduction="none"
            ) * mask
            end_loss = nn.functional.binary_cross_entropy_with_logits(
                end_logits, end_positions.float(), pos_weight=boundary_pos_weight, reduction="none"
            ) * mask
            # Mean over valid positions and labels; clamp keeps the division
            # safe when the mask is all zeros.
            denom = mask.sum().clamp_min(1.0) * token_logits.shape[-1]
            token_loss = token_loss.sum() / denom
            boundary_loss = (start_loss.sum() + end_loss.sum()) / (2.0 * denom)
            token_weight = float(getattr(self.config, "token_presence_weight", 1.0))
            boundary_weight = float(getattr(self.config, "boundary_loss_weight", 1.0))
            loss = token_weight * token_loss + boundary_weight * boundary_loss

        return MultitaskSpanOutput(
            loss=loss,
            token_logits=token_logits,
            start_logits=start_logits,
            end_logits=end_logits,
        )