kaurm43 commited on
Commit
2b6a0dd
·
verified ·
1 Parent(s): e9d2273

Update PolyAgent/orchestrator.py

Browse files
Files changed (1) hide show
  1. PolyAgent/orchestrator.py +8 -2
PolyAgent/orchestrator.py CHANGED
@@ -1645,12 +1645,18 @@ class PolymerOrchestrator:
1645
  spm_path=self.config.spm_model_path,
1646
  max_len=128,
1647
  )
1648
- vocab_sz = getattr(self._psmiles_tokenizer, "vocab_size", 270)
 
1649
 
1650
  gine = GineEncoder().to(self.config.device)
1651
  schnet = NodeSchNetWrapper().to(self.config.device)
1652
  fp = FingerprintEncoder().to(self.config.device)
1653
- psm = PSMILESDebertaEncoder(model_dir_or_name=None).to(self.config.device)
 
 
 
 
 
1654
  model = MultimodalContrastiveModel(gine, schnet, fp, psm, emb_dim=600).to(self.config.device)
1655
 
1656
  try:
 
1645
  spm_path=self.config.spm_model_path,
1646
  max_len=128,
1647
  )
1648
+ vocab_sz = len(self._psmiles_tokenizer)
1649
+ pad_id = self._psmiles_tokenizer.pad_token_id if self._psmiles_tokenizer.pad_token_id is not None else 0
1650
 
1651
  gine = GineEncoder().to(self.config.device)
1652
  schnet = NodeSchNetWrapper().to(self.config.device)
1653
  fp = FingerprintEncoder().to(self.config.device)
1654
+
1655
+ psm = PSMILESDebertaEncoder(
1656
+ model_dir_or_name=None,
1657
+ vocab_size=vocab_sz,
1658
+ pad_token_id=pad_id,
1659
+ ).to(self.config.device)
1660
  model = MultimodalContrastiveModel(gine, schnet, fp, psm, emb_dim=600).to(self.config.device)
1661
 
1662
  try: