Fixed weight loading, refactored model
Browse files
- config.json +1 -1
- modeling_m5_encoder.py +43 -2
config.json
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
"AutoConfig": "modeling_m5_encoder.M5EncoderConfig",
|
| 7 |
-
"AutoModel": "modeling_m5_encoder.
|
| 8 |
"AutoModelForSequenceClassification": "modeling_m5_encoder.M5ModelForRegression"
|
| 9 |
},
|
| 10 |
"classifier_dropout": 0,
|
|
|
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
"AutoConfig": "modeling_m5_encoder.M5EncoderConfig",
|
| 7 |
+
"AutoModel": "modeling_m5_encoder.M5Model",
|
| 8 |
"AutoModelForSequenceClassification": "modeling_m5_encoder.M5ModelForRegression"
|
| 9 |
},
|
| 10 |
"classifier_dropout": 0,
|
modeling_m5_encoder.py
CHANGED
|
@@ -54,18 +54,59 @@ class M5EncoderConfig(T5Config):
|
|
| 54 |
|
| 55 |
class M5Encoder(PreTrainedModel):
|
| 56 |
config_class = M5EncoderConfig
|
|
|
|
| 57 |
|
| 58 |
def __init__(self, config):
|
| 59 |
super().__init__(config)
|
| 60 |
self.model = M5EncoderModel(config)
|
| 61 |
-
#self.model = torch.compile(self.model, mode="max-autotune", fullgraph=True)
|
| 62 |
|
| 63 |
def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
|
| 64 |
return self.model(input_ids=input_ids,
|
| 65 |
attention_mask=attention_mask,
|
| 66 |
relative_position=relative_position)
|
| 67 |
|
| 68 |
-
def get_positional_embeddings_and_align(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
return get_positional_encodings_and_align(smiles, token_regr, seed)
|
| 70 |
|
| 71 |
class M5EncoderModel(T5EncoderModel):
|
|
|
|
| 54 |
|
| 55 |
class M5Encoder(PreTrainedModel):
|
| 56 |
config_class = M5EncoderConfig
|
| 57 |
+
base_model_prefix = "encoder"
|
| 58 |
|
| 59 |
def __init__(self, config):
|
| 60 |
super().__init__(config)
|
| 61 |
self.model = M5EncoderModel(config)
|
|
|
|
| 62 |
|
| 63 |
def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
|
| 64 |
return self.model(input_ids=input_ids,
|
| 65 |
attention_mask=attention_mask,
|
| 66 |
relative_position=relative_position)
|
| 67 |
|
| 68 |
+
def get_positional_embeddings_and_align(
|
| 69 |
+
self,
|
| 70 |
+
smiles: str,
|
| 71 |
+
seed: int,
|
| 72 |
+
token_regr: Optional[np.ndarray] = None,
|
| 73 |
+
) -> tuple[str, np.ndarray, Optional[np.ndarray]]:
|
| 74 |
+
"""
|
| 75 |
+
Convert a SMILES string into a SELFIES tokenization, compute pairwise
|
| 76 |
+
molecular-graph distance encodings, and optionally align token-level
|
| 77 |
+
regression labels to the new token order.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
smiles: Input molecule as a SMILES string. Does not need to be
|
| 81 |
+
canonical — canonicalization and optional randomization are
|
| 82 |
+
applied internally.
|
| 83 |
+
seed: Epoch/seed value controlling SMILES augmentation. When 0,
|
| 84 |
+
the canonical SELFIES is used; any other value produces a
|
| 85 |
+
reproducible randomized SELFIES variant.
|
| 86 |
+
token_regr: Optional; omit to skip label alignment.
|
| 87 |
+
Array of per-atom regression labels (e.g.
|
| 88 |
+
Löwdin charges) aligned to the original SMILES atom order.
|
| 89 |
+
If provided, labels are re-aligned to match the SELFIES token
|
| 90 |
+
order of the (possibly randomized) output SMILES.
|
| 91 |
+
Shape: ``(n_atoms,)``.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
A tuple of:
|
| 95 |
+
- **selfies** (``str``): SELFIES encoding of the (possibly
|
| 96 |
+
randomized) SMILES.
|
| 97 |
+
- **pos_encod** (``np.ndarray``): Pairwise distance matrix of
|
| 98 |
+
shape ``(seq_len, seq_len)`` with ``dtype=np.int16``. Entries
|
| 99 |
+
are shortest-path graph distances between atoms, capped at
|
| 100 |
+
``np.iinfo(np.int16).max - 1``. Special values: ``0`` for
|
| 101 |
+
CLS-to-token, token-to-CLS, and ring/dot-separated fragment
|
| 102 |
+
pairs; ``-1`` for intra-branch/ring structural tokens;
|
| 103 |
+
``np.iinfo(np.int16).max`` for padding positions.
|
| 104 |
+
- **token_regr_selfies** (``np.ndarray`` or ``None``): Labels
|
| 105 |
+
re-aligned to SELFIES token positions, shape
|
| 106 |
+
``(seq_len - 1,)``, with ``np.nan`` for non-atom tokens
|
| 107 |
+
(branches, rings, dots). ``None`` if ``token_regr`` was not
|
| 108 |
+
provided.
|
| 109 |
+
"""
|
| 110 |
return get_positional_encodings_and_align(smiles, token_regr, seed)
|
| 111 |
|
| 112 |
class M5EncoderModel(T5EncoderModel):
|