PeteBleackley committed on
Commit
9252eda
·
verified ·
1 Parent(s): 1b09699

Upload DisamBert

Browse files
Files changed (3) hide show
  1. DisamBert.py +121 -55
  2. config.json +1 -0
  3. model.safetensors +1 -1
DisamBert.py CHANGED
@@ -1,11 +1,20 @@
1
  from collections.abc import Generator, Iterable
2
  from dataclasses import dataclass
3
  from enum import StrEnum
 
4
 
5
  import pandas as pd
6
  import torch
7
  import torch.nn as nn
8
- from transformers import AutoConfig, AutoModel, AutoTokenizer, ModernBertModel, PreTrainedConfig, PreTrainedModel
 
 
 
 
 
 
 
 
9
 
10
  BATCH_SIZE = 64
11
 
@@ -36,10 +45,13 @@ class DisamBert(PreTrainedModel):
36
  self.__entities = None
37
  else:
38
  self.BaseModel = ModernBertModel(config)
39
- self.classifier_head = nn.Parameter(torch.empty((config.vocab_size,config.hidden_size)))
40
- self._entities
 
 
41
  config.init_basemodel = False
42
  self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
 
43
  self.post_init()
44
 
45
  @classmethod
@@ -82,51 +94,58 @@ class DisamBert(PreTrainedModel):
82
  self.__entities = pd.Series(self.config.entities)
83
  return self.__entities
84
 
85
- def forward(self, sentences: Iterable[str], indices: Iterable[list[int]]) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
86
  assert not nn.parameter.is_lazy(self.classifier_head), (
87
  "Run init_classifier to initialise weights"
88
  )
89
- all_indices = []
90
- all_tokens = []
91
- with self.BaseModel.device:
92
- for sentence, span_indices in zip(sentences, indices, strict=True):
93
- indices = []
94
- tokens = []
95
- last_span = len(span_indices) - 2
96
- for i, position in enumerate(span_indices[:-1]):
97
- span = sentence[position : span_indices[i + 1]]
98
- span_tokens = self.tokenizer([span], padding=False)["input_ids"][0]
99
- if i > 0:
100
- span_tokens = span_tokens[1:]
101
- if i < last_span:
102
- span_tokens = span_tokens[:-1]
103
- indices.append(len(span_tokens))
104
- tokens.extend(span_tokens)
105
- all_indices.append(indices)
106
- all_tokens.append(tokens)
107
- sentence_lengths = [len(boundaries) for boundaries in all_indices]
108
- maxlen = max(sentence_lengths)
109
- batch = self.pad(all_tokens)
110
- token_vectors = self.BaseModel(batch.input_ids, batch.attention_mask).last_hidden_state
111
- span_vectors = torch.cat(
112
- [
113
- torch.vstack(
114
- [
115
- torch.sum(chunk, dim=0)
116
- for chunk in self.split(token_vectors[i], sentence_indices)
117
- ]
118
- )
119
- for (i, sentence_indices) in enumerate(all_indices)
120
- ]
121
- )
122
- logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head)
123
- split_logits = torch.split(logits, sentence_lengths, dim=1)
124
- return torch.stack(
125
- [
126
- self.extend_to_max_length(sentence, length, maxlen)
127
- for (sentence, length) in zip(split_logits, sentence_lengths, strict=True)
128
- ]
129
- )
130
 
131
  def split(self, vectors: torch.Tensor, lengths: list[int]) -> tuple[torch.Tensor, ...]:
132
  maxlen = vectors.shape[0]
@@ -135,7 +154,7 @@ class DisamBert(PreTrainedModel):
135
  chunks = vectors.split((lengths + [maxlen - total_length]) if is_padded else lengths)
136
  return chunks[:-1] if is_padded else chunks
137
 
138
- def pad(self, tokens: list[int]) -> PaddedBatch:
139
  lengths = [len(sentence) for sentence in tokens]
140
  maxlen = max(lengths)
141
  input_ids = torch.tensor(
@@ -152,14 +171,61 @@ class DisamBert(PreTrainedModel):
152
  def extend_to_max_length(
153
  self, sentence: torch.Tensor, length: int, maxlength: int
154
  ) -> torch.Tensor:
155
- return (
156
- torch.cat(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  [
158
- sentence,
159
- torch.zeros((self.__entities.shape[0], maxlength - length)),
160
- ],
161
- dim=1,
162
  )
163
- if length < maxlength
164
- else sentence
165
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from collections.abc import Generator, Iterable
2
  from dataclasses import dataclass
3
  from enum import StrEnum
4
+ from itertools import chain
5
 
6
  import pandas as pd
7
  import torch
8
  import torch.nn as nn
9
+ from transformers import (
10
+ AutoConfig,
11
+ AutoModel,
12
+ AutoTokenizer,
13
+ ModernBertModel,
14
+ PreTrainedConfig,
15
+ PreTrainedModel,
16
+ )
17
+ from transformers.modeling_outputs import TokenClassifierOutput
18
 
19
  BATCH_SIZE = 64
20
 
 
45
  self.__entities = None
46
  else:
47
  self.BaseModel = ModernBertModel(config)
48
+ self.classifier_head = nn.Parameter(
49
+ torch.empty((config.vocab_size, config.hidden_size))
50
+ )
51
+ self.__entities = pd.Series(config.entities)
52
  config.init_basemodel = False
53
  self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
54
+ self.loss = nn.CrossEntropyLoss()
55
  self.post_init()
56
 
57
  @classmethod
 
94
  self.__entities = pd.Series(self.config.entities)
95
  return self.__entities
96
 
97
+ def forward(
98
+ self,
99
+ input_ids: torch.Tensor,
100
+ attention_mask: torch.Tensor,
101
+ lengths: list[list[int]],
102
+ candidates: list[list[list[int]]],
103
+ labels: Iterable[list[int]] | None = None,
104
+ output_hidden_states: bool = False,
105
+ output_attentions: bool = False,
106
+ ) -> TokenClassifierOutput:
107
  assert not nn.parameter.is_lazy(self.classifier_head), (
108
  "Run init_classifier to initialise weights"
109
  )
110
+ base_model_output = self.BaseModel(
111
+ input_ids,
112
+ attention_mask,
113
+ output_hidden_states=output_hidden_states,
114
+ output_attentions=output_attentions,
115
+ )
116
+ token_vectors = base_model_output.last_hidden_state
117
+ span_vectors = torch.cat(
118
+ [
119
+ torch.vstack(
120
+ [
121
+ torch.sum(chunk, dim=0)
122
+ for chunk in self.split(token_vectors[i], sentence_indices)
123
+ ]
124
+ )
125
+ for (i, sentence_indices) in enumerate(lengths)
126
+ ]
127
+ )
128
+ logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head)
129
+ logits1 = logits - logits.min()
130
+ mask = torch.zeros_like(logits)
131
+ for (i,concepts) in enumerate(chain.from_iterable(candidates)):
132
+ mask[concepts,i] = torch.tensor(1.0)
133
+ logits2 = logits1 * mask
134
+ sentence_lengths = [len(sentence_indices) for sentence_indices in lengths]
135
+ maxlen = max(sentence_lengths)
136
+ split_logits = torch.split(logits2, sentence_lengths, dim=1)
137
+ logits3 = torch.stack(
138
+ [
139
+ self.extend_to_max_length(sentence, length, maxlen)
140
+ for (sentence, length) in zip(split_logits, sentence_lengths, strict=True)
141
+ ]
142
+ )
143
+ return TokenClassifierOutput(
144
+ logits=logits3,
145
+ loss=self.loss(logits3, labels) if labels is not None else None,
146
+ hidden_states=base_model_output.hidden_states if output_hidden_states else None,
147
+ attentions=base_model_output.attentions if output_attentions else None,
148
+ )
 
 
149
 
150
  def split(self, vectors: torch.Tensor, lengths: list[int]) -> tuple[torch.Tensor, ...]:
151
  maxlen = vectors.shape[0]
 
154
  chunks = vectors.split((lengths + [maxlen - total_length]) if is_padded else lengths)
155
  return chunks[:-1] if is_padded else chunks
156
 
157
+ def pad(self, tokens: Iterable[list[int]]) -> PaddedBatch:
158
  lengths = [len(sentence) for sentence in tokens]
159
  maxlen = max(lengths)
160
  input_ids = torch.tensor(
 
171
  def extend_to_max_length(
172
  self, sentence: torch.Tensor, length: int, maxlength: int
173
  ) -> torch.Tensor:
174
+ with self.BaseModel.device:
175
+ return (
176
+ torch.cat(
177
+ [
178
+ sentence,
179
+ torch.zeros((self.__entities.shape[0], maxlength - length)),
180
+ ],
181
+ dim=1,
182
+ )
183
+ if length < maxlength
184
+ else sentence
185
+ )
186
+
187
+ def pad_labels(self, labels: list[list[int]]) -> torch.Tensor:
188
+ unk = len(self.config.entities) - 1
189
+ lengths = [len(seq) for seq in labels]
190
+ maxlen = max(lengths)
191
+ with self.BaseModel.device:
192
+ return torch.tensor(
193
  [
194
+ seq + [unk] * (maxlen - length)
195
+ for (seq, length) in zip(labels, lengths, strict=True)
196
+ ]
 
197
  )
198
+
199
+ def tokenize(
200
+ self, batch: list[dict[str, str | list[int]]]
201
+ ) -> dict[str, torch.Tensor | list[list[int]]]:
202
+ all_indices = []
203
+ all_tokens = []
204
+ with self.BaseModel.device:
205
+ for example in batch:
206
+ text = example["text"]
207
+ span_indices = example["indices"]
208
+ indices = []
209
+ tokens = []
210
+ last_span = len(span_indices) - 2
211
+ for i, position in enumerate(span_indices[:-1]):
212
+ span = text[position : span_indices[i + 1]]
213
+ span_tokens = self.tokenizer([span], padding=False)["input_ids"][0]
214
+ if i > 0:
215
+ span_tokens = span_tokens[1:]
216
+ if i < last_span:
217
+ span_tokens = span_tokens[:-1]
218
+ indices.append(len(span_tokens))
219
+ tokens.extend(span_tokens)
220
+ all_indices.append(indices)
221
+ all_tokens.append(tokens)
222
+ padded = self.pad(all_tokens)
223
+ result = {
224
+ "input_ids": padded.input_ids,
225
+ "attention_mask": padded.attention_mask,
226
+ "lengths": all_indices,
227
+ "candidates": [example['candidates'] for example in batch]
228
+ }
229
+ if "labels" in batch[0]:
230
+ result["labels"] = self.pad_labels([example["labels"] for example in batch])
231
+ return result
config.json CHANGED
@@ -117741,5 +117741,6 @@
117741
  "tie_word_embeddings": true,
117742
  "tokenizer_path": "answerdotai/ModernBERT-base",
117743
  "transformers_version": "5.0.0",
 
117744
  "vocab_size": 117660
117745
  }
 
117741
  "tie_word_embeddings": true,
117742
  "tokenizer_path": "answerdotai/ModernBERT-base",
117743
  "transformers_version": "5.0.0",
117744
+ "use_cache": false,
117745
  "vocab_size": 117660
117746
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:79d0851573b5002b29d196af74a0b87c06e774b30889fe729bd17f323af7fc2f
3
  size 957523088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50c403c889a37e9ed106f0912eafe6e97fd2e9bffff26a34d9af7b284643657e
3
  size 957523088