PeteBleackley committed on
Commit
86d4840
·
verified ·
1 Parent(s): d5c9b75

Upload DisamBert

Browse files
Files changed (3) hide show
  1. DisamBert.py +5 -5
  2. config.json +2 -1
  3. model.safetensors +1 -1
DisamBert.py CHANGED
@@ -47,9 +47,9 @@ class DisamBert(PreTrainedModel):
47
  else:
48
  self.BaseModel = ModernBertModel(config)
49
  self.classifier_head = nn.Parameter(
50
- torch.empty((config.vocab_size, config.hidden_size))
51
  )
52
- self.bias = nn.Parameter(torch.empty((config.vocab_size, 1)))
53
  self.__entities = pd.Series(config.entities)
54
  config.init_basemodel = False
55
  self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -87,11 +87,11 @@ class DisamBert(PreTrainedModel):
87
 
88
  self.__entities = pd.Series(entity_ids)
89
  self.config.entities = entity_ids
90
- self.config.vocab_size = len(entity_ids)
91
  self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
92
  self.bias = nn.Parameter(
93
  torch.nn.init.normal_(
94
- torch.empty((self.config.vocab_size, 1)), std=self.classifier_head.std().item()
95
  )
96
  )
97
 
@@ -183,7 +183,7 @@ class DisamBert(PreTrainedModel):
183
  torch.cat(
184
  [
185
  sentence,
186
- torch.zeros((self.__entities.shape[0], maxlength - length)),
187
  ],
188
  dim=1,
189
  )
 
47
  else:
48
  self.BaseModel = ModernBertModel(config)
49
  self.classifier_head = nn.Parameter(
50
+ torch.empty((config.ontology_size, config.hidden_size))
51
  )
52
+ self.bias = nn.Parameter(torch.empty((config.ontology_size, 1)))
53
  self.__entities = pd.Series(config.entities)
54
  config.init_basemodel = False
55
  self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
 
87
 
88
  self.__entities = pd.Series(entity_ids)
89
  self.config.entities = entity_ids
90
+ self.config.ontology_size = len(entity_ids)
91
  self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
92
  self.bias = nn.Parameter(
93
  torch.nn.init.normal_(
94
+ torch.empty((self.config.ontology_size, 1)), std=self.classifier_head.std().item()
95
  )
96
  )
97
 
 
183
  torch.cat(
184
  [
185
  sentence,
186
+ torch.zeros((self.config.ontology_size, maxlength - length)),
187
  ],
188
  dim=1,
189
  )
config.json CHANGED
@@ -117722,6 +117722,7 @@
117722
  "norm_eps": 1e-05,
117723
  "num_attention_heads": 12,
117724
  "num_hidden_layers": 22,
 
117725
  "pad_token_id": 50283,
117726
  "position_embedding_type": "absolute",
117727
  "repad_logits_with_grad": false,
@@ -117742,5 +117743,5 @@
117742
  "tokenizer_path": "answerdotai/ModernBERT-base",
117743
  "transformers_version": "5.0.0",
117744
  "use_cache": false,
117745
- "vocab_size": 117660
117746
  }
 
117722
  "norm_eps": 1e-05,
117723
  "num_attention_heads": 12,
117724
  "num_hidden_layers": 22,
117725
+ "ontology_size": 117660,
117726
  "pad_token_id": 50283,
117727
  "position_embedding_type": "absolute",
117728
  "repad_logits_with_grad": false,
 
117743
  "tokenizer_path": "answerdotai/ModernBERT-base",
117744
  "transformers_version": "5.0.0",
117745
  "use_cache": false,
117746
+ "vocab_size": 50368
117747
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ff4e9bebae857919d9ca236d04b7bb8aae63f405f9cd624bc7ee5ac59f2bd54f
3
  size 957993808
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cad83795b87cb440ec0f169e2264c4b5072c8b762307706c2c04db26fbced65
3
  size 957993808