File size: 4,031 Bytes
80ad4cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

import torch
from transformers import AutoTokenizer
import sys
import os
from hydra import compose, initialize_config_dir
from pathlib import Path
import numpy as np

# Make modules in the current working directory (e.g. DLM_emb_model.py)
# importable even when this script file lives somewhere else.
sys.path.append(os.getcwd())

try:
    from DLM_emb_model import MolEmbDLM
except ImportError as err:
    # Also report the underlying error: an ImportError raised *inside*
    # DLM_emb_model (e.g. a missing dependency) would otherwise be hidden
    # behind the generic "wrong directory" hint.
    print("Could not import MolEmbDLM. Make sure you are running from ApexOracle directory.", file=sys.stderr)
    print(f"Underlying error: {err}", file=sys.stderr)
    # sys.exit is always available; the bare exit() builtin is injected by the
    # `site` module and is missing when Python runs with -S.
    sys.exit(1)

def load_source_model(
    ckpt_path="/data2/tianang/projects/mdlm/Checkpoints_fangping/1-255000-fine-tune.ckpt",
    config_dir=None,
    model_name="ibm-research/materials.selfies-ted",
):
    """Build the original MolEmbDLM from a Hydra config and a DiT checkpoint.

    Replicates the construction logic used in DLM_emb_model.py so its output can
    be compared against the exported Hugging Face model.

    Args:
        ckpt_path: Checkpoint file handed to ``MolEmbDLM``. Defaults to the path
            this script was originally written against.
        config_dir: Directory containing the Hydra ``config`` file. Defaults to
            ``<cwd>/configs``. Relative paths are made absolute, as Hydra's
            ``initialize_config_dir`` requires an absolute directory.
        model_name: Hugging Face id of the tokenizer; its vocabulary size and
            mask token id are passed to ``MolEmbDLM``.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.
    """
    print("Loading Source Model...")
    if config_dir is None:
        config_dir = os.path.join(os.getcwd(), "configs")

    with initialize_config_dir(config_dir=os.path.abspath(str(config_dir)), version_base=None):
        config = compose(config_name="config")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = MolEmbDLM(config, len(tokenizer.get_vocab()), ckpt_path, tokenizer.mask_token_id)
    model.eval()
    return model, tokenizer

def load_hf_model(
    model_path="/data2/tianang/projects/mdlm/huggingface/huggingface_model",
    fallback_path=".",
):
    """Load the exported MolEmbDLM and its tokenizer via ``from_pretrained``.

    Tries ``model_path`` first. If any step fails, prints the error and loads
    whatever is still missing from ``fallback_path``.

    Args:
        model_path: Directory of the exported Hugging Face model and tokenizer.
        fallback_path: Directory retried when loading from ``model_path`` fails.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.

    Raises:
        Whatever ``from_pretrained`` raises if loading from ``fallback_path``
        also fails.
    """
    print("Loading HF Model...")
    # We use the same class but loaded via from_pretrained.
    tokenizer = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = MolEmbDLM.from_pretrained(model_path)
    except Exception as e:
        print(f"Failed to load HF model: {e}")
        # If the tokenizer load was the step that failed, it must be retried
        # too; otherwise `tokenizer` is unbound at the return below.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(fallback_path)
        model = MolEmbDLM.from_pretrained(fallback_path)
    model.eval()
    return model, tokenizer

def _embed(model, encoded, device):
    """Run ``model`` on tokenizer output ``encoded`` on ``device`` without gradients.

    Only the ``input_ids`` and ``attention_mask`` entries are forwarded; any
    other tokenizer outputs are dropped.
    """
    model_inputs = {
        k: v.to(device) for k, v in encoded.items() if k in ["input_ids", "attention_mask"]
    }
    with torch.no_grad():
        return model(**model_inputs)


def main(selfies="[C][C][=O][O]", atol=1e-5):
    """Check that the source and Hugging Face MolEmbDLM yield matching embeddings.

    Loads both models, tokenizes ``selfies`` with each model's own tokenizer,
    runs both forward passes and prints the element-wise differences along with
    a SUCCESS/FAILURE verdict.

    Args:
        selfies: SELFIES string to embed. The default "[C][C][=O][O]" is NOT
            ethanol (ethanol is "[C][C][O]"); it is simply a small test input.
        atol: Largest absolute element-wise difference still reported as SUCCESS.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Source Model
    source_model, source_tokenizer = load_source_model()
    source_model.to(device)

    # Load HF Model
    hf_model, hf_tokenizer = load_hf_model()
    hf_model.to(device)

    # Space-separate SELFIES symbols before tokenizing: "[C][C]" -> "[C] [C]".
    processed_selfies = selfies.replace('][', '] [')
    print(f"Testing with SELFIES: {processed_selfies}")

    # Each model is fed by its own tokenizer. The HF folder ships its own
    # tokenizer files while the source model uses "ibm-research/materials.selfies-ted";
    # they should agree, so verify the input IDs match too.
    inputs_source = source_tokenizer(processed_selfies, return_tensors="pt", padding=False, truncation=False)
    inputs_hf = hf_tokenizer(processed_selfies, return_tensors="pt", padding=False, truncation=False)

    print(f"Source Input IDs: {inputs_source['input_ids']}")
    print(f"HF Input IDs:     {inputs_hf['input_ids']}")

    if not torch.equal(inputs_source['input_ids'], inputs_hf['input_ids']):
        print("WARNING: Tokenizers produced different input IDs!")

    emb_source = _embed(source_model, inputs_source, device)
    emb_hf = _embed(hf_model, inputs_hf, device)

    print(f'Huggingface Embeddings: {emb_hf[0][0]}')

    print(f"Source Emb Shape: {emb_source.shape}")
    print(f"HF Emb Shape:     {emb_hf.shape}")

    # Subtracting mismatched shapes would either raise or silently broadcast
    # into a meaningless comparison, so bail out explicitly.
    if emb_source.shape != emb_hf.shape:
        print("FAILURE: Embedding shapes differ; cannot compare element-wise.")
        return

    # Compare
    abs_diff = torch.abs(emb_source - emb_hf)
    diff = abs_diff.sum().item()
    max_diff = abs_diff.max().item()

    print(f"Sum of Absolute Differences: {diff}")
    print(f"Max Absolute Difference:     {max_diff}")

    # Judge on the largest element-wise deviation: a threshold on the *sum*
    # grows with the number of elements, so float rounding noise alone could
    # fail it for otherwise identical embeddings.
    if max_diff < atol:
        print("SUCCESS: Embeddings are identical (or extremely close).")
    else:
        print("FAILURE: Embeddings differ significantly.")
        print(f"Source Mean: {emb_source.mean().item()}")
        print(f"HF Mean:     {emb_hf.mean().item()}")

if __name__ == "__main__":
    # Script entry point: run the embedding comparison only when executed
    # directly, never as a side effect of importing this module.
    main()