PeteBleackley committed (verified)
Commit d5c9b75 · 1 Parent(s): 9252eda

Upload DisamBert

Files changed (2):
  DisamBert.py        +11 -4
  model.safetensors   +2 -2
DisamBert.py CHANGED
@@ -42,12 +42,14 @@ class DisamBert(PreTrainedModel):
         if config.init_basemodel:
             self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
             self.classifier_head = nn.UninitializedParameter()
+            self.bias = nn.UninitializedParameter()
             self.__entities = None
         else:
             self.BaseModel = ModernBertModel(config)
             self.classifier_head = nn.Parameter(
                 torch.empty((config.vocab_size, config.hidden_size))
             )
+            self.bias = nn.Parameter(torch.empty((config.vocab_size, 1)))
             self.__entities = pd.Series(config.entities)
         config.init_basemodel = False
         self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
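This hunk adds a per-entity bias alongside the classifier head, following the head's two construction paths: a lazy placeholder when the base model is loaded, and a materialised (vocab_size, 1) parameter otherwise. A minimal sketch of that placeholder-then-replace pattern, using invented sizes and a hypothetical TinyHead class rather than the repository's code:

import torch
import torch.nn as nn

class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholders with no shape yet, as in the init_basemodel branch.
        self.classifier_head = nn.UninitializedParameter()
        self.bias = nn.UninitializedParameter()

    def build(self, vocab_size: int, hidden_size: int) -> None:
        # Swap the placeholders for real parameters once the entity
        # vocabulary size is known (sizes here are whatever is passed in).
        self.classifier_head = nn.Parameter(torch.empty(vocab_size, hidden_size))
        self.bias = nn.Parameter(torch.zeros(vocab_size, 1))

head = TinyHead()
head.build(vocab_size=5, hidden_size=8)
print(head.classifier_head.shape, head.bias.shape)  # torch.Size([5, 8]) torch.Size([5, 1])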
@@ -87,6 +89,11 @@ class DisamBert(PreTrainedModel):
         self.config.entities = entity_ids
         self.config.vocab_size = len(entity_ids)
         self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
+        self.bias = nn.Parameter(
+            torch.nn.init.normal_(
+                torch.empty((self.config.vocab_size, 1)), std=self.classifier_head.std().item()
+            )
+        )

     @property
     def entities(self) -> pd.Series:
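Here the new bias is drawn from a normal distribution whose standard deviation matches the spread of the freshly built classifier head. A standalone sketch of the same initialisation pattern, with toy dimensions rather than real model weights:

import torch

vocab_size, hidden_size = 5, 8                        # toy dimensions
classifier_head = torch.randn(vocab_size, hidden_size)
# Draw the bias from N(0, std) with std taken from the head itself,
# matching the initialisation added in this hunk.
bias = torch.nn.init.normal_(
    torch.empty((vocab_size, 1)), std=classifier_head.std().item()
)
print(bias.shape)                                     # torch.Size([5, 1])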
@@ -125,11 +132,11 @@ class DisamBert(PreTrainedModel):
                 for (i, sentence_indices) in enumerate(lengths)
             ]
         )
-        logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head)
+        logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head) + self.bias
         logits1 = logits - logits.min()
         mask = torch.zeros_like(logits)
-        for (i,concepts) in enumerate(chain.from_iterable(candidates)):
-            mask[concepts,i] = torch.tensor(1.0)
+        for i, concepts in enumerate(chain.from_iterable(candidates)):
+            mask[concepts, i] = torch.tensor(1.0)
         logits2 = logits1 * mask
         sentence_lengths = [len(sentence_indices) for sentence_indices in lengths]
         maxlen = max(sentence_lengths)
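With the bias in place, the span-vs-entity scores become the einsum output plus a per-entity offset, and the candidate mask zeroes out entities that were never proposed for a span. A toy walk-through with invented sizes and candidate lists, just to show the shapes and broadcasting involved:

import torch

num_spans, hidden_size, vocab_size = 3, 4, 6           # invented sizes
span_vectors = torch.randn(num_spans, hidden_size)
classifier_head = torch.randn(vocab_size, hidden_size)
bias = torch.zeros(vocab_size, 1)                      # broadcasts across spans

# (spans, hidden) x (vocab, hidden) -> (vocab, spans), one column per span.
logits = torch.einsum("ij,kj->ki", span_vectors, classifier_head) + bias
assert logits.shape == (vocab_size, num_spans)

candidates = [[0, 2], [1], [3, 5]]                     # candidate entity ids per span
mask = torch.zeros_like(logits)
for i, concepts in enumerate(candidates):
    mask[concepts, i] = 1.0
masked = (logits - logits.min()) * mask                # non-candidates scored as zero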
@@ -224,7 +231,7 @@ class DisamBert(PreTrainedModel):
             "input_ids": padded.input_ids,
             "attention_mask": padded.attention_mask,
             "lengths": all_indices,
-            "candidates": [example['candidates'] for example in batch]
+            "candidates": [example["candidates"] for example in batch],
         }
         if "labels" in batch[0]:
             result["labels"] = self.pad_labels([example["labels"] for example in batch])
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:50c403c889a37e9ed106f0912eafe6e97fd2e9bffff26a34d9af7b284643657e
-size 957523088
+oid sha256:ff4e9bebae857919d9ca236d04b7bb8aae63f405f9cd624bc7ee5ac59f2bd54f
+size 957993808