IlPakoZ commited on
Commit
01a9e83
·
1 Parent(s): 86b3b93

Added data_collater function, refactor

Browse files
Files changed (2) hide show
  1. README.md +28 -1
  2. modeling_m5_encoder.py +104 -3
README.md CHANGED
@@ -40,6 +40,33 @@ model = AutoModelForSequenceClassification.from_pretrained(
40
  )
41
  ```
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  ## Architecture
45
 
@@ -114,7 +141,7 @@ The processed dataset contains **82,686,706 SMILES sequences**, each paired with
114
  | Split | Sequences | Tokens (approx.) |
115
  |---|---|---|
116
  | Train | 66,149,364 | ~2.5 B (×2 with augmentation → ~5 B) |
117
- | Validation | 8,268,673 | |
118
  | Test | 8,268,669 | ~ 0.82 B (×2 with augmentation → ~1.64 B) |
119
 
120
  Training augmentation generates randomized SELFIES on the fly from each SMILES. Labels are normalized before training.
 
40
  )
41
  ```
42
 
43
+ ### Preparing inputs
44
+
45
+ Inputs require SELFIES tokenization **and** a precomputed distance matrix
46
+ (`relative_position`). Use the helper bundled in the repo:
47
+
48
+ ```python
49
+
50
+ tokenizer = AutoTokenizer.from_pretrained("IlPakoZ/m5-encoder", trust_remote_code=True)
51
+
52
+ smiles = "CCO"
53
+
54
+ # seed = 0 produces the canonical SELFIES; other values generate reproducible random variations
55
+ selfies, pos_encod, _ = model.get_positional_encodings_and_align(smiles, seed=0)
56
+
57
+ encoding = tokenizer(selfies, return_tensors="pt")
58
+ input_ids = encoding["input_ids"]
59
+ attn_mask = encoding["attention_mask"]
60
+ rel_pos = torch.tensor(pos_encod).unsqueeze(0) # (1, seq_len, seq_len)
61
+
62
+ outputs = model(input_ids=input_ids, attention_mask=attn_mask, relative_position=rel_pos)
63
+ hidden = outputs.last_hidden_state # (1, seq_len, 512)
64
+ ```
65
+
66
+ A helper ``model.collate_for_dataset`` is also available to perform collation for use with PyTorch's ``DataLoader``. It takes a list of tuples, each composed of:
67
+ - a dictionary with keys ``"input_ids"`` (``np.ndarray``, shape ``(L,)``) and ``"attention_mask"`` (``np.ndarray``, shape ``(L,)``), as produced by a tokenizer;
68
+ - the positional embedding matrix;
69
+ - (optional) token regression labels. These are kept mainly for reproducibility of our paper's results and can be left as ``None`` in most circumstances.
70
 
71
  ## Architecture
72
 
 
141
  | Split | Sequences | Tokens (approx.) |
142
  |---|---|---|
143
  | Train | 66,149,364 | ~2.5 B (×2 with augmentation → ~5 B) |
144
+ | Validation | 8,268,673 | TBD |
145
  | Test | 8,268,669 | ~ 0.82 B (×2 with augmentation → ~1.64 B) |
146
 
147
  Training augmentation generates randomized SELFIES on the fly from each SMILES. Labels are normalized before training.
modeling_m5_encoder.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import math
4
  import logging
5
 
6
- from typing import Optional, Union
7
  import torch.nn as nn
8
  from transformers import PreTrainedModel, T5EncoderModel, T5ForConditionalGeneration, T5ForQuestionAnswering, T5ForTokenClassification, T5Model, load_tf_weights_in_t5
9
  from torch import nn
@@ -64,9 +64,9 @@ class M5Encoder(PreTrainedModel):
64
  return self.model(input_ids=input_ids,
65
  attention_mask=attention_mask,
66
  relative_position=relative_position)
67
-
 
68
  def get_positional_encodings_and_align(
69
- self,
70
  smiles: str,
71
  seed: int,
72
  token_regr: Optional[np.ndarray] = None,
@@ -107,7 +107,107 @@ class M5Encoder(PreTrainedModel):
107
  (branches, rings, dots). ``None`` if ``token_regr`` was not
108
  provided.
109
  """
 
110
  return get_positional_encodings_and_align(smiles, token_regr, seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  class M5EncoderModel(T5EncoderModel):
113
  def __init__(self, config: T5Config):
@@ -161,6 +261,7 @@ class M5EncoderModel(T5EncoderModel):
161
  input_ids=input_ids,
162
  attention_mask=attention_mask,
163
  inputs_embeds=inputs_embeds,
 
164
  head_mask=head_mask,
165
  output_attentions=output_attentions,
166
  output_hidden_states=output_hidden_states,
 
3
  import math
4
  import logging
5
 
6
+ from typing import Any, Optional, Union, Sequence
7
  import torch.nn as nn
8
  from transformers import PreTrainedModel, T5EncoderModel, T5ForConditionalGeneration, T5ForQuestionAnswering, T5ForTokenClassification, T5Model, load_tf_weights_in_t5
9
  from torch import nn
 
64
  return self.model(input_ids=input_ids,
65
  attention_mask=attention_mask,
66
  relative_position=relative_position)
67
+
68
+ @staticmethod
69
  def get_positional_encodings_and_align(
 
70
  smiles: str,
71
  seed: int,
72
  token_regr: Optional[np.ndarray] = None,
 
107
  (branches, rings, dots). ``None`` if ``token_regr`` was not
108
  provided.
109
  """
110
+
111
  return get_positional_encodings_and_align(smiles, token_regr, seed)
112
+
113
+ @staticmethod
114
+ def collate_for_dataset(batch: list[dict[str, Any]], n_global_regr: int = 0, PAD_TOKEN_ID: int = 2):
115
+ """
116
+ Collate processed data for pytorch dataloaders.
117
+
118
+ Each item in ``batch`` is a 3-tuple ``(token_dict, pos_encod, reg)``
119
+ where:
120
+
121
+ - ``token_dict`` is a dict with keys ``"input_ids"`` (``np.ndarray``,
122
+ shape ``(L,)``) and ``"attention_mask"`` (``np.ndarray``, shape
123
+ ``(L,)``), as produced by a tokenizer.
124
+ - ``pos_encod`` is an ``np.ndarray`` of shape ``(L, L)`` and dtype
125
+ ``np.int16`` holding pairwise molecular-graph distances, as returned
126
+ by :meth:`get_positional_encodings_and_align`.
127
+ - ``reg`` is an ``np.ndarray`` of shape
128
+ ``(n_global_regr + L - 1,)`` containing first the
129
+ ``n_global_regr`` sequence-level regression targets followed by
130
+ ``L - 1`` token-level targets (one per non-CLS token). Ignored when
131
+ ``n_global_regr == 0``.
132
+
133
+ All sequences are right-padded to the length of the longest sequence
134
+ in the batch (``L_max``):
135
+
136
+ - ``input_ids`` is padded with ``PAD_TOKEN_ID``.
137
+ - ``attention_mask`` is padded with ``0``.
138
+ - ``pos_encod`` is padded with ``np.iinfo(np.int16).max``; the
139
+ diagonal of the padded region is set to ``0`` to be consistent with
140
+ real token self-distances.
141
+ - ``labels`` (when present) is padded with ``float("nan")`` so that
142
+ padding positions can be masked out in the loss.
143
+
144
+ Args:
145
+ batch: List of ``(token_dict, pos_encod, reg)`` tuples, one per
146
+ sample.
147
+ n_global_regr: Number of sequence-level regression targets at the
148
+ start of each ``reg`` array. When ``0``, no ``"labels"`` key
149
+ is included in the returned dict.
150
+ PAD_TOKEN_ID: Token id used to fill padded positions in
151
+ ``input_ids``. Defaults to ``2``.
152
+
153
+ Returns:
154
+ A dict with the following keys:
155
+
156
+ - ``"input_ids"`` — ``torch.LongTensor`` of shape
157
+ ``(B, L_max)``.
158
+ - ``"attention_mask"`` — ``torch.LongTensor`` of shape
159
+ ``(B, L_max)``; ``1`` for real tokens, ``0`` for padding.
160
+ - ``"positional_encodings"`` — ``torch.ShortTensor`` of shape
161
+ ``(B, L_max, L_max)``.
162
+ - ``"labels"`` *(only when* ``n_global_regr > 0`` *)* —
163
+ ``torch.FloatTensor`` of shape
164
+ ``(B, n_global_regr + L_max - 1)``; ``nan`` for padding
165
+ positions.
166
+ """
167
+ token_dicts, pos_encod, regs = zip(*batch)
168
+ lengths = [td["input_ids"].shape[0] for td in token_dicts]
169
+ L_max = max(lengths)
170
+ B = len(batch)
171
+
172
+ input_ids_out = np.full((B, L_max), PAD_TOKEN_ID, dtype=np.int64)
173
+ attn_mask_out = np.zeros((B, L_max), dtype=np.int64)
174
+ pos_encod_out = np.full((B, L_max, L_max), np.iinfo(np.int16).max, dtype=np.int16)
175
+
176
+ if n_global_regr > 0:
177
+ reg_out = np.full((B, n_global_regr + L_max - 1), float("nan"), dtype=np.float32)
178
+
179
+ # Set diagonal to 0 up-front for the full L_max grid; individual items
180
+ # already have their diagonal zeroed — this covers the padded extension.
181
+ diag_idx = np.arange(L_max)
182
+ pos_encod_out[:, diag_idx, diag_idx] = 0
183
+
184
+ for i, (td, pe, reg) in enumerate(zip(token_dicts, pos_encod, regs)):
185
+ L = lengths[i]
186
+
187
+ # Token ids & attention mask
188
+ input_ids_out[i, :L] = td["input_ids"]
189
+ attn_mask_out[i, :L] = td["attention_mask"]
190
+
191
+ # Positional embedding (L x L)
192
+ pos_encod_out[i, :L, :L] = pe
193
+
194
+ # Regression: global part + token part (length L - 1, excluding CLS)
195
+ if n_global_regr > 0:
196
+ reg_out[i, :n_global_regr] = reg[:n_global_regr]
197
+ reg_out[i, n_global_regr:n_global_regr + L - 1] = reg[n_global_regr:]
198
+
199
+ out = {
200
+ "input_ids": torch.from_numpy(input_ids_out),
201
+ "attention_mask": torch.from_numpy(attn_mask_out),
202
+ "positional_encodings": torch.from_numpy(pos_encod_out),
203
+ }
204
+
205
+ if n_global_regr > 0:
206
+ out["labels"] = torch.from_numpy(reg_out)
207
+
208
+ return out
209
+
210
+
211
 
212
  class M5EncoderModel(T5EncoderModel):
213
  def __init__(self, config: T5Config):
 
261
  input_ids=input_ids,
262
  attention_mask=attention_mask,
263
  inputs_embeds=inputs_embeds,
264
+
265
  head_mask=head_mask,
266
  output_attentions=output_attentions,
267
  output_hidden_states=output_hidden_states,