kaurm43 committed on
Commit 124c90f · verified · 1 Parent(s): 698cedb

Update PolyFusion/DeBERTav2.py

Files changed (1)
  1. PolyFusion/DeBERTav2.py +195 -32
PolyFusion/DeBERTav2.py CHANGED
@@ -1,21 +1,34 @@
  """
  DeBERTaV2 masked language modeling pretraining for polymer SMILES (PSMILES).
  """

  import os
  import time
  import json
  import shutil
  import argparse
  import warnings
- from typing import Optional, List

  warnings.filterwarnings("ignore")

  def set_cuda_visible_devices(gpu: str = "0") -> None:
      """Set CUDA_VISIBLE_DEVICES before importing torch/transformers heavy modules."""
      os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

  def parse_args() -> argparse.Namespace:
      """CLI arguments for paths and key training/data settings."""
      parser = argparse.ArgumentParser(description="DeBERTaV2 MLM pretraining for polymer pSMILES.")
@@ -127,13 +140,18 @@ def train_sentencepiece_if_needed(train_txt: str, spm_model_prefix: str, vocab_s
      return model_path


- def build_tokenizer(spm_model_path: str):
-     """Create a DebertaV2Tokenizer backed by a SentencePiece model."""
      from transformers import DebertaV2Tokenizer

-     tokenizer = DebertaV2Tokenizer(vocab_file=spm_model_path, do_lower_case=False)
-     tokenizer.add_special_tokens({"pad_token": "<pad>", "mask_token": "<mask>"})
-     return tokenizer


  def tokenize_and_save_dataset(train_psmiles: List[str], val_psmiles: List[str], tokenizer, save_dir: str) -> None:
@@ -144,7 +162,7 @@ def tokenize_and_save_dataset(train_psmiles: List[str], val_psmiles: List[str],
      hf_val = Dataset.from_dict({"text": val_psmiles})

      def tokenize_batch(examples):
-         return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

      train_tok = hf_train.map(tokenize_batch, batched=True, batch_size=10_000, num_proc=10)
      val_tok = hf_val.map(tokenize_batch, batched=True, batch_size=10_000, num_proc=10)
@@ -169,14 +187,13 @@ def load_tokenized_dataset(tokenized_dir: str):

  class EpochMetricsCallback:
      """
-     TrainerCallback that:
      - Tracks best validation loss
      - Implements early stopping on val_loss with patience
      - Saves best model + tokenizer.model copy
      - Prints epoch-level stats
      """

-     # NOTE: We import TrainerCallback lazily to keep module import minimal in helpers.
      def __init__(self, tokenizer_model_path: str, output_dir: str, patience: int = 10):
          from transformers.trainer_callback import TrainerCallback
          from sentencepiece import SentencePieceProcessor
@@ -216,7 +233,6 @@ class EpochMetricsCallback:
          self._last_train_loss = None

      def as_trainer_callback(self):
-         """Return an instance that HuggingFace Trainer can register."""
          return self._cb_cls(self)

      def _save_model(self, trainer_obj, suffix: str) -> None:
@@ -312,33 +328,181 @@ def compute_metrics(eval_pred):
      preds = np.argmax(masked_logits, axis=-1)

      f1 = f1_score(masked_labels, preds, average="weighted")
-     accuracy = np.mean(masked_labels == preds)
      return {"eval_f1": f1, "eval_accuracy": accuracy}


  def build_model_and_trainer(tokenizer, dataset_train, dataset_test, spm_model_path: str, output_dir: str):
      """Construct model, training args, callback, and Trainer."""
      import torch
-     import numpy as np
-     from transformers import DebertaV2Config, DebertaV2ForMaskedLM, Trainer, TrainingArguments
-     from transformers import DataCollatorForLanguageModeling

      data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

      vocab_size = len(tokenizer)
      pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

-     config = DebertaV2Config(
          vocab_size=vocab_size,
          hidden_size=600,
          num_attention_heads=12,
          num_hidden_layers=12,
          intermediate_size=512,
-         pad_token_id=pad_token_id,
      )
-
-     model = DebertaV2ForMaskedLM(config)
-     model.resize_token_embeddings(len(tokenizer))

      device = "cuda:0" if torch.cuda.is_available() else "cpu"
      model.to(device)
@@ -351,7 +515,7 @@ def build_model_and_trainer(tokenizer, dataset_train, dataset_test, spm_model_pa
          per_device_eval_batch_size=8,
          eval_accumulation_steps=1000,
          gradient_accumulation_steps=4,
-         eval_strategy="epoch",  # kept exactly as provided
          logging_strategy="steps",
          logging_steps=500,
          logging_first_step=True,
@@ -364,8 +528,9 @@ def build_model_and_trainer(tokenizer, dataset_train, dataset_test, spm_model_pa
      )

      callback_wrapper = EpochMetricsCallback(tokenizer_model_path=spm_model_path, output_dir=output_dir, patience=10)
      trainer = Trainer(
-         model=model,
          args=training_args,
          train_dataset=dataset_train,
          eval_dataset=dataset_test,
@@ -373,26 +538,21 @@ def build_model_and_trainer(tokenizer, dataset_train, dataset_test, spm_model_pa
          compute_metrics=compute_metrics,
          callbacks=[callback_wrapper.as_trainer_callback()],
      )
-
      callback_wrapper.trainer_ref = trainer
      return model, trainer, callback_wrapper


  def run_training(csv_file: str, nrows: int, train_txt: str, spm_prefix: str, tokenized_dir: str, output_dir: str) -> None:
      """End-to-end: load data, train tokenizer (if needed), tokenize, train model, print final report."""
-     import torch
-
      psmiles_list = load_psmiles_from_csv(csv_file, nrows=nrows)
      train_psmiles, val_psmiles = train_val_split(psmiles_list, test_size=0.2, random_state=42)

      write_sentencepiece_training_text(train_psmiles, train_txt)
      spm_model_path = train_sentencepiece_if_needed(train_txt, spm_prefix, vocab_size=265)

-     tokenizer = build_tokenizer(spm_model_path)

-     # Tokenize and save dataset
      tokenize_and_save_dataset(train_psmiles, val_psmiles, tokenizer, tokenized_dir)
-
      dataset_train, dataset_test = load_tokenized_dataset(tokenized_dir)

      model, trainer, callback = build_model_and_trainer(
@@ -408,6 +568,10 @@ def run_training(csv_file: str, nrows: int, train_txt: str, spm_prefix: str, tok
      total_time = time.time() - start_time

      # Final report
      print(f"\n=== Final Results ===")
      print(f"Total Training Time (s): {total_time:.2f}")
      print(f"Best Validation Loss: {callback.best_val_loss:.4f}")
@@ -415,11 +579,10 @@ def run_training(csv_file: str, nrows: int, train_txt: str, spm_prefix: str, tok
      print(f"Best Validation Accuracy: {callback.best_val_accuracy:.4f}" if callback.best_val_accuracy is not None else "Best Validation Accuracy: None")
      print(f"Best Perplexity: {callback.best_perplexity:.2f}" if callback.best_perplexity is not None else "Best Perplexity: None")
      print(f"Best Model Epoch: {int(callback.best_epoch)}")
-     print(f"Final Training Loss: {train_output.training_loss:.4f}")
-
-     total_params = sum(p.numel() for p in model.parameters())
-     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-     non_trainable_params = total_params - trainable_params
      print(f"Total Parameters: {total_params}")
      print(f"Trainable Parameters: {trainable_params}")
      print(f"Non-trainable Parameters: {non_trainable_params}")
 
  """
+ DeBERTav2.py
  DeBERTaV2 masked language modeling pretraining for polymer SMILES (PSMILES).
+
+ This file provides:
+ - build_psmiles_tokenizer(spm_path, max_len)
+ - PSMILESDebertaEncoder: a dual-use wrapper
+     * If labels provided -> behaves like MLM model (HF Trainer compatible)
+     * If labels not provided -> returns pooled embedding (for CL.py)
+     * token_logits(...) helper for reconstruction in CL.py
+ - End-to-end MLM training utilities (kept aligned with your original script)
  """

+ from __future__ import annotations
+
  import os
  import time
  import json
  import shutil
  import argparse
  import warnings
+ from typing import Optional, List, Tuple

  warnings.filterwarnings("ignore")

+
  def set_cuda_visible_devices(gpu: str = "0") -> None:
      """Set CUDA_VISIBLE_DEVICES before importing torch/transformers heavy modules."""
      os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

+
  def parse_args() -> argparse.Namespace:
      """CLI arguments for paths and key training/data settings."""
      parser = argparse.ArgumentParser(description="DeBERTaV2 MLM pretraining for polymer pSMILES.")
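As a quick orientation, a minimal sketch of the two calling modes the new module docstring describes; build_psmiles_tokenizer and PSMILESDebertaEncoder are defined further down in this file, and the SentencePiece path is hypothetical:

# Sketch only: illustrates the dual-use contract of the wrapper described above.
tok = build_psmiles_tokenizer("spm_psmiles.model", max_len=128)   # hypothetical SentencePiece model path
enc = PSMILESDebertaEncoder(vocab_size=len(tok))

batch = tok(["[*]CC(=O)O[*]"], return_tensors="pt", padding="max_length", max_length=128, truncation=True)

# Encoder mode (no labels): pooled embedding of shape (B, emb_dim), as consumed by CL.py
emb = enc(batch["input_ids"], attention_mask=batch["attention_mask"])

# MLM mode (labels given): Hugging Face-style output with .loss and .logits
# (labels would normally come from the MLM collator; input_ids are reused here only to show the call shape)
out = enc(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["input_ids"])
print(emb.shape, out.loss.item())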
 
      return model_path


+ def build_psmiles_tokenizer(spm_path: str, max_len: int = 128):
+     """
+     Build tokenizer exactly as CL.py expects.
+     Uses SentencePiece-backed DebertaV2Tokenizer.
+     """
      from transformers import DebertaV2Tokenizer

+     tok = DebertaV2Tokenizer(vocab_file=spm_path, do_lower_case=False)
+     tok.add_special_tokens({"pad_token": "<pad>", "mask_token": "<mask>"})
+     # store max_len for convenience (not required by HF)
+     tok.model_max_length = max_len
+     return tok


  def tokenize_and_save_dataset(train_psmiles: List[str], val_psmiles: List[str], tokenizer, save_dir: str) -> None:
 
      hf_val = Dataset.from_dict({"text": val_psmiles})

      def tokenize_batch(examples):
+         return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=tokenizer.model_max_length)

      train_tok = hf_train.map(tokenize_batch, batched=True, batch_size=10_000, num_proc=10)
      val_tok = hf_val.map(tokenize_batch, batched=True, batch_size=10_000, num_proc=10)
 

  class EpochMetricsCallback:
      """
+     TrainerCallback wrapper that:
      - Tracks best validation loss
      - Implements early stopping on val_loss with patience
      - Saves best model + tokenizer.model copy
      - Prints epoch-level stats
      """

      def __init__(self, tokenizer_model_path: str, output_dir: str, patience: int = 10):
          from transformers.trainer_callback import TrainerCallback
          from sentencepiece import SentencePieceProcessor
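The callback body is outside this hunk; purely as an illustration of the patience-based early stopping the docstring describes (a generic sketch, not the committed implementation), the bookkeeping amounts to:

# Illustration only: patience counter on validation loss.
best_val_loss = float("inf")
epochs_without_improvement = 0
patience = 10

def should_stop(val_loss: float) -> bool:
    """Return True once val_loss has not improved for `patience` epochs."""
    global best_val_loss, epochs_without_improvement
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0   # new best -> this is where the best model would be saved
    else:
        epochs_without_improvement += 1
    return epochs_without_improvement >= patience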
 
          self._last_train_loss = None

      def as_trainer_callback(self):
          return self._cb_cls(self)

      def _save_model(self, trainer_obj, suffix: str) -> None:
 
      preds = np.argmax(masked_logits, axis=-1)

      f1 = f1_score(masked_labels, preds, average="weighted")
+     accuracy = float(np.mean(masked_labels == preds))
      return {"eval_f1": f1, "eval_accuracy": accuracy}


+ # =============================================================================
+ # Encoder wrapper used by CL.py AND used here for MLM training
+ # =============================================================================
+
+ class PSMILESDebertaEncoder:
+     """
+     Dual-use wrapper:
+
+     - For MLM training (HF Trainer):
+         forward(input_ids, attention_mask, labels) -> HF outputs (with .loss, .logits)
+     - For CL:
+         forward(input_ids, attention_mask) -> pooled embedding (B, emb_dim)
+     - token_logits(...) helper for reconstruction in CL.py
+     """
+
+     def __init__(
+         self,
+         model_dir_or_name: Optional[str] = None,
+         hidden_size: int = 600,
+         num_hidden_layers: int = 12,
+         num_attention_heads: int = 12,
+         intermediate_size: int = 512,
+         vocab_size: Optional[int] = None,
+         pad_token_id: int = 0,
+         emb_dim: int = 600,
+     ):
+         import torch
+         import torch.nn as nn
+         from transformers import DebertaV2Config, DebertaV2ForMaskedLM
+
+         self.torch = torch
+         self.nn = nn
+
+         if model_dir_or_name is not None:
+             self.model = DebertaV2ForMaskedLM.from_pretrained(model_dir_or_name)
+         else:
+             if vocab_size is None:
+                 vocab_size = 265  # fallback; will be resized by caller if tokenizer provided
+             config = DebertaV2Config(
+                 vocab_size=vocab_size,
+                 hidden_size=hidden_size,
+                 num_attention_heads=num_attention_heads,
+                 num_hidden_layers=num_hidden_layers,
+                 intermediate_size=intermediate_size,
+                 pad_token_id=pad_token_id,
+             )
+             self.model = DebertaV2ForMaskedLM(config)
+
+         # pool_proj required by CL.py
+         # Use hidden size from config if available
+         hs = int(getattr(self.model.config, "hidden_size", hidden_size))
+         self.pool_proj = nn.Linear(hs, emb_dim)
+
+         # allow .to() and .parameters() by delegating via nn.Module-like behavior
+         # (We keep it simple: expose these methods explicitly.)
+         # Note: CL.py uses encoder as nn.Module; to ensure compatibility, we provide:
+         #   - to()
+         #   - parameters()
+         #   - state_dict()/load_state_dict()
+         #   - train()/eval()
+         #   - __call__ routes to forward()
+         self._device = None
+
+     # ---- nn.Module-like API ----
+     def to(self, device):
+         self.model.to(device)
+         self.pool_proj.to(device)
+         self._device = device
+         return self
+
+     def train(self, mode: bool = True):
+         self.model.train(mode)
+         self.pool_proj.train(mode)
+         return self
+
+     def eval(self):
+         return self.train(False)
+
+     def parameters(self):
+         for p in self.model.parameters():
+             yield p
+         for p in self.pool_proj.parameters():
+             yield p
+
+     def state_dict(self):
+         sd = {"model": self.model.state_dict(), "pool_proj": self.pool_proj.state_dict()}
+         return sd
+
+     def load_state_dict(self, state_dict, strict: bool = False):
+         if isinstance(state_dict, dict) and "model" in state_dict and "pool_proj" in state_dict:
+             self.model.load_state_dict(state_dict["model"], strict=strict)
+             self.pool_proj.load_state_dict(state_dict["pool_proj"], strict=strict)
+         else:
+             # allow loading a raw HF state_dict (best-effort)
+             try:
+                 self.model.load_state_dict(state_dict, strict=strict)
+             except Exception:
+                 # ignore if incompatible; CL often uses strict=False
+                 pass
+         return self
+
+     def __call__(self, input_ids, attention_mask=None, labels=None):
+         return self.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+
+     # ---- core helpers ----
+     def _pool_hidden(self, last_hidden_state, attention_mask=None):
+         """
+         Pool token embeddings -> sequence embedding.
+         Use attention-masked mean pooling (robust).
+         """
+         import torch
+
+         if attention_mask is None:
+             return last_hidden_state.mean(dim=1)
+
+         mask = attention_mask.to(last_hidden_state.device).unsqueeze(-1).float()
+         denom = mask.sum(dim=1).clamp(min=1.0)
+         pooled = (last_hidden_state * mask).sum(dim=1) / denom
+         return pooled
+
+     def forward(self, input_ids, attention_mask=None, labels=None):
+         """
+         If labels is provided -> MLM mode: return HF outputs (Trainer compatible).
+         Else -> encoder mode: return pooled embedding for CL.
+         """
+         if labels is not None:
+             return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+
+         out = self.model.deberta(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
+         last_hidden = out.last_hidden_state
+         pooled = self._pool_hidden(last_hidden, attention_mask=attention_mask)
+         return self.pool_proj(pooled)
+
+     def token_logits(self, input_ids, attention_mask=None, labels=None):
+         """
+         CL helper:
+         - If labels provided: returns loss tensor from HF MLM forward
+         - Else: returns token logits (B, L, V)
+         """
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+         if labels is not None:
+             return outputs.loss
+         return outputs.logits
+
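To make the attention-masked mean pooling in _pool_hidden concrete, a toy check with plain torch (not part of the commit):

# Sketch: masked mean pooling ignores padded positions.
import torch

hidden = torch.arange(12.0).reshape(1, 3, 4)   # (B=1, L=3, H=4)
mask = torch.tensor([[1, 1, 0]])               # last token is padding
pooled = (hidden * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1.0)
print(pooled)   # tensor([[2., 3., 4., 5.]]) -> mean of the two real token vectors only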
  def build_model_and_trainer(tokenizer, dataset_train, dataset_test, spm_model_path: str, output_dir: str):
      """Construct model, training args, callback, and Trainer."""
      import torch
+     from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

      data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

      vocab_size = len(tokenizer)
      pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

+     # Use wrapper so it is used in THIS file and also imported by CL.py
+     model = PSMILESDebertaEncoder(
+         model_dir_or_name=None,
          vocab_size=vocab_size,
+         pad_token_id=pad_token_id,
          hidden_size=600,
          num_attention_heads=12,
          num_hidden_layers=12,
          intermediate_size=512,
+         emb_dim=600,
      )
+     # resize HF embeddings
+     try:
+         model.model.resize_token_embeddings(len(tokenizer))
+     except Exception:
+         pass

      device = "cuda:0" if torch.cuda.is_available() else "cpu"
      model.to(device)
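A small sketch of what DataCollatorForLanguageModeling hands the Trainer here, assuming a tokenizer built as above (the SentencePiece path is hypothetical):

# Sketch: dynamic MLM masking at mlm_probability=0.15.
from transformers import DataCollatorForLanguageModeling

tok = build_psmiles_tokenizer("spm_psmiles.model")   # hypothetical path
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=True, mlm_probability=0.15)
features = [tok("[*]CC(C)[*]", truncation=True, padding="max_length", max_length=128)]
batch = collator(features)
# labels are -100 at unmasked positions; only the ~15% masked positions contribute to the MLM loss
print(batch["input_ids"].shape, int((batch["labels"] != -100).sum()))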
 
          per_device_eval_batch_size=8,
          eval_accumulation_steps=1000,
          gradient_accumulation_steps=4,
+         eval_strategy="epoch",
          logging_strategy="steps",
          logging_steps=500,
          logging_first_step=True,

      )

      callback_wrapper = EpochMetricsCallback(tokenizer_model_path=spm_model_path, output_dir=output_dir, patience=10)
+
      trainer = Trainer(
+         model=model,  # wrapper is Trainer-compatible
          args=training_args,
          train_dataset=dataset_train,
          eval_dataset=dataset_test,

          compute_metrics=compute_metrics,
          callbacks=[callback_wrapper.as_trainer_callback()],
      )
      callback_wrapper.trainer_ref = trainer
      return model, trainer, callback_wrapper


  def run_training(csv_file: str, nrows: int, train_txt: str, spm_prefix: str, tokenized_dir: str, output_dir: str) -> None:
      """End-to-end: load data, train tokenizer (if needed), tokenize, train model, print final report."""
      psmiles_list = load_psmiles_from_csv(csv_file, nrows=nrows)
      train_psmiles, val_psmiles = train_val_split(psmiles_list, test_size=0.2, random_state=42)

      write_sentencepiece_training_text(train_psmiles, train_txt)
      spm_model_path = train_sentencepiece_if_needed(train_txt, spm_prefix, vocab_size=265)

+     tokenizer = build_psmiles_tokenizer(spm_path=spm_model_path, max_len=128)

      tokenize_and_save_dataset(train_psmiles, val_psmiles, tokenizer, tokenized_dir)
      dataset_train, dataset_test = load_tokenized_dataset(tokenized_dir)

      model, trainer, callback = build_model_and_trainer(

      total_time = time.time() - start_time

      # Final report
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     non_trainable_params = total_params - trainable_params
+
      print(f"\n=== Final Results ===")
      print(f"Total Training Time (s): {total_time:.2f}")
      print(f"Best Validation Loss: {callback.best_val_loss:.4f}")

      print(f"Best Validation Accuracy: {callback.best_val_accuracy:.4f}" if callback.best_val_accuracy is not None else "Best Validation Accuracy: None")
      print(f"Best Perplexity: {callback.best_perplexity:.2f}" if callback.best_perplexity is not None else "Best Perplexity: None")
      print(f"Best Model Epoch: {int(callback.best_epoch)}")
+     try:
+         print(f"Final Training Loss: {train_output.training_loss:.4f}")
+     except Exception:
+         pass
      print(f"Total Parameters: {total_params}")
      print(f"Trainable Parameters: {trainable_params}")
      print(f"Non-trainable Parameters: {non_trainable_params}")
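For completeness, a hypothetical __main__ entry point wiring parse_args() into run_training(); the real flag names live in parse_args() and are not shown in this diff:

# Sketch only: argument attribute names below are assumptions.
if __name__ == "__main__":
    args = parse_args()
    set_cuda_visible_devices(getattr(args, "gpu", "0"))
    run_training(
        csv_file=args.csv_file,
        nrows=args.nrows,
        train_txt=args.train_txt,
        spm_prefix=args.spm_prefix,
        tokenized_dir=args.tokenized_dir,
        output_dir=args.output_dir,
    )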