PeteBleackley commited on
Commit
a827632
·
verified ·
1 Parent(s): 14c2e65

End of training

Browse files
Files changed (4) hide show
  1. DisamBertSingleSense.py +10 -4
  2. README.md +36 -37
  3. model.safetensors +2 -2
  4. training_args.bin +1 -1
DisamBertSingleSense.py CHANGED
@@ -43,18 +43,21 @@ class DisamBertSingleSense(PreTrainedModel):
43
  self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
44
  self.config.vocab_size += 2
45
  self.BaseModel.resize_token_embeddings(self.config.vocab_size)
 
46
  self.classifier_head = nn.UninitializedParameter()
47
  self.bias = nn.UninitializedParameter()
48
  self.__entities = None
49
  else:
50
  self.BaseModel = ModernBertModel(config)
 
 
51
  self.classifier_head = nn.Parameter(
52
- torch.empty((config.ontology_size, config.hidden_size))
53
  )
54
  self.bias = nn.Parameter(torch.empty((1,config.ontology_size)))
55
  self.__entities = pd.Series(config.entities)
56
  config.init_basemodel = False
57
-
58
  self.loss = nn.CrossEntropyLoss()
59
  self.post_init()
60
 
@@ -92,7 +95,9 @@ class DisamBertSingleSense(PreTrainedModel):
92
  self.__entities = pd.Series(entity_ids)
93
  self.config.entities = entity_ids
94
  self.config.ontology_size = len(entity_ids)
95
- self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
 
 
96
  self.bias = nn.Parameter(
97
  torch.nn.init.normal_(
98
  torch.empty((1,self.config.ontology_size)),
@@ -124,7 +129,8 @@ class DisamBertSingleSense(PreTrainedModel):
124
  output_attentions=output_attentions,
125
  )
126
  token_vectors = base_model_output.last_hidden_state[:, 0]
127
- logits = torch.einsum("ij,kj->ik", token_vectors, self.classifier_head) + self.bias
 
128
 
129
  return TokenClassifierOutput(
130
  logits=logits,
 
43
  self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
44
  self.config.vocab_size += 2
45
  self.BaseModel.resize_token_embeddings(self.config.vocab_size)
46
+ self.classifier_projection = nn.UninitializedParameter()
47
  self.classifier_head = nn.UninitializedParameter()
48
  self.bias = nn.UninitializedParameter()
49
  self.__entities = None
50
  else:
51
  self.BaseModel = ModernBertModel(config)
52
+ self.classifier_projection = nn.Parameter(
53
+ torch.empty((256,config.hidden_size)))
54
  self.classifier_head = nn.Parameter(
55
+ torch.empty((config.ontology_size, 256))
56
  )
57
  self.bias = nn.Parameter(torch.empty((1,config.ontology_size)))
58
  self.__entities = pd.Series(config.entities)
59
  config.init_basemodel = False
60
+ self.activation = nn.Tanhshrink()
61
  self.loss = nn.CrossEntropyLoss()
62
  self.post_init()
63
 
 
95
  self.__entities = pd.Series(entity_ids)
96
  self.config.entities = entity_ids
97
  self.config.ontology_size = len(entity_ids)
98
+ (U,S,Vh) = torch.linalg.svd(torch.cat(vectors, dim=0),False)
99
+ self.classifier_head = nn.Parameter(U[:,:256])
100
+ self.classifier_projection = nn.Parameter(Vh[:256])
101
  self.bias = nn.Parameter(
102
  torch.nn.init.normal_(
103
  torch.empty((1,self.config.ontology_size)),
 
129
  output_attentions=output_attentions,
130
  )
131
  token_vectors = base_model_output.last_hidden_state[:, 0]
132
+ projection = self.activation(torch.einsum('ij,kj->ik',token_vectors,self.classifier_projection))
133
+ logits = torch.einsum("ij,kj->ik", projection, self.classifier_head) + self.bias
134
 
135
  return TokenClassifierOutput(
136
  logits=logits,
README.md CHANGED
@@ -22,11 +22,11 @@ should probably proofread and complete it, then remove this comment. -->
22
 
23
  This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the semcor dataset.
24
  It achieves the following results on the evaluation set:
25
- - Loss: 4.3051
26
- - Precision: 0.6129
27
- - Recall: 0.6236
28
- - F1: 0.6182
29
- - Matthews: 0.6233
30
 
31
  ## Model description
32
 
@@ -53,43 +53,42 @@ The following hyperparameters were used during training:
53
  - lr_scheduler_type: inverse_sqrt
54
  - lr_scheduler_warmup_steps: 1000
55
  - num_epochs: 30
56
- - label_smoothing_factor: 0.1
57
 
58
  ### Training results
59
 
60
  | Training Loss | Epoch | Step | Validation Loss | Precision | Recall | F1 | Matthews |
61
  |:-------------:|:-----:|:------:|:---------------:|:---------:|:------:|:------:|:--------:|
62
- | No log | 0 | 0 | 208.8371 | 0.0 | 0.0 | 0.0 | 0.0 |
63
- | 6.4576 | 1.0 | 14014 | 7.0514 | 0.5818 | 0.5259 | 0.5524 | 0.5256 |
64
- | 4.4009 | 2.0 | 28028 | 5.0733 | 0.5949 | 0.5819 | 0.5884 | 0.5819 |
65
- | 2.8900 | 3.0 | 42042 | 4.5159 | 0.6520 | 0.6131 | 0.6319 | 0.6127 |
66
- | 2.4798 | 4.0 | 56056 | 4.2910 | 0.6449 | 0.6060 | 0.6249 | 0.6058 |
67
- | 2.1994 | 5.0 | 70070 | 4.1419 | 0.6295 | 0.6126 | 0.6209 | 0.6126 |
68
- | 1.9889 | 6.0 | 84084 | 4.0561 | 0.6316 | 0.6192 | 0.6253 | 0.6191 |
69
- | 1.8689 | 7.0 | 98098 | 3.9877 | 0.6350 | 0.6183 | 0.6266 | 0.6182 |
70
- | 1.7944 | 8.0 | 112112 | 3.9447 | 0.6216 | 0.6218 | 0.6217 | 0.6217 |
71
- | 1.6724 | 9.0 | 126126 | 3.9353 | 0.6037 | 0.6096 | 0.6066 | 0.6094 |
72
- | 1.6316 | 10.0 | 140140 | 3.9487 | 0.6135 | 0.6148 | 0.6141 | 0.6147 |
73
- | 1.6296 | 11.0 | 154154 | 3.9428 | 0.6160 | 0.6231 | 0.6195 | 0.6231 |
74
- | 1.5991 | 12.0 | 168168 | 4.0174 | 0.6137 | 0.6161 | 0.6149 | 0.6160 |
75
- | 1.5809 | 13.0 | 182182 | 4.0325 | 0.6087 | 0.6166 | 0.6126 | 0.6165 |
76
- | 1.5724 | 14.0 | 196196 | 4.0345 | 0.6157 | 0.6236 | 0.6196 | 0.6235 |
77
- | 1.5707 | 15.0 | 210210 | 4.0787 | 0.6142 | 0.6236 | 0.6189 | 0.6235 |
78
- | 1.5606 | 16.0 | 224224 | 4.0881 | 0.6146 | 0.6205 | 0.6175 | 0.6204 |
79
- | 1.5534 | 17.0 | 238238 | 4.1319 | 0.6041 | 0.6139 | 0.6090 | 0.6137 |
80
- | 1.5543 | 18.0 | 252252 | 4.1268 | 0.6133 | 0.6231 | 0.6182 | 0.6229 |
81
- | 1.5438 | 19.0 | 266266 | 4.1633 | 0.6080 | 0.6174 | 0.6127 | 0.6172 |
82
- | 1.5446 | 20.0 | 280280 | 4.1796 | 0.6080 | 0.6201 | 0.6140 | 0.6198 |
83
- | 1.5378 | 21.0 | 294294 | 4.2057 | 0.6144 | 0.6236 | 0.6190 | 0.6233 |
84
- | 1.5371 | 22.0 | 308308 | 4.2225 | 0.6119 | 0.6218 | 0.6168 | 0.6216 |
85
- | 1.5343 | 23.0 | 322322 | 4.2246 | 0.6051 | 0.6179 | 0.6114 | 0.6176 |
86
- | 1.5313 | 24.0 | 336336 | 4.2584 | 0.6086 | 0.6166 | 0.6126 | 0.6163 |
87
- | 1.5306 | 25.0 | 350350 | 4.2558 | 0.6084 | 0.6183 | 0.6133 | 0.6181 |
88
- | 1.5268 | 26.0 | 364364 | 4.2737 | 0.6134 | 0.6231 | 0.6182 | 0.6229 |
89
- | 1.5271 | 27.0 | 378378 | 4.2826 | 0.6059 | 0.6174 | 0.6116 | 0.6172 |
90
- | 1.5267 | 28.0 | 392392 | 4.2831 | 0.6041 | 0.6161 | 0.6100 | 0.6159 |
91
- | 1.5250 | 29.0 | 406406 | 4.2994 | 0.6095 | 0.6192 | 0.6143 | 0.6189 |
92
- | 1.5238 | 30.0 | 420420 | 4.3051 | 0.6129 | 0.6236 | 0.6182 | 0.6233 |
93
 
94
 
95
  ### Framework versions
 
22
 
23
  This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the semcor dataset.
24
  It achieves the following results on the evaluation set:
25
+ - Loss: 4.9159
26
+ - Precision: 0.6058
27
+ - Recall: 0.6152
28
+ - F1: 0.6105
29
+ - Matthews: 0.6150
30
 
31
  ## Model description
32
 
 
53
  - lr_scheduler_type: inverse_sqrt
54
  - lr_scheduler_warmup_steps: 1000
55
  - num_epochs: 30
 
56
 
57
  ### Training results
58
 
59
  | Training Loss | Epoch | Step | Validation Loss | Precision | Recall | F1 | Matthews |
60
  |:-------------:|:-----:|:------:|:---------------:|:---------:|:------:|:------:|:--------:|
61
+ | No log | 0 | 0 | 11.6611 | 0.0 | 0.0 | 0.0 | -0.0000 |
62
+ | 2.5218 | 1.0 | 14014 | 4.1247 | 0.5003 | 0.5245 | 0.5121 | 0.5243 |
63
+ | 1.7184 | 2.0 | 28028 | 3.8822 | 0.5656 | 0.5727 | 0.5692 | 0.5726 |
64
+ | 1.2533 | 3.0 | 42042 | 3.9284 | 0.5859 | 0.5907 | 0.5883 | 0.5905 |
65
+ | 0.9708 | 4.0 | 56056 | 4.0396 | 0.5868 | 0.5907 | 0.5888 | 0.5905 |
66
+ | 0.7932 | 5.0 | 70070 | 4.1447 | 0.5899 | 0.5968 | 0.5934 | 0.5966 |
67
+ | 0.6030 | 6.0 | 84084 | 4.1830 | 0.5932 | 0.6017 | 0.5974 | 0.6014 |
68
+ | 0.5155 | 7.0 | 98098 | 4.2383 | 0.6065 | 0.6082 | 0.6074 | 0.6080 |
69
+ | 0.4701 | 8.0 | 112112 | 4.2015 | 0.6014 | 0.6122 | 0.6068 | 0.6120 |
70
+ | 0.4166 | 9.0 | 126126 | 4.2186 | 0.6096 | 0.6131 | 0.6113 | 0.6128 |
71
+ | 0.3191 | 10.0 | 140140 | 4.3041 | 0.6076 | 0.6096 | 0.6086 | 0.6093 |
72
+ | 0.2979 | 11.0 | 154154 | 4.3275 | 0.6082 | 0.6104 | 0.6093 | 0.6102 |
73
+ | 0.2633 | 12.0 | 168168 | 4.3902 | 0.6171 | 0.6209 | 0.6190 | 0.6207 |
74
+ | 0.2061 | 13.0 | 182182 | 4.4546 | 0.6141 | 0.6196 | 0.6168 | 0.6194 |
75
+ | 0.1829 | 14.0 | 196196 | 4.3960 | 0.6134 | 0.6161 | 0.6147 | 0.6159 |
76
+ | 0.1793 | 15.0 | 210210 | 4.4565 | 0.6151 | 0.6196 | 0.6174 | 0.6194 |
77
+ | 0.1473 | 16.0 | 224224 | 4.4976 | 0.6165 | 0.6218 | 0.6192 | 0.6216 |
78
+ | 0.1631 | 17.0 | 238238 | 4.4916 | 0.6113 | 0.6179 | 0.6146 | 0.6177 |
79
+ | 0.1679 | 18.0 | 252252 | 4.5221 | 0.6114 | 0.6161 | 0.6137 | 0.6159 |
80
+ | 0.1567 | 19.0 | 266266 | 4.5560 | 0.6057 | 0.6166 | 0.6111 | 0.6164 |
81
+ | 0.1670 | 20.0 | 280280 | 4.6266 | 0.6127 | 0.6179 | 0.6153 | 0.6177 |
82
+ | 0.1817 | 21.0 | 294294 | 4.5746 | 0.6117 | 0.6196 | 0.6157 | 0.6194 |
83
+ | 0.1752 | 22.0 | 308308 | 4.6536 | 0.6131 | 0.6192 | 0.6161 | 0.6190 |
84
+ | 0.2083 | 23.0 | 322322 | 4.7661 | 0.6108 | 0.6192 | 0.6150 | 0.6190 |
85
+ | 0.1764 | 24.0 | 336336 | 4.7735 | 0.6105 | 0.6170 | 0.6137 | 0.6168 |
86
+ | 0.2072 | 25.0 | 350350 | 4.8155 | 0.6076 | 0.6157 | 0.6116 | 0.6155 |
87
+ | 0.1668 | 26.0 | 364364 | 4.7572 | 0.6025 | 0.6109 | 0.6067 | 0.6107 |
88
+ | 0.2046 | 27.0 | 378378 | 4.8226 | 0.6028 | 0.6113 | 0.6070 | 0.6111 |
89
+ | 0.2653 | 28.0 | 392392 | 4.8000 | 0.6032 | 0.6166 | 0.6098 | 0.6163 |
90
+ | 0.3166 | 29.0 | 406406 | 4.8968 | 0.6062 | 0.6174 | 0.6118 | 0.6172 |
91
+ | 0.3265 | 30.0 | 420420 | 4.9159 | 0.6058 | 0.6152 | 0.6105 | 0.6150 |
92
 
93
 
94
  ### Framework versions
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f23f0a9070c06ea4016a745d0c112399fece10b493814edd9213c75d987e690f
3
- size 957996876
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb2f14abb4480bbe1187b576b4cb231464599407287fe0263c7c64640fb24f65
3
+ size 717817772
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:aaaef26dc3ff2a41322089fbda6952b609589eb9324e926c820581bc232909b1
3
  size 4856
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cbe6643d4a9a9d097d7d190319dbfe4cdc9057b9380ab52337a02f9c1143eb1
3
  size 4856