# ApexOracle / temp_fangping.py
# (Hugging Face commit 80ad4cd by Kiria-Nozan: "solve same embedding bug")
import pandas as pd
import numpy as np
from DLM_emb_model import MolEmbDLM
from transformers import AutoTokenizer
import torch
import selfies as sf
# --- Model setup -----------------------------------------------------------
MODEL_DIR = "Kiria-Nozan/ApexOracle"

# Prefer the first CUDA device when one is available, else fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pull the tokenizer and the ApexOracle embedding model from the Hub, switch
# the model to inference mode, and move its weights to the chosen device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = MolEmbDLM.from_pretrained(MODEL_DIR).eval().to(device)
# --- Data loading -----------------------------------------------------------
df = pd.read_csv("temp_data/polymers_lit_scraped.csv")

# Pool every distinct monomer SMILES across the six monomer columns; a column
# may be absent from a given CSV revision, hence the membership check.
monomer_columns = ["monomer A", "monomer B", "monomer C", "monomer D", "monomer E", "monomer F"]
all_monomers = {
    smi
    for col in monomer_columns
    if col in df.columns
    for smi in df[col].dropna().unique()
}
print(f"Total unique monomers: {len(all_monomers)}")
# --- SMILES -> SELFIES conversion -------------------------------------------
# The downstream embedding step feeds SELFIES strings to the tokenizer, so
# every SMILES is converted first.  Molecules that the selfies encoder rejects
# are reported and skipped rather than aborting the run.
monomer_selfies = {}
valid_monomers = []
for smiles in all_monomers:
    try:
        selfies = sf.encoder(smiles)
    except Exception as e:
        print(f"Error converting {smiles} to SELFIES: {e}")
    else:
        monomer_selfies[smiles] = selfies
        valid_monomers.append((smiles, selfies))
print(f"Valid monomers for embedding: {len(valid_monomers)}")
# --- Embedding generation ----------------------------------------------------
monomer_embeddings = {}
for smiles, selfies in valid_monomers:
    # Insert a space at every "][" boundary — presumably so the tokenizer sees
    # one SELFIES symbol per whitespace-separated chunk (mirrors example.py).
    spaced = selfies.replace('][', '] [')
    batch = tokenizer(
        spaced,
        padding="max_length",
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    )
    batch = {key: tensor.to(device) for key, tensor in batch.items()}

    with torch.no_grad():
        # NOTE(review): the mask handed to the model is all ones (the original
        # wrote it as mask + 1 - mask), i.e. padding positions are attended to
        # as well — apparently the deliberate "same embedding bug" workaround.
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=torch.ones_like(batch["attention_mask"]),
        )
    # Keep the [0][0] slice of the model output as the molecule-level vector,
    # moved to CPU numpy.  (Exact output structure is model-defined — the
    # original comment claimed average pooling, which this indexing is not.)
    monomer_embeddings[smiles] = outputs[0][0].cpu().numpy()

print(f"Generated embeddings for {len(monomer_embeddings)} monomers")
print(f"Embedding shape: {list(monomer_embeddings.values())[0].shape}")
# --- Duplicate-embedding scan ------------------------------------------------
# Compare every unordered pair of embeddings; pairs that agree to within a very
# tight tolerance are recorded as effectively identical.
print("\nChecking for identical embeddings...")
embedding_list = list(monomer_embeddings.items())
identical_pairs = []
for idx, (smi_a, vec_a) in enumerate(embedding_list):
    for smi_b, vec_b in embedding_list[idx + 1:]:
        if np.allclose(vec_a, vec_b, rtol=1e-09, atol=1e-09):
            identical_pairs.append((smi_a, smi_b))
if identical_pairs:
    print(f"Found {len(identical_pairs)} pairs of identical embeddings:")
    for smiles1, smiles2 in identical_pairs:
        print(f" {smiles1} <-> {smiles2}")

    # Analyze the identical groups
    print("\nAnalyzing identical embedding groups...")
    # Build the connected components of the "identical" relation with a small
    # union-find.  The previous single-pass scan over identical_pairs missed
    # transitive links listed earlier than the pair that connected them (e.g.
    # with pairs (A,B),(C,D),(B,C), D was never merged into {A,B,C}, and the
    # `processed` guard then suppressed D's group entirely).
    parent = {}

    def _find(node):
        # Iterative find with path halving.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    for smiles1, smiles2 in identical_pairs:
        parent.setdefault(smiles1, smiles1)
        parent.setdefault(smiles2, smiles2)
        root1, root2 = _find(smiles1), _find(smiles2)
        if root1 != root2:
            parent[root2] = root1

    # Collect members per root; keep the frozenset-keyed dict shape of the
    # original so any later consumer sees the same structure.
    components = {}
    for node in parent:
        components.setdefault(_find(node), set()).add(node)
    identical_groups = {frozenset(group): group for group in components.values()}

    print(f"Found {len(identical_groups)} groups of molecules with identical embeddings:")
    for i, group in enumerate(identical_groups.values(), 1):
        print(f"\nGroup {i} ({len(group)} molecules):")
        for smiles in sorted(group):
            selfies_repr = monomer_selfies.get(smiles, "N/A")
            print(f" SMILES: {smiles}")
            print(f" SELFIES: {selfies_repr}")
            print()
else:
    print("No identical embeddings found.")
# --- Persist results ---------------------------------------------------------
# np.save stores the dict as a 0-d object array (pickled); reload it with
# np.load(out_path, allow_pickle=True).item().
out_path = "temp_data/monomer_embeddings.npy"
np.save(out_path, monomer_embeddings)
# Bug fix: the original message claimed "monomer_embeddings.npy" but the file
# is actually written under temp_data/ — report the real path.
print(f"Embeddings saved to {out_path}")