| import argparse |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import torch |
| import os |
| import random |
| import requests |
| import tempfile |
| import sys |
| from Bio.PDB import PDBParser, PPBuilder |
|
|
| from esm2.modeling_fastesm import FastEsmModel |
|
|
|
|
def download_random_pdb():
    """
    Download a random protein chain PDB file.

    The PDB ID is drawn from a built-in list of example entries and fetched
    from the RCSB download endpoint. The file is written to a temporary
    location that is NOT deleted automatically; the caller owns the cleanup.

    Returns:
        str: Path to the downloaded PDB file.

    Raises:
        Exception: If the server answers with a status code other than 200.
        requests.RequestException: If the request itself fails (connection
            error, timeout, ...).
    """
    example_pdbs = ["1AKE"]

    pdb_id = random.choice(example_pdbs)
    print(f"Selected random PDB ID: {pdb_id}")

    url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    # A timeout keeps the script from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)

    # Validate the response BEFORE creating the temporary file so that a
    # failed download does not leave an orphaned file behind.
    if response.status_code != 200:
        raise Exception(f"Failed to download PDB file: {response.status_code}")

    # delete=False: the file must outlive this function; the caller removes it.
    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
        temp_file.write(response.content)
        temp_file_path = temp_file.name

    print(f"Downloaded PDB file to: {temp_file_path}")
    return temp_file_path
|
|
|
|
def parse_pdb(pdb_file):
    """
    Read a PDB file and return the sequence and CA coordinates of its first polypeptide.

    Only the first polypeptide produced by ``PPBuilder.build_peptides`` is
    used; any further polypeptides in the structure are ignored.

    NOTE: residues without a 'CA' atom are skipped when collecting
    coordinates but are not removed from the sequence, so ``len(sequence)``
    and ``coords.shape[0]`` can differ for such structures.

    Parameters:
        pdb_file (str): Path to the PDB file.

    Returns:
        tuple: (sequence (str), coords (np.ndarray with one row per CA atom))

    Raises:
        ValueError: If no polypeptide is found, or the first polypeptide
            contains no CA atoms.
    """
    structure = PDBParser(QUIET=True).get_structure("protein", pdb_file)
    builder = PPBuilder()

    # Take the first polypeptide only (None when the builder yields nothing).
    first_peptide = next(iter(builder.build_peptides(structure)), None)
    if first_peptide is None:
        raise ValueError("No polypeptide chains were found in the PDB file.")

    sequence = str(first_peptide.get_sequence())
    ca_positions = [res['CA'].get_coord() for res in first_peptide if 'CA' in res]
    if not ca_positions:
        raise ValueError("No CA atoms found in the polypeptide.")

    return sequence, np.array(ca_positions)
|
|
|
|
def compute_distance_matrix(coords):
    """
    Build the matrix of pairwise Euclidean distances between points.

    Parameters:
        coords (np.ndarray): Array of shape (L, 3) where L is the number of residues.

    Returns:
        np.ndarray: Matrix of shape (L, L); entry (i, j) is the distance
        between row i and row j of ``coords``.
    """
    # Broadcast (L, 1, 3) against (1, L, 3) to get every pairwise difference.
    as_rows = coords[:, np.newaxis, :]
    as_cols = coords[np.newaxis, :, :]
    squared_deltas = np.square(as_rows - as_cols)
    return np.sqrt(squared_deltas.sum(axis=-1))
|
|
|
|
def get_esm_contact_map(sequence):
    """
    Use the ESM model to predict a contact map for the given protein sequence.

    The model ("Synthyra/ESM2-650M") is loaded on every call and run on CUDA
    when available, otherwise on the CPU.

    Parameters:
        sequence (str): Amino acid sequence.

    Returns:
        np.ndarray: A 2D array (L x L) with contact probabilities.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = "Synthyra/ESM2-650M"
    model = FastEsmModel.from_pretrained(model_path).eval().to(device)
    tokenizer = model.tokenizer

    inputs = tokenizer(sequence, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        contact_map = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])

    # Remove only the leading batch axis. A bare squeeze() would also drop
    # the residue axes for a length-1 sequence and return a 0-d array.
    # NOTE(review): assumes predict_contacts returns (batch, L, L) — confirm.
    return contact_map.squeeze(0).cpu().numpy()
|
|
|
|
def plot_maps(true_contact_map, predicted_contact_map, pdb_file, threshold=8.0):
    """
    Generate two side-by-side subplots and save them as a PNG:
        1. ESM predicted contact map.
        2. True contact map from the PDB (binary, thresholded).

    The image is written to the current working directory as
    ``contact_maps_<pdb basename without extension>.png``.

    Parameters:
        true_contact_map (np.ndarray): Binary (0/1) contact map from PDB.
        predicted_contact_map (np.ndarray): Predicted contact probabilities from ESM.
        pdb_file (str): Path to the PDB file, used to generate output filename.
        threshold (float): Distance cutoff (in Å) that was used to build
            ``true_contact_map``; only shown in the plot title (default: 8.0).

    Returns:
        str: Name of the image file that was written.
    """
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    panels = (
        (predicted_contact_map, "Predicted contact probabilities"),
        # ':g' renders the default 8.0 as "8", matching the previous title.
        (true_contact_map, f"True contacts (PDB, threshold = {threshold:g} Å)"),
    )
    for ax, (matrix, title) in zip(axs, panels):
        image = ax.imshow(matrix, cmap='RdYlBu_r', aspect='equal')
        ax.set_title(title)
        ax.set_xlabel("Residue index")
        ax.set_ylabel("Residue index")
        fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    fig.tight_layout()

    pdb_name = os.path.splitext(os.path.basename(pdb_file))[0]
    output_file = f"contact_maps_{pdb_name}.png"
    try:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return output_file


def main():
    """
    Command-line entry point.

    Parses ``--pdb_file`` and ``--threshold``, extracts the sequence and CA
    coordinates, derives the true contact map by thresholding pairwise
    distances, predicts a contact map with ESM, and saves both as an image.
    When no PDB file is given, one is downloaded to a temporary file that is
    removed again before returning.
    """
    parser = argparse.ArgumentParser(
        description="Extract protein sequence and compute contact maps from a PDB file using ESM predictions."
    )
    parser.add_argument(
        "--pdb_file",
        type=str,
        help="Path to the PDB file of the protein. If not provided, a random PDB will be downloaded.",
        default=None,
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=8.0,
        help="Distance threshold (in Å) for defining true contacts (default: 8.0 Å)."
    )
    args = parser.parse_args()

    if args.pdb_file is None:
        pdb_file = download_random_pdb()
    else:
        pdb_file = args.pdb_file

    try:
        sequence, coords = parse_pdb(pdb_file)
        print("Extracted Protein Sequence:")
        print(sequence)

        dist_matrix = compute_distance_matrix(coords)

        # A pair of residues is "in contact" when their CA atoms are closer
        # than the threshold (the diagonal is therefore always 1).
        true_contact_map = (dist_matrix < args.threshold).astype(float)

        predicted_contact_map = get_esm_contact_map(sequence)

        # Compare full shapes so that a malformed (e.g. non-2D) prediction is
        # reported as well, not only a differing first dimension.
        if predicted_contact_map.shape != true_contact_map.shape:
            print("Warning: The predicted contact map and true contact map have different dimensions.")

        # Pass the threshold actually used so the plot title matches the data.
        output_file = plot_maps(
            true_contact_map, predicted_contact_map, pdb_file, threshold=args.threshold
        )

        print(f"Contact maps saved to: {output_file}")

    finally:
        # Only remove the file if this run downloaded it; never delete a
        # user-supplied PDB file.
        if args.pdb_file is None and os.path.exists(pdb_file):
            os.remove(pdb_file)
            print(f"Removed temporary PDB file: {pdb_file}")
|
|
|
| if __name__ == '__main__': |
| main() |
|
|
|
|
|
|
|
|