| from __future__ import annotations |
|
|
| import json |
| import h5py |
| import traceback |
| import numpy as np |
| import traceback |
| from tqdm import tqdm |
| from pathlib import Path |
| from typing import Optional, Union, BinaryIO, TextIO |
| from dataclasses import dataclass |
| from scipy.spatial.distance import cdist |
|
|
| import torch |
|
|
| |
| from esm.utils import residue_constants as RC |
| from esm.utils.structure.protein_chain import ProteinChain |
|
|
| |
| import biotite.structure as bs |
| from biotite.database import rcsb |
| from biotite.structure.io.pdb import PDBFile |
| from biotite.structure import annotate_sse |
|
|
| from cloudpathlib import CloudPath |
| from Bio.Data import PDBData |
|
|
| import py3Dmol |
|
|
| |
| from ..utils.constants import BASE_DIR |
| from ..utils.loading import load_epitopes_csv, load_epitopes_csv_single, load_species |
| from .pc import AMINO_ACID_1TO3, AMINO_ACID_3TO1, MAX_ASA |
| from ..model.ReCEP import ReCEP |
| from ..data.utils import create_graph_data |
|
|
|
|
# Type alias for an input given either as a filesystem path (str / Path) or as
# an already-open binary or text stream.
PathOrBuffer = Union[str, Path, BinaryIO, TextIO]
|
|
| @dataclass |
| class AntigenChain(ProteinChain): |
| """ |
| Extended ProteinChain class that adds additional functionalities, |
| such as computing surface residues based on SASA and maxASA constants. |
| """ |
def __post_init__(self, token: Optional[str] = "1mzAo8l1uxaU8UfVcGgV7B"):
    """
    Finish construction by deriving lookup tables, labels and runtime settings.

    Sets `resnum_to_index`, `epitopes`, `token` and `device` on the instance.

    Args:
        token (Optional[str]): ESM Forge API token used by `get_embeddings`
            for the "esmc" encoder. A non-empty `ESM_API_KEY` environment
            variable takes precedence over this argument, because the
            dataclass-generated `__init__` calls this hook without arguments
            and so offers no other way to supply a per-user token.
    """
    # Function-scope import so this block does not depend on a file-level change.
    import os

    super().__post_init__()

    # Map residue numbers (as stored in residue_index) to 0-based sequence positions.
    self.resnum_to_index = {int(rnum): i for i, rnum in enumerate(self.residue_index)}

    # Epitope labels are resolved eagerly; this may read CSV files and RSA caches.
    self.epitopes = self.get_epitopes()

    # SECURITY NOTE(review): the default value of `token` is a credential committed
    # to source control. It is kept only for backward compatibility; it should be
    # rotated and removed. Prefer supplying the token through the environment.
    # NOTE(review): ESM_API_KEY is the variable name assumed here - confirm it
    # matches the deployment's convention.
    self.token = os.environ.get("ESM_API_KEY") or token

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@staticmethod
def convert_letter_1to3(letter: str) -> str:
    """
    Translate a one-letter amino acid code into its three-letter form.

    The input is upper-cased before the lookup in AMINO_ACID_1TO3.

    Args:
        letter (str): One-letter amino acid code, e.g. "A".

    Returns:
        str: The three-letter code, e.g. "ALA"; "UNK" when the letter has
            no entry in the table.
    """
    normalized = letter.upper()
    return AMINO_ACID_1TO3.get(normalized, "UNK")
|
|
@staticmethod
def convert_letter_3to1(three_letter: str) -> str:
    """
    Translate a three-letter amino acid code into its one-letter form.

    The input is upper-cased before the lookup in AMINO_ACID_3TO1.

    Args:
        three_letter (str): Three-letter amino acid code, e.g. "ALA".

    Returns:
        str: The one-letter code, e.g. "A"; "X" when the code has no entry
            in the table.
    """
    normalized = three_letter.upper()
    return AMINO_ACID_3TO1.get(normalized, "X")
| |
def get_species(self) -> str:
    """
    Return the species classification recorded for this antigen.

    The mapping returned by `load_species()` is consulted first (keyed by
    `self.id`). On a miss, `get_chain_organism` is queried and the enlarged
    mapping is written to `BASE_DIR/data/species.json`.

    Returns:
        str: The classification, or "Unknown" when the lookup itself fails.
            A failure to persist the cache does not discard a value that
            was fetched successfully.
    """
    from ..utils.tools import get_chain_organism

    species_dict = load_species()
    if self.id in species_dict:
        return species_dict[self.id]['classification']

    try:
        species = get_chain_organism(self.id, self.chain_id)
    except Exception as e:
        # Best-effort lookup: any failure degrades to "Unknown".
        print(f"[ERROR] Failed to get species for {self.id}_{self.chain_id}: {str(e)}")
        return "Unknown"

    species_dict[self.id] = {'classification': species}
    species_file_path = Path(f"{BASE_DIR}/data/species.json")
    try:
        # Serialize before opening the file so a serialization error cannot
        # leave a truncated species.json behind.
        payload = json.dumps(species_dict, indent=2)
        species_file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(species_file_path, "w", encoding="utf-8") as f:
            f.write(payload)
    except (OSError, TypeError, ValueError) as e:
        print(f"[WARNING] Failed to cache species for {self.id}_{self.chain_id}: {e}")

    return species
| |
def get_backbone_atoms(self) -> np.ndarray:
    """
    Return backbone atom coordinates in the order N, CA, C.

    The result is cached at `BASE_DIR/data/coords/{id}_{chain_id}.npy`; an
    existing file is loaded and returned as-is without validation.

    Returns:
        np.ndarray: The N, CA and C columns of `atom37_positions`, so that
            [:, 0] is N, [:, 1] is CA and [:, 2] is C.
            Expected shape is [L, 3, 3], assuming `atom37_positions` is
            [L, 37, 3] - confirm against ProteinChain.
    """
    file = Path(f"{BASE_DIR}/data/coords/{self.id}_{self.chain_id}.npy")
    if file.exists():
        return np.load(file)

    # Column order here defines the on-disk layout; keep it N, CA, C.
    backbone_columns = [RC.atom_order[name] for name in ("N", "CA", "C")]
    backbone_atoms = self.atom37_positions[:, backbone_columns, :]

    file.parent.mkdir(parents=True, exist_ok=True)
    np.save(file, backbone_atoms)
    return backbone_atoms
| |
def get_secondary_structure(self) -> Optional[str]:
    """
    Annotate three-state secondary structure for the chain.

    Runs biotite's `annotate_sse` on `self.atom_array` and translates its
    labels "a"/"b"/"c" to "H"/"E"/"C".

    Returns:
        Optional[str]: One character per residue, or None when the annotation
            length differs from the sequence length or when annotation raises
            (a message is printed in both cases).
    """
    try:
        ss3_arr = annotate_sse(self.atom_array)
        biotite_ss3_str = "".join(ss3_arr)

        if len(biotite_ss3_str) != len(self.sequence):
            print(f"[WARNING] Secondary structure prediction length ({len(biotite_ss3_str)}) "
                  f"doesn't match sequence length ({len(self.sequence)}) "
                  f"for protein {self.id}_{self.chain_id}")
            return None

        translation_table = str.maketrans({
            "a": "H",
            "b": "E",
            "c": "C",
        })
        return biotite_ss3_str.translate(translation_table)

    except Exception as e:
        # Best-effort: callers receive None rather than an exception.
        print(f"[ERROR] Failed to predict secondary structure for "
              f"{self.id}_{self.chain_id}: {str(e)}")
        return None
| |
def get_ss_onehot(self) -> np.ndarray:
    """
    One-hot encode the secondary structure as [helix, sheet] columns.

    Only H (helix) and E (sheet) get a column; coil is the all-zero row.
    The string from `get_secondary_structure()` is also stored on
    `self.secondary_structure`.

    Returns:
        np.ndarray: float32 array of shape (seq_len, 2). When secondary
            structure annotation is unavailable (None), an all-zero array
            sized by `self.sequence` is returned instead of raising.
    """
    self.secondary_structure = self.get_secondary_structure()
    if self.secondary_structure is None:
        # Annotation failed upstream (already reported there): treat every
        # residue as coil so feature construction can continue.
        return np.zeros((len(self.sequence), 2), dtype=np.float32)

    # Explicit string dtype keeps the comparison element-wise for empty input too.
    ss_array = np.array(list(self.secondary_structure), dtype="<U1")
    return np.stack([ss_array == "H", ss_array == "E"], axis=1).astype(np.float32)
|
|
def get_rsa(self) -> np.ndarray:
    """
    Compute relative solvent accessibility (RSA) per residue.

    RSA is the residue's SASA divided by the MAX_ASA entry for its
    three-letter code. Residues with no entry, or a zero entry, keep 0.
    Results are cached at `BASE_DIR/data/rsa/{id}_{chain_id}.npy`; an
    existing cache file is returned without recomputation.

    Returns:
        np.ndarray: One RSA value per residue of the sequence.
    """
    cache_file = Path(BASE_DIR) / "data" / "rsa" / f"{self.id}_{self.chain_id}.npy"
    if cache_file.exists():
        return np.load(cache_file)

    per_residue_sasa = self.sasa()
    rsa_values = np.zeros(len(self.sequence), dtype=np.float32)

    for position, (residue, accessible_area) in enumerate(zip(self.sequence, per_residue_sasa)):
        reference_area = MAX_ASA.get(self.convert_letter_1to3(residue))
        if reference_area is None or reference_area == 0:
            # No usable reference area: leave the pre-filled 0.
            continue
        rsa_values[position] = accessible_area / reference_area

    cache_file.parent.mkdir(parents=True, exist_ok=True)
    np.save(cache_file, rsa_values)

    return rsa_values
|
|
def get_surface_residues(self, threshold: float = 0.25) -> tuple:
    """
    Identify surface-exposed residues from their RSA values.

    A residue counts as surface-exposed when its RSA is at least `threshold`.

    Args:
        threshold (float): Minimum RSA for a residue to count as exposed.

    Returns:
        tuple: `(residue_numbers, indices)` - two lists of ints of equal
            length. The first holds the `residue_index` entries of the
            exposed residues, the second their 0-based sequence positions,
            both in ascending position order.
    """
    rsa_values = np.asarray(self.get_rsa())

    # NaN compares False, so residues with undefined RSA are never "surface".
    surface_residue_indices = np.flatnonzero(rsa_values >= threshold).tolist()
    surface_residue_numbers = [int(self.residue_index[idx]) for idx in surface_residue_indices]

    return surface_residue_numbers, surface_residue_indices
| |
def get_epitopes(self, threshold: float = 0.25) -> np.ndarray:
    """
    Build the boolean epitope mask for this chain.

    Labels come from `load_epitopes_csv()` under the key "{id}_{chain_id}";
    when that key is absent, `get_epitopes_single()` supplies them. Unless
    `threshold` is exactly 0.0, positions that are not surface-exposed at
    that RSA threshold are cleared.

    Args:
        threshold (float): RSA threshold passed to `get_surface_residues`.
            0.0 disables the surface filter.

    Returns:
        np.ndarray: Boolean array with one entry per residue; True marks an
            epitope position. All False when the labels are missing, empty,
            or of a length different from the sequence.
    """
    _, _, epitopes = load_epitopes_csv()
    chain_key = f'{self.id}_{self.chain_id}'

    if chain_key not in epitopes:
        print(f"[WARNING] Epitopes not found for {self.id}_{self.chain_id}. Use single epitopes.")
        binary_labels = self.get_epitopes_single()
    else:
        binary_labels = epitopes[chain_key]

    n_residues = len(self.sequence)
    epitope_array = np.zeros(n_residues, dtype=bool)

    has_labels = binary_labels is not None and len(binary_labels) > 0
    if has_labels:
        if len(binary_labels) != n_residues:
            print(f"[WARNING] Binary labels length ({len(binary_labels)}) doesn't match "
                  f"sequence length ({len(self.sequence)}) for {self.id}_{self.chain_id}")
            return epitope_array
        epitope_array = np.array(binary_labels, dtype=bool)

    if threshold == 0.0:
        return epitope_array

    # Keep only epitope positions that are also surface-exposed.
    _, surface_indices = self.get_surface_residues(threshold=threshold)
    surface_mask = np.zeros(n_residues, dtype=bool)
    in_range = [pos for pos in surface_indices if 0 <= pos < n_residues]
    surface_mask[in_range] = True

    return epitope_array & surface_mask
|
|
def get_epitopes_single(self) -> np.ndarray:
    """
    Build a 0/1 epitope array from the single-epitope table.

    `load_epitopes_csv_single()` is searched for "{id}_{chain_id}" with the
    id upper-cased, unchanged, and lower-cased, in that order. The entry is
    treated as a collection of residue numbers; those present in
    `resnum_to_index` are marked.

    Returns:
        np.ndarray: Integer array with one entry per residue, 1 at epitope
            positions. All zeros (with a warning) when no entry is found.
    """
    _, _, epitopes = load_epitopes_csv_single()

    candidate_keys = (
        f'{self.id.upper()}_{self.chain_id}',
        f'{self.id}_{self.chain_id}',
        f'{self.id.lower()}_{self.chain_id}',
    )
    # First matching key wins.
    epitopes_resnums = next(
        (epitopes.get(key) for key in candidate_keys if key in epitopes),
        None,
    )

    epitope_array = np.zeros(len(self.sequence), dtype=int)
    if epitopes_resnums is None:
        print(f"[WARNING] Single Epitopes not found for {self.id}_{self.chain_id}. Use no epitopes.")
        return epitope_array

    for resnum in epitopes_resnums:
        if resnum in self.resnum_to_index:
            epitope_array[self.resnum_to_index[resnum]] = 1
    return epitope_array
| |
def get_epitope_residue_numbers(self) -> list:
    """
    List the residue numbers flagged in the boolean epitope array.

    Returns:
        list: `residue_index` entries (as ints) at every position where
            `self.epitopes` is truthy, in ascending position order.
    """
    flagged_positions = np.nonzero(self.epitopes)[0]
    return [int(self.residue_index[position]) for position in flagged_positions]
|
|
def get_embeddings(self, override: bool = False, encoder: str = "esmc") -> np.ndarray:
    """
    Retrieve or compute per-residue embeddings for the chain.

    Embeddings are cached at
    `BASE_DIR/data/embeddings/{encoder}/{id}_{chain_id}.h5` under the
    "embedding" dataset. A cache hit is returned without validating the
    encoder name.

    Args:
        override (bool): Recompute and overwrite even if a cache file exists.
        encoder (str): "esmc" (remote ESM-C through the Forge client; needs
            `self.token`) or "esm2" (esm2_t33_650M_UR50D via torch.hub).

    Returns:
        np.ndarray: float32 array with the first and last token positions of
            the model output removed. Only the first 2046 residues are
            embedded, so longer sequences yield fewer rows than residues.

    Raises:
        ValueError: If the "esmc" token is unset, or `encoder` is not supported.
    """
    full_file = Path(BASE_DIR) / "data" / "embeddings" / f"{encoder}" / f"{self.id}_{self.chain_id}.h5"

    if full_file.exists() and not override:
        with h5py.File(full_file, "r") as h5f:
            return h5f["embedding"][:]

    truncated_sequence = self.sequence[:2046]

    if encoder == "esmc":
        if self.token is None:
            raise ValueError("ESM token is not set. Please go to https://forge.evolutionaryscale.ai/ to get a token.")

        print(f"[INFO] Generating with ESM-C...")

        from esm.sdk.api import ESMProtein, LogitsConfig
        from esm.sdk.forge import ESM3ForgeInferenceClient

        model = ESM3ForgeInferenceClient(
            model="esmc-6b-2024-12",
            url="https://forge.evolutionaryscale.ai",
            token=self.token
        )
        config = LogitsConfig(sequence=True, return_embeddings=True)

        protein = ESMProtein(truncated_sequence)
        protein_tensor = model.encode(protein)
        output = model.logits(protein_tensor, config)
        token_embeddings = output.embeddings

    elif encoder == "esm2":
        model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t33_650M_UR50D")
        batch_converter = alphabet.get_batch_converter()
        model.eval()
        _, _, batch_tokens = batch_converter([("antigen", truncated_sequence)])
        model.to(self.device)
        batch_tokens = batch_tokens.to(self.device)
        with torch.no_grad():
            # NOTE(review): return_contacts=True is kept to leave the forward
            # pass unchanged; the contact output itself is not used here.
            results = model(batch_tokens, repr_layers=[33], return_contacts=True)
        token_embeddings = results["representations"][33]

    else:
        # Previously this fell through to an UnboundLocalError at the return.
        raise ValueError(f"Unsupported encoder '{encoder}'. Expected 'esmc' or 'esm2'.")

    # Drop the batch dimension and the first/last token positions.
    full_embedding = token_embeddings.squeeze(0)[1:-1, :].to(torch.float32).cpu().numpy()

    full_file.parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(full_file, "w") as h5f:
        h5f.create_dataset("embedding", data=full_embedding, compression="gzip")

    return full_embedding
| |
def _scan_surface_residues(self, radius: float, threshold: float = 0.25) -> tuple:
    """
    Compute, for every surface residue, which surface residues its sphere covers.

    Each surface residue with a usable CA atom is a sphere center. A surface
    residue is covered when any of its masked atoms lies strictly within
    `radius` of that CA. Precision and recall are measured against
    `self.epitopes`.

    Args:
        radius (float): Sphere radius (in Angstroms); must be positive.
        threshold (float): RSA threshold defining surface residues, in [0, 1].

    Returns:
        tuple:
            - coverage (dict): center index -> (covered indices,
              covered epitope indices, precision, recall). Index lists are
              sorted ascending.
            - max_recall_res (int | None): Center index with the highest
              recall (first one on ties).
            - max_precision_res (int | None): Center index with the highest
              precision (first one on ties).
            `({}, None, None)` when there are no surface residues, no usable
            CA atoms, or the distance computation fails.

    Raises:
        ValueError: If `radius` or `threshold` is out of range.
    """
    if radius <= 0:
        raise ValueError("Radius must be positive")
    if threshold < 0 or threshold > 1:
        raise ValueError("Threshold must be between 0 and 1")

    _, surface_indices = self.get_surface_residues(threshold=threshold)

    n_residues = len(self.sequence)
    valid_surface_indices = [idx for idx in surface_indices if 0 <= idx < n_residues]
    if not valid_surface_indices:
        return {}, None, None

    # Pool every masked atom of every surface residue, remembering its owner.
    all_atoms = []
    all_res_indices = []
    for idx in valid_surface_indices:
        coords = self.atom37_positions[idx][self.atom37_mask[idx]]
        if len(coords) > 0:
            all_atoms.append(coords)
            all_res_indices.extend([idx] * len(coords))

    if not all_atoms:
        return {idx: ([], [], 0.0, 0.0) for idx in valid_surface_indices}, None, None

    all_atoms = np.vstack(all_atoms).astype(np.float32)
    all_res_indices = np.array(all_res_indices)

    # Sphere centers: CA atoms that are present in the mask and not NaN.
    ca_idx = RC.atom_order["CA"]
    surface_ca = []
    valid_center_indices = []
    for idx in valid_surface_indices:
        ca_coord = self.atom37_positions[idx, ca_idx, :]
        if not np.any(np.isnan(ca_coord)) and self.atom37_mask[idx, ca_idx]:
            surface_ca.append(ca_coord)
            valid_center_indices.append(idx)

    if not surface_ca:
        return {}, None, None

    surface_ca = np.array(surface_ca, dtype=np.float32).reshape(-1, 3)

    try:
        dist_matrix = cdist(surface_ca, all_atoms)
    except ValueError as e:
        print(f"Error in distance calculation: {e}")
        print(f"surface_ca shape: {surface_ca.shape}")
        print(f"all_atoms shape: {all_atoms.shape}")
        return {}, None, None

    epitope_indices = np.where(self.epitopes)[0]
    if len(epitope_indices) == 0:
        print(f"No epitopes records for protein {self.id}_{self.chain_id}")
    # Built once: membership set of plain ints for the per-center filter below.
    epitope_set = {int(idx) for idx in epitope_indices}
    n_epitopes = len(epitope_indices)

    max_recall = -1
    max_recall_res = None
    max_precision = -1
    max_precision_res = None
    coverage = {}

    for row, center_idx in enumerate(valid_center_indices):
        within_radius = dist_matrix[row] < radius
        # np.unique returns sorted values, so both lists below are ascending.
        covered = [int(idx) for idx in np.unique(all_res_indices[within_radius])]
        covered_epitopes = [idx for idx in covered if idx in epitope_set]

        precision = len(covered_epitopes) / len(covered) if covered else 0.0
        recall = len(covered_epitopes) / n_epitopes if n_epitopes > 0 else 0.0

        # Strict '>' keeps the first center on ties.
        if recall > max_recall:
            max_recall = recall
            max_recall_res = center_idx
        if precision > max_precision:
            max_precision = precision
            max_precision_res = center_idx

        coverage[int(center_idx)] = (covered, covered_epitopes, float(precision), float(recall))

    return coverage, max_recall_res, max_precision_res
| |
def get_surface_coverage(self, radius: float = 18,
                         threshold: float = 0.25,
                         index: bool = True,
                         override: bool = False) -> tuple:
    """
    Retrieve (or compute and cache) the sphere-coverage mapping of surface residues.

    Coverage is computed by `_scan_surface_residues` and cached in
    `BASE_DIR/data/antigen_sphere/{id}_{chain_id}.h5` under the group
    `f"r{radius}"`.

    NOTE(review): the cache key is built from `radius` only. `threshold` is
    not part of it, and `18` and `18.0` produce different keys ("r18" vs
    "r18.0"). A cached entry is therefore reused regardless of the
    `threshold` passed here - pass `override=True` after changing it.

    Args:
        radius (float): Sphere radius (in Angstroms).
        threshold (float): RSA threshold defining surface residues; only
            used when the coverage is (re)computed.
        index (bool): If True, keys and members are 0-based sequence
            indices; if False they are residue numbers from `residue_index`.
        override (bool): Recompute even if a cached entry exists.

    Returns:
        tuple:
            - coverage (dict): center -> (covered residues,
              covered epitope residues, precision, recall).
            - max_recall_res (int | None): Center with the highest recall.
            - max_precision_res (int | None): Center with the highest
              precision.
            Centers are visited in ascending index order and the first one
            wins ties, for cached and freshly computed results alike.
    """
    cache_dir = Path(BASE_DIR) / "data" / "antigen_sphere"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = cache_dir / f"{self.id}_{self.chain_id}.h5"
    radius_key = f"r{radius}"

    # --- Try the cache -----------------------------------------------------
    coverage = None
    if cache_path.exists() and not override:
        try:
            with h5py.File(cache_path, "r") as h5f:
                if radius_key in h5f:
                    radius_group = h5f[radius_key]
                    cached = {}
                    # HDF5 lists names as strings ("10" before "2"); sort
                    # numerically to match the order of a fresh computation.
                    for center_key in sorted(radius_group.keys(), key=int):
                        center_group = radius_group[center_key]
                        cached[int(center_key)] = (
                            center_group['covered_indices'][:].tolist(),
                            center_group['covered_epitope_indices'][:].tolist(),
                            float(center_group.attrs['precision']),
                            float(center_group.attrs['recall']),
                        )
                    # Publish only after the whole group was read successfully.
                    coverage = cached
        except (OSError, KeyError, ValueError) as e:
            print(f"[WARNING] Error reading cache file {cache_path}: {e}")
            print(f"[INFO] Recomputing surface coverage...")

    if coverage is not None:
        # The best centers are not stored in the cache; derive them from the
        # stored precision/recall values.
        max_recall = -1
        max_recall_res = None
        max_precision = -1
        max_precision_res = None
        for center_idx, (_, _, precision, recall) in coverage.items():
            if recall > max_recall:
                max_recall = recall
                max_recall_res = center_idx
            if precision > max_precision:
                max_precision = precision
                max_precision_res = center_idx
    else:
        # --- Compute and store ------------------------------------------------
        coverage, max_recall_res, max_precision_res = self._scan_surface_residues(radius, threshold)

        with h5py.File(cache_path, "a") as h5f:
            # Replace any stale group for this radius.
            if radius_key in h5f:
                del h5f[radius_key]
            radius_group = h5f.create_group(radius_key)

            for center_idx, (covered_indices, covered_epitope_indices, precision, recall) in coverage.items():
                center_group = radius_group.create_group(str(center_idx))
                center_group.create_dataset('covered_indices', data=np.array(covered_indices, dtype=np.int32), compression='gzip')
                center_group.create_dataset('covered_epitope_indices', data=np.array(covered_epitope_indices, dtype=np.int32), compression='gzip')
                center_group.attrs['precision'] = precision
                center_group.attrs['recall'] = recall

    if index:
        return coverage, max_recall_res, max_precision_res

    # --- Translate sequence indices to residue numbers --------------------------
    n_positions = len(self.residue_index)

    def to_resnums(indices):
        # Out-of-range indices are dropped rather than raising.
        return [int(self.residue_index[idx]) for idx in indices if 0 <= idx < n_positions]

    coverage_resnums = {}
    for center_idx, (covered_indices, covered_epitope_indices, precision, recall) in coverage.items():
        center_res_num = int(self.residue_index[center_idx])
        coverage_resnums[center_res_num] = (
            to_resnums(covered_indices),
            to_resnums(covered_epitope_indices),
            precision,
            recall,
        )

    max_recall_res_num = None
    if max_recall_res is not None:
        max_recall_res_num = int(self.residue_index[max_recall_res])
    max_precision_res_num = None
    if max_precision_res is not None:
        max_precision_res_num = int(self.residue_index[max_precision_res])

    return coverage_resnums, max_recall_res_num, max_precision_res_num
| |
def data_preparation(self, radius: Optional[float] = None, encoder: str = "esmc", override: bool = False):
    """
    Gather (and cache) the per-chain inputs needed for region-level modelling.

    Args:
        radius (Optional[float]): Sphere radius for the surface coverage. When
            None, coverage is computed and cached for radii 16, 18 and 20 and
            no coverage mapping is returned.
        encoder (str): Encoder name forwarded to `get_embeddings`.
        override (bool): Forwarded to `get_surface_coverage` only; embeddings,
            backbone coordinates and RSA still come from their own caches.

    Returns:
        tuple:
            - embeddings: result of `get_embeddings(encoder=encoder)`
            - backbone_atoms: result of `get_backbone_atoms()`
            - rsa: result of `get_rsa()`
            - coverage_dict: index-keyed coverage mapping from
              `get_surface_coverage`, or None when `radius` is None.
    """
    embeddings = self.get_embeddings(encoder=encoder)
    backbone_atoms = self.get_backbone_atoms()
    rsa = self.get_rsa()

    if radius is None:
        # Warm the coverage cache for the default radius sweep.
        for sweep_radius in range(16, 21, 2):
            self.get_surface_coverage(radius=sweep_radius, override=override)
        return embeddings, backbone_atoms, rsa, None

    coverage_dict, _, _ = self.get_surface_coverage(radius=radius, override=override)
    return embeddings, backbone_atoms, rsa, coverage_dict
| |
def evaluate(self, model_path: str = None, device_id: int = 1, radius: float = 19.0, k: int = 7,
             threshold: float = None, verbose: bool = True, encoder: str = "esmc", use_gpu: bool = True):
    """
    Evaluate epitope prediction with a ReCEP model over spherical surface regions.

    Every region from `get_surface_coverage` is scored, the `k` best-scoring
    regions are kept, and node-level predictions inside them are aggregated
    per residue (mean probability and majority vote) and compared with the
    known epitopes of the chain.

    Args:
        model_path (str): Path to the trained ReCEP model. Defaults to a
            checkpoint under `BASE_DIR/models/ReCEP`.
        device_id (int): CUDA device index; a negative value selects the CPU.
        radius (float): Radius of the spherical regions.
        k (int): Number of top-scoring regions to keep.
        threshold (float): Probability cut-off for node-level prediction.
            When None, the threshold returned by `ReCEP.load` is used.
        verbose (bool): Whether to print progress and diagnostics. Also
            forwarded to `create_graph_data`.
        encoder (str): Encoder name forwarded to `data_preparation`.
        use_gpu (bool): Allow CUDA when it is available.

    Returns:
        dict: Empty when the model cannot be loaded, data preparation fails,
            or no region can be scored. Otherwise:
            - 'predicted_epitopes': residue numbers with mean probability >= threshold
            - 'voted_epitopes': residue numbers flagged by at least half of their votes
            - 'true_epitopes': set of known epitope residue numbers
            - 'predicted_precision', 'predicted_recall'
            - 'voted_precision', 'voted_recall'
            - 'predictions': {residue number: probability}; residues outside
              the selected regions get 1e-2
            - 'top_k_regions': per-region details of the selected regions
            - 'residue_votes': {residue number: list of 0/1 votes}
    """
    # --- Device -------------------------------------------------------------
    if use_gpu and torch.cuda.is_available() and device_id >= 0:
        device = torch.device(f"cuda:{device_id}")
    else:
        device = torch.device("cpu")
    if verbose:
        print(f"[INFO] Using device: {device}")

    # --- Model --------------------------------------------------------------
    try:
        if model_path is None:
            model_path = f"{BASE_DIR}/models/ReCEP/20250626_110438/best_mcc_model.bin"

        model, loaded_threshold = ReCEP.load(model_path, device=device, strict=False, verbose=False)
        if threshold is None:
            # Fall back to the threshold that ships with the checkpoint.
            threshold = loaded_threshold

        model.eval()
        if verbose:
            print(f"[INFO] Loaded ReCEP model from {model_path}")
    except Exception as e:
        if verbose:
            print(f"[ERROR] Failed to load model: {str(e)}")
        return {}

    # --- Inputs -------------------------------------------------------------
    try:
        embeddings, backbone_atoms, rsa, coverage_dict = self.data_preparation(radius=radius, encoder=encoder)
        if verbose:
            print(f"[INFO] Retrieved protein data for {len(coverage_dict)} surface regions")
    except Exception as e:
        if verbose:
            print(f"[ERROR] Failed to prepare data: {str(e)}")
            traceback.print_exc()
        return {}

    if not coverage_dict:
        if verbose:
            print("[WARNING] No surface regions found")
        return {}

    epitope_indices = np.where(self.epitopes)[0].tolist()

    # --- Pass 1: score every region -------------------------------------------
    region_predictions = []

    with torch.no_grad():
        for center_idx, (covered_indices, covered_epitope_indices, precision, recall) in tqdm(
                coverage_dict.items(), desc="Predicting region values", disable=not verbose):

            # A region with fewer than two residues is skipped.
            if len(covered_indices) < 2:
                continue

            try:
                graph_data = create_graph_data(
                    center_idx=center_idx,
                    covered_indices=covered_indices,
                    covered_epitope_indices=covered_epitope_indices,
                    embeddings=embeddings,
                    backbone_atoms=backbone_atoms,
                    rsa_values=rsa,
                    epitope_indices=epitope_indices,
                    recall=recall,
                    precision=precision,
                    pdb_id=self.id,
                    chain_id=self.chain_id,
                    # Honour the caller's flag (was hard-coded to True).
                    verbose=verbose
                )

                if graph_data is None:
                    if verbose:
                        print(f"[WARNING] Failed to create graph data for region {center_idx}")
                    continue

                graph_data = graph_data.to(device)

                # Single-graph batch: every node belongs to graph 0.
                graph_data.batch = torch.zeros(graph_data.num_nodes, dtype=torch.long, device=device)

                outputs = model(graph_data)

                if 'global_pred' in outputs:
                    graph_pred = torch.sigmoid(outputs['global_pred']).cpu().item()
                else:
                    # No graph-level head: use the mean node probability instead.
                    node_preds = torch.sigmoid(outputs['node_preds']).cpu().numpy()
                    graph_pred = float(np.mean(node_preds))

                region_predictions.append({
                    'center_idx': center_idx,
                    'covered_indices': covered_indices,
                    'covered_epitope_indices': covered_epitope_indices,
                    'graph_pred': graph_pred,
                    'true_recall': recall,
                    'graph_data': graph_data
                })

            except Exception as e:
                if verbose:
                    print(f"[WARNING] Error processing region {center_idx}: {str(e)}")
                    traceback.print_exc()
                continue

    if not region_predictions:
        if verbose:
            print("[WARNING] No valid region predictions")
        return {}

    # --- Select the k best-scoring regions --------------------------------------
    region_predictions.sort(key=lambda x: x['graph_pred'], reverse=True)
    top_k_regions = region_predictions[:k]

    if verbose:
        print(f"[INFO] Selected top {len(top_k_regions)} regions:")
        for i, region in enumerate(top_k_regions):
            print(f"  Region {i+1}: center={region['center_idx']}, "
                  f"predicted_value={region['graph_pred']:.3f}, "
                  f"true_recall={region['true_recall']:.3f}")

    # --- Pass 2: node-level predictions inside the selected regions -------------
    residue_votes = {}
    residue_probs = {}

    with torch.no_grad():
        for region in tqdm(top_k_regions, desc="Predicting node values", disable=not verbose):
            try:
                graph_data = region['graph_data']

                if not hasattr(graph_data, 'batch') or graph_data.batch is None:
                    graph_data.batch = torch.zeros(graph_data.num_nodes, dtype=torch.long, device=device)

                outputs = model(graph_data)
                node_preds = torch.sigmoid(outputs['node_preds']).cpu().numpy()

                for local_idx, residue_idx in enumerate(region['covered_indices']):
                    # Read the probability first: if the lookup fails, no empty
                    # entry is left behind (an empty vote list would later
                    # count as a positive vote, and an empty mean is NaN).
                    prob = float(node_preds[local_idx])
                    residue_probs.setdefault(residue_idx, []).append(prob)
                    residue_votes.setdefault(residue_idx, []).append(1 if prob >= threshold else 0)

            except Exception as e:
                if verbose:
                    print(f"[WARNING] Error in node prediction for region {region['center_idx']}: {str(e)}")
                    traceback.print_exc()
                continue

    # --- Aggregate per residue ----------------------------------------------------
    all_residue_predictions = {}
    for idx in range(len(self.residue_index)):
        residue_num = int(self.residue_index[idx])
        if idx in residue_probs:
            all_residue_predictions[residue_num] = float(np.mean(residue_probs[idx]))
        else:
            # Residues outside every selected region get a small floor value.
            all_residue_predictions[residue_num] = 1e-2

    # Majority vote; a tie counts as an epitope.
    voted_epitope_indices = [
        residue_idx for residue_idx, votes in residue_votes.items()
        if sum(votes) >= len(votes) / 2
    ]
    voted_epitope_resnums = [int(self.residue_index[idx]) for idx in voted_epitope_indices
                             if 0 <= idx < len(self.residue_index)]

    predicted_epitope_resnums = [
        residue_num for residue_num, prob in all_residue_predictions.items()
        if prob >= threshold
    ]

    # --- Metrics against the known epitopes -----------------------------------------
    true_epitope_resnums = set(self.get_epitope_residue_numbers())

    voted_tp = len(set(voted_epitope_resnums) & true_epitope_resnums)
    voted_precision = voted_tp / len(voted_epitope_resnums) if voted_epitope_resnums else 0
    voted_recall = voted_tp / len(true_epitope_resnums) if true_epitope_resnums else 0

    predicted_tp = len(set(predicted_epitope_resnums) & true_epitope_resnums)
    predicted_precision = predicted_tp / len(predicted_epitope_resnums) if predicted_epitope_resnums else 0
    predicted_recall = predicted_tp / len(true_epitope_resnums) if true_epitope_resnums else 0

    if verbose:
        print(f"\n[INFO] Final Results:")
        print(f"  True epitopes: {len(true_epitope_resnums)}")
        print(f"  Residues in top-k regions: {len(residue_probs)}/{len(self.residue_index)}")
        print(f"\n  Voting-based prediction:")
        print(f"    Voted epitopes: {len(voted_epitope_resnums)}")
        print(f"    Voted precision: {voted_precision:.3f}")
        print(f"    Voted recall: {voted_recall:.3f}")
        print(f"\n  Probability-based prediction (threshold={threshold}):")
        print(f"    Predicted epitopes: {len(predicted_epitope_resnums)}")
        print(f"    Predicted precision: {predicted_precision:.3f}")
        print(f"    Predicted recall: {predicted_recall:.3f}")

    return {
        'predicted_epitopes': predicted_epitope_resnums,
        'voted_epitopes': voted_epitope_resnums,
        'true_epitopes': true_epitope_resnums,
        'predicted_precision': predicted_precision,
        'predicted_recall': predicted_recall,
        'voted_precision': voted_precision,
        'voted_recall': voted_recall,
        'predictions': all_residue_predictions,
        'top_k_regions': [
            {
                'center_residue': int(self.residue_index[region['center_idx']]),
                'center_idx': region['center_idx'],
                'predicted_value': region['graph_pred'],
                'true_recall': region['true_recall'],
                'covered_residues': [int(self.residue_index[idx]) for idx in region['covered_indices']]
            }
            for region in top_k_regions
        ],
        'residue_votes': {
            int(self.residue_index[idx]): votes
            for idx, votes in residue_votes.items()
            if 0 <= idx < len(self.residue_index)
        }
    }
| |
def predict(self, model_path: str = None, device_id: int = 1, radius: float = 19.0, k: int = 7,
            threshold: float = None, verbose: bool = True, encoder: str = "esmc", use_gpu: bool = True,
            auto_cleanup: bool = False):
    """
    Predict epitopes using ReCEP model with spherical regions (for unknown true epitopes).

    Every region returned by ``data_preparation`` is scored once by the model. The ``k``
    best-scoring regions are kept, and the node-level probabilities the model produced for
    those regions are averaged per residue. Residues outside the selected regions get
    probability 0.0.

    Args:
        model_path (str): Path to the trained ReCEP model. When None, the bundled
            ``best_mcc_model.bin`` under ``BASE_DIR`` is used.
        device_id (int): GPU device ID to use. If the ID is not a visible CUDA device,
            ``cuda:0`` is used instead; a negative ID selects the CPU.
        radius (float): Radius for spherical regions
        k (int): Number of top regions to select
        threshold (float): Threshold for node-level epitope prediction. When None, the
            threshold returned by ``ReCEP.load`` is used.
        verbose (bool): Whether to print progress information
        encoder (str): Encoder type for embeddings
        use_gpu (bool): Whether to use GPU for computation
        auto_cleanup (bool): Whether to automatically delete generated data files after prediction

    Returns:
        dict: An empty dict when the model cannot be loaded, data preparation fails, or no
        region could be scored. Otherwise a dictionary containing:
            - 'predicted_epitopes': List of predicted epitope residue numbers
            - 'predictions': Dictionary of all residue probabilities {resnum: probability}
            - 'top_k_centers': List of top-k center residue numbers
            - 'top_k_region_residues': List of all residues covered by top-k regions (union)
            - 'top_k_regions': Detailed information about selected regions
            - 'antigen_rate': Mean region-level score over the selected regions
            - 'epitope_rate': Mean residue probability over all residues
    """
    # Resolve the compute device. The default device_id (1) does not exist on
    # single-GPU machines, so an out-of-range ID falls back to cuda:0 instead of
    # failing later while the model is being loaded.
    device = torch.device("cpu")
    if use_gpu and torch.cuda.is_available() and device_id >= 0:
        visible_gpus = torch.cuda.device_count()
        if device_id < visible_gpus:
            device = torch.device(f"cuda:{device_id}")
        else:
            if verbose:
                print(f"[WARNING] CUDA device {device_id} is not available "
                      f"({visible_gpus} visible); falling back to cuda:0")
            device = torch.device("cuda:0")
    if verbose:
        print(f"[INFO] Using device: {device}")

    # Load the model; keep a caller-supplied threshold if one was given.
    try:
        if model_path is None:
            model_path = f"{BASE_DIR}/models/ReCEP/20250626_110438/best_mcc_model.bin"

        model, loaded_threshold = ReCEP.load(model_path, device=device, strict=False, verbose=False)
        if threshold is None:
            threshold = loaded_threshold

        model.eval()
        if verbose:
            print(f"[INFO] Loaded ReCEP model from {model_path}")
    except Exception as e:
        if verbose:
            print(f"[ERROR] Failed to load model: {str(e)}")
        return {}

    # Embeddings, backbone atoms, RSA values and the region coverage map.
    try:
        embeddings, backbone_atoms, rsa, coverage_dict = self.data_preparation(radius=radius, encoder=encoder)
        if verbose:
            print(f"[INFO] Retrieved protein data for {len(coverage_dict)} surface regions")
    except Exception as e:
        if verbose:
            print(f"[ERROR] Failed to prepare data: {str(e)}")
            traceback.print_exc()
        return {}

    if not coverage_dict:
        if verbose:
            print("[WARNING] No surface regions found")
        return {}

    # Single forward pass per region. Node probabilities are cached on the CPU so the
    # selected regions need no second pass and no graph tensor stays on the device.
    region_predictions = []

    with torch.no_grad():
        for center_idx, (covered_indices, _epitope_indices, _precision, _recall) in tqdm(
                coverage_dict.items(), desc="Predicting region values", disable=not verbose):

            if len(covered_indices) < 2:
                continue

            try:
                # True epitopes are unknown here, so all label fields are left empty.
                graph_data = create_graph_data(
                    center_idx=center_idx,
                    covered_indices=covered_indices,
                    covered_epitope_indices=[],
                    embeddings=embeddings,
                    backbone_atoms=backbone_atoms,
                    rsa_values=rsa,
                    epitope_indices=[],
                    recall=0.0,
                    precision=0.0,
                    pdb_id=self.id,
                    chain_id=self.chain_id,
                    verbose=False
                )

                if graph_data is None:
                    if verbose:
                        print(f"[WARNING] Failed to create graph data for region {center_idx}")
                    continue

                graph_data = graph_data.to(device)
                # One graph per call: every node belongs to batch element 0.
                graph_data.batch = torch.zeros(graph_data.num_nodes, dtype=torch.long, device=device)

                outputs = model(graph_data)

                if 'global_pred' in outputs:
                    graph_pred = torch.sigmoid(outputs['global_pred']).cpu().item()
                    node_probs = (torch.sigmoid(outputs['node_preds']).cpu().numpy()
                                  if 'node_preds' in outputs else None)
                else:
                    # No graph-level head: score the region by its mean node probability.
                    node_probs = torch.sigmoid(outputs['node_preds']).cpu().numpy()
                    graph_pred = float(np.mean(node_probs))

                region_predictions.append({
                    'center_idx': center_idx,
                    'covered_indices': covered_indices,
                    'graph_pred': graph_pred,
                    'node_probs': node_probs
                })

            except Exception as e:
                if verbose:
                    print(f"[WARNING] Error processing region {center_idx}: {str(e)}")
                    traceback.print_exc()
                continue

    if not region_predictions:
        if verbose:
            print("[WARNING] No valid region predictions")
        return {}

    # Keep the k highest-scoring regions.
    region_predictions.sort(key=lambda x: x['graph_pred'], reverse=True)
    top_k_regions = region_predictions[:k]

    if verbose:
        print(f"[INFO] Selected top {len(top_k_regions)} regions:")
        for i, region in enumerate(top_k_regions):
            print(f" Region {i+1}: center={region['center_idx']}, "
                  f"predicted_value={region['graph_pred']:.3f}")

    # Collect, per residue index, the probabilities from every selected region covering it.
    residue_probs = {}
    for region in top_k_regions:
        node_probs = region['node_probs']
        if node_probs is None:
            if verbose:
                print(f"[WARNING] Error in node prediction for region {region['center_idx']}: "
                      f"model output has no 'node_preds'")
            continue
        try:
            for local_idx, residue_idx in enumerate(region['covered_indices']):
                # Convert first so a failure never leaves an empty list behind.
                prob = float(node_probs[local_idx])
                residue_probs.setdefault(residue_idx, []).append(prob)
        except (IndexError, TypeError, ValueError) as e:
            if verbose:
                print(f"[WARNING] Error in node prediction for region {region['center_idx']}: {str(e)}")
                traceback.print_exc()
            continue

    # Average overlapping predictions; residues outside every selected region get 0.0.
    all_residue_predictions = {}
    for idx, resnum in enumerate(self.residue_index):
        probs = residue_probs.get(idx)
        all_residue_predictions[int(resnum)] = float(np.mean(probs)) if probs else 0.0

    predicted_epitope_resnums = [
        residue_num for residue_num, prob in all_residue_predictions.items() if prob >= threshold
    ]
    node_mean = (sum(all_residue_predictions.values()) / len(all_residue_predictions)
                 if all_residue_predictions else 0.0)

    top_k_centers = [int(self.residue_index[region['center_idx']]) for region in top_k_regions]

    all_covered_indices = set()
    for region in top_k_regions:
        all_covered_indices.update(region['covered_indices'])
    # Guarded: k == 0 leaves top_k_regions empty.
    graph_mean = (sum(region['graph_pred'] for region in top_k_regions) / len(top_k_regions)
                  if top_k_regions else 0.0)

    top_k_region_residues = [int(self.residue_index[idx]) for idx in all_covered_indices
                             if 0 <= idx < len(self.residue_index)]

    if verbose:
        print(f"\n[INFO] Prediction Results:")
        print(f" Predicted epitopes: {len(predicted_epitope_resnums)}")
        print(f" Top-k centers: {top_k_centers}")
        print(f" Total residues in top-k regions: {len(top_k_region_residues)}")

    results = {
        'predicted_epitopes': predicted_epitope_resnums,
        'predictions': all_residue_predictions,
        'top_k_centers': top_k_centers,
        'top_k_region_residues': top_k_region_residues,
        'top_k_regions': [
            {
                'center_residue': int(self.residue_index[region['center_idx']]),
                'center_idx': region['center_idx'],
                'predicted_value': region['graph_pred'],
                'covered_residues': [int(self.residue_index[idx]) for idx in region['covered_indices']]
            }
            for region in top_k_regions
        ],
        'antigen_rate': graph_mean,
        'epitope_rate': node_mean
    }

    if auto_cleanup:
        self._cleanup_generated_data(encoder=encoder, verbose=verbose)

    return results
| |
def _cleanup_generated_data(self, encoder: str = "esmc", verbose: bool = True):
    """
    Remove the cached data files generated for this antigen chain.

    Four files under ``BASE_DIR/data`` are targeted: the encoder embeddings, the
    coordinates, the RSA values and the antigen-sphere file. Missing files are
    reported and skipped; a failed deletion is recorded and does not stop the loop.

    Args:
        encoder (str): Encoder type used for embeddings (selects the embeddings sub-folder)
        verbose (bool): Whether to print cleanup information
    """
    import os

    data_root = Path(BASE_DIR) / "data"
    stem = f"{self.id}_{self.chain_id}"
    candidates = [
        data_root / "embeddings" / encoder / f"{stem}.h5",
        data_root / "coords" / f"{stem}.npy",
        data_root / "rsa" / f"{stem}.npy",
        data_root / "antigen_sphere" / f"{stem}.h5",
    ]

    removed = []
    failures = []
    bytes_freed = 0

    for target in candidates:
        if not target.exists():
            if verbose:
                print(f"[INFO] File not found (already deleted or not generated): {target}")
            continue

        try:
            # Size must be read before the file disappears.
            size = target.stat().st_size
            os.remove(target)
            removed.append(target)
            bytes_freed += size
            if verbose:
                print(f"[INFO] Deleted: {target}")
        except Exception as err:
            failures.append((target, str(err)))
            if verbose:
                print(f"[WARNING] Failed to delete {target}: {str(err)}")

    if not verbose:
        return

    print(f"[INFO] Cleanup completed for {self.id}_{self.chain_id}")
    print(f" - Files deleted: {len(removed)}")
    print(f" - Failed deletions: {len(failures)}")
    if bytes_freed > 0:
        print(f" - Total space freed: {bytes_freed / (1024**2):.2f} MB")
| |
def visualize(self,
              mode: str = 'normal',
              style: str = 'cartoon',
              predicted_epitopes: list = None,
              predict_results: dict = None,
              prediction_mode: str = 'residue',
              center_res: int = None,
              radius: float = None,
              region_index: int = None,
              width: int = 800,
              height: int = 600,
              base_color: str = '#e6e6f7',
              true_epitope_color: str = '#f1b54c',
              false_positive_color: str = '#ef5331',
              true_positive_color: str = '#a0d293',
              coverage_color: str = '#9C6ADE',
              prediction_color: str = '#9C6ADE',
              center_color: str = '#2C3E50',
              probability_colormap: str = 'RdYlBu_r',
              show_surface: bool = True,
              show_shape: bool = True,
              show_center: bool = True,
              center_radius: float = 0.7,
              n_points: int = 50,
              shape_opacity: float = 0.3,
              surface_opacity: float = 1.0,
              wireframe: bool = True,
              show_epitope: bool = True,
              show_coverage: bool = True,
              show_top_regions: bool = True,
              max_spheres: int = None,
              prob_threshold: float = 0.5):
    """
    Build a py3Dmol view of this chain and decorate it according to ``mode``.

    A mode is only honoured when the inputs it needs are present; otherwise (and for
    'normal' or any unrecognised mode) the chain is drawn with the plain base style.

    Args:
        mode (str): Visualization mode. Options:
            - 'normal': Basic protein structure
            - 'epitope': Predicted vs true epitopes (needs ``predicted_epitopes``);
              a sphere is added when ``show_shape``, ``center_res`` and ``radius`` are set
            - 'coverage': Spherical coverage region (needs ``center_res`` and ``radius``)
            - 'evaluation': Results from evaluate() (needs ``predict_results``)
            - 'prediction': Results from predict() (needs ``predict_results``)
            - 'probability': Residue probabilities as a color gradient (needs ``predict_results``)
            - 'top_regions': Top-k regions from a prediction (needs ``predict_results``)
            - 'comparison': Voted vs predicted epitopes (needs ``predict_results``)
        style (str): Representation ('cartoon', 'stick', 'sphere', 'surface'); in the
            fallback branch any other value is drawn as cartoon
        predicted_epitopes (list): Predicted epitope residue numbers ('epitope' mode)
        predict_results (dict): Results dictionary from predict()/evaluate()
        prediction_mode (str): Sub-mode for 'prediction':
            - 'residue': color predicted epitopes by probability
            - 'region': color all residues in the top-k regions uniformly
        center_res (int): Center residue number for coverage visualization
        radius (float): Radius for spherical coverage
        region_index (int): 0-based region to show in 'probability' mode; None shows
            residues selected by ``prob_threshold``
        width (int): View width passed to the base view
        height (int): View height passed to the base view
        base_color (str): Color for residues that are not highlighted
        true_epitope_color (str): Color forwarded for true epitopes
        false_positive_color (str): Color forwarded for false positives
        true_positive_color (str): Color forwarded for true positives
        coverage_color (str): Color forwarded for coverage regions/spheres
        prediction_color (str): Color forwarded to 'prediction' mode
        center_color (str): Color forwarded for sphere centers
        probability_colormap (str): Colormap name for 'probability' mode
        show_surface (bool): Forwarded; whether surfaces are drawn
        show_shape (bool): Forwarded; whether spheres are drawn
        show_center (bool): Forwarded; whether sphere centers are drawn
        center_radius (float): Forwarded; size of the center marker
        n_points (int): Forwarded to 'coverage' mode only
        shape_opacity (float): Forwarded; sphere opacity
        surface_opacity (float): Forwarded; surface opacity
        wireframe (bool): Forwarded; whether spheres are wireframes
        show_epitope (bool): Forwarded to 'coverage' mode only
        show_coverage (bool): Forwarded to 'epitope' mode only
        show_top_regions (bool): Accepted but not used by this method
        max_spheres (int): Forwarded; cap on the number of spheres drawn
        prob_threshold (float): Threshold for probability-based coloring

    Returns:
        py3Dmol.view: The molecular visualization view object
    """
    view = self._create_base_view(width, height)

    results_given = predict_results is not None
    sphere_given = center_res is not None and radius is not None

    if mode == 'epitope' and predicted_epitopes is not None:
        self._add_epitope_visualization(
            view, style, predicted_epitopes,
            base_color, true_epitope_color, false_positive_color,
            true_positive_color, coverage_color,
            show_surface, surface_opacity, show_coverage,
            center_res, radius
        )
        # Optional sphere around the chosen center residue.
        if show_shape and sphere_given:
            self._add_shape_visualization(
                view, center_res, radius,
                coverage_color, center_color,
                show_center, center_radius,
                shape_opacity, wireframe
            )

    elif mode == 'coverage' and sphere_given:
        self._add_coverage_visualization(
            view, style, center_res, radius,
            base_color, coverage_color, true_positive_color, true_epitope_color,
            show_surface, show_shape, show_center,
            surface_opacity, shape_opacity, center_radius,
            n_points, center_color, wireframe, show_epitope
        )

    elif mode == 'evaluation' and results_given:
        self._add_evaluation_visualization(
            view, style, predict_results,
            base_color, true_epitope_color, false_positive_color,
            true_positive_color, coverage_color,
            show_surface, surface_opacity, show_shape, radius, max_spheres
        )

    elif mode == 'prediction' and results_given:
        self._add_prediction_visualization(
            view, style, predict_results, prediction_mode,
            base_color, prediction_color, show_surface, surface_opacity,
            show_shape, shape_opacity, show_center, center_radius,
            wireframe, radius, max_spheres
        )

    elif mode == 'probability' and results_given:
        self._add_probability_visualization(
            view, style, predict_results,
            base_color, probability_colormap, show_surface, surface_opacity,
            prob_threshold, region_index, radius, show_shape, shape_opacity,
            show_center, center_radius, wireframe, coverage_color, center_color
        )

    elif mode == 'top_regions' and results_given:
        self._add_top_regions_visualization(
            view, style, predict_results,
            base_color, coverage_color, center_color,
            show_surface, show_shape, show_center,
            surface_opacity, shape_opacity, center_radius,
            wireframe, radius, max_spheres
        )

    elif mode == 'comparison' and results_given:
        self._add_comparison_visualization(
            view, style, predict_results,
            base_color, true_epitope_color, false_positive_color,
            true_positive_color, coverage_color, show_surface, surface_opacity
        )

    else:
        # Plain structure; unknown style names are drawn as cartoon.
        known_styles = ('cartoon', 'stick', 'sphere', 'surface')
        representation = style if style in known_styles else 'cartoon'
        view.setStyle({'chain': self.chain_id}, {representation: {}})

    view.zoomTo()
    return view
| |
| def _add_prediction_visualization(self, view, style, predict_results, prediction_mode, |
| base_color, prediction_color, show_surface, surface_opacity, |
| show_shape, shape_opacity, show_center, center_radius, |
| wireframe, radius, max_spheres): |
| """Add visualization for prediction results""" |
| if prediction_mode == 'residue': |
| self._add_prediction_residue_mode( |
| view, style, predict_results, base_color, prediction_color, |
| show_surface, surface_opacity |
| ) |
| elif prediction_mode == 'region': |
| self._add_prediction_region_mode( |
| view, style, predict_results, base_color, prediction_color, |
| show_surface, surface_opacity, show_shape, shape_opacity, |
| show_center, center_radius, wireframe, radius, max_spheres |
| ) |
| |
| def _add_prediction_residue_mode(self, view, style, predict_results, base_color, prediction_color, |
| show_surface, surface_opacity): |
| """Add visualization for prediction results in residue mode""" |
| import matplotlib.pyplot as plt |
| import matplotlib.colors as mcolors |
| |
| |
| predictions = predict_results.get('predictions', {}) |
| predicted_epitopes = predict_results.get('predicted_epitopes', []) |
| |
| |
| style_dict = { |
| 'cartoon': {'cartoon': {}}, |
| 'stick': {'stick': {}}, |
| 'sphere': {'sphere': {}}, |
| 'surface': {'surface': {}} |
| } |
| base_style = style_dict.get(style, {'cartoon': {}}) |
| |
| if not predictions: |
| |
| view.setStyle({'chain': self.chain_id}, {**base_style, |
| list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
| if show_surface: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id}) |
| return |
| |
| |
| epitope_predictions = {res: prob for res, prob in predictions.items() |
| if res in predicted_epitopes} |
| |
| if not epitope_predictions: |
| |
| view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
| if show_surface: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id}) |
| return |
| |
| |
| probs = list(epitope_predictions.values()) |
| min_prob, max_prob = min(probs), max(probs) |
| |
| |
| |
| epitope_colors = [ |
| '#FFE4B5', |
| '#FFD700', |
| '#FFA500', |
| '#FF8C00', |
| '#FF6347', |
| '#FF4500', |
| '#DC143C' |
| ] |
| n_colors = len(epitope_colors) |
| |
| |
| view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
| |
| |
| for residue_num, prob in epitope_predictions.items(): |
| |
| if max_prob > min_prob: |
| norm_prob = (prob - min_prob) / (max_prob - min_prob) |
| else: |
| norm_prob = 0.5 |
| |
| |
| color_idx = int(norm_prob * (n_colors - 1)) |
| color_idx = max(0, min(color_idx, n_colors - 1)) |
| color = epitope_colors[color_idx] |
| |
| |
| style_name = list(base_style.keys())[0] |
| colored_style = {style_name: {'color': color}} |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': residue_num}, |
| colored_style |
| ) |
| |
| |
| if show_surface: |
| |
| all_residues = set(int(res) for res in self.residue_index) |
| non_epitope_residues = all_residues - set(predicted_epitopes) |
| |
| if non_epitope_residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id, 'resi': list(non_epitope_residues)}) |
| |
| |
| for residue_num, prob in epitope_predictions.items(): |
| |
| if max_prob > min_prob: |
| norm_prob = (prob - min_prob) / (max_prob - min_prob) |
| else: |
| norm_prob = 0.5 |
| |
| |
| color_idx = int(norm_prob * (n_colors - 1)) |
| color_idx = max(0, min(color_idx, n_colors - 1)) |
| color = epitope_colors[color_idx] |
| |
| |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': color |
| }, {'chain': self.chain_id, 'resi': residue_num}) |
| |
| def _add_prediction_region_mode(self, view, style, predict_results, base_color, prediction_color, |
| show_surface, surface_opacity, show_shape, shape_opacity, |
| show_center, center_radius, wireframe, radius, max_spheres): |
| """Add visualization for prediction results in region mode""" |
| |
| top_k_regions = predict_results.get('top_k_regions', []) |
| top_k_region_residues = predict_results.get('top_k_region_residues', []) |
| |
| |
| style_dict = { |
| 'cartoon': {'cartoon': {}}, |
| 'stick': {'stick': {}}, |
| 'sphere': {'sphere': {}}, |
| 'surface': {'surface': {}} |
| } |
| base_style = style_dict.get(style, {'cartoon': {}}) |
| |
| if not top_k_region_residues: |
| |
| view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
| if show_surface: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id}) |
| return |
| |
| |
| view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
| |
| |
| if top_k_region_residues: |
| style_name = list(base_style.keys())[0] |
| colored_style = {style_name: {'color': prediction_color}} |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': top_k_region_residues}, |
| colored_style |
| ) |
| |
| |
| if show_surface: |
| |
| all_residues = set(int(res) for res in self.residue_index) |
| non_region_residues = all_residues - set(top_k_region_residues) |
| |
| if non_region_residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id, 'resi': list(non_region_residues)}) |
| |
| |
| if top_k_region_residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': prediction_color |
| }, {'chain': self.chain_id, 'resi': top_k_region_residues}) |
| |
| |
| if show_shape and top_k_regions: |
| self._add_multi_shape_visualization( |
| view, top_k_regions, radius, max_spheres, |
| show_center, center_radius, shape_opacity, wireframe |
| ) |
| |
| def _add_evaluation_visualization(self, view, style, predict_results, |
| base_color, true_epitope_color, false_positive_color, |
| true_positive_color, coverage_color, |
| show_surface, surface_opacity, show_shape, radius, max_spheres): |
| """Add visualization for evaluation results""" |
| |
| predicted_epitopes = set(predict_results.get('predicted_epitopes', [])) |
| true_epitopes = set(predict_results.get('true_epitopes', [])) |
| |
| |
| true_positives = predicted_epitopes & true_epitopes |
| false_positives = predicted_epitopes - true_epitopes |
| false_negatives = true_epitopes - predicted_epitopes |
| |
| |
| style_dict = { |
| 'cartoon': {'cartoon': {}}, |
| 'stick': {'stick': {}}, |
| 'sphere': {'sphere': {}}, |
| 'surface': {'surface': {}} |
| } |
| base_style = style_dict.get(style, {'cartoon': {}}) |
| |
| |
| view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
| |
| |
| for residues, color in [ |
| (true_positives, true_positive_color), |
| (false_positives, false_positive_color), |
| (false_negatives, true_epitope_color) |
| ]: |
| if residues: |
| |
| style_name = list(base_style.keys())[0] |
| colored_style = {style_name: {'color': color}} |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': list(residues)}, |
| colored_style |
| ) |
| |
| |
| if show_surface: |
| |
| all_colored_residues = true_positives | false_positives | false_negatives |
| |
| |
| if all_colored_residues: |
| all_residues = set(int(res) for res in self.residue_index) |
| non_colored_residues = all_residues - all_colored_residues |
| |
| if non_colored_residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id, 'resi': list(non_colored_residues)}) |
| else: |
| |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id}) |
| |
| |
| for residues, color in [ |
| (true_positives, true_positive_color), |
| (false_positives, false_positive_color), |
| (false_negatives, true_epitope_color) |
| ]: |
| if residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': color |
| }, {'chain': self.chain_id, 'resi': list(residues)}) |
| |
| |
| if show_shape and 'top_k_regions' in predict_results: |
| top_regions = predict_results['top_k_regions'] |
| self._add_multi_shape_visualization( |
| view, top_regions, radius, max_spheres, |
| True, 0.5, 0.2, True |
| ) |
| |
| def _add_probability_visualization(self, view, style, predict_results, |
| base_color, colormap, show_surface, surface_opacity, threshold, |
| region_index, radius, show_shape, shape_opacity, |
| show_center, center_radius, wireframe, coverage_color, center_color): |
| """ |
| Add visualization based on prediction probabilities with enhanced support for |
| specific region selection and surface rendering. |
| |
| Args: |
| view: py3Dmol view object |
| style (str): Protein representation style |
| predict_results (dict): Results from predict() function |
| base_color (str): Base color for non-highlighted residues |
| colormap (str): Colormap name for probability visualization |
| show_surface (bool): Whether to show surface |
| surface_opacity (float): Surface opacity |
| threshold (float): Probability threshold for coloring |
| region_index (int): Index of specific region to show (0-based), None for all |
| Each region_index uses a distinct color for shape visualization |
| radius (float): Radius for spherical regions |
| show_shape (bool): Whether to show spherical shapes |
| shape_opacity (float): Shape opacity |
| show_center (bool): Whether to show center points |
| center_radius (float): Center point radius |
| wireframe (bool): Whether to show wireframe spheres |
| coverage_color (str): Color for coverage regions (not used when region_index is specified) |
| center_color (str): Color for center points |
| """ |
| import matplotlib.pyplot as plt |
| import matplotlib.colors as mcolors |
| |
| |
| predictions = predict_results.get('predictions', {}) |
| top_k_regions = predict_results.get('top_k_regions', []) |
| |
| |
| style_dict = { |
| 'cartoon': {'cartoon': {}}, |
| 'stick': {'stick': {}}, |
| 'sphere': {'sphere': {}}, |
| 'surface': {'surface': {}} |
| } |
| base_style = style_dict.get(style, {'cartoon': {}}) |
| |
| if not predictions: |
| |
| view.setStyle({'chain': self.chain_id}, {**base_style, |
| list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
| if show_surface: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id}) |
| return |
| |
| |
| view.setStyle({'chain': self.chain_id}, {**base_style, |
| list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
| |
| |
| target_residues = {} |
| selected_region = None |
| |
| if region_index is not None and 0 <= region_index < len(top_k_regions): |
| |
| selected_region = top_k_regions[region_index] |
| covered_residues = selected_region.get('covered_residues', []) |
| |
| |
| for res_num in covered_residues: |
| if res_num in predictions: |
| target_residues[res_num] = predictions[res_num] |
| else: |
| |
| target_residues = {res: prob for res, prob in predictions.items() |
| if prob >= threshold} |
| |
| if not target_residues: |
| |
| if show_surface: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id}) |
| return |
| |
| |
| probs = list(target_residues.values()) |
| min_prob, max_prob = min(probs), max(probs) |
| |
| |
| if colormap in ['RdYlBu_r', 'coolwarm', 'RdBu_r']: |
| |
| probability_colors = [ |
| '#c6dbef', |
| '#9ecae1', |
| '#6baed6', |
| '#4292c6', |
| '#2171b5', |
| '#fcbba1', |
| '#fc9272', |
| '#fb6a4a', |
| '#ef3b2c', |
| '#cb181d' |
| ] |
| n_colors = len(probability_colors) |
| else: |
| |
| cmap = plt.cm.get_cmap(colormap) |
| probability_colors = [] |
| n_colors = 10 |
| for i in range(n_colors): |
| color_rgba = cmap(i / (n_colors - 1)) |
| |
| softened_rgba = [ |
| color_rgba[0] * 0.7 + 0.3, |
| color_rgba[1] * 0.7 + 0.3, |
| color_rgba[2] * 0.7 + 0.3, |
| ] |
| |
| softened_rgba = [min(1.0, val) for val in softened_rgba] |
| probability_colors.append(mcolors.rgb2hex(softened_rgba)) |
| |
| |
| colored_residues = [] |
| for residue_num, prob in target_residues.items(): |
| |
| if max_prob > min_prob: |
| norm_prob = (prob - min_prob) / (max_prob - min_prob) |
| else: |
| norm_prob = 0.5 |
| |
| |
| color_idx = int(norm_prob * (n_colors - 1)) |
| color_idx = max(0, min(color_idx, n_colors - 1)) |
| color = probability_colors[color_idx] |
| |
| |
| style_name = list(base_style.keys())[0] |
| colored_style = {style_name: {'color': color}} |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': residue_num}, |
| colored_style |
| ) |
| colored_residues.append(residue_num) |
| |
| |
| if show_surface: |
| |
| all_residues = set(int(res) for res in self.residue_index) |
| non_colored_residues = all_residues - set(colored_residues) |
| |
| if non_colored_residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }, {'chain': self.chain_id, 'resi': list(non_colored_residues)}) |
| |
| |
| for residue_num, prob in target_residues.items(): |
| |
| if max_prob > min_prob: |
| norm_prob = (prob - min_prob) / (max_prob - min_prob) |
| else: |
| norm_prob = 0.5 |
| |
| |
| color_idx = int(norm_prob * (n_colors - 1)) |
| color_idx = max(0, min(color_idx, n_colors - 1)) |
| color = probability_colors[color_idx] |
| |
| |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': color |
| }, {'chain': self.chain_id, 'resi': residue_num}) |
| |
| |
| if selected_region is not None and show_shape: |
| center_res = selected_region['center_residue'] |
| |
| |
| sphere_radius = radius or 19.0 |
| |
| |
| region_colors = [ |
| '#FF6B6B', |
| '#4ECDC4', |
| '#45B7D1', |
| '#96CEB4', |
| '#FFEAA7', |
| '#DDA0DD', |
| '#87CEEB', |
| '#F0E68C', |
| '#FFB6C1', |
| '#98FB98', |
| '#9C6ADE', |
| '#FF9A8B' |
| ] |
| |
| |
| shape_color = region_colors[region_index % len(region_colors)] |
| |
| |
| self._add_shape_visualization( |
| view, center_res, sphere_radius, |
| shape_color, center_color, |
| show_center, center_radius, |
| shape_opacity * 0.6, |
| wireframe |
| ) |
| |
| |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': center_res}, |
| {list(base_style.keys())[0]: {'color': shape_color}} |
| ) |
| |
| def _add_top_regions_visualization(self, view, style, predict_results, |
| base_color, coverage_color, center_color, |
| show_surface, show_shape, show_center, |
| surface_opacity, shape_opacity, center_radius, |
| wireframe, radius, max_spheres): |
| """Add visualization for top-k regions""" |
| |
| view.setStyle({'chain': self.chain_id}, {style: {'color': base_color}}) |
| |
| |
| top_regions = predict_results.get('top_k_regions', []) |
| |
| |
| if max_spheres is not None: |
| top_regions = top_regions[:max_spheres] |
| |
| |
| region_colors = [ |
| '#FF6B6B', |
| '#96CEB4', |
| '#4ECDC4', |
| '#45B7D1', |
| '#FFEAA7', |
| '#DDA0DD', |
| '#87CEEB', |
| '#F0E68C', |
| '#FFB6C1', |
| '#98FB98' |
| ] |
| |
| for i, region in enumerate(top_regions): |
| center_res = region['center_residue'] |
| covered_residues = region.get('covered_residues', []) |
| region_color = region_colors[i % len(region_colors)] |
| |
| |
| if covered_residues: |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': covered_residues}, |
| {style: {'color': region_color}} |
| ) |
| |
| |
| if show_shape: |
| self._add_shape_visualization( |
| view, center_res, radius or 18.0, |
| region_color, center_color, |
| show_center, center_radius * 0.8, |
| shape_opacity, wireframe |
| ) |
| |
| |
| if show_surface: |
| |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }) |
| |
| |
| for i, region in enumerate(top_regions): |
| covered_residues = region.get('covered_residues', []) |
| region_color = region_colors[i % len(region_colors)] |
| |
| if covered_residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': region_color |
| }, {'resi': covered_residues}) |
| |
| def _add_comparison_visualization(self, view, style, predict_results, |
| base_color, true_epitope_color, false_positive_color, |
| true_positive_color, coverage_color, show_surface, surface_opacity): |
| """Add visualization comparing voted vs predicted epitopes""" |
| |
| view.setStyle({'chain': self.chain_id}, {style: {'color': base_color}}) |
| |
| |
| predicted_epitopes = set(predict_results.get('predicted_epitopes', [])) |
| voted_epitopes = set(predict_results.get('voted_epitopes', [])) |
| true_epitopes = set(predict_results.get('true_epitopes', [])) |
| |
| |
| both_methods = predicted_epitopes & voted_epitopes |
| only_predicted = predicted_epitopes - voted_epitopes |
| only_voted = voted_epitopes - predicted_epitopes |
| |
| |
| both_correct = both_methods & true_epitopes |
| both_incorrect = both_methods - true_epitopes |
| only_pred_correct = only_predicted & true_epitopes |
| only_pred_incorrect = only_predicted - true_epitopes |
| only_vote_correct = only_voted & true_epitopes |
| only_vote_incorrect = only_voted - true_epitopes |
| |
| |
| color_assignments = [ |
| (both_correct, '#00FF00'), |
| (both_incorrect, '#FF0000'), |
| (only_pred_correct, '#90EE90'), |
| (only_pred_incorrect, '#FFB6C1'), |
| (only_vote_correct, '#87CEEB'), |
| (only_vote_incorrect, '#DDA0DD') |
| ] |
| |
| for residues, color in color_assignments: |
| if residues: |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': list(residues)}, |
| {style: {'color': color}} |
| ) |
| |
| |
| if show_surface: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': base_color |
| }) |
| |
| for residues, color in color_assignments: |
| if residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': color |
| }, {'resi': list(residues)}) |
|
|
def _create_base_view(self, width: int, height: int) -> py3Dmol.view:
    """
    Create a py3Dmol view and load this chain into it as a PDB model.

    The chain is serialised to a single-model PDB string: one ATOM record per
    atom flagged in ``atom37_mask``, with occupancy 1.00 and B-factor 0.00.

    Args:
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.

    Returns:
        The view with the chain added as a "pdb" model (no style applied yet).
    """
    view = py3Dmol.view(width=width, height=height)

    # Collect records in a list and join once instead of growing a string.
    records = ["MODEL 1\n"]
    atom_serial = 1
    per_residue = zip(
        self.sequence, self.residue_index, self.atom37_mask, self.atom37_positions
    )
    for one_letter, resnum, mask, positions in per_residue:
        resname = self.convert_letter_1to3(one_letter)
        # Names of the atom37 slots that are actually present for this residue.
        atom_names = [name for name, present in zip(RC.atom_types, mask) if present]

        for atom_name, (x, y, z) in zip(atom_names, positions[mask]):
            # PDB ATOM records are fixed-column: serial in cols 7-11, atom name
            # in 13-16, residue name in 18-20, chain in 22, residue number in
            # 23-26, x/y/z in 31-54, occupancy in 55-60, B-factor in 61-66.
            # The literal spacing below is what places each field correctly.
            records.append(
                f"ATOM  {atom_serial:5d}  {atom_name:<3s} {resname:>3s} "
                f"{self.chain_id:1s}{resnum:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00\n"
            )
            atom_serial += 1

    records.append("ENDMDL\n")
    view.addModel("".join(records), "pdb")
    return view
|
|
| def _add_epitope_visualization(self, view, style, predicted_epitopes, |
| base_color, true_epitope_color, false_positive_color, true_positive_color, coverage_color, |
| show_surface, surface_opacity, show_coverage, |
| center_res=None, radius=None): |
| """添加表位可视化""" |
| |
| view.setStyle({'chain': self.chain_id}, {style: {'color': base_color}}) |
| |
| true_epitopes = set(self.get_epitope_residue_numbers()) |
| true_positives = set(predicted_epitopes) & true_epitopes |
| false_positives = set(predicted_epitopes) - true_epitopes |
| false_negatives = true_epitopes - set(predicted_epitopes) |
| |
| |
| covered_residues = [] |
| if center_res is not None and radius is not None: |
| coverage_dict, _, _ = self.get_surface_coverage( |
| radius=radius, threshold=0.25, index=False |
| ) |
| covered_res_list = coverage_dict.get(center_res, [[], [], 0, 0])[0] |
| covered_residues = covered_res_list |
| |
| |
| if covered_residues: |
| true_negatives = [res for res in covered_residues |
| if res not in true_epitopes and res not in predicted_epitopes] |
| |
| |
| true_negative_color = '#888888' |
| |
| if true_negatives: |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': true_negatives}, |
| {style: {'color': true_negative_color}} |
| ) |
| |
| |
| for residues, color in [ |
| (true_positives, true_positive_color), |
| (false_positives, false_positive_color), |
| (false_negatives, true_epitope_color) |
| ]: |
| if residues: |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': list(residues)}, |
| {style: {'color': color}} |
| ) |
| |
| |
| if show_surface: |
| |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': base_color |
| }) |
| |
| |
| for residues, color in [ |
| (true_positives, true_positive_color), |
| (false_positives, false_positive_color), |
| (false_negatives, true_epitope_color) |
| ]: |
| if residues: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': color |
| }, {'resi': list(residues)}) |
| |
| |
| if center_res is not None and radius is not None and covered_residues and show_coverage: |
| true_negatives = [res for res in covered_residues |
| if res not in true_epitopes and res not in predicted_epitopes] |
| |
| if true_negatives: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': coverage_color |
| }, {'resi': true_negatives}) |
|
|
| def _add_shape_visualization(self, view, center_res, radius, |
| coverage_color, center_color, |
| show_center, center_radius, |
| shape_opacity, wireframe): |
| """添加球形可视化""" |
| center_idx = self.resnum_to_index.get(center_res) |
| if center_idx is None: |
| return |
| |
| ca_idx = RC.atom_order["CA"] |
| center_coord = self.atom37_positions[center_idx, ca_idx, :] |
| |
| |
| sphere_params = { |
| 'center': {'x': float(center_coord[0]), |
| 'y': float(center_coord[1]), |
| 'z': float(center_coord[2])}, |
| 'radius': float(radius), |
| 'color': coverage_color |
| } |
| if wireframe: |
| sphere_params.update({'wireframe': True, 'linewidth': 1.5}) |
| else: |
| sphere_params.update({'opacity': shape_opacity}) |
| view.addSphere(sphere_params) |
| |
| |
| if show_center: |
| view.addSphere({ |
| 'center': {'x': float(center_coord[0]), |
| 'y': float(center_coord[1]), |
| 'z': float(center_coord[2])}, |
| 'radius': float(center_radius), |
| 'color': center_color, |
| 'opacity': 1.0 |
| }) |
|
|
| def _add_coverage_visualization(self, view, style, center_res, radius, |
| base_color, coverage_color, true_positive_color, true_epitope_color, |
| show_surface, show_shape, show_center, |
| surface_opacity, shape_opacity, center_radius, |
| n_points, center_color, wireframe, show_epitope): |
| """添加覆盖区域可视化""" |
| |
| view.setStyle({'chain': self.chain_id}, {style: {'color': base_color}}) |
| |
| |
| coverage_dict, _, _ = self.get_surface_coverage( |
| radius=radius, threshold=0.25, index=False |
| ) |
| |
| covered_res_list = coverage_dict.get(center_res, [[], [], 0, 0])[0] |
| covered_residues = covered_res_list |
|
|
| |
| if show_epitope: |
| true_epitopes = set(self.get_epitope_residue_numbers()) |
| else: |
| true_epitopes = set() |
| |
| |
| true_positives = set(covered_residues) & true_epitopes |
| false_negatives = true_epitopes - set(covered_residues) |
| covered_non_epitopes = set(covered_residues) - true_epitopes |
| |
| |
| if show_surface: |
| |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 1.0, |
| 'color': base_color |
| }) |
| |
| |
| if false_negatives: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': true_epitope_color |
| }, {'resi': list(false_negatives)}) |
| |
| |
| if true_positives: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity, |
| 'color': true_positive_color |
| }, {'resi': list(true_positives)}) |
| |
| |
| if covered_non_epitopes: |
| view.addSurface(py3Dmol.VDW, { |
| 'opacity': surface_opacity * 0.9, |
| 'color': coverage_color |
| }, {'resi': list(covered_non_epitopes)}) |
| |
| |
| if false_negatives: |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': list(false_negatives)}, |
| {style: {'color': true_epitope_color}} |
| ) |
| |
| if true_positives: |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': list(true_positives)}, |
| {style: {'color': true_positive_color}} |
| ) |
| |
| if covered_non_epitopes: |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': list(covered_non_epitopes)}, |
| {style: {'color': coverage_color}} |
| ) |
| |
| |
| view.addStyle( |
| {'chain': self.chain_id, 'resi': center_res}, |
| {style: {'color': '#FFD700'}} |
| ) |
| |
| |
| if show_shape: |
| self._add_shape_visualization( |
| view, center_res, radius, |
| coverage_color, |
| center_color, |
| show_center, center_radius, |
| shape_opacity, wireframe |
| ) |
|
|
| def _add_multi_shape_visualization(self, view, regions_data, radius, max_spheres, |
| show_center, center_radius, shape_opacity, wireframe): |
| """Add multiple spherical regions with different colors""" |
| if not regions_data: |
| return |
| |
| |
| regions_to_show = regions_data[:max_spheres] if max_spheres else regions_data |
| |
| |
| sphere_colors = [ |
| '#d671f1', |
| '#7190f1', |
| '#FF6B6B', |
| '#96CEB4', |
| '#FFEAA7', |
| '#FFB6C1', |
| '#4ECDC4', |
| '#87CEEB', |
| '#F0E68C', |
| '#98FB98', |
| '#45B7D1' |
| ] |
| |
| for i, region_data in enumerate(regions_to_show): |
| if isinstance(region_data, dict): |
| |
| center_res = region_data['center_residue'] |
| else: |
| |
| center_res = region_data |
| |
| sphere_color = sphere_colors[i % len(sphere_colors)] |
| self._add_shape_visualization( |
| view, center_res, radius or 18.0, |
| sphere_color, '#FFD700', |
| show_center, center_radius, shape_opacity, wireframe |
| ) |
|
|
@classmethod
def from_pdb(
    cls,
    path: Optional[PathOrBuffer] = None,
    chain_id: str = "detect",
    id: Optional[str] = None,
    is_predicted: bool = False,
) -> Optional["AntigenChain"]:
    """
    Build an AntigenChain from a pdb file.

    If `path` is not provided, the following locations under
    ``BASE_DIR/data`` are tried in order (with `id` lower-cased):
        1. antigen_structs/{id}_{chain_id}.pdb
        2. antigen_structs/{id}.pdb
        3. PDB/{id}.pdb
    If none exists, the structure is fetched from RCSB PDB into
    ``BASE_DIR/data/pdb`` and read from there.

    Args:
        path (Optional[PathOrBuffer]): Path or open buffer to read the pdb
            file from. If None, the file is located (or downloaded) by `id`.
        chain_id (str, optional): Select a chain corresponding to (author) chain id.
            "detect" uses the chain of the first atom in the file.
        id (Optional[str], optional): Protein identifier (pdb_id); it is
            lower-cased. If not provided and `path` is a filesystem path, the
            id is the file name without its suffix; for a buffer it is "null".
        is_predicted (bool, optional): If True, the CA b-factor of each
            residue is stored as its confidence; otherwise confidence is 1.

    Returns:
        Optional[AntigenChain]: The constructed antigen chain, or None when
        the download or the parsing of the pdb file fails (an error is printed).

    Raises:
        ValueError: If neither `path` nor `id` is provided.
    """
    # Only normalise an id that was actually supplied; callers may pass just `path`.
    if id is not None:
        id = id.lower()

    if path is None:
        if id is None:
            raise ValueError("Either 'path' or 'id' must be provided to locate the pdb file.")

        data_dir = Path(BASE_DIR) / "data"
        possible_paths = [
            data_dir / "antigen_structs" / f"{id}_{chain_id}.pdb",
            data_dir / "antigen_structs" / f"{id}.pdb",
            data_dir / "PDB" / f"{id}.pdb",
        ]

        # First existing candidate wins.
        path = next((p for p in possible_paths if p.exists()), None)

        if path is not None:
            print(f"Found pdb file at {path}")
        else:
            try:
                # NOTE(review): downloads go to data/pdb while the lookup above
                # checks data/PDB; these differ on case-sensitive filesystems.
                save_dir = data_dir / "pdb"
                save_dir.mkdir(parents=True, exist_ok=True)

                # Fetches the whole entry (complex), not a single chain.
                rcsb.fetch(id, "pdb", target_path=save_dir)

                path = save_dir / f"{id}.pdb"
                print(f"No existing pdb file for {id}_{chain_id}, downloaded {id} complex pdb file to {path}")

            except Exception as e:
                # Best effort: report and signal failure to the caller with None.
                print(f"[ERROR] Failed to download pdb file for {id}: {str(e)}")
                return None
    elif not hasattr(path, "read"):
        # Filesystem-like input; open buffers are handed to the reader untouched.
        path = Path(path)

    # Identifier stored on the chain.
    if id is not None:
        file_id = id
    elif isinstance(path, Path):
        file_id = path.with_suffix("").name
    else:
        # A buffer carries no file name to derive an id from.
        file_id = "null"

    try:
        atom_array = PDBFile.read(path).get_structure(model=1, extra_fields=["b_factor"])
    except Exception as e:
        print(f"[ERROR] Failed to read pdb file {path}: {str(e)}")
        return None

    if chain_id == "detect":
        chain_id = atom_array.chain_id[0]
        print(f"[WARNING] No chain_id provided, using the first detected chain: {chain_id}")

    # Keep only non-hetero amino-acid atoms of the requested chain.
    atom_array = atom_array[
        bs.filter_amino_acids(atom_array)
        & ~atom_array.hetero
        & (atom_array.chain_id == chain_id)
    ]

    entity_id = 1

    # One letter per residue; names that do not map to a single letter become "X".
    sequence = "".join(
        (
            r if len((r := PDBData.protein_letters_3to1.get(monomer[0].res_name, "X"))) == 1 else "X"
        )
        for monomer in bs.residue_iter(atom_array)
    )
    num_res = len(sequence)

    # Per-residue arrays in atom37 layout; absent atoms stay NaN / False.
    atom_positions = np.full([num_res, RC.atom_type_num, 3], np.nan, dtype=np.float32)
    atom_mask = np.full([num_res, RC.atom_type_num], False, dtype=bool)
    residue_index = np.full([num_res], -1, dtype=np.int64)
    insertion_code = np.full([num_res], "", dtype="<U4")
    confidence = np.ones([num_res], dtype=np.float32)

    for i, res in enumerate(bs.residue_iter(atom_array)):
        for atom in res:
            atom_name = atom.atom_name
            # Selenomethionine: store the selenium in the SD slot.
            if atom_name == "SE" and atom.res_name == "MSE":
                atom_name = "SD"
            if atom_name in RC.atom_order:
                slot = RC.atom_order[atom_name]
                atom_positions[i, slot] = atom.coord
                atom_mask[i, slot] = True
                if is_predicted and atom_name == "CA":
                    confidence[i] = atom.b_factor
        residue_index[i] = res[0].res_id
        insertion_code[i] = res[0].ins_code

    assert all(sequence), "Some residue name was not specified correctly"

    return cls(
        id=file_id,
        sequence=sequence,
        chain_id=chain_id,
        entity_id=entity_id,
        atom37_positions=atom_positions,
        atom37_mask=atom_mask,
        residue_index=residue_index,
        insertion_code=insertion_code,
        confidence=confidence,
    )