Tags: Feature Extraction, Transformers, Safetensors, English, modernbert, Generated from Trainer, custom_code, text-embeddings-inference
GliteTech/DisamBertCrossEncoder-base can be used from libraries, inference providers, notebooks, and local apps. The snippets below show how to get started with Transformers.
How to use GliteTech/DisamBertCrossEncoder-base with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="GliteTech/DisamBertCrossEncoder-base", trust_remote_code=True)

# Load model directly
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("GliteTech/DisamBertCrossEncoder-base", trust_remote_code=True)
model = AutoModel.from_pretrained("GliteTech/DisamBertCrossEncoder-base", trust_remote_code=True)
```
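The custom code that accompanies this model is reproduced below.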
```python
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum
import pprint

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    BatchEncoding,
    ModernBertModel,
    PreTrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput
BATCH_SIZE = 16


class ModelURI(StrEnum):
    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass
class LexicalExample:
    # An ontology entry: a concept identifier and its textual definition (gloss).
    concept: str
    definition: str


@dataclass
class PaddedBatch:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
class DisamBertSingleSense(PreTrainedModel):
    def __init__(self, config: PreTrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # First initialisation: load the ModernBERT encoder and grow the vocabulary
            # by the three marker tokens (entity start, entity end, gloss).
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path,
                                                       attn_implementation="flash_attention_2",
                                                       dtype=torch.bfloat16,
                                                       device_map="auto")
            self.config.vocab_size += 3
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            # Reloading a saved checkpoint: build the encoder from the stored config.
            self.BaseModel = ModernBertModel(config)
        config.init_basemodel = False
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int, gloss: int):
        # Record the marker token ids so forward() can locate them in input_ids.
        self.config.start_token = start
        self.config.end_token = end
        self.config.gloss_token = gloss
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Iterable[int] | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Mark the span between each entity start/end token pair.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        # Sum the marked token vectors into one entity vector per example.
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(
            token_vectors,
            input_ids,
        )
        # Score each candidate gloss by its dot product with the entity vector.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)
        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )
    def gloss_vectors(self, token_vectors: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        with self.device:
            # Pick out the hidden state at every gloss-marker position; rows may hold
            # different numbers of candidate senses.
            selection = (input_ids == self.config.gloss_token)
            candidates_per_row = selection.sum(axis=1)
            max_candidates = candidates_per_row.max()
            indices = torch.flatten(selection)
            vectors = torch.reshape(token_vectors,
                                    (token_vectors.shape[0] * token_vectors.shape[1],
                                     token_vectors.shape[2]))
            gloss_vectors = vectors[indices]
            # Zero-pad each row to the batch-wide maximum number of candidates.
            return torch.stack([torch.cat([chunk,
                                           torch.zeros((max_candidates - chunk.shape[0],
                                                        chunk.shape[1]),
                                                       dtype=torch.bfloat16)])
                                for chunk in torch.split(gloss_vectors,
                                                         tuple(candidates_per_row.tolist()))])
class CandidateLabeller:
    def __init__(self, tokenizer: PreTrainedTokenizer,
                 ontology: Generator[LexicalExample],
                 device: torch.device,
                 retain_candidates: bool = False):
        self.tokenizer = tokenizer
        self.device = device
        # Map every concept to its definition so candidate glosses can be looked up.
        self.glosses = {
            example.concept: example.definition
            for example in ontology
        }
        self.retain_candidates = retain_candidates

    def __call__(self, batch: dict[str, list]) -> dict:
        with self.device:
            # One gloss per candidate sense, joined into the tokenizer's second segment.
            glosses = ["\n".join(self.glosses[candidate]
                                 for candidate in example)
                       for example in batch["candidates"]]
            tokens = self.tokenizer(batch["text"], glosses, padding=True, return_tensors="pt")
            result = {"input_ids": tokens.input_ids,
                      "attention_mask": tokens.attention_mask}
            if "label" in batch:
                # The training target is the gold sense's index within its candidate list.
                result["labels"] = torch.tensor(
                    [candidates.index(label)
                     for (candidates, label) in zip(batch["candidates"],
                                                    batch["label"],
                                                    strict=True)]
                )
            if self.retain_candidates:
                result["candidates"] = batch["candidates"]
            return result
```
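For orientation, here is a minimal sketch of how the pieces above fit together, building a fresh model from the ModernBERT base rather than loading the released checkpoint. The marker strings `[E]`, `[/E]` and `[GLOSS]`, the two-sense toy ontology, and the convention of prefixing each definition with the gloss marker are assumptions made for this example; `forward()` only requires that the entity span sit between the start/end marker ids and that each candidate contribute exactly one gloss-marker position. Since the constructor requests FlashAttention 2 with `device_map="auto"`, the sketch assumes a CUDA GPU with flash-attn installed.

```python
# Minimal, illustrative sketch only -- marker strings, the toy ontology, and the
# gloss-prefix convention are assumptions, not part of the repository.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(ModelURI.BASE)
tokenizer.add_tokens(["[E]", "[/E]", "[GLOSS]"])        # three new marker tokens

model = DisamBertSingleSense.from_base(ModelURI.BASE)   # grows the vocabulary by 3
model.add_special_tokens(*tokenizer.convert_tokens_to_ids(["[E]", "[/E]", "[GLOSS]"]))

# Toy ontology: each definition starts with the gloss marker so that forward()
# finds exactly one gloss position per candidate sense.
ontology = iter([
    LexicalExample("bank%river", "[GLOSS] sloping land beside a body of water"),
    LexicalExample("bank%finance", "[GLOSS] an institution that accepts deposits"),
])
labeller = CandidateLabeller(tokenizer, ontology, device=model.device)

# The ambiguous word is wrapped in the entity markers; candidates list the senses.
batch = labeller({
    "text": ["She sat on the [E] bank [/E] of the river."],
    "candidates": [["bank%river", "bank%finance"]],
    "label": ["bank%river"],
})

with torch.no_grad():
    output = model(batch["input_ids"], batch["attention_mask"], labels=batch["labels"])
print(output.logits)  # one score per candidate sense; argmax picks the prediction
print(output.loss)
```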