PeteBleackley commited on
Commit
fbbff43
·
verified ·
1 Parent(s): a827632

End of training

Browse files
Files changed (5) hide show
  1. DisamBertSingleSense.py +17 -13
  2. README.md +17 -37
  3. config.json +2 -0
  4. model.safetensors +2 -2
  5. training_args.bin +1 -1
DisamBertSingleSense.py CHANGED
@@ -43,21 +43,18 @@ class DisamBertSingleSense(PreTrainedModel):
43
  self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
44
  self.config.vocab_size += 2
45
  self.BaseModel.resize_token_embeddings(self.config.vocab_size)
46
- self.classifier_projection = nn.UninitializedParameter()
47
  self.classifier_head = nn.UninitializedParameter()
48
  self.bias = nn.UninitializedParameter()
49
  self.__entities = None
50
  else:
51
  self.BaseModel = ModernBertModel(config)
52
- self.classifier_projection = nn.Parameter(
53
- torch.empty((256,config.hidden_size)))
54
  self.classifier_head = nn.Parameter(
55
- torch.empty((config.ontology_size, 256))
56
  )
57
- self.bias = nn.Parameter(torch.empty((1,config.ontology_size)))
58
  self.__entities = pd.Series(config.entities)
59
  config.init_basemodel = False
60
- self.activation = nn.Tanhshrink()
61
  self.loss = nn.CrossEntropyLoss()
62
  self.post_init()
63
 
@@ -74,6 +71,9 @@ class DisamBertSingleSense(PreTrainedModel):
74
  vectors = []
75
  batch = []
76
  n = 0
 
 
 
77
  with self.BaseModel.device:
78
  torch.cuda.empty_cache()
79
  for entity in entities:
@@ -95,12 +95,10 @@ class DisamBertSingleSense(PreTrainedModel):
95
  self.__entities = pd.Series(entity_ids)
96
  self.config.entities = entity_ids
97
  self.config.ontology_size = len(entity_ids)
98
- (U,S,Vh) = torch.linalg.svd(torch.cat(vectors, dim=0),False)
99
- self.classifier_head = nn.Parameter(U[:,:256])
100
- self.classifier_projection = nn.Parameter(Vh[:256])
101
  self.bias = nn.Parameter(
102
  torch.nn.init.normal_(
103
- torch.empty((1,self.config.ontology_size)),
104
  std=self.classifier_head.std().item() * np.sqrt(self.config.hidden_size),
105
  )
106
  )
@@ -128,9 +126,15 @@ class DisamBertSingleSense(PreTrainedModel):
128
  output_hidden_states=output_hidden_states,
129
  output_attentions=output_attentions,
130
  )
131
- token_vectors = base_model_output.last_hidden_state[:, 0]
132
- projection = self.activation(torch.einsum('ij,kj->ik',token_vectors,self.classifier_projection))
133
- logits = torch.einsum("ij,kj->ik", projection, self.classifier_head) + self.bias
 
 
 
 
 
 
134
 
135
  return TokenClassifierOutput(
136
  logits=logits,
 
43
  self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
44
  self.config.vocab_size += 2
45
  self.BaseModel.resize_token_embeddings(self.config.vocab_size)
 
46
  self.classifier_head = nn.UninitializedParameter()
47
  self.bias = nn.UninitializedParameter()
48
  self.__entities = None
49
  else:
50
  self.BaseModel = ModernBertModel(config)
 
 
51
  self.classifier_head = nn.Parameter(
52
+ torch.empty((config.ontology_size, config.hidden_size))
53
  )
54
+ self.bias = nn.Parameter(torch.empty((1, config.ontology_size)))
55
  self.__entities = pd.Series(config.entities)
56
  config.init_basemodel = False
57
+
58
  self.loss = nn.CrossEntropyLoss()
59
  self.post_init()
60
 
 
71
  vectors = []
72
  batch = []
73
  n = 0
74
+ special_tokens = tokenizer.get_added_vocab()
75
+ self.config.start_token = special_tokens['[START]']
76
+ self.config.end_token = special_tokens['[END]']
77
  with self.BaseModel.device:
78
  torch.cuda.empty_cache()
79
  for entity in entities:
 
95
  self.__entities = pd.Series(entity_ids)
96
  self.config.entities = entity_ids
97
  self.config.ontology_size = len(entity_ids)
98
+ self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
 
 
99
  self.bias = nn.Parameter(
100
  torch.nn.init.normal_(
101
+ torch.empty((1, self.config.ontology_size)),
102
  std=self.classifier_head.std().item() * np.sqrt(self.config.hidden_size),
103
  )
104
  )
 
126
  output_hidden_states=output_hidden_states,
127
  output_attentions=output_attentions,
128
  )
129
+ token_vectors = base_model_output.last_hidden_state
130
+ selection = torch.zeros_like(input_ids,dtype=token_vectors.dtype)
131
+ starts = (input_ids==self.config.start_token).nonzero()
132
+ ends = (input_ids==self.config.end_token).nonzero()
133
+ for (startpos,endpos) in zip(starts,ends,strict=True):
134
+ selection[startpos[0],startpos[1]:endpos[1]+1]=1.0
135
+ selection[:,0] = 1.0
136
+ entity_vectors = torch.einsum('ijk,ij->ik',token_vectors,selection)
137
+ logits = torch.einsum("ij,kj->ik", entity_vectors, self.classifier_head) + self.bias
138
 
139
  return TokenClassifierOutput(
140
  logits=logits,
README.md CHANGED
@@ -22,11 +22,11 @@ should probably proofread and complete it, then remove this comment. -->
22
 
23
  This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the semcor dataset.
24
  It achieves the following results on the evaluation set:
25
- - Loss: 4.9159
26
- - Precision: 0.6058
27
- - Recall: 0.6152
28
- - F1: 0.6105
29
- - Matthews: 0.6150
30
 
31
  ## Model description
32
 
@@ -52,43 +52,23 @@ The following hyperparameters were used during training:
52
  - optimizer: Use OptimizerNames.ADAMW_TORCH with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
53
  - lr_scheduler_type: inverse_sqrt
54
  - lr_scheduler_warmup_steps: 1000
55
- - num_epochs: 30
56
 
57
  ### Training results
58
 
59
  | Training Loss | Epoch | Step | Validation Loss | Precision | Recall | F1 | Matthews |
60
  |:-------------:|:-----:|:------:|:---------------:|:---------:|:------:|:------:|:--------:|
61
- | No log | 0 | 0 | 11.6611 | 0.0 | 0.0 | 0.0 | -0.0000 |
62
- | 2.5218 | 1.0 | 14014 | 4.1247 | 0.5003 | 0.5245 | 0.5121 | 0.5243 |
63
- | 1.7184 | 2.0 | 28028 | 3.8822 | 0.5656 | 0.5727 | 0.5692 | 0.5726 |
64
- | 1.2533 | 3.0 | 42042 | 3.9284 | 0.5859 | 0.5907 | 0.5883 | 0.5905 |
65
- | 0.9708 | 4.0 | 56056 | 4.0396 | 0.5868 | 0.5907 | 0.5888 | 0.5905 |
66
- | 0.7932 | 5.0 | 70070 | 4.1447 | 0.5899 | 0.5968 | 0.5934 | 0.5966 |
67
- | 0.6030 | 6.0 | 84084 | 4.1830 | 0.5932 | 0.6017 | 0.5974 | 0.6014 |
68
- | 0.5155 | 7.0 | 98098 | 4.2383 | 0.6065 | 0.6082 | 0.6074 | 0.6080 |
69
- | 0.4701 | 8.0 | 112112 | 4.2015 | 0.6014 | 0.6122 | 0.6068 | 0.6120 |
70
- | 0.4166 | 9.0 | 126126 | 4.2186 | 0.6096 | 0.6131 | 0.6113 | 0.6128 |
71
- | 0.3191 | 10.0 | 140140 | 4.3041 | 0.6076 | 0.6096 | 0.6086 | 0.6093 |
72
- | 0.2979 | 11.0 | 154154 | 4.3275 | 0.6082 | 0.6104 | 0.6093 | 0.6102 |
73
- | 0.2633 | 12.0 | 168168 | 4.3902 | 0.6171 | 0.6209 | 0.6190 | 0.6207 |
74
- | 0.2061 | 13.0 | 182182 | 4.4546 | 0.6141 | 0.6196 | 0.6168 | 0.6194 |
75
- | 0.1829 | 14.0 | 196196 | 4.3960 | 0.6134 | 0.6161 | 0.6147 | 0.6159 |
76
- | 0.1793 | 15.0 | 210210 | 4.4565 | 0.6151 | 0.6196 | 0.6174 | 0.6194 |
77
- | 0.1473 | 16.0 | 224224 | 4.4976 | 0.6165 | 0.6218 | 0.6192 | 0.6216 |
78
- | 0.1631 | 17.0 | 238238 | 4.4916 | 0.6113 | 0.6179 | 0.6146 | 0.6177 |
79
- | 0.1679 | 18.0 | 252252 | 4.5221 | 0.6114 | 0.6161 | 0.6137 | 0.6159 |
80
- | 0.1567 | 19.0 | 266266 | 4.5560 | 0.6057 | 0.6166 | 0.6111 | 0.6164 |
81
- | 0.1670 | 20.0 | 280280 | 4.6266 | 0.6127 | 0.6179 | 0.6153 | 0.6177 |
82
- | 0.1817 | 21.0 | 294294 | 4.5746 | 0.6117 | 0.6196 | 0.6157 | 0.6194 |
83
- | 0.1752 | 22.0 | 308308 | 4.6536 | 0.6131 | 0.6192 | 0.6161 | 0.6190 |
84
- | 0.2083 | 23.0 | 322322 | 4.7661 | 0.6108 | 0.6192 | 0.6150 | 0.6190 |
85
- | 0.1764 | 24.0 | 336336 | 4.7735 | 0.6105 | 0.6170 | 0.6137 | 0.6168 |
86
- | 0.2072 | 25.0 | 350350 | 4.8155 | 0.6076 | 0.6157 | 0.6116 | 0.6155 |
87
- | 0.1668 | 26.0 | 364364 | 4.7572 | 0.6025 | 0.6109 | 0.6067 | 0.6107 |
88
- | 0.2046 | 27.0 | 378378 | 4.8226 | 0.6028 | 0.6113 | 0.6070 | 0.6111 |
89
- | 0.2653 | 28.0 | 392392 | 4.8000 | 0.6032 | 0.6166 | 0.6098 | 0.6163 |
90
- | 0.3166 | 29.0 | 406406 | 4.8968 | 0.6062 | 0.6174 | 0.6118 | 0.6172 |
91
- | 0.3265 | 30.0 | 420420 | 4.9159 | 0.6058 | 0.6152 | 0.6105 | 0.6150 |
92
 
93
 
94
  ### Framework versions
 
22
 
23
  This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the semcor dataset.
24
  It achieves the following results on the evaluation set:
25
+ - Loss: 10.0010
26
+ - Precision: 0.6717
27
+ - Recall: 0.6486
28
+ - F1: 0.6599
29
+ - Matthews: 0.6479
30
 
31
  ## Model description
32
 
 
52
  - optimizer: Use OptimizerNames.ADAMW_TORCH with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
53
  - lr_scheduler_type: inverse_sqrt
54
  - lr_scheduler_warmup_steps: 1000
55
+ - num_epochs: 10
56
 
57
  ### Training results
58
 
59
  | Training Loss | Epoch | Step | Validation Loss | Precision | Recall | F1 | Matthews |
60
  |:-------------:|:-----:|:------:|:---------------:|:---------:|:------:|:------:|:--------:|
61
+ | No log | 0 | 0 | 641.2748 | 0.0 | 0.0 | 0.0 | -0.0000 |
62
+ | 4.9398 | 1.0 | 14014 | 7.1390 | 0.5863 | 0.5649 | 0.5754 | 0.5641 |
63
+ | 1.9762 | 2.0 | 28028 | 6.1541 | 0.6409 | 0.6117 | 0.6260 | 0.6110 |
64
+ | 1.1673 | 3.0 | 42042 | 6.2676 | 0.6534 | 0.6328 | 0.6429 | 0.6321 |
65
+ | 0.4893 | 4.0 | 56056 | 6.9641 | 0.6609 | 0.6394 | 0.6499 | 0.6387 |
66
+ | 0.2413 | 5.0 | 70070 | 7.8858 | 0.6637 | 0.6363 | 0.6497 | 0.6356 |
67
+ | 0.1245 | 6.0 | 84084 | 8.9750 | 0.6662 | 0.6310 | 0.6481 | 0.6304 |
68
+ | 0.0557 | 7.0 | 98098 | 9.4948 | 0.6693 | 0.6398 | 0.6542 | 0.6391 |
69
+ | 0.0451 | 8.0 | 112112 | 9.7435 | 0.6682 | 0.6402 | 0.6539 | 0.6395 |
70
+ | 0.0359 | 9.0 | 126126 | 9.9980 | 0.6676 | 0.6306 | 0.6486 | 0.6299 |
71
+ | 0.0188 | 10.0 | 140140 | 10.0010 | 0.6717 | 0.6486 | 0.6599 | 0.6479 |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  ### Framework versions
config.json CHANGED
@@ -17,6 +17,7 @@
17
  "deterministic_flash_attn": false,
18
  "dtype": "float32",
19
  "embedding_dropout": 0.0,
 
20
  "entities": [
21
  "able.a.01",
22
  "unable.a.01",
@@ -117737,6 +117738,7 @@
117737
  "sep_token_id": 50282,
117738
  "sparse_pred_ignore_index": -100,
117739
  "sparse_prediction": false,
 
117740
  "tie_word_embeddings": true,
117741
  "transformers_version": "5.2.0",
117742
  "use_cache": false,
 
17
  "deterministic_flash_attn": false,
18
  "dtype": "float32",
19
  "embedding_dropout": 0.0,
20
+ "end_token": 50369,
21
  "entities": [
22
  "able.a.01",
23
  "unable.a.01",
 
117738
  "sep_token_id": 50282,
117739
  "sparse_pred_ignore_index": -100,
117740
  "sparse_prediction": false,
117741
+ "start_token": 50368,
117742
  "tie_word_embeddings": true,
117743
  "transformers_version": "5.2.0",
117744
  "use_cache": false,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fb2f14abb4480bbe1187b576b4cb231464599407287fe0263c7c64640fb24f65
3
- size 717817772
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a51dac68b3405593343a667569adf0b33a56734d8818c585b478c96647e8171
3
+ size 957996876
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0cbe6643d4a9a9d097d7d190319dbfe4cdc9057b9380ab52337a02f9c1143eb1
3
  size 4856
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:290a6dde229c724e565072da4f33d9559a54b559464d187b49178893aa79cbc3
3
  size 4856