| import matplotlib.pyplot as plt |
| import matplotlib as mpl |
| import numpy as np |
| import os |
| import pandas as pd |
| from rdkit import Chem, DataStructs |
| from rdkit.Chem import AllChem |
| from rdkit.ML.Cluster import Butina |
| from lightning.pytorch import seed_everything |
| import torch |
| from tqdm import tqdm |
| from transformers import AutoModelForMaskedLM |
| from datasets import Dataset, DatasetDict |
| from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
|
|
seed_everything(1986)

df = pd.read_csv("caco2.csv")

# Parse every SMILES with RDKit; keep the mol + canonical SMILES for rows
# that parse, count and drop the rest.
mols = []
canon = []
keep_rows = []
bad = 0

for row_idx, raw_smiles in enumerate(df["SMILES"].astype(str)):
    mol = Chem.MolFromSmiles(raw_smiles)
    if mol is None:
        bad += 1
        continue
    keep_rows.append(row_idx)
    mols.append(mol)
    canon.append(Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True))

# Restrict the frame to parseable rows and attach the canonical form.
df = df.iloc[keep_rows].reset_index(drop=True)
df["SMILES_CANON"] = canon

print(f"Invalid SMILES dropped: {bad} / {len(df) + bad}")
|
|
| |
| dup_mask = df.duplicated(subset=["SMILES_CANON"], keep="first") |
| df = df.loc[~dup_mask].reset_index(drop=True) |
| mols = [m for m, isdup in zip(mols, dup_mask) if not isdup] |
|
|
| |
# Morgan (ECFP4-style) fingerprints: radius 2, 2048 bits, chirality-aware.
morgan = AllChem.GetMorganGenerator(radius=2, fpSize=2048, includeChirality=True)
fps = [morgan.GetFingerprint(m) for m in mols]


# Two molecules cluster together at Tanimoto similarity >= 0.6,
# i.e. Tanimoto distance <= 0.4.
sim_thresh = 0.6
dist_thresh = 1.0 - sim_thresh


# Build the condensed lower-triangular distance list row by row: for each i,
# distances of fps[i] to all fps[j] with j < i. This is the flat layout
# Butina.ClusterData consumes when isDistData=True (see RDKit docs).
# NOTE(review): the list is O(n^2) in memory — fine for a few thousand
# molecules, revisit for much larger datasets.
dists = []
n = len(fps)
for i in range(1, n):
    sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
    dists.extend([1.0 - x for x in sims])


clusters = Butina.ClusterData(dists, nPts=n, distThresh=dist_thresh, isDistData=True)


# Flatten the per-cluster index tuples into a per-molecule cluster id,
# aligned with the (deduplicated) df rows.
cluster_ids = np.empty(n, dtype=int)
for cid, idxs in enumerate(clusters):
    for idx in idxs:
        cluster_ids[idx] = cid


df["cluster_id"] = cluster_ids
|
|
| |
# Cluster-level train/val split: whole clusters go to one side so that
# near-duplicate molecules never straddle the split.
train_fraction = 0.8
# BUG FIX: default_rng() was previously constructed without a seed —
# seed_everything() seeds the legacy np.random global state, not a fresh
# Generator — so the cluster shuffle (and hence the split) was
# non-reproducible. Seed it with the same constant used at the top of the
# script.
rng = np.random.default_rng(1986)


unique_clusters = df["cluster_id"].unique()
rng.shuffle(unique_clusters)


# Greedily assign shuffled clusters to train until the target row count is
# reached; compute every cluster's size once instead of re-scanning df per
# cluster.
cluster_sizes = df["cluster_id"].value_counts()
train_target = int(train_fraction * len(df))
train_clusters = set()
count = 0
for cid in unique_clusters:
    csize = int(cluster_sizes[cid])
    if count + csize <= train_target:
        train_clusters.add(cid)
        count += csize


df["split"] = np.where(df["cluster_id"].isin(train_clusters), "train", "val")
|
|
# Persist the two splits separately plus the full annotated frame.
for split_name, out_csv in (("train", "caco2_train.csv"), ("val", "caco2_val.csv")):
    df.loc[df["split"] == split_name].to_csv(out_csv, index=False)
df.to_csv("caco2_meta_with_split.csv", index=False)


print(df["split"].value_counts())
|
|
| |
| |
| |
# Embedding settings.
MAX_LENGTH = 768   # max tokens per SMILES before truncation
BATCH_SIZE = 128   # sequences per forward pass


# Split CSVs written by the clustering/splitting section above.
TRAIN_CSV = "caco2_train.csv"
VAL_CSV = "caco2_val.csv"


SMILES_COL = "SMILES"  # input column read from the split CSVs
LABEL_COL = "Caco2"    # numeric target column


# Base path for the HuggingFace datasets saved at the bottom of the script.
OUT_PATH = "./Classifier_Weight/training_data_cleaned/permeability_caco2/caco2_smiles_with_embeddings"


# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
|
|
| |
| |
| |
print("Loading tokenizer and model...")
# SMILES pair-encoding tokenizer built from project-local vocab/splits files.
tokenizer = SMILES_SPE_Tokenizer(
    "./Classifier_Weight/tokenizer/new_vocab.txt",
    "./Classifier_Weight/tokenizer/new_splits.txt",
)


# Load PeptideCLM and keep only its RoFormer encoder (drops the MLM head);
# eval() disables dropout so embeddings are deterministic.
embedding_model = AutoModelForMaskedLM.from_pretrained("aaronfeller/PeptideCLM-23M-all").roformer
embedding_model.to(device)
embedding_model.eval()


# NOTE(review): HIDDEN_KEY is not referenced anywhere in this file's visible
# code (outputs.last_hidden_state is accessed directly below) — confirm it is
# unused before removing.
HIDDEN_KEY = "last_hidden_state"
|
|
def get_special_ids(tokenizer):
    """Return the sorted, de-duplicated special-token ids exposed by *tokenizer*.

    Looks up the standard HF-style attributes (pad/cls/sep/bos/eos/mask);
    attributes that are missing or None are skipped. Prints a warning when
    none are found, since mean-pooling then relies on the attention mask alone.
    """
    attr_names = (
        "pad_token_id",
        "cls_token_id",
        "sep_token_id",
        "bos_token_id",
        "eos_token_id",
        "mask_token_id",
    )
    found = {getattr(tokenizer, name, None) for name in attr_names}
    found.discard(None)
    special_ids = sorted(found)
    if not special_ids:
        print("[WARN] No special token ids found on tokenizer; pooling will only exclude padding via attention_mask.")
    return special_ids
|
|
# Pre-build a device-resident tensor of special-token ids for torch.isin;
# None when the tokenizer exposes no special ids at all.
SPECIAL_IDS = get_special_ids(tokenizer)
if SPECIAL_IDS:
    SPECIAL_IDS_T = torch.tensor(SPECIAL_IDS, device=device, dtype=torch.long)
else:
    SPECIAL_IDS_T = None
|
|
@torch.no_grad()
def embed_batch_return_both(batch_sequences, max_length, device):
    """Embed one batch of SMILES and return both pooled and per-token views.

    Returns a 4-tuple:
      pooled         -- (B, H) float32 numpy array, mean over valid tokens
      token_emb_list -- list of (L_i, H) float16 numpy arrays (valid tokens only)
      mask_list      -- list of all-ones int8 arrays, one per sequence
      lengths        -- list of valid-token counts L_i

    "Valid" tokens are those set in the attention mask, minus any special-token
    ids collected at module load.
    """
    encoded = tokenizer(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        max_length=max_length,
        truncation=True,
    )
    ids = encoded["input_ids"].to(device)
    attn = encoded["attention_mask"].to(device)

    hidden = embedding_model(input_ids=ids, attention_mask=attn).last_hidden_state

    # Boolean mask of tokens that should contribute to the embeddings.
    keep = attn.bool()
    if SPECIAL_IDS_T is not None and SPECIAL_IDS_T.numel() > 0:
        keep = keep & (~torch.isin(ids, SPECIAL_IDS_T))

    # Masked mean pooling; clamp avoids division by zero when a row has no
    # valid tokens.
    weights = keep.unsqueeze(-1).float()
    summed = torch.sum(hidden * weights, dim=1)
    denom = torch.clamp(weights.sum(dim=1), min=1e-9)
    pooled = (summed / denom).detach().cpu().numpy()

    # Per-sequence token embeddings, stored compactly (valid tokens only, fp16).
    token_emb_list = []
    mask_list = []
    lengths = []
    for row in range(hidden.shape[0]):
        kept = hidden[row, keep[row]]
        n_tok = int(kept.shape[0])
        token_emb_list.append(kept.detach().cpu().to(torch.float16).numpy())
        lengths.append(n_tok)
        mask_list.append(np.ones((n_tok,), dtype=np.int8))

    return pooled, token_emb_list, mask_list, lengths
|
|
def generate_embeddings_batched_both(sequences, batch_size, max_length):
    """Embed *sequences* in batches, concatenating the per-batch results.

    Returns (pooled, token_emb_list, mask_list, lengths) where pooled is a
    single stacked (N, H) array and the other three are flat per-sequence lists.
    """
    pooled_chunks = []
    token_emb_all = []
    mask_all = []
    lengths_all = []

    for start in tqdm(range(0, len(sequences), batch_size), desc="Embedding batches"):
        chunk = sequences[start:start + batch_size]
        pooled, toks, masks, lens = embed_batch_return_both(chunk, max_length, device)
        pooled_chunks.append(pooled)
        token_emb_all.extend(toks)
        mask_all.extend(masks)
        lengths_all.extend(lens)

    return np.vstack(pooled_chunks), token_emb_all, mask_all, lengths_all
|
|
| from datasets import Dataset, DatasetDict |
|
|
def make_split_datasets(csv_path, split_name):
    """Load one split CSV, embed its SMILES, and build two HF Datasets.

    Parameters:
      csv_path   -- CSV with at least SMILES_COL and LABEL_COL columns
      split_name -- split label; currently unused, kept for interface
                    compatibility with existing callers

    Returns (pooled_ds, full_ds): pooled_ds holds one mean-pooled embedding
    per sequence; full_ds holds per-token embeddings plus attention masks
    and lengths.
    """
    df = pd.read_csv(csv_path)
    df = df.dropna(subset=[SMILES_COL, LABEL_COL]).reset_index(drop=True)
    df["sequence"] = df[SMILES_COL].astype(str)

    # Coerce labels to numeric exactly once and drop unparseable rows.
    # (Previously pd.to_numeric ran twice over the same column.)
    numeric_labels = pd.to_numeric(df[LABEL_COL], errors="coerce")
    valid = ~numeric_labels.isna()
    df = df.loc[valid].reset_index(drop=True)
    sequences = df["sequence"].tolist()
    labels = numeric_labels.loc[valid].tolist()

    pooled_embs, token_emb_list, mask_list, lengths = generate_embeddings_batched_both(
        sequences, batch_size=BATCH_SIZE, max_length=MAX_LENGTH
    )

    pooled_ds = Dataset.from_dict({
        "sequence": sequences,
        "label": labels,
        "embedding": pooled_embs,
    })

    full_ds = Dataset.from_dict({
        "sequence": sequences,
        "label": labels,
        "embedding": token_emb_list,
        "attention_mask": mask_list,
        "length": lengths,
    })

    return pooled_ds, full_ds
|
|
# Build both dataset flavors for each split and write them to disk.
train_pooled, train_full = make_split_datasets(TRAIN_CSV, "train")
val_pooled, val_full = make_split_datasets(VAL_CSV, "val")


ds_pooled = DatasetDict({"train": train_pooled, "val": val_pooled})
ds_full = DatasetDict({"train": train_full, "val": val_full})


for bundle, out_dir in ((ds_pooled, OUT_PATH), (ds_full, OUT_PATH + "_unpooled")):
    bundle.save_to_disk(out_dir)
|
|