kaurm43 commited on
Commit
e6d5395
·
verified ·
1 Parent(s): 735362f

Update PolyFusion/Transformer.py

Browse files
Files changed (1) hide show
  1. PolyFusion/Transformer.py +125 -77
PolyFusion/Transformer.py CHANGED
@@ -1,14 +1,23 @@
1
  """
 
2
  Fingerprint masked language modeling (MLM) using a Transformer encoder.
 
 
 
 
 
 
3
  """
4
 
 
 
5
  import os
6
  import json
7
  import time
8
  import sys
9
  import csv
10
  import argparse
11
- from typing import List
12
 
13
  # Increase max CSV field size limit (fingerprints can be long)
14
  csv.field_size_limit(sys.maxsize)
@@ -130,6 +139,7 @@ def load_fingerprints(csv_path: str, target_rows: int, chunksize: int) -> List[L
130
 
131
  class FingerprintDataset(Dataset):
132
  """Dataset of fixed-length fingerprint bit vectors (stored as torch.long tensors)."""
 
133
  def __init__(self, fps: List[torch.Tensor]):
134
  self.fps = fps
135
 
@@ -150,8 +160,8 @@ def collate_batch(batch):
150
  B = len(batch)
151
  if B == 0:
152
  return {
153
- "z": torch.zeros((0, FP_LENGTH), dtype=torch.long),
154
- "labels_z": torch.zeros((0, FP_LENGTH), dtype=torch.long),
155
  "attention_mask": torch.zeros((0, FP_LENGTH), dtype=torch.bool),
156
  }
157
 
@@ -159,35 +169,11 @@ def collate_batch(batch):
159
  for item in batch:
160
  if isinstance(item, torch.Tensor):
161
  tensors.append(item)
162
- elif isinstance(item, dict):
163
- if "fp" in item:
164
- val = item["fp"]
165
- if not isinstance(val, torch.Tensor):
166
- val = torch.tensor(val, dtype=torch.long)
167
- tensors.append(val)
168
- else:
169
- found = None
170
- for v in item.values():
171
- if isinstance(v, torch.Tensor):
172
- found = v
173
- break
174
- elif isinstance(v, np.ndarray):
175
- found = torch.tensor(v, dtype=torch.long)
176
- break
177
- elif isinstance(v, list):
178
- try:
179
- found = torch.tensor(v, dtype=torch.long)
180
- break
181
- except Exception:
182
- continue
183
- if found is None:
184
- raise KeyError(f"collate_batch: couldn't find tensor-like fp in item keys: {list(item.keys())}")
185
- tensors.append(found)
186
  else:
187
  tensors.append(torch.tensor(item, dtype=torch.long))
188
 
189
  all_inputs = torch.stack(tensors, dim=0).long()
190
- labels_z = torch.full_like(all_inputs, fill_value=-100, dtype=torch.long)
191
  z_masked = all_inputs.clone()
192
 
193
  for i in range(B):
@@ -199,7 +185,7 @@ def collate_batch(batch):
199
 
200
  sel_idx = torch.nonzero(is_selected).squeeze(-1)
201
  if sel_idx.numel() > 0:
202
- labels_z[i, sel_idx] = z[sel_idx]
203
 
204
  probs = torch.rand(sel_idx.size(0))
205
  mask_choice = probs < 0.8
@@ -212,14 +198,22 @@ def collate_batch(batch):
212
  z_masked[i, sel_idx[rand_choice]] = rand_bits
213
 
214
  attention_mask = torch.ones_like(all_inputs, dtype=torch.bool)
215
- return {"z": z_masked, "labels_z": labels_z, "attention_mask": attention_mask}
216
 
217
 
218
  class FingerprintEncoder(nn.Module):
219
  """Transformer encoder over a length-FP_LENGTH token sequence with small vocab {0,1,MASK}."""
220
- def __init__(self, vocab_size=VOCAB_SIZE, hidden_dim=HIDDEN_DIM, seq_len=FP_LENGTH,
221
- num_layers=TRANSFORMER_NUM_LAYERS, nhead=TRANSFORMER_NHEAD, dim_feedforward=TRANSFORMER_FF,
222
- dropout=DROPOUT):
 
 
 
 
 
 
 
 
223
  super().__init__()
224
  self.token_emb = nn.Embedding(vocab_size, hidden_dim)
225
  self.pos_emb = nn.Embedding(seq_len, hidden_dim)
@@ -242,31 +236,73 @@ class FingerprintEncoder(nn.Module):
242
  return self.transformer(x, src_key_padding_mask=key_padding_mask)
243
 
244
 
245
- class MaskedFingerprintModel(nn.Module):
246
- """Encoder + token classification head; returns scalar loss when labels_z provided."""
247
- def __init__(self, hidden_dim=HIDDEN_DIM, vocab_size=VOCAB_SIZE):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  super().__init__()
249
- self.encoder = FingerprintEncoder(vocab_size=vocab_size, hidden_dim=hidden_dim)
 
 
 
 
 
 
 
 
250
  self.mlm_head = nn.Linear(hidden_dim, vocab_size)
 
251
 
252
- def forward(self, z, attention_mask=None, labels_z=None):
253
- embeddings = self.encoder(z, attention_mask=attention_mask)
254
- logits = self.mlm_head(embeddings)
 
 
 
255
 
256
- if labels_z is not None:
257
- mask = labels_z != -100
258
- if mask.sum() == 0:
259
- return torch.tensor(0.0, device=z.device)
 
 
260
 
 
 
 
 
261
  logits_masked = logits[mask]
262
- labels_masked = labels_z[mask].long()
263
  return F.cross_entropy(logits_masked, labels_masked)
264
 
265
- return logits
 
 
 
266
 
267
 
268
  class ValLossCallback(TrainerCallback):
269
  """Tracks best eval loss, prints metrics, saves best model, early-stops."""
 
270
  def __init__(self, best_model_dir: str, val_loader: DataLoader, patience: int = 10, trainer_ref=None):
271
  self.best_val_loss = float("inf")
272
  self.epochs_no_improve = 0
@@ -301,12 +337,12 @@ class ValLossCallback(TrainerCallback):
301
 
302
  with torch.no_grad():
303
  for batch in self.val_loader:
304
- z = batch["z"].to(device_local)
305
- labels_z = batch["labels_z"].to(device_local)
306
- attention_mask = batch.get("attention_mask", torch.ones_like(z, dtype=torch.bool)).to(device_local)
307
 
308
  try:
309
- loss = model_eval(z, attention_mask=attention_mask, labels_z=labels_z)
310
  except Exception:
311
  loss = None
312
 
@@ -314,17 +350,16 @@ class ValLossCallback(TrainerCallback):
314
  total_loss += loss.item()
315
  n_batches += 1
316
 
317
- logits = model_eval(z, attention_mask=attention_mask)
318
-
319
- mask = labels_z != -100
320
  if mask.sum().item() == 0:
321
  continue
322
 
323
  logits_masked_list.append(logits[mask])
324
- labels_masked_list.append(labels_z[mask])
325
 
326
  pred_bits = torch.argmax(logits[mask], dim=-1)
327
- true_b = labels_z[mask]
328
 
329
  preds_bits.extend(pred_bits.cpu().tolist())
330
  true_bits.extend(true_b.cpu().tolist())
@@ -379,25 +414,38 @@ def train_and_eval(args: argparse.Namespace) -> None:
379
  train_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in train_idx]
380
  val_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in val_idx]
381
 
382
- # Compute class weights
383
- counts = np.ones((2,), dtype=np.float64)
384
- for fp in train_fps:
385
- vals = fp.cpu().numpy().astype(int)
386
- counts[0] += np.sum(vals == 0)
387
- counts[1] += np.sum(vals == 1)
388
- freq = counts / counts.sum()
389
- inv_freq = 1.0 / (freq + 1e-12)
390
- class_weights_arr = inv_freq / inv_freq.mean()
391
- class_weights = torch.tensor(class_weights_arr, dtype=torch.float)
392
- print("Class weights (for bit 0 and bit 1):", class_weights.numpy())
393
-
394
  train_dataset = FingerprintDataset(train_fps)
395
  val_dataset = FingerprintDataset(val_fps)
396
 
397
- train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_batch, drop_last=False, num_workers=args.num_workers)
398
- val_loader = DataLoader(val_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=collate_batch, drop_last=False, num_workers=args.num_workers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
- model = MaskedFingerprintModel(hidden_dim=HIDDEN_DIM, vocab_size=VOCAB_SIZE)
401
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
402
  model.to(device)
403
 
@@ -451,21 +499,21 @@ def train_and_eval(args: argparse.Namespace) -> None:
451
 
452
  with torch.no_grad():
453
  for batch in val_loader:
454
- z = batch["z"].to(device)
455
- labels_z = batch["labels_z"].to(device)
456
- attention_mask = batch.get("attention_mask", torch.ones_like(z, dtype=torch.bool)).to(device)
457
 
458
- logits = model(z, attention_mask=attention_mask)
459
 
460
- mask = labels_z != -100
461
  if mask.sum().item() == 0:
462
  continue
463
 
464
  logits_masked_final.append(logits[mask])
465
- labels_masked_final.append(labels_z[mask])
466
 
467
  pred_bits = torch.argmax(logits[mask], dim=-1)
468
- true_b = labels_z[mask]
469
 
470
  preds_bits_all.extend(pred_bits.cpu().tolist())
471
  true_bits_all.extend(true_b.cpu().tolist())
 
1
  """
2
+ Transformer.py
3
  Fingerprint masked language modeling (MLM) using a Transformer encoder.
4
+
5
+ This file provides (and uses internally):
6
+ - PooledFingerprintEncoder (used by CL.py AND used for MLM training here)
7
+ * forward(...) returns pooled embedding if labels are None (for CL.py)
8
+ * forward(...) returns loss if labels provided (Trainer-compatible for MLM)
9
+ * token_logits(...) returns per-token logits for reconstruction in CL.py
10
  """
11
 
12
+ from __future__ import annotations
13
+
14
  import os
15
  import json
16
  import time
17
  import sys
18
  import csv
19
  import argparse
20
+ from typing import List, Optional
21
 
22
  # Increase max CSV field size limit (fingerprints can be long)
23
  csv.field_size_limit(sys.maxsize)
 
139
 
140
  class FingerprintDataset(Dataset):
141
  """Dataset of fixed-length fingerprint bit vectors (stored as torch.long tensors)."""
142
+
143
  def __init__(self, fps: List[torch.Tensor]):
144
  self.fps = fps
145
 
 
160
  B = len(batch)
161
  if B == 0:
162
  return {
163
+ "input_ids": torch.zeros((0, FP_LENGTH), dtype=torch.long),
164
+ "labels": torch.zeros((0, FP_LENGTH), dtype=torch.long),
165
  "attention_mask": torch.zeros((0, FP_LENGTH), dtype=torch.bool),
166
  }
167
 
 
169
  for item in batch:
170
  if isinstance(item, torch.Tensor):
171
  tensors.append(item)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  else:
173
  tensors.append(torch.tensor(item, dtype=torch.long))
174
 
175
  all_inputs = torch.stack(tensors, dim=0).long()
176
+ labels = torch.full_like(all_inputs, fill_value=-100, dtype=torch.long)
177
  z_masked = all_inputs.clone()
178
 
179
  for i in range(B):
 
185
 
186
  sel_idx = torch.nonzero(is_selected).squeeze(-1)
187
  if sel_idx.numel() > 0:
188
+ labels[i, sel_idx] = z[sel_idx]
189
 
190
  probs = torch.rand(sel_idx.size(0))
191
  mask_choice = probs < 0.8
 
198
  z_masked[i, sel_idx[rand_choice]] = rand_bits
199
 
200
  attention_mask = torch.ones_like(all_inputs, dtype=torch.bool)
201
+ return {"input_ids": z_masked, "labels": labels, "attention_mask": attention_mask}
202
 
203
 
204
  class FingerprintEncoder(nn.Module):
205
  """Transformer encoder over a length-FP_LENGTH token sequence with small vocab {0,1,MASK}."""
206
+
207
+ def __init__(
208
+ self,
209
+ vocab_size=VOCAB_SIZE,
210
+ hidden_dim=HIDDEN_DIM,
211
+ seq_len=FP_LENGTH,
212
+ num_layers=TRANSFORMER_NUM_LAYERS,
213
+ nhead=TRANSFORMER_NHEAD,
214
+ dim_feedforward=TRANSFORMER_FF,
215
+ dropout=DROPOUT,
216
+ ):
217
  super().__init__()
218
  self.token_emb = nn.Embedding(vocab_size, hidden_dim)
219
  self.pos_emb = nn.Embedding(seq_len, hidden_dim)
 
236
  return self.transformer(x, src_key_padding_mask=key_padding_mask)
237
 
238
 
239
+ # =============================================================================
240
+ # Wrapper used by CL.py AND used here for MLM training
241
+ # =============================================================================
242
+
243
+ class PooledFingerprintEncoder(nn.Module):
244
+ """
245
+ Dual-use:
246
+ - labels is None -> return pooled embedding (B, emb_dim) [for CL.py]
247
+ - labels provided -> return loss scalar [Trainer-compatible MLM]
248
+ Also provides token_logits(...) used by CL.py reconstruction.
249
+ """
250
+
251
+ def __init__(
252
+ self,
253
+ vocab_size=VOCAB_SIZE,
254
+ hidden_dim=HIDDEN_DIM,
255
+ seq_len=FP_LENGTH,
256
+ num_layers=TRANSFORMER_NUM_LAYERS,
257
+ nhead=TRANSFORMER_NHEAD,
258
+ dim_feedforward=TRANSFORMER_FF,
259
+ dropout=DROPOUT,
260
+ emb_dim: int = 600,
261
+ ):
262
  super().__init__()
263
+ self.encoder = FingerprintEncoder(
264
+ vocab_size=vocab_size,
265
+ hidden_dim=hidden_dim,
266
+ seq_len=seq_len,
267
+ num_layers=num_layers,
268
+ nhead=nhead,
269
+ dim_feedforward=dim_feedforward,
270
+ dropout=dropout,
271
+ )
272
  self.mlm_head = nn.Linear(hidden_dim, vocab_size)
273
+ self.pool_proj = nn.Linear(hidden_dim, emb_dim)
274
 
275
+ def _pool(self, h, attention_mask=None):
276
+ if attention_mask is None:
277
+ return h.mean(dim=1)
278
+ mask = attention_mask.unsqueeze(-1).float()
279
+ denom = mask.sum(dim=1).clamp(min=1.0)
280
+ return (h * mask).sum(dim=1) / denom
281
 
282
+ def token_logits(self, input_ids, attention_mask=None):
283
+ h = self.encoder(input_ids, attention_mask=attention_mask)
284
+ return self.mlm_head(h)
285
+
286
+ def forward(self, input_ids, attention_mask=None, labels=None):
287
+ logits = self.token_logits(input_ids, attention_mask=attention_mask)
288
 
289
+ if labels is not None:
290
+ mask = labels != -100
291
+ if mask.sum() == 0:
292
+ return torch.tensor(0.0, device=input_ids.device)
293
  logits_masked = logits[mask]
294
+ labels_masked = labels[mask].long()
295
  return F.cross_entropy(logits_masked, labels_masked)
296
 
297
+ # pooled embedding for CL
298
+ h = self.encoder(input_ids, attention_mask=attention_mask)
299
+ pooled = self._pool(h, attention_mask=attention_mask)
300
+ return self.pool_proj(pooled)
301
 
302
 
303
  class ValLossCallback(TrainerCallback):
304
  """Tracks best eval loss, prints metrics, saves best model, early-stops."""
305
+
306
  def __init__(self, best_model_dir: str, val_loader: DataLoader, patience: int = 10, trainer_ref=None):
307
  self.best_val_loss = float("inf")
308
  self.epochs_no_improve = 0
 
337
 
338
  with torch.no_grad():
339
  for batch in self.val_loader:
340
+ input_ids = batch["input_ids"].to(device_local)
341
+ labels = batch["labels"].to(device_local)
342
+ attention_mask = batch.get("attention_mask", torch.ones_like(input_ids, dtype=torch.bool)).to(device_local)
343
 
344
  try:
345
+ loss = model_eval(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
346
  except Exception:
347
  loss = None
348
 
 
350
  total_loss += loss.item()
351
  n_batches += 1
352
 
353
+ logits = model_eval.token_logits(input_ids=input_ids, attention_mask=attention_mask)
354
+ mask = labels != -100
 
355
  if mask.sum().item() == 0:
356
  continue
357
 
358
  logits_masked_list.append(logits[mask])
359
+ labels_masked_list.append(labels[mask])
360
 
361
  pred_bits = torch.argmax(logits[mask], dim=-1)
362
+ true_b = labels[mask]
363
 
364
  preds_bits.extend(pred_bits.cpu().tolist())
365
  true_bits.extend(true_b.cpu().tolist())
 
414
  train_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in train_idx]
415
  val_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in val_idx]
416
 
 
 
 
 
 
 
 
 
 
 
 
 
417
  train_dataset = FingerprintDataset(train_fps)
418
  val_dataset = FingerprintDataset(val_fps)
419
 
420
+ train_loader = DataLoader(
421
+ train_dataset,
422
+ batch_size=TRAIN_BATCH_SIZE,
423
+ shuffle=True,
424
+ collate_fn=collate_batch,
425
+ drop_last=False,
426
+ num_workers=args.num_workers,
427
+ )
428
+ val_loader = DataLoader(
429
+ val_dataset,
430
+ batch_size=EVAL_BATCH_SIZE,
431
+ shuffle=False,
432
+ collate_fn=collate_batch,
433
+ drop_last=False,
434
+ num_workers=args.num_workers,
435
+ )
436
+
437
+ # Use wrapper so it's also used inside this file (not just for CL.py)
438
+ model = PooledFingerprintEncoder(
439
+ vocab_size=VOCAB_SIZE,
440
+ hidden_dim=HIDDEN_DIM,
441
+ seq_len=FP_LENGTH,
442
+ num_layers=TRANSFORMER_NUM_LAYERS,
443
+ nhead=TRANSFORMER_NHEAD,
444
+ dim_feedforward=TRANSFORMER_FF,
445
+ dropout=DROPOUT,
446
+ emb_dim=600,
447
+ )
448
 
 
449
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
450
  model.to(device)
451
 
 
499
 
500
  with torch.no_grad():
501
  for batch in val_loader:
502
+ input_ids = batch["input_ids"].to(device)
503
+ labels = batch["labels"].to(device)
504
+ attention_mask = batch.get("attention_mask", torch.ones_like(input_ids, dtype=torch.bool)).to(device)
505
 
506
+ logits = model.token_logits(input_ids=input_ids, attention_mask=attention_mask)
507
 
508
+ mask = labels != -100
509
  if mask.sum().item() == 0:
510
  continue
511
 
512
  logits_masked_final.append(logits[mask])
513
+ labels_masked_final.append(labels[mask])
514
 
515
  pred_bits = torch.argmax(logits[mask], dim=-1)
516
+ true_b = labels[mask]
517
 
518
  preds_bits_all.extend(pred_bits.cpu().tolist())
519
  true_bits_all.extend(true_b.cpu().tolist())