kaurm43 committed · Commit 58d199f · verified · Parent(s): 6637320

Update PolyFusion/DeBERTav2.py

Files changed (1):
  1. PolyFusion/DeBERTav2.py +6 -30
PolyFusion/DeBERTav2.py CHANGED
@@ -1,14 +1,6 @@
 """
 DeBERTav2.py
 DeBERTaV2 masked language modeling pretraining for polymer SMILES (PSMILES).
-
-This file provides:
-- build_psmiles_tokenizer(spm_path, max_len)
-- PSMILESDebertaEncoder: a dual-use wrapper
-  * If labels provided -> behaves like MLM model (HF Trainer compatible)
-  * If labels not provided -> returns pooled embedding (for CL.py)
-  * token_logits(...) helper for reconstruction in CL.py
-- End-to-end MLM training utilities (kept aligned with your original script)
 """
 
 from __future__ import annotations
@@ -142,7 +134,6 @@ def train_sentencepiece_if_needed(train_txt: str, spm_model_prefix: str, vocab_s
 
 def build_psmiles_tokenizer(spm_path: str, max_len: int = 128):
     """
-    Build tokenizer exactly as CL.py expects.
     Uses SentencePiece-backed DebertaV2Tokenizer.
     """
     from transformers import DebertaV2Tokenizer
@@ -333,18 +324,15 @@ def compute_metrics(eval_pred):
 
 
 # =============================================================================
-# Encoder wrapper used by CL.py AND used here for MLM training
+# Encoder wrapper for MLM training
 # =============================================================================
 
 class PSMILESDebertaEncoder:
     """
     Dual-use wrapper:
-
     - For MLM training (HF Trainer):
       forward(input_ids, attention_mask, labels) -> HF outputs (with .loss, .logits)
-    - For CL:
-      forward(input_ids, attention_mask) -> pooled embedding (B, emb_dim)
-    - token_logits(...) helper for reconstruction in CL.py
+    - token_logits(...) helper for reconstruction
     """
 
     def __init__(
@@ -380,19 +368,9 @@ class PSMILESDebertaEncoder:
         )
         self.model = DebertaV2ForMaskedLM(config)
 
-        # pool_proj required by CL.py
         # Use hidden size from config if available
         hs = int(getattr(self.model.config, "hidden_size", hidden_size))
         self.pool_proj = nn.Linear(hs, emb_dim)
-
-        # allow .to() and .parameters() by delegating via nn.Module-like behavior
-        # (We keep it simple: expose these methods explicitly.)
-        # Note: CL.py uses encoder as nn.Module; to ensure compatibility, we provide:
-        #   - to()
-        #   - parameters()
-        #   - state_dict()/load_state_dict()
-        #   - train()/eval()
-        #   - __call__ routes to forward()
         self._device = None
 
     # ---- nn.Module-like API ----
@@ -429,14 +407,14 @@ class PSMILESDebertaEncoder:
         try:
             self.model.load_state_dict(state_dict, strict=strict)
         except Exception:
-            # ignore if incompatible; CL often uses strict=False
+            # ignore if incompatible
            pass
        return self
 
    def __call__(self, input_ids, attention_mask=None, labels=None):
        return self.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
 
-    # ---- core helpers ----
+    # ---- Core helpers ----
    def _pool_hidden(self, last_hidden_state, attention_mask=None):
        """
        Pool token embeddings -> sequence embedding.
@@ -455,7 +433,7 @@ class PSMILESDebertaEncoder:
    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        If labels is provided -> MLM mode: return HF outputs (Trainer compatible).
-        Else -> encoder mode: return pooled embedding for CL.
+        Else -> encoder mode: return pooled embedding.
        """
        if labels is not None:
            return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
@@ -467,7 +445,6 @@ class PSMILESDebertaEncoder:
 
    def token_logits(self, input_ids, attention_mask=None, labels=None):
        """
-        CL helper:
        - If labels provided: returns loss tensor from HF MLM forward
        - Else: returns token logits (B, L, V)
        """
@@ -487,7 +464,6 @@ def build_model_and_trainer(tokenizer, dataset_train, dataset_test, spm_model_pa
    vocab_size = len(tokenizer)
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
 
-    # Use wrapper so it is used in THIS file and also imported by CL.py
    model = PSMILESDebertaEncoder(
        model_dir_or_name=None,
        vocab_size=vocab_size,
@@ -603,4 +579,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
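For readers landing on this commit: a minimal usage sketch of the wrapper's two modes, as the docstrings in the diff describe them. The SentencePiece model path and the example PSMILES strings are hypothetical; only the function names, keyword arguments, and return shapes come from the code above.

    from PolyFusion.DeBERTav2 import build_psmiles_tokenizer, PSMILESDebertaEncoder

    tokenizer = build_psmiles_tokenizer("psmiles.model", max_len=128)  # hypothetical SPM path
    # Construct the encoder as build_model_and_trainer does; constructor arguments
    # beyond the two visible in the diff are elided here.
    encoder = PSMILESDebertaEncoder(model_dir_or_name=None, vocab_size=len(tokenizer))
    encoder.eval()

    batch = tokenizer(["[*]CC[*]", "[*]CC(C)C[*]"], padding=True, return_tensors="pt")

    # MLM mode: labels present -> HF outputs with .loss and .logits.
    out = encoder(batch["input_ids"], attention_mask=batch["attention_mask"],
                  labels=batch["input_ids"])

    # Encoder mode: no labels -> pooled embedding of shape (B, emb_dim).
    emb = encoder(batch["input_ids"], attention_mask=batch["attention_mask"])

    # Reconstruction helper: token logits of shape (B, L, V).
    logits = encoder.token_logits(batch["input_ids"], attention_mask=batch["attention_mask"])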
 
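In real MLM training the labels would come from a masking collator rather than the raw input_ids above; with HF Trainer the standard choice is DataCollatorForLanguageModeling. A short sketch (the 15% masking rate is an assumption, not read from the diff):

    from transformers import DataCollatorForLanguageModeling

    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )
    # Trainer(..., data_collator=collator) then feeds batches whose `labels` are -100
    # at unmasked positions, which routes the wrapper's forward() into MLM mode.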
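The body of _pool_hidden falls outside the hunks shown. A common reading of "pool token embeddings -> sequence embedding" with an attention mask is masked mean pooling followed by the pool_proj projection; the sketch below is an assumption about that body, not a copy of it.

    from typing import Optional

    import torch

    def masked_mean_pool(last_hidden_state: torch.Tensor,
                         attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Average token embeddings over real (non-padding) positions.

        last_hidden_state: (B, L, H); attention_mask: (B, L), 1 = real token.
        Returns a (B, H) sequence embedding.
        """
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (B, L, 1)
        summed = (last_hidden_state * mask).sum(dim=1)                   # (B, H)
        counts = mask.sum(dim=1).clamp(min=1e-9)                         # (B, 1)
        return summed / counts

    # Inside the wrapper, the result would then be projected:
    # pooled = self.pool_proj(masked_mean_pool(hidden, attention_mask))  # (B, emb_dim)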