kaurm43 committed on
Commit
b5ed4b6
·
verified ·
1 Parent(s): 47269bd

Update PolyFusion/Transformer.py

Browse files
Files changed (1) hide show
  1. PolyFusion/Transformer.py +4 -11
PolyFusion/Transformer.py CHANGED
@@ -1,12 +1,6 @@
1
  """
2
  Transformer.py
3
  Fingerprint masked language modeling (MLM) using a Transformer encoder.
4
-
5
- This file provides (and uses internally):
6
- - PooledFingerprintEncoder (used by CL.py AND used for MLM training here)
7
- * forward(...) returns pooled embedding if labels are None (for CL.py)
8
- * forward(...) returns loss if labels provided (Trainer-compatible for MLM)
9
- * token_logits(...) returns per-token logits for reconstruction in CL.py
10
  """
11
 
12
  from __future__ import annotations
@@ -237,15 +231,15 @@ class FingerprintEncoder(nn.Module):
237
 
238
 
239
  # =============================================================================
240
- # Wrapper used by CL.py AND used here for MLM training
241
  # =============================================================================
242
 
243
  class PooledFingerprintEncoder(nn.Module):
244
  """
245
  Dual-use:
246
- - labels is None -> return pooled embedding (B, emb_dim) [for CL.py]
247
- - labels provided -> return loss scalar [Trainer-compatible MLM]
248
- Also provides token_logits(...) used by CL.py reconstruction.
249
  """
250
 
251
  def __init__(
@@ -434,7 +428,6 @@ def train_and_eval(args: argparse.Namespace) -> None:
434
  num_workers=args.num_workers,
435
  )
436
 
437
- # Use wrapper so it's also used inside this file (not just for CL.py)
438
  model = PooledFingerprintEncoder(
439
  vocab_size=VOCAB_SIZE,
440
  hidden_dim=HIDDEN_DIM,
 
1
  """
2
  Transformer.py
3
  Fingerprint masked language modeling (MLM) using a Transformer encoder.
 
 
 
 
 
 
4
  """
5
 
6
  from __future__ import annotations
 
231
 
232
 
233
  # =============================================================================
234
+ # Wrapper used for MLM training
235
  # =============================================================================
236
 
237
  class PooledFingerprintEncoder(nn.Module):
238
  """
239
  Dual-use:
240
+ - labels is None -> return pooled embedding (B, emb_dim)
241
+ - labels provided -> return loss scalar [Trainer-compatible MLM]
242
+ Also provides token_logits(...) used for reconstruction.
243
  """
244
 
245
  def __init__(
 
428
  num_workers=args.num_workers,
429
  )
430
 
 
431
  model = PooledFingerprintEncoder(
432
  vocab_size=VOCAB_SIZE,
433
  hidden_dim=HIDDEN_DIM,