Fixed missing casting of positional encodings to int32
Browse files- README.md +4 -2
- modeling_m5_encoder.py +3 -4
README.md
CHANGED
|
@@ -56,6 +56,7 @@ selfies, pos_encod, _ = model.get_positional_encodings_and_align(smiles, seed=0)
|
|
| 56 |
encoding = tokenizer(selfies, return_tensors="pt")
|
| 57 |
input_ids = encoding["input_ids"]
|
| 58 |
attn_mask = encoding["attention_mask"]
|
|
|
|
| 59 |
rel_pos = torch.tensor(pos_encod).unsqueeze(0) # (1, seq_len, seq_len)
|
| 60 |
|
| 61 |
outputs = model(input_ids=input_ids, attention_mask=attn_mask, relative_position=rel_pos)
|
|
@@ -155,6 +156,7 @@ The HDF5 files are available for download below. These are intended to be proces
|
|
| 155 |
|
| 156 |
## Limitations
|
| 157 |
|
| 158 |
-
- **Token length:** The built-in `prepare_data` helper encodes pairwise molecular-graph distances in an `int16` matrix.
|
|
|
|
| 159 |
- **Conformer handling:** Duplicate SMILES representing different conformers are kept in the dataset. The model therefore predicts an implicit average over conformers rather than a geometry-specific value, which may reduce accuracy for conformation-sensitive properties.
|
| 160 |
-
- **Scope:** The model is pretrained on
|
|
|
|
| 56 |
encoding = tokenizer(selfies, return_tensors="pt")
|
| 57 |
input_ids = encoding["input_ids"]
|
| 58 |
attn_mask = encoding["attention_mask"]
|
| 59 |
+
|
| 60 |
rel_pos = torch.tensor(pos_encod).unsqueeze(0) # (1, seq_len, seq_len)
|
| 61 |
|
| 62 |
outputs = model(input_ids=input_ids, attention_mask=attn_mask, relative_position=rel_pos)
|
|
|
|
| 156 |
|
| 157 |
## Limitations
|
| 158 |
|
| 159 |
+
- **Token length:** The built-in `prepare_data` helper encodes pairwise molecular-graph distances in an `int16` matrix.
|
| 160 |
+
This was done to decrease the size of pairwise-distance matrices in case one intends to pre-compute them before training. Due to the `prepare_data` limitations, molecules whose SELFIES tokenization exceeds **32,766 tokens** (`numpy.iinfo(numpy.int16).max - 1`) are not supported. In practice, most molecules will be well below this limit.
|
| 161 |
- **Conformer handling:** Duplicate SMILES representing different conformers are kept in the dataset. The model therefore predicts an implicit average over conformers rather than a geometry-specific value, which may reduce accuracy for conformation-sensitive properties.
|
| 162 |
+
- **Scope:** The model is pretrained on molecules present in PubChemQC. Performance on certain compound types and large macromolecules outside the training distribution has not been evaluated. Therefore, the model will perform best on molecules with MW <= 1000 or number of heavy atoms <= 79.
|
modeling_m5_encoder.py
CHANGED
|
@@ -33,8 +33,8 @@ class M5EncoderConfig(T5Config):
|
|
| 33 |
dropout_rate = 0,
|
| 34 |
feed_forward_proj = "gated-gelu",
|
| 35 |
classifier_dropout=0,
|
| 36 |
-
relative_attention_max_distance=
|
| 37 |
-
relative_attention_num_buckets=
|
| 38 |
vocab_size=1032,
|
| 39 |
num_decoder_layers=0,
|
| 40 |
**kwargs,
|
|
@@ -263,12 +263,11 @@ class M5EncoderModel(T5EncoderModel):
|
|
| 263 |
input_ids=input_ids,
|
| 264 |
attention_mask=attention_mask,
|
| 265 |
inputs_embeds=inputs_embeds,
|
| 266 |
-
|
| 267 |
head_mask=head_mask,
|
| 268 |
output_attentions=output_attentions,
|
| 269 |
output_hidden_states=output_hidden_states,
|
| 270 |
return_dict=return_dict,
|
| 271 |
-
relative_position=relative_position
|
| 272 |
)
|
| 273 |
|
| 274 |
return encoder_outputs
|
|
|
|
| 33 |
dropout_rate = 0,
|
| 34 |
feed_forward_proj = "gated-gelu",
|
| 35 |
classifier_dropout=0,
|
| 36 |
+
relative_attention_max_distance=96,
|
| 37 |
+
relative_attention_num_buckets=32,
|
| 38 |
vocab_size=1032,
|
| 39 |
num_decoder_layers=0,
|
| 40 |
**kwargs,
|
|
|
|
| 263 |
input_ids=input_ids,
|
| 264 |
attention_mask=attention_mask,
|
| 265 |
inputs_embeds=inputs_embeds,
|
|
|
|
| 266 |
head_mask=head_mask,
|
| 267 |
output_attentions=output_attentions,
|
| 268 |
output_hidden_states=output_hidden_states,
|
| 269 |
return_dict=return_dict,
|
| 270 |
+
relative_position=relative_position.to(dtype=torch.int32) if relative_position is not None else None
|
| 271 |
)
|
| 272 |
|
| 273 |
return encoder_outputs
|