| """ |
| Pipeline: |
| 1. Read *_meta_with_split.csv (sequence, label, id, split) |
| 2. Convert wt sequences to SMILES via: fasta2smi -i peptides.fasta -o peptides.p2smi |
| 3. Parse .p2smi format: "{seq}-linear: {SMILES}" |
| 4. Embed SMILES with ChemBERTa to save pooled + unpooled DatasetDicts |
| 5. Embed SMILES with PeptideCLM to save pooled + unpooled DatasetDicts |
| """ |
|
|
import os
import subprocess
import sys
import tempfile

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
|
|
PROJECT_ROOT = "<>"  # placeholder; set to the project root directory
|
|
# Input metadata (one row per peptide: sequence, label, id, split).
META_CSV = (
    f"{PROJECT_ROOT}/training_data_cleaned/"
    "permeability_penetrance/permeability_meta_with_split.csv"
)
BASE_OUT = f"{PROJECT_ROOT}/alternative_embeddings"
|
|
# ChemBERTa encoder.
CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM"
CHEMBERTA_OUT = f"{BASE_OUT}/permeability_chemberta/perm_smiles_with_embeddings"
|
|
# PeptideCLM uses a local SMILES pair-encoding (SPE) tokenizer shipped in the repo.
sys.path.append(PROJECT_ROOT)
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
|
PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all"
PEPTIDECLM_TOKENIZER = f"{PROJECT_ROOT}/tokenizer/new_vocab.txt"
PEPTIDECLM_SPLITS = f"{PROJECT_ROOT}/tokenizer/new_splits.txt"
PEPTIDECLM_OUT = f"{BASE_OUT}/permeability_peptideclm/perm_smiles_with_embeddings"
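# Each encoder writes two DatasetDicts via _save(): <OUT> holds mean-pooled
# per-sequence vectors, <OUT>_unpooled holds per-token embeddings.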
|
|
# Column names expected in META_CSV.
SEQ_COL = "sequence"
LABEL_COL = "label"
SPLIT_COL = "split"
ID_COL = "id"
|
|
# fasta2smi executable (assumed to be on PATH; use an absolute path otherwise).
FASTA2SMI_BIN = "fasta2smi"
|
|
# Tokenization / batching limits.
MAX_LENGTH_CHEMBERTA = 512
MAX_LENGTH_PEPTIDECLM = 768
BATCH_SIZE = 128
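# NOTE: peptide SMILES strings are far longer than the underlying amino-acid
# sequences, so these are per-encoder token budgets; anything longer is
# truncated (see _embed_batch).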
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
# ---------------------------------------------------------------------------
# Steps 2-3: FASTA -> SMILES via fasta2smi
# ---------------------------------------------------------------------------
def sequences_to_smiles(sequences: list[str], ids: list[str]) -> dict[str, str]:
    """
    Convert peptide sequences to SMILES with fasta2smi; returns {sequence: SMILES}.

    .p2smi format produced by fasta2smi:
        MIIFAIAASHKK-linear: N[C@@H](CCSC)C(=O)...
        KIAKLKAKIQ...-linear: N[C@@H](CCCCN)C(=O)...
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        fasta_path = os.path.join(tmpdir, "peptides.fasta")
        p2smi_path = os.path.join(tmpdir, "peptides.p2smi")

        # Write a temporary FASTA file with one record per peptide.
        with open(fasta_path, "w") as fh:
            for sid, seq in zip(ids, sequences):
                fh.write(f">{sid}\n{seq}\n")

        cmd = [FASTA2SMI_BIN, "-i", fasta_path, "-o", p2smi_path]
        print(f"  Running: {' '.join(cmd)}")
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(
                f"fasta2smi failed (exit {result.returncode}):\n"
                f"  stdout: {result.stdout}\n  stderr: {result.stderr}"
            )

        seq2smi = _parse_p2smi(p2smi_path)

        n_ok = len(seq2smi)
        n_fail = len(sequences) - n_ok
        print(f"  fasta2smi: {n_ok}/{len(sequences)} converted ({n_fail} failed/skipped)")
        return seq2smi
|
|
|
|
def _parse_p2smi(path: str) -> dict[str, str]:
    """Parse a .p2smi file into {sequence: SMILES}, skipping blanks, comments, and failures."""
    seq2smi: dict[str, str] = {}
    with open(path) as fh:
        for line in fh:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if "-linear: " not in line:
                print(f"    [WARN] Unexpected p2smi line, skipping: {line[:80]}")
                continue
            aa_seq, smi = line.split("-linear: ", maxsplit=1)
            smi = smi.strip()
            if smi and smi.lower() not in ("none", "null", "n/a"):
                seq2smi[aa_seq] = smi
    return seq2smi
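# Example (hypothetical record, same shape as the fasta2smi output above):
# a file containing the single line
#   ACDK-linear: N[C@@H](C)C(=O)...
# parses to {"ACDK": "N[C@@H](C)C(=O)..."}.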
|
|
|
|
# ---------------------------------------------------------------------------
# Embedding helpers (shared by both encoders)
# ---------------------------------------------------------------------------
def _get_special_ids_tensor(tokenizer):
    """Collect the tokenizer's special-token ids so they can be masked out of pooling."""
    attrs = [
        "pad_token_id", "cls_token_id", "sep_token_id",
        "bos_token_id", "eos_token_id", "mask_token_id",
    ]
    ids = sorted({getattr(tokenizer, a, None) for a in attrs} - {None})
    return torch.tensor(ids, device=device, dtype=torch.long) if ids else None
|
|
|
|
@torch.no_grad()
def _embed_batch(tokenizer, model, special_ids_t, sequences, max_length):
    tok = tokenizer(
        sequences, return_tensors="pt",
        padding=True, max_length=max_length, truncation=True,
    )
    input_ids = tok["input_ids"].to(device)
    attention_mask = tok["attention_mask"].to(device)

    out = model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = out.last_hidden_state

    # Mask out padding and special tokens before pooling.
    valid = attention_mask.bool()
    if special_ids_t is not None:
        valid = valid & (~torch.isin(input_ids, special_ids_t))

    # Mean-pool over valid tokens (clamp avoids division by zero for empty rows).
    valid_f = valid.unsqueeze(-1).float()
    pooled = (
        torch.sum(last_hidden * valid_f, dim=1)
        / torch.clamp(valid_f.sum(dim=1), min=1e-9)
    ).cpu().numpy()

    # Per-token embeddings, stored as float16 to keep the unpooled dataset small.
    token_embs, masks, lengths = [], [], []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]].cpu().to(torch.float16).numpy()
        token_embs.append(emb)
        masks.append(np.ones(emb.shape[0], dtype=np.int8))
        lengths.append(emb.shape[0])

    return pooled, token_embs, masks, lengths
|
|
|
|
def _embed_all(tokenizer, model, special_ids_t, sequences, max_length):
    pooled_all, token_all, mask_all, len_all = [], [], [], []
    for i in tqdm(range(0, len(sequences), BATCH_SIZE), desc="  batches"):
        pooled, toks, masks, lens = _embed_batch(
            tokenizer, model, special_ids_t,
            sequences[i:i + BATCH_SIZE], max_length,
        )
        pooled_all.append(pooled)
        token_all.extend(toks)
        mask_all.extend(masks)
        len_all.extend(lens)
    return np.vstack(pooled_all), token_all, mask_all, len_all
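# Shapes: pooled embeddings form a (num_sequences, hidden_size) float32 matrix;
# token embeddings are ragged (length_i, hidden_size) float16 arrays, one per sequence.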
|
|
|
|
def _build_datasets(wt_seqs, smiles, labels, tokenizer, model, special_ids_t, max_length):
    pooled, tok_embs, masks, lengths = _embed_all(
        tokenizer, model, special_ids_t, smiles, max_length
    )
    pooled_ds = Dataset.from_dict({
        "sequence": wt_seqs,
        "smiles": smiles,
        "label": labels,
        "embedding": pooled,
    })
    full_ds = Dataset.from_dict({
        "sequence": wt_seqs,
        "smiles": smiles,
        "label": labels,
        "embedding": tok_embs,
        "attention_mask": masks,
        "length": lengths,
    })
    return pooled_ds, full_ds
|
|
|
|
def _save(splits: dict, out_path: str):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    DatasetDict({k: v[0] for k, v in splits.items()}).save_to_disk(out_path)
    DatasetDict({k: v[1] for k, v in splits.items()}).save_to_disk(out_path + "_unpooled")
    print(f"    Saved pooled   to {out_path}")
    print(f"    Saved unpooled to {out_path}_unpooled")
|
|
|
|
# ---------------------------------------------------------------------------
# Step 4: ChemBERTa embeddings
# ---------------------------------------------------------------------------
def run_chemberta(meta: pd.DataFrame):
    print(f"\n{'='*60}")
    print("  Encoder: ChemBERTa")
    print(f"{'='*60}")

    print(f"  Loading {CHEMBERTA_MODEL} ...")
    tokenizer = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL)
    model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(device).eval()
    special_ids_t = _get_special_ids_tensor(tokenizer)

    splits: dict[str, tuple] = {}
    for split_name in ["train", "val"]:
        df = meta[meta[SPLIT_COL] == split_name].reset_index(drop=True)
        print(f"\n  [{split_name}] {len(df)} rows")
        if df.empty:
            print("    [WARN] Empty split, skipping.")
            continue
        pooled_ds, full_ds = _build_datasets(
            df[SEQ_COL].tolist(), df["smiles"].tolist(),
            df[LABEL_COL].tolist(),
            tokenizer, model, special_ids_t, MAX_LENGTH_CHEMBERTA,
        )
        splits[split_name] = (pooled_ds, full_ds)

    _save(splits, CHEMBERTA_OUT)

    # Free GPU memory before the next encoder runs.
    del model
    torch.cuda.empty_cache()
|
|
|
|
# ---------------------------------------------------------------------------
# Step 5: PeptideCLM embeddings
# ---------------------------------------------------------------------------
def run_peptideclm(meta: pd.DataFrame):
    print(f"\n{'='*60}")
    print("  Encoder: PeptideCLM")
    print(f"{'='*60}")

    print(f"  Loading tokenizer from {PEPTIDECLM_TOKENIZER} ...")
    tokenizer = SMILES_SPE_Tokenizer(PEPTIDECLM_TOKENIZER, PEPTIDECLM_SPLITS)

    print(f"  Loading {PEPTIDECLM_MODEL} ...")
    full_model = AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL)
    # Keep only the RoFormer encoder; the MLM head is not needed for embeddings.
    model = full_model.roformer.to(device).eval()
    special_ids_t = _get_special_ids_tensor(tokenizer)

    splits: dict[str, tuple] = {}
    for split_name in ["train", "val"]:
        df = meta[meta[SPLIT_COL] == split_name].reset_index(drop=True)
        print(f"\n  [{split_name}] {len(df)} rows")
        if df.empty:
            print("    [WARN] Empty split, skipping.")
            continue
        pooled_ds, full_ds = _build_datasets(
            df[SEQ_COL].tolist(), df["smiles"].tolist(),
            df[LABEL_COL].tolist(),
            tokenizer, model, special_ids_t, MAX_LENGTH_PEPTIDECLM,
        )
        splits[split_name] = (pooled_ds, full_ds)

    _save(splits, PEPTIDECLM_OUT)

    # Drop both the encoder and the full MLM wrapper before freeing the cache.
    del model, full_model
    torch.cuda.empty_cache()
|
|
|
|
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    print(f"\nDevice : {device}")
    print(f"Meta   : {META_CSV}")

    # Step 1: load the metadata (sep=None lets pandas sniff the delimiter).
    meta = pd.read_csv(META_CSV, sep=None, engine="python")
    print(f"Loaded {len(meta)} rows. Columns: {meta.columns.tolist()}")
    for col in [SEQ_COL, LABEL_COL, SPLIT_COL]:
        if col not in meta.columns:
            raise ValueError(f"Expected column '{col}' not found. Available: {meta.columns.tolist()}")

    # Coerce labels to numeric; drop rows missing a sequence or label.
    meta[LABEL_COL] = pd.to_numeric(meta[LABEL_COL], errors="coerce")
    meta = meta.dropna(subset=[SEQ_COL, LABEL_COL]).reset_index(drop=True)

    # Use the id column for FASTA headers when present, else synthesize ids.
    if ID_COL in meta.columns:
        ids = meta[ID_COL].astype(str).tolist()
    else:
        ids = [f"seq_{i}" for i in range(len(meta))]

    # Steps 2-3: convert sequences to SMILES and keep only successful rows.
    print("\nConverting peptide sequences to SMILES ...")
    seqs = meta[SEQ_COL].astype(str).tolist()
    seq2smi = sequences_to_smiles(seqs, ids)

    meta["smiles"] = meta[SEQ_COL].astype(str).map(seq2smi)
    n_missing = meta["smiles"].isna().sum()
    if n_missing:
        print(f"    [WARN] {n_missing} sequences had no SMILES; dropping.")
        meta = meta.dropna(subset=["smiles"]).reset_index(drop=True)
    print(f"    Retained {len(meta)} rows with valid SMILES.")

    smiles_meta_path = os.path.join(BASE_OUT, "permeability_smiles_meta_with_split.csv")
    os.makedirs(BASE_OUT, exist_ok=True)
    meta.to_csv(smiles_meta_path, index=False)
    print(f"    Saved SMILES meta to {smiles_meta_path}")

    # Steps 4-5: embed the SMILES with each encoder.
    run_chemberta(meta)
    run_peptideclm(meta)

    print("\nAll done.")
|
|
|
|
if __name__ == "__main__":
    main()
|
|