IlPakoZ committed on
Commit
eff6238
·
1 Parent(s): 76b03d7

Fixed weight loading, refactored model

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. modeling_m5_encoder.py +43 -2
config.json CHANGED
@@ -4,7 +4,7 @@
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "modeling_m5_encoder.M5EncoderConfig",
7
- "AutoModel": "modeling_m5_encoder.M5ModelForRegression",
8
  "AutoModelForSequenceClassification": "modeling_m5_encoder.M5ModelForRegression"
9
  },
10
  "classifier_dropout": 0,
 
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "modeling_m5_encoder.M5EncoderConfig",
7
+ "AutoModel": "modeling_m5_encoder.M5Model",
8
  "AutoModelForSequenceClassification": "modeling_m5_encoder.M5ModelForRegression"
9
  },
10
  "classifier_dropout": 0,
modeling_m5_encoder.py CHANGED
@@ -54,18 +54,59 @@ class M5EncoderConfig(T5Config):
54
 
55
  class M5Encoder(PreTrainedModel):
56
  config_class = M5EncoderConfig
 
57
 
58
  def __init__(self, config):
59
  super().__init__(config)
60
  self.model = M5EncoderModel(config)
61
- #self.model = torch.compile(self.model, mode="max-autotune", fullgraph=True)
62
 
63
  def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
64
  return self.model(input_ids=input_ids,
65
  attention_mask=attention_mask,
66
  relative_position=relative_position)
67
 
68
- def get_positional_embeddings_and_align(self, smiles, token_regr, seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return get_positional_encodings_and_align(smiles, token_regr, seed)
70
 
71
  class M5EncoderModel(T5EncoderModel):
 
54
 
55
  class M5Encoder(PreTrainedModel):
56
  config_class = M5EncoderConfig
57
+ base_model_prefix = "encoder"
58
 
59
  def __init__(self, config):
60
  super().__init__(config)
61
  self.model = M5EncoderModel(config)
 
62
 
63
  def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
64
  return self.model(input_ids=input_ids,
65
  attention_mask=attention_mask,
66
  relative_position=relative_position)
67
 
68
+ def get_positional_embeddings_and_align(
69
+ self,
70
+ smiles: str,
71
+ seed: int,
72
+ token_regr: Optional[np.ndarray] = None,
73
+ ) -> tuple[str, np.ndarray, Optional[np.ndarray]]:
74
+ """
75
+ Convert a SMILES string into a SELFIES tokenization, compute pairwise
76
+ molecular-graph distance encodings, and optionally align token-level
77
+ regression labels to the new token order.
78
+
79
+ Args:
80
+ smiles: Input molecule as a SMILES string. Does not need to be
81
+ canonical — canonicalization and optional randomization are
82
+ applied internally.
83
+ seed: Epoch/seed value controlling SMILES augmentation. When 0,
84
+ the canonical SELFIES is used; any other value produces a
85
+ reproducible randomized SELFIES variant.
86
+ token_regr: Optional array of per-atom regression labels
87
+ (e.g. Löwdin charges) aligned to the original SMILES atom
88
+ order.
89
+ If provided, labels are re-aligned to match the SELFIES token
90
+ order of the (possibly randomized) output SMILES.
91
+ Shape: ``(n_atoms,)``.
92
+
93
+ Returns:
94
+ A tuple of:
95
+ - **selfies** (``str``): SELFIES encoding of the (possibly
96
+ randomized) SMILES.
97
+ - **pos_encod** (``np.ndarray``): Pairwise distance matrix of
98
+ shape ``(seq_len, seq_len)`` with ``dtype=np.int16``. Entries
99
+ are shortest-path graph distances between atoms, capped at
100
+ ``np.iinfo(np.int16).max - 1``. Special values: ``0`` for
101
+ CLS-to-token, token-to-CLS, and ring/dot-separated fragment
102
+ pairs; ``-1`` for intra-branch/ring structural tokens;
103
+ ``np.iinfo(np.int16).max`` for padding positions.
104
+ - **token_regr_selfies** (``np.ndarray`` or ``None``): Labels
105
+ re-aligned to SELFIES token positions, shape
106
+ ``(seq_len - 1,)``, with ``np.nan`` for non-atom tokens
107
+ (branches, rings, dots). ``None`` if ``token_regr`` was not
108
+ provided.
109
+ """
110
  return get_positional_encodings_and_align(smiles, token_regr, seed)
111
 
112
  class M5EncoderModel(T5EncoderModel):