Update PolyFusion/DeBERTav2.py
PolyFusion/DeBERTav2.py  (+6 -30)  CHANGED
@@ -1,14 +1,6 @@
 """
 DeBERTav2.py
 DeBERTaV2 masked language modeling pretraining for polymer SMILES (PSMILES).
-
-This file provides:
-- build_psmiles_tokenizer(spm_path, max_len)
-- PSMILESDebertaEncoder: a dual-use wrapper
-  * If labels provided -> behaves like MLM model (HF Trainer compatible)
-  * If labels not provided -> returns pooled embedding (for CL.py)
-  * token_logits(...) helper for reconstruction in CL.py
-- End-to-end MLM training utilities (kept aligned with your original script)
 """

 from __future__ import annotations
@@ -142,7 +134,6 @@ def train_sentencepiece_if_needed(train_txt: str, spm_model_prefix: str, vocab_s

 def build_psmiles_tokenizer(spm_path: str, max_len: int = 128):
     """
-    Build tokenizer exactly as CL.py expects.
     Uses SentencePiece-backed DebertaV2Tokenizer.
     """
     from transformers import DebertaV2Tokenizer
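For reference, a SentencePiece-backed DebertaV2Tokenizer can be built directly from a trained .model file. The sketch below shows the likely shape of build_psmiles_tokenizer; the diff does not show the body, so the exact options are assumptions:

from transformers import DebertaV2Tokenizer

def build_psmiles_tokenizer(spm_path: str, max_len: int = 128):
    # vocab_file is the SentencePiece .model file trained on PSMILES strings;
    # special tokens ([CLS], [SEP], [MASK], ...) use the tokenizer's defaults.
    return DebertaV2Tokenizer(vocab_file=spm_path, model_max_length=max_len)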
@@ -333,18 +324,15 @@ def compute_metrics(eval_pred):


 # =============================================================================
-# Encoder wrapper
+# Encoder wrapper for MLM training
 # =============================================================================

 class PSMILESDebertaEncoder:
     """
     Dual-use wrapper:
-
     - For MLM training (HF Trainer):
       forward(input_ids, attention_mask, labels) -> HF outputs (with .loss, .logits)
-    -
-      forward(input_ids, attention_mask) -> pooled embedding (B, emb_dim)
-    - token_logits(...) helper for reconstruction in CL.py
+    - token_logits(...) helper for reconstruction
     """

     def __init__(
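To make the dual-use contract concrete, a hypothetical usage sketch follows. The constructor keywords mirror those visible later in this diff (model_dir_or_name, vocab_size); the PSMILES strings and tokenizer filename are illustrative only:

tok = build_psmiles_tokenizer("psmiles_spm.model", max_len=128)
enc = PSMILESDebertaEncoder(model_dir_or_name=None, vocab_size=len(tok))

batch = tok(["[*]CC[*]", "[*]CC(C)[*]"], padding=True, return_tensors="pt")

# MLM mode: labels supplied -> HF-style output carrying .loss and .logits.
mlm_out = enc(batch["input_ids"], batch["attention_mask"], labels=batch["input_ids"])

# Encoder mode: no labels -> pooled embedding of shape (B, emb_dim).
emb = enc(batch["input_ids"], batch["attention_mask"])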
@@ -380,19 +368,9 @@ class PSMILESDebertaEncoder:
         )
         self.model = DebertaV2ForMaskedLM(config)

-        # pool_proj required by CL.py
         # Use hidden size from config if available
         hs = int(getattr(self.model.config, "hidden_size", hidden_size))
         self.pool_proj = nn.Linear(hs, emb_dim)
-
-        # allow .to() and .parameters() by delegating via nn.Module-like behavior
-        # (We keep it simple: expose these methods explicitly.)
-        # Note: CL.py uses encoder as nn.Module; to ensure compatibility, we provide:
-        #   - to()
-        #   - parameters()
-        #   - state_dict()/load_state_dict()
-        #   - train()/eval()
-        #   - __call__ routes to forward()
         self._device = None

     # ---- nn.Module-like API ----
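The deleted comment block enumerated the nn.Module-like surface kept for compatibility: to(), parameters(), state_dict()/load_state_dict(), train()/eval(), and __call__ routing to forward(). Those methods live in the unchanged "# ---- nn.Module-like API ----" section; as a sketch, such delegation usually looks like the following, assuming only the self.model and self.pool_proj attributes created above:

def to(self, device):
    self.model.to(device)
    self.pool_proj.to(device)
    self._device = device
    return self

def parameters(self):
    # Yield parameters of both sub-modules so an optimizer sees everything.
    yield from self.model.parameters()
    yield from self.pool_proj.parameters()

def train(self, mode: bool = True):
    self.model.train(mode)
    self.pool_proj.train(mode)
    return self

def eval(self):
    return self.train(False)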
@@ -429,14 +407,14 @@ class PSMILESDebertaEncoder:
         try:
             self.model.load_state_dict(state_dict, strict=strict)
         except Exception:
-            # ignore if incompatible
+            # ignore if incompatible
             pass
         return self

     def __call__(self, input_ids, attention_mask=None, labels=None):
         return self.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

-    # ----
+    # ---- Core helpers ----
     def _pool_hidden(self, last_hidden_state, attention_mask=None):
         """
         Pool token embeddings -> sequence embedding.
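The pooling body sits outside this hunk. A mask-aware mean is the usual choice for "token embeddings -> sequence embedding"; the following is a guess at what _pool_hidden computes, not the repo's confirmed code:

def _pool_hidden(self, last_hidden_state, attention_mask=None):
    # (B, L, H) -> (B, H): average token states, skipping padding positions.
    if attention_mask is None:
        return last_hidden_state.mean(dim=1)
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)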
@@ -455,7 +433,7 @@ class PSMILESDebertaEncoder:
     def forward(self, input_ids, attention_mask=None, labels=None):
         """
         If labels is provided -> MLM mode: return HF outputs (Trainer compatible).
-        Else -> encoder mode: return pooled embedding
+        Else -> encoder mode: return pooled embedding.
         """
         if labels is not None:
             return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
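Only the MLM branch is visible here. A plausible shape for the full dispatch, under the assumption that the encoder branch runs the bare .deberta body of DebertaV2ForMaskedLM before pooling and projecting (the actual else-branch is not in this hunk):

def forward(self, input_ids, attention_mask=None, labels=None):
    if labels is not None:
        # MLM mode: the HF head computes masked-LM loss and logits.
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    # Encoder mode: bare transformer body, then pool and project.
    hidden = self.model.deberta(
        input_ids=input_ids, attention_mask=attention_mask
    ).last_hidden_state
    return self.pool_proj(self._pool_hidden(hidden, attention_mask))  # (B, emb_dim)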
@@ -467,7 +445,6 @@ class PSMILESDebertaEncoder:

     def token_logits(self, input_ids, attention_mask=None, labels=None):
         """
-        CL helper:
         - If labels provided: returns loss tensor from HF MLM forward
         - Else: returns token logits (B, L, V)
         """
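Given that contract, the helper reduces to a thin wrapper over the HF forward. A sketch consistent with the docstring, though not necessarily the exact body:

def token_logits(self, input_ids, attention_mask=None, labels=None):
    out = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    # With labels, HF populates .loss; without, only .logits of shape (B, L, V).
    return out.loss if labels is not None else out.logits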
@@ -487,7 +464,6 @@ def build_model_and_trainer(tokenizer, dataset_train, dataset_test, spm_model_pa
     vocab_size = len(tokenizer)
     pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

-    # Use wrapper so it is used in THIS file and also imported by CL.py
     model = PSMILESDebertaEncoder(
         model_dir_or_name=None,
         vocab_size=vocab_size,
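For orientation, build_model_and_trainer typically wires this wrapper into an MLM Trainer roughly as follows. The collator settings and TrainingArguments values are illustrative assumptions, not the repo's hyperparameters; tokenizer, dataset_train, dataset_test, and compute_metrics come from the surrounding function and file:

from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
trainer = Trainer(
    model=model,  # the PSMILESDebertaEncoder wrapper built above
    args=TrainingArguments(output_dir="mlm_out", per_device_train_batch_size=32),
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    data_collator=collator,
    compute_metrics=compute_metrics,
)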
@@ -603,4 +579,4 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()