Spaces:
Running
Running
Update PolyFusion/Transformer.py
Browse files - PolyFusion/Transformer.py +4 -11
PolyFusion/Transformer.py
CHANGED
|
@@ -1,12 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
Transformer.py
|
| 3 |
Fingerprint masked language modeling (MLM) using a Transformer encoder.
|
| 4 |
-
|
| 5 |
-
This file provides (and uses internally):
|
| 6 |
-
- PooledFingerprintEncoder (used by CL.py AND used for MLM training here)
|
| 7 |
-
* forward(...) returns pooled embedding if labels are None (for CL.py)
|
| 8 |
-
* forward(...) returns loss if labels provided (Trainer-compatible for MLM)
|
| 9 |
-
* token_logits(...) returns per-token logits for reconstruction in CL.py
|
| 10 |
"""
|
| 11 |
|
| 12 |
from __future__ import annotations
|
|
@@ -237,15 +231,15 @@ class FingerprintEncoder(nn.Module):
|
|
| 237 |
|
| 238 |
|
| 239 |
# =============================================================================
|
| 240 |
-
# Wrapper used
|
| 241 |
# =============================================================================
|
| 242 |
|
| 243 |
class PooledFingerprintEncoder(nn.Module):
|
| 244 |
"""
|
| 245 |
Dual-use:
|
| 246 |
-
- labels is None -> return pooled embedding (B, emb_dim)
|
| 247 |
-
- labels provided -> return loss scalar
|
| 248 |
-
Also provides token_logits(...) used
|
| 249 |
"""
|
| 250 |
|
| 251 |
def __init__(
|
|
@@ -434,7 +428,6 @@ def train_and_eval(args: argparse.Namespace) -> None:
|
|
| 434 |
num_workers=args.num_workers,
|
| 435 |
)
|
| 436 |
|
| 437 |
-
# Use wrapper so it's also used inside this file (not just for CL.py)
|
| 438 |
model = PooledFingerprintEncoder(
|
| 439 |
vocab_size=VOCAB_SIZE,
|
| 440 |
hidden_dim=HIDDEN_DIM,
|
|
|
|
| 1 |
"""
|
| 2 |
Transformer.py
|
| 3 |
Fingerprint masked language modeling (MLM) using a Transformer encoder.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
|
|
|
| 231 |
|
| 232 |
|
| 233 |
# =============================================================================
|
| 234 |
+
# Wrapper used for MLM training
|
| 235 |
# =============================================================================
|
| 236 |
|
| 237 |
class PooledFingerprintEncoder(nn.Module):
|
| 238 |
"""
|
| 239 |
Dual-use:
|
| 240 |
+
- labels is None -> return pooled embedding (B, emb_dim)
|
| 241 |
+
- labels provided -> return loss scalar [Trainer-compatible MLM]
|
| 242 |
+
Also provides token_logits(...) used for reconstruction.
|
| 243 |
"""
|
| 244 |
|
| 245 |
def __init__(
|
|
|
|
| 428 |
num_workers=args.num_workers,
|
| 429 |
)
|
| 430 |
|
|
|
|
| 431 |
model = PooledFingerprintEncoder(
|
| 432 |
vocab_size=VOCAB_SIZE,
|
| 433 |
hidden_dim=HIDDEN_DIM,
|