| | from collections.abc import Generator, Iterable |
| | from dataclasses import dataclass |
| | from enum import StrEnum |
| | from itertools import chain |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | import torch.nn as nn |
| | from transformers import ( |
| | AutoConfig, |
| | AutoModel, |
| | AutoTokenizer, |
| | ModernBertModel, |
| | PreTrainedConfig, |
| | PreTrainedModel, |
| | ) |
| | from transformers.modeling_outputs import TokenClassifierOutput |
| |
|
# Number of entity definitions encoded per forward pass in init_classifier.
BATCH_SIZE = 64
| |
|
| |
|
class ModelURI(StrEnum):
    """Hugging Face Hub ids of the supported ModernBERT base checkpoints."""

    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"
| |
|
| |
|
@dataclass(slots=True, frozen=True)
class LexicalExample:
    """One ontology entry: a concept id paired with its textual definition."""

    # Unique entity identifier; becomes a row index of the classifier head.
    concept: str
    # Natural-language definition encoded to initialise that row.
    definition: str
| |
|
| |
|
@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """A rectangular token batch as produced by DisamBert.pad."""

    # (batch, maxlen) right-padded token ids.
    input_ids: torch.Tensor
    # (batch, maxlen) mask: 1.0 for real tokens, 0.0 for padding.
    attention_mask: torch.Tensor
| |
|
| |
|
class DisamBert(PreTrainedModel):
    """ModernBERT-based entity disambiguation model.

    The classifier head is an (ontology_size, hidden_size) parameter whose
    rows are initialised from encoder [CLS] embeddings of entity definitions
    (see ``init_classifier``), so scoring a span against the whole ontology
    is a single matrix product.

    NOTE(review): the annotation below uses ``PreTrainedConfig`` as imported
    at the top of the file — confirm the pinned transformers version exports
    that name (the long-standing canonical spelling is ``PretrainedConfig``).
    """

    def __init__(self, config: PreTrainedConfig):
        """Build either a fresh model from a pretrained base (lazy classifier
        head) or restore a fully shaped model from ``config``.
        """
        super().__init__(config)
        if config.init_basemodel:
            # Bootstrapping from a pretrained checkpoint: the head stays lazy
            # until init_classifier() materialises it from definitions.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
            self.classifier_head = nn.UninitializedParameter()
            self.bias = nn.UninitializedParameter()
            self.__entities = None
        else:
            # Restoring a saved model: shapes come from the config; actual
            # weights are loaded afterwards by from_pretrained().
            self.BaseModel = ModernBertModel(config)
            self.classifier_head = nn.Parameter(
                torch.empty((config.ontology_size, config.hidden_size))
            )
            self.bias = nn.Parameter(torch.empty((config.ontology_size, 1)))
            self.__entities = pd.Series(config.entities)
        # Any re-save of this config must not retrigger base-model download.
        config.init_basemodel = False
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Alternate constructor: start from a pretrained ModernBERT base."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        config.tokenizer_path = base_id
        return cls(config)

    def _encode_definitions(self, definitions: list[str]) -> torch.Tensor:
        """Encode a batch of definition strings; return their [CLS] vectors
        (position 0 of the last hidden state), detached from autograd."""
        tokens = self.tokenizer(definitions, padding=True, return_tensors="pt")
        encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
        return encoding.last_hidden_state.detach()[:, 0]

    def init_classifier(self, entities: Iterable[LexicalExample]) -> None:
        """Materialise the lazy classifier head from entity definitions.

        Each definition is encoded in batches of BATCH_SIZE and its [CLS]
        vector becomes that entity's row of the head; the bias is sampled
        from a normal distribution scaled to the head's magnitude.
        """
        entity_ids = []
        vectors = []
        batch = []
        # no_grad: outputs were only ever used detached, so skip tracking.
        with self.BaseModel.device, torch.no_grad():
            for entity in entities:
                entity_ids.append(entity.concept)
                batch.append(entity.definition)
                if len(batch) == BATCH_SIZE:
                    vectors.append(self._encode_definitions(batch))
                    batch = []
            if batch:  # leftover partial batch
                vectors.append(self._encode_definitions(batch))

            self.__entities = pd.Series(entity_ids)
            self.config.entities = entity_ids
            self.config.ontology_size = len(entity_ids)
            self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
            self.bias = nn.Parameter(
                torch.nn.init.normal_(
                    torch.empty((self.config.ontology_size, 1)),
                    std=self.classifier_head.std().item() * np.sqrt(self.config.hidden_size),
                )
            )

    @property
    def entities(self) -> pd.Series:
        """Entity ids indexed by classifier-head row; lazily rebuilt from
        the config after loading a saved model."""
        if self.__entities is None and hasattr(self.config, "entities"):
            self.__entities = pd.Series(self.config.entities)
        return self.__entities

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        lengths: list[list[int]],
        candidates: list[list[list[int]]],
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score every token span in the batch against every entity.

        Args:
            input_ids / attention_mask: padded token batch (see ``pad``).
            lengths: per example, the token length of each span.
            candidates: per example, per span, the candidate entity row
                indices; all other entities are masked to zero.
            labels: padded gold entity indices (see ``pad_labels``), shape
                (batch, max_spans); optional.

        Returns:
            TokenClassifierOutput whose logits have shape
            (batch, ontology_size, max_spans_per_example).
        """
        assert not nn.parameter.is_lazy(self.classifier_head), (
            "Run init_classifier to initialise weights"
        )
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Sum token vectors within each span; the spans of all examples are
        # concatenated along dim 0, matching enumerate(chain(...)) below.
        span_vectors = torch.cat(
            [
                torch.vstack(
                    [
                        torch.sum(chunk, dim=0)
                        for chunk in self.split(token_vectors[i], sentence_indices)
                    ]
                )
                for (i, sentence_indices) in enumerate(lengths)
            ]
        )
        # (ontology_size, n_spans): each column scores one span.
        logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head) + self.bias
        # Shift to non-negative so entries zeroed by the candidate mask are
        # guaranteed minimal.
        logits1 = logits - logits.min()
        mask = torch.zeros_like(logits)
        for i, concepts in enumerate(chain.from_iterable(candidates)):
            # Plain float assignment — no per-span tensor allocation needed.
            mask[concepts, i] = 1.0
        logits2 = logits1 * mask
        # Regroup columns by example and pad to the longest example so the
        # pieces stack into (batch, ontology_size, max_spans).
        sentence_lengths = [len(sentence_indices) for sentence_indices in lengths]
        maxlen = max(sentence_lengths)
        split_logits = torch.split(logits2, sentence_lengths, dim=1)
        logits3 = torch.stack(
            [
                self.extend_to_max_length(sentence, length, maxlen)
                for (sentence, length) in zip(split_logits, sentence_lengths, strict=True)
            ]
        )
        return TokenClassifierOutput(
            logits=logits3,
            loss=self.loss(logits3, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def split(self, vectors: torch.Tensor, lengths: list[int]) -> tuple[torch.Tensor, ...]:
        """Split a (tokens, hidden) matrix into per-span chunks of the given
        lengths, discarding any trailing padding rows."""
        padded_length = vectors.shape[0]
        total_length = sum(lengths)
        is_padded = total_length < padded_length
        chunks = vectors.split(
            [*lengths, padded_length - total_length] if is_padded else lengths
        )
        return chunks[:-1] if is_padded else chunks

    def pad(self, tokens: Iterable[list[int]]) -> PaddedBatch:
        """Right-pad token-id sequences into a rectangular batch with the
        matching attention mask."""
        # Materialise once: the argument is iterated twice below, which
        # would silently drop everything for a generator input.
        sentences = list(tokens)
        lengths = [len(sentence) for sentence in sentences]
        maxlen = max(lengths)
        input_ids = torch.tensor(
            [
                sentence + [self.config.pad_token_id] * (maxlen - length)
                for (sentence, length) in zip(sentences, lengths, strict=True)
            ]
        )
        attention_mask = torch.vstack(
            [torch.cat((torch.ones(length), torch.zeros(maxlen - length))) for length in lengths]
        )
        return PaddedBatch(input_ids, attention_mask)

    def extend_to_max_length(
        self, sentence: torch.Tensor, length: int, maxlength: int
    ) -> torch.Tensor:
        """Zero-pad an (ontology_size, length) logit slice along dim 1 up to
        ``maxlength`` so per-example slices can be stacked."""
        with self.BaseModel.device:
            return (
                torch.cat(
                    [
                        sentence,
                        torch.zeros((self.config.ontology_size, maxlength - length)),
                    ],
                    dim=1,
                )
                if length < maxlength
                else sentence
            )

    def pad_labels(self, labels: list[list[int]]) -> torch.Tensor:
        """Right-pad label sequences into a (batch, max_spans) tensor, using
        the last entity row as the padding label.

        NOTE(review): assumes the final entry of config.entities is a
        dedicated UNK/padding concept — confirm against the ontology build.
        """
        unk = len(self.config.entities) - 1
        lengths = [len(seq) for seq in labels]
        maxlen = max(lengths)
        with self.BaseModel.device:
            return torch.tensor(
                [
                    seq + [unk] * (maxlen - length)
                    for (seq, length) in zip(labels, lengths, strict=True)
                ]
            )

    def tokenize(
        self, batch: list[dict[str, str | list[int]]]
    ) -> dict[str, torch.Tensor | list[list[int]]]:
        """Tokenize raw examples into model inputs.

        Each example carries "text" and "indices" (character offsets that
        delimit the spans). Every span is tokenized separately; the BOS
        token is stripped from all spans but the first and the EOS token
        from all but the last, so the concatenation matches a single-pass
        tokenization while per-span token counts remain known.
        """
        all_indices = []
        all_tokens = []
        with self.BaseModel.device:
            for example in batch:
                text = example["text"]
                span_indices = example["indices"]
                indices = []
                tokens = []
                last_span = len(span_indices) - 2
                for i, position in enumerate(span_indices[:-1]):
                    span = text[position : span_indices[i + 1]]
                    span_tokens = self.tokenizer([span], padding=False)["input_ids"][0]
                    if i > 0:
                        span_tokens = span_tokens[1:]  # drop BOS on inner spans
                    if i < last_span:
                        span_tokens = span_tokens[:-1]  # drop EOS on inner spans
                    indices.append(len(span_tokens))
                    tokens.extend(span_tokens)
                all_indices.append(indices)
                all_tokens.append(tokens)
            padded = self.pad(all_tokens)
            result = {
                "input_ids": padded.input_ids,
                "attention_mask": padded.attention_mask,
                "lengths": all_indices,
                "candidates": [example["candidates"] for example in batch],
            }
            if "labels" in batch[0]:
                result["labels"] = self.pad_labels([example["labels"] for example in batch])
            return result
| |
|