"""ModernBERT span classifier that scores text spans against a fixed inventory of concepts."""

from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum
from itertools import chain, pairwise

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    ModernBertModel,
    PreTrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import TokenClassifierOutput

# Number of concept definitions encoded per forward pass in init_classifier.
BATCH_SIZE = 16


class ModelURI(StrEnum):
    """Hub identifiers of the supported ModernBERT base checkpoints."""

    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    """One ontology entry used to seed the classifier head."""

    # Identifier of the concept; stored in config.entities in input order.
    concept: str
    # Text that is encoded to produce the concept's initial head vector.
    definition: str


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """Right-padded token ids plus the matching attention mask."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor


class DisamBert(PreTrainedModel):
    """ModernBERT encoder with one learned vector (and bias) per concept.

    Each input span is represented by the sum of its token vectors.
    Concept ``k`` scores a span by the dot product with ``classifier_head[k]``
    plus ``bias[k]``. The parameter names ``BaseModel``, ``classifier_head``
    and ``bias`` are part of the checkpoint layout and must not be renamed.
    """

    def __init__(self, config: PreTrainedConfig):
        super().__init__(config)
        if getattr(config, "init_basemodel", False):
            # Fresh model: download pretrained encoder weights; the head is
            # created lazily by init_classifier once the ontology is known.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
            self.classifier_head = nn.UninitializedParameter()
            self.bias = nn.UninitializedParameter()
            self.__entities = None
        else:
            # Rebuilding from a saved config: allocate shapes only; the values
            # are expected to come from the checkpoint being loaded.
            self.BaseModel = ModernBertModel(config)
            self.classifier_head = nn.Parameter(
                torch.empty((config.ontology_size, config.hidden_size))
            )
            self.bias = nn.Parameter(torch.empty((config.ontology_size, 1)))
            self.__entities = pd.Series(config.entities)
        # Presumably so that a config saved from this model takes the
        # "rebuild" branch above when reloaded.
        config.init_basemodel = False
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Create an untrained model on top of the pretrained encoder ``base_id``."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        config.tokenizer_path = base_id
        return cls(config)

    def _encode_definitions(self, definitions: list[str]) -> torch.Tensor:
        """Encode a batch of texts and return the first-token hidden vector of each.

        Must be called inside the device context set up by init_classifier so
        the tokenizer's tensors are created on the model's device.
        """
        tokens = self.tokenizer(definitions, padding=True, return_tensors="pt")
        encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
        return encoding.last_hidden_state.detach()[:, 0]

    def init_classifier(self, entities: Iterable[LexicalExample]) -> None:
        """Initialise the classifier head from concept definitions.

        Row ``k`` of ``classifier_head`` is the first-token encoding of the
        ``k``-th definition. ``bias`` is drawn from a normal distribution whose
        std is ``classifier_head.std() * sqrt(hidden_size)``. Also records the
        concept ids in ``config.entities`` and ``config.ontology_size``.

        Raises:
            ValueError: if ``entities`` yields nothing.
        """
        entity_ids: list[str] = []
        vectors: list[torch.Tensor] = []
        batch: list[str] = []
        with self.BaseModel.device:
            torch.cuda.empty_cache()
            # Only the encodings are needed, so skip building autograd graphs.
            with torch.no_grad():
                for entity in entities:
                    entity_ids.append(entity.concept)
                    batch.append(entity.definition)
                    if len(batch) == BATCH_SIZE:
                        vectors.append(self._encode_definitions(batch))
                        batch = []
                if batch:
                    vectors.append(self._encode_definitions(batch))
            if not vectors:
                raise ValueError("init_classifier needs at least one LexicalExample")
            self.__entities = pd.Series(entity_ids)
            self.config.entities = entity_ids
            self.config.ontology_size = len(entity_ids)
            self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
            self.bias = nn.Parameter(
                torch.nn.init.normal_(
                    torch.empty((self.config.ontology_size, 1)),
                    std=self.classifier_head.std().item() * np.sqrt(self.config.hidden_size),
                )
            )

    @property
    def entities(self) -> pd.Series:
        """Concept ids indexed by classifier row, or None before initialisation."""
        if self.__entities is None and hasattr(self.config, "entities"):
            self.__entities = pd.Series(self.config.entities)
        return self.__entities

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        lengths: list[list[int]],
        candidates: list[list[list[int]]],
        labels: torch.Tensor | Iterable[list[int]] | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score every span of every sentence against the concept inventory.

        Args:
            input_ids, attention_mask: padded batch, e.g. from ``tokenize``.
            lengths: ``lengths[i]`` holds the token count of each span of
                sentence ``i``; the spans tile the sentence's leading tokens.
            candidates: ``candidates[i][j]`` lists the concept indices allowed
                for span ``j`` of sentence ``i``.
            labels: gold concept index per span, either a tensor padded to the
                longest sentence (as ``pad_labels`` returns) or per-sentence
                lists. Positions past a sentence's last span do not count
                towards the loss.

        Returns:
            ``TokenClassifierOutput`` whose logits have shape
            ``(batch, ontology_size, max_spans)``; positions past a sentence's
            last span are zero.
        """
        assert not nn.parameter.is_lazy(self.classifier_head), (
            "Run init_classifier to initialise weights"
        )
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # One row per span (sum of its token vectors), all sentences concatenated.
        span_vectors = torch.cat(
            [
                torch.stack(
                    [chunk.sum(dim=0) for chunk in self.split(token_vectors[i], span_lengths)]
                )
                for i, span_lengths in enumerate(lengths)
            ]
        )
        # (ontology_size, total_spans) scores plus a per-concept bias column.
        logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head) + self.bias
        # NOTE(review): non-candidates are not set to -inf. The shift makes all
        # scores >= 0 and the mask then pins non-candidates to 0, i.e. to the
        # smallest raw score in the whole batch.
        shifted = logits - logits.min()
        mask = torch.zeros_like(logits)
        for span, concepts in enumerate(chain.from_iterable(candidates)):
            mask[concepts, span] = 1.0
        masked = shifted * mask

        sentence_lengths = [len(span_lengths) for span_lengths in lengths]
        maxlen = max(sentence_lengths)
        per_sentence = torch.split(masked, sentence_lengths, dim=1)
        padded_logits = torch.stack(
            [
                self.extend_to_max_length(sentence, length, maxlen)
                for sentence, length in zip(per_sentence, sentence_lengths, strict=True)
            ]
        )
        loss = None
        if labels is not None:
            loss = self._span_loss(padded_logits, labels, sentence_lengths)
        return TokenClassifierOutput(
            logits=padded_logits,
            loss=loss,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def _span_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor | Iterable[list[int]],
        sentence_lengths: list[int],
    ) -> torch.Tensor:
        """Cross-entropy over real span positions only.

        Padded positions would otherwise each add a constant ``log(C)`` term
        against all-zero logits and dilute the mean over real spans.
        """
        if not isinstance(labels, torch.Tensor):
            labels = self.pad_labels(list(labels))
        labels = labels.to(logits.device)
        positions = torch.arange(labels.shape[-1], device=logits.device)
        n_spans = torch.tensor(sentence_lengths, device=logits.device)
        is_padding = positions.unsqueeze(0) >= n_spans.unsqueeze(1)
        target = labels.masked_fill(is_padding, self.loss.ignore_index)
        return self.loss(logits, target)

    def split(self, vectors: torch.Tensor, lengths: list[int]) -> tuple[torch.Tensor, ...]:
        """Split the leading ``sum(lengths)`` rows of ``vectors`` into chunks of ``lengths``.

        Trailing rows (padding) are dropped; raises if ``lengths`` sums to more
        rows than ``vectors`` has.
        """
        return vectors[: sum(lengths)].split(lengths)

    def pad(self, tokens: Iterable[list[int]]) -> PaddedBatch:
        """Right-pad token id lists with ``config.pad_token_id`` to a common length.

        The attention mask is 1 for real tokens and 0 for padding, with integer
        dtype like the tokenizer's own masks.
        """
        # Materialise once: the sequence is traversed twice below.
        sentences = list(tokens)
        lengths = [len(sentence) for sentence in sentences]
        maxlen = max(lengths)
        input_ids = torch.tensor(
            [
                sentence + [self.config.pad_token_id] * (maxlen - length)
                for sentence, length in zip(sentences, lengths, strict=True)
            ]
        )
        attention_mask = (
            torch.arange(maxlen).unsqueeze(0) < torch.tensor(lengths).unsqueeze(1)
        ).long()
        return PaddedBatch(input_ids, attention_mask)

    def extend_to_max_length(
        self, sentence: torch.Tensor, length: int, maxlength: int
    ) -> torch.Tensor:
        """Right-pad a ``(ontology_size, length)`` score matrix with zero columns to ``maxlength``."""
        if length >= maxlength:
            return sentence
        # F.pad keeps the input's dtype and device, unlike a fresh float32
        # torch.zeros on the encoder's device.
        return nn.functional.pad(sentence, (0, maxlength - length))

    def pad_labels(self, labels: list[list[int]]) -> torch.Tensor:
        """Stack per-sentence label lists into a tensor, padding with the last entity index.

        NOTE(review): the last ontology entry is treated as the "unknown"
        concept here; confirm the ontology actually ends with such an entry.
        ``forward`` excludes these padded positions from the loss.
        """
        unk = len(self.config.entities) - 1
        lengths = [len(seq) for seq in labels]
        maxlen = max(lengths)
        with self.BaseModel.device:
            return torch.tensor(
                [
                    seq + [unk] * (maxlen - length)
                    for seq, length in zip(labels, lengths, strict=True)
                ]
            )

    def tokenize(
        self, batch: list[dict[str, str | list[int]]]
    ) -> dict[str, torch.Tensor | list[list[int]]]:
        """Turn raw examples into the keyword arguments of ``forward``.

        Each example needs ``"text"``, ``"indices"`` (character offsets; each
        consecutive pair delimits one span) and ``"candidates"``, and may carry
        ``"labels"``. Every span is tokenized on its own; the leading special
        token is kept only for the first span and the trailing one only for the
        last, so the concatenation reads as one sequence.
        """
        all_indices: list[list[int]] = []
        all_tokens: list[list[int]] = []
        with self.BaseModel.device:
            for example in batch:
                text = example["text"]
                spans = [text[start:end] for start, end in pairwise(example["indices"])]
                encoded = self.tokenizer(spans, padding=False)["input_ids"] if spans else []
                last_span = len(spans) - 1
                indices: list[int] = []
                tokens: list[int] = []
                for i, span_tokens in enumerate(encoded):
                    if i > 0:
                        span_tokens = span_tokens[1:]
                    if i < last_span:
                        span_tokens = span_tokens[:-1]
                    indices.append(len(span_tokens))
                    tokens.extend(span_tokens)
                all_indices.append(indices)
                all_tokens.append(tokens)
            padded = self.pad(all_tokens)
            result = {
                "input_ids": padded.input_ids,
                "attention_mask": padded.attention_mask,
                "lengths": all_indices,
                "candidates": [example["candidates"] for example in batch],
            }
            if "labels" in batch[0]:
                result["labels"] = self.pad_labels([example["labels"] for example in batch])
            return result