from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    ModernBertModel,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

# Number of definitions encoded per forward pass when initialising the classifier head.
BATCH_SIZE = 16


class ModelURI(StrEnum):
    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    concept: str
    definition: str


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


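# Illustrative helper, not part of the original classes: one hedged way to stream a sense
# inventory into DisamBertSingleSense.init_classifier. The column names "concept" and
# "definition" are assumptions about the ontology table, not something the code above imposes.
def iter_lexical_examples(inventory: pd.DataFrame) -> Generator[LexicalExample]:
    for row in inventory.itertuples(index=False):
        yield LexicalExample(concept=str(row.concept), definition=str(row.definition))

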
class DisamBertSingleSense(PreTrainedModel):
    """ModernBERT encoder with a sense-classification head initialised from definition embeddings."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # First construction: load the pretrained encoder and leave the classifier head
            # uninitialised until init_classifier() has seen the sense inventory.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
            self.config.vocab_size += 2  # make room for two extra special tokens
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
            self.classifier_head = nn.UninitializedParameter()
            self.bias = nn.UninitializedParameter()
            self.__entities = None
        else:
            # Reloading a checkpoint: shapes and the entity list come from the saved config.
            self.BaseModel = ModernBertModel(config)
            self.classifier_head = nn.Parameter(
                torch.empty((config.ontology_size, config.hidden_size))
            )
            self.bias = nn.Parameter(torch.empty((1, config.ontology_size)))
            self.__entities = pd.Series(config.entities)
        config.init_basemodel = False

        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI) -> "DisamBertSingleSense":
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def init_classifier(
        self, entities: Generator[LexicalExample], tokenizer: PreTrainedTokenizer
    ) -> None:
        """Initialise the classifier head from the [CLS] embedding of each concept's definition."""
        entity_ids = []
        vectors = []
        batch = []
        n = 0
        with self.BaseModel.device:
            torch.cuda.empty_cache()
            for entity in entities:
                entity_ids.append(entity.concept)
                batch.append(entity.definition)

                n += 1
                if n == BATCH_SIZE:
                    tokens = tokenizer(batch, padding=True, return_tensors="pt")
                    encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
                    vectors.append(encoding.last_hidden_state.detach()[:, 0])
                    n = 0
                    batch = []
            if n > 0:
                # Encode the final, partially filled batch.
                tokens = tokenizer(batch, padding=True, return_tensors="pt")
                encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
                vectors.append(encoding.last_hidden_state.detach()[:, 0])

            self.__entities = pd.Series(entity_ids)
            self.config.entities = entity_ids
            self.config.ontology_size = len(entity_ids)
            self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
            self.bias = nn.Parameter(
                torch.nn.init.normal_(
                    torch.empty((1, self.config.ontology_size)),
                    std=self.classifier_head.std().item() * np.sqrt(self.config.hidden_size),
                )
            )

    @property
    def entities(self) -> pd.Series:
        if self.__entities is None and hasattr(self.config, "entities"):
            self.__entities = pd.Series(self.config.entities)
        return self.__entities

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        assert not nn.parameter.is_lazy(self.classifier_head), (
            "Run init_classifier to initialise weights"
        )
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        # Score each sequence's [CLS] vector against every concept embedding in the head.
        token_vectors = base_model_output.last_hidden_state[:, 0]
        logits = torch.einsum("ij,kj->ik", token_vectors, self.classifier_head) + self.bias

        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )
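
# Hedged usage sketch (not part of the module proper): the call order the class expects is
# from_base -> init_classifier -> forward. The marker token strings, the "ontology.csv" path
# and its column names are illustrative assumptions, not values taken from the code above.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(ModelURI.BASE)
    # Two extra tokens, matching the `vocab_size += 2` performed in __init__ (names assumed).
    tokenizer.add_tokens(["<target>", "</target>"])

    model = DisamBertSingleSense.from_base(ModelURI.BASE)
    inventory = pd.read_csv("ontology.csv")  # assumed: columns "concept" and "definition"
    model.init_classifier(iter_lexical_examples(inventory), tokenizer)

    sentence = "The <target>bank</target> raised its interest rates."
    encoded = tokenizer(sentence, return_tensors="pt").to(model.BaseModel.device)
    output = model(encoded["input_ids"], encoded["attention_mask"])
    print(model.entities.iloc[output.logits.argmax(dim=-1).item()])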