"""Single-sense entity disambiguation on top of ModernBERT.

The classifier head holds one vector per ontology entity; rows are initialised
from the [CLS] encodings of each entity's definition and scored against the
[CLS] encoding of the input at inference time.
"""

from collections.abc import Generator
from dataclasses import dataclass
from enum import StrEnum

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    ModernBertModel,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

BATCH_SIZE = 16


class ModelURI(StrEnum):
    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    concept: str
    definition: str


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


class DisamBertSingleSense(PreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # Fresh start from a pretrained checkpoint: load the encoder, grow the
            # vocabulary by two (entity marker tokens) and leave the classifier
            # uninitialised until init_classifier is run.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
            self.classifier_head = nn.UninitializedParameter()
            self.bias = nn.UninitializedParameter()
            self.__entities = None
        else:
            # Reloading a saved checkpoint: the config already carries the ontology,
            # so build the full architecture and let from_pretrained fill the weights.
            self.BaseModel = ModernBertModel(config)
            self.classifier_head = nn.Parameter(
                torch.empty((config.ontology_size, config.hidden_size))
            )
            self.bias = nn.Parameter(torch.empty((1, config.ontology_size)))
            self.__entities = pd.Series(config.entities)
        config.init_basemodel = False
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def init_classifier(
        self, entities: Generator[LexicalExample], tokenizer: PreTrainedTokenizer
    ) -> None:
        """Initialise the classifier head from the [CLS] encodings of entity definitions."""
        entity_ids = []
        vectors = []
        batch = []
        n = 0
        with self.BaseModel.device:
            torch.cuda.empty_cache()
            for entity in entities:
                entity_ids.append(entity.concept)
                batch.append(entity.definition)
                n += 1
                if n == BATCH_SIZE:
                    tokens = tokenizer(batch, padding=True, return_tensors="pt").to(
                        self.BaseModel.device
                    )
                    encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
                    vectors.append(encoding.last_hidden_state.detach()[:, 0])
                    n = 0
                    batch = []
            if n > 0:
                # Encode the final, partially filled batch.
                tokens = tokenizer(batch, padding=True, return_tensors="pt").to(
                    self.BaseModel.device
                )
                encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
                vectors.append(encoding.last_hidden_state.detach()[:, 0])
        self.__entities = pd.Series(entity_ids)
        # Persist the ontology in the config so checkpoints can be reloaded.
        self.config.entities = entity_ids
        self.config.ontology_size = len(entity_ids)
        self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
        self.bias = nn.Parameter(
            torch.nn.init.normal_(
                torch.empty((1, self.config.ontology_size)),
                std=self.classifier_head.std().item() * np.sqrt(self.config.hidden_size),
            )
        )

    @property
    def entities(self) -> pd.Series:
        if self.__entities is None and hasattr(self.config, "entities"):
            self.__entities = pd.Series(self.config.entities)
        return self.__entities

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        assert not nn.parameter.is_lazy(self.classifier_head), (
            "Run init_classifier to initialise weights"
        )
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        # Score the [CLS] vector of each input against every entity vector.
        token_vectors = base_model_output.last_hidden_state[:, 0]
        logits = torch.einsum("ij,kj->ik", token_vectors, self.classifier_head) + self.bias
        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )
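

# --- Illustrative usage sketch (not part of the class above) ---
# A minimal end-to-end example of the intended workflow. The marker tokens
# "[ENT]"/"[/ENT]", the two-entry example ontology, and the sample sentence are
# hypothetical placeholders chosen to mirror the `vocab_size += 2` adjustment
# in __init__; only the classes and methods defined above are assumed.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(ModelURI.BASE)
    # Two marker tokens to line up with the vocabulary growth in __init__;
    # the exact token strings are an assumption.
    tokenizer.add_tokens(["[ENT]", "[/ENT]"])

    model = DisamBertSingleSense.from_base(ModelURI.BASE)

    # Hypothetical ontology: each entry pairs a concept identifier with a gloss.
    example_ontology = (
        LexicalExample(concept=c, definition=d)
        for c, d in [
            ("bank%finance", "An institution that accepts deposits and makes loans."),
            ("bank%river", "The sloping land alongside a body of water."),
        ]
    )
    model.init_classifier(example_ontology, tokenizer)

    # Score a marked mention against every concept in the ontology.
    tokens = tokenizer(
        "She deposited the cheque at the [ENT] bank [/ENT] on Monday.",
        return_tensors="pt",
    ).to(model.BaseModel.device)
    with torch.no_grad():
        output = model(tokens["input_ids"], tokens["attention_mask"])
    predicted = model.entities.iloc[output.logits.argmax(dim=-1).item()]
    print(predicted)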