import base64
import errno
import io
import os
import pickle
import signal
from functools import wraps
from math import ceil

import line_profiler
import numpy as np
import psutil
import torch
import yaml
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from scipy.spatial import KDTree
from tqdm import tqdm


def num_parameters(model: torch.nn.Module) -> int:
    """Return the number of trainable parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class Config:
    """Read configuration from a YAML file and expose its keys as attributes."""

    def __init__(self, yaml_file: str):
        with open(yaml_file, "r") as f:
            config = yaml.safe_load(f)
        for k, v in config.items():
            setattr(self, k, v)

    def update(self, new_yaml_file: str):
        """Merge keys from another YAML file, overwriting existing attributes."""
        with open(new_yaml_file, "r") as f:
            config = yaml.safe_load(f)
        for k, v in config.items():
            setattr(self, k, v)

    def save(self, yaml_file: str):
        """Dump every current attribute back out to ``yaml_file``."""
        with open(yaml_file, "w") as f:
            yaml.dump(self.__dict__, f)


def memory_usage_psutil():
    """Return this process's memory usage as a percentage, like ``top``."""
    process = psutil.Process(os.getpid())
    return process.memory_percent()


def is_wandb_running():
    """Return True if a wandb sweep agent launched this process.

    NOTE(review): this checks ``WANDB_SWEEP_ID`` specifically, so it detects a
    sweep run rather than any wandb run — confirm that is the intent.
    """
    return "WANDB_SWEEP_ID" in os.environ


class TimeoutError(Exception):
    # NOTE(review): shadows the builtin TimeoutError. Kept under this name so
    # existing callers that catch utils.TimeoutError keep working.
    pass


def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """Decorator raising ``TimeoutError`` if the call exceeds ``seconds``.

    Implemented with ``SIGALRM``, so it only works in the main thread of a
    POSIX process, and nested use overwrites the previous alarm/handler.
    """
    def decorator(func):
        def _handle_timeout(signum, frame):
            raise TimeoutError(error_message)

        def wrapper(*args, **kwargs):
            signal.signal(signal.SIGALRM, _handle_timeout)
            signal.alarm(seconds)
            try:
                result = func(*args, **kwargs)
            finally:
                # Always cancel the pending alarm, even if func raised.
                signal.alarm(0)
            return result

        return wraps(func)(wrapper)

    return decorator


def shorten_path(path: str, max_len: int = 30) -> str:
    """Shorten ``path`` to roughly ``max_len`` characters.

    Keeps the first and last ``max_len // 2`` characters with ``"..."`` in
    between; paths already short enough are returned unchanged.
    """
    if len(path) > max_len:
        return path[:max_len // 2] + "..." + path[-max_len // 2:]
    return path


def cluster_points(data: torch.Tensor, d: float) -> torch.Tensor:
    """
    Cluster points based on the Euclidean distance.

    :param data: Input data, shape (n_points, n_features), type torch.Tensor.
    :param d: Distance threshold for clustering.
    :return: Cluster indices, shape (n_points,), type torch.Tensor.
    """
    dist = torch.cdist(data, data)
    indices = torch.full((data.shape[0],), -1, dtype=torch.long)
    cluster_id = 0
    for i in range(data.shape[0]):
        if indices[i] == -1:
            # NOTE(review): this assigns *every* point within d of point i,
            # including points already placed in an earlier cluster, so later
            # seeds can steal members from earlier clusters. Behavior kept
            # as-is — confirm this single-pass assignment is intentional.
            indices[dist[i] < d] = cluster_id
            cluster_id += 1
    return indices


def bron_kerbosch(R, P, X, graph):
    """Yield maximal cliques via the (non-pivoting) Bron–Kerbosch recursion.

    R: current clique; P: candidate vertices; X: excluded vertices.
    ``P`` and ``X`` are mutated in place by the recursion.
    """
    if not P and not X:
        yield R
    while P:
        v = P.pop()
        yield from bron_kerbosch(
            R | {v},
            P & set(graph[v]),
            X & set(graph[v]),
            graph
        )
        X.add(v)


def find_cliques(graph):
    """
    Find all maximal cliques in an undirected graph with the Bron-Kerbosch
    algorithm.

    :param graph: Input graph as a NetworkX graph
    :return: List of maximal cliques
    """
    return list(bron_kerbosch(set(), set(graph.nodes()), set(), graph))


def segment_cmd(cmd_str: str, max_len: int = 1000):
    """Split a ``;``-separated command string into chunks of ~``max_len`` chars.

    Commands are never split mid-statement: a new chunk starts only at a ``;``
    boundary, so a chunk may exceed ``max_len`` if a single command does.

    NOTE(review): any trailing text after the final ``;`` is dropped — confirm
    callers always terminate the last command with ``;``.
    """
    cmds = ['']
    prev = 0
    for i, c in enumerate(cmd_str):
        if c == ';':
            if len(cmds[-1]) + len(cmd_str[prev:i]) > max_len:
                cmds.append('')
            cmds[-1] += cmd_str[prev:i + 1]
            prev = i + 1
    return cmds


def get_color(v):
    """Map ``v`` in [0, 1] to a PyMOL RGB string interpolating green -> brown."""
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    # green to brown
    color1 = np.array([0, 128, 0])
    color2 = np.array([165, 42, 42])
    v = v * (color2 - color1) + color1
    v /= 255
    return f'[{v[0]:.2f},{v[1]:.2f},{v[2]:.2f}]'


def generate_pymol_script(possible_sites):
    """Build a PyMOL command string placing a blue pseudoatom at each site."""
    cmd = ''
    for i, pos in enumerate(possible_sites):
        cmd += f"pseudoatom s{i},pos=[{pos[0]:.1f},{pos[1]:.1f},{pos[2]:.1f}];color blue,s{i};"
    return cmd


def remove_close_points_kdtree(points, min_distance):
    """Greedily thin ``points`` so no two kept points are within ``min_distance``.

    Points are scanned in order; each kept point suppresses all of its
    not-yet-visited neighbors. ``points`` must be a NumPy array (boolean
    mask indexing is used for the return value).
    """
    tree = KDTree(points)
    keep = np.ones(len(points), dtype=bool)
    for i, point in enumerate(points):
        if not keep[i]:
            continue
        neighbors = tree.query_ball_point(point, min_distance)
        # Suppress every neighbor (including i itself), then re-keep i.
        keep[neighbors] = False
        keep[i] = True  # Keep the current point
    return points[keep]
@line_profiler.profile
def pack_bit(x: torch.Tensor):
    """
    Pack the bit tensor to a sequence of bytes (LSB-first within each byte).

    Args:
        x (torch.Tensor): The input tensor of shape (batch_size, num_bits),
            expected to contain 0/1 values.

    Returns:
        torch.Tensor: The packed uint8 tensor of shape
            (batch_size, ceil(num_bits / 8)).
    """
    batch_size, num_bits = x.shape
    num_bytes = (num_bits + 7) // 8
    output = torch.zeros(batch_size, num_bytes, dtype=torch.uint8, device=x.device)
    for i in range(num_bits):
        byte_index = i // 8
        bit_index = i % 8
        # Bit i of the input lands in bit (i % 8) of byte (i // 8).
        output[:, byte_index] |= (x[:, i] << bit_index).to(torch.uint8)
    return output


@line_profiler.profile
def unpack_bit(x: torch.Tensor, num_bits: int):
    """
    Unpack the bit tensor from a sequence of bytes (inverse of ``pack_bit``).

    Args:
        x (torch.Tensor): The packed uint8 tensor of shape
            (batch_size, num_bytes).
        num_bits (int): The number of bits to unpack.

    Returns:
        torch.Tensor: The unpacked 0/1 uint8 tensor of shape
            (batch_size, num_bits).
    """
    batch_size, num_bytes = x.shape
    output = torch.zeros(batch_size, num_bits, dtype=torch.uint8, device=x.device)
    for i in range(num_bits):
        byte_index = i // 8
        bit_index = i % 8
        output[:, i] = (x[:, byte_index] >> bit_index) & 1
    return output


def safe_dist(vec1: torch.Tensor, vec2: torch.Tensor,
              max_size: int = 100_000_000, p: int = 2):
    """
    Compute the minimum distance from each point in vec2 to the points in vec1:

    vec1: (N, 3), N could be very very large, i.e., all atoms' coordinates in
        a large protein
    vec2: (M, 3), M is not very large, usually the coordinates of the binding
        sites
    max_size: the maximum size of the distance matrix to compute at once
        (vec2 is processed in batches of ceil(max_size / N) columns)
    p: the p-norm to use for distance calculation

    return: (M,) the minimum distance of each binding site to the protein
    """
    size1 = vec1.shape
    size2 = vec2.shape
    # Batch over vec2 so each cdist matrix has at most ~max_size entries.
    batch_size = ceil(max_size / size1[0])
    dists = []
    for i in range(0, size2[0], batch_size):
        dist = torch.cdist(vec1, vec2[i:i + batch_size], p=p)
        dists.append(dist.min(dim=0).values)
    return torch.cat(dists)


@line_profiler.profile
def safe_filter(nos: torch.Tensor, pos: torch.Tensor, thr: torch.Tensor,
                all: torch.Tensor, lb: float, max_size: int = 100_000_000):
    """
    Filter candidate positions against per-site distance thresholds, in batches
    that bound peak memory.

    nos: (N, 3), coordinates of the binding sites
    pos: (M, 3), candidate coordinates to test, could be very very large
    thr: (N, 2), the [lower, upper] distance threshold for each binding site
    all: (P, 3), coordinates of all atoms in the protein
        NOTE(review): parameter name shadows the builtin ``all``; kept to
        preserve the keyword-argument interface for existing callers.
    lb: the lower bound of the distance from any protein atom
    max_size: maximum number of distance-matrix entries per batch

    return: (packed, interests) where ``packed`` is the bit-packed (per
        ``pack_bit``, transposed) availability matrix over the kept positions,
        and ``interests`` is an (M,) boolean mask of positions that satisfied
        at least one site's thresholds.
    """
    N, M, P = nos.shape[0], pos.shape[0], all.shape[0]
    batch_size = ceil(max_size / N)
    output = []
    interests = []
    for i in tqdm(range(0, M, batch_size), leave=False,
                  desc=f'Filtering (batch_size: {batch_size})'):
        # Within-threshold test against every binding site.
        dist = torch.cdist(pos[i:i + batch_size], nos)
        dist = (dist <= thr[:, 1].unsqueeze(0)) & \
            (dist >= thr[:, 0].unsqueeze(0))
        # Reject positions closer than lb to any protein atom.
        dist_all = safe_dist(all, pos[i:i + batch_size]) > lb
        dist = dist & dist_all.unsqueeze(-1)
        # Keep only positions that match at least one site.
        mask = dist.any(dim=1)
        output.append(pack_bit(dist[mask]).T)
        interests.append(mask)
    return torch.cat(output, dim=1), torch.cat(interests)


def backbone(atoms, chain_id):
    """Return the C-alpha (backbone) atoms of the chain ``chain_id``."""
    return atoms[
        (atoms.chain_id == chain_id) &
        (atoms.atom_name == "CA") &
        (atoms.element == "C")]


def load_private_key_from_file(private_key_file=None):
    """Load a base64-encoded PEM RSA private key.

    Reads from ``private_key_file`` if given, otherwise from the
    ``ModelCheckpointPrivateKey`` environment variable.

    Raises:
        ValueError: if no key source is available.
    """
    if private_key_file is None:
        private_key_b64 = os.environ.get('ModelCheckpointPrivateKey')
        if private_key_b64 is None:
            # Fail with a clear message instead of b64decode's TypeError.
            raise ValueError(
                'ModelCheckpointPrivateKey environment variable is not set '
                'and no private_key_file was provided')
    else:
        with open(private_key_file, 'r') as f:
            private_key_b64 = f.read().strip()
    private_pem = base64.b64decode(private_key_b64)
    private_key = serialization.load_pem_private_key(
        private_pem,
        password=None,
        backend=default_backend()
    )
    return private_key


def decrypt_checkpoint(encrypted_path, private_key):
    """Decrypt a hybrid-encrypted checkpoint file and deserialize it.

    File layout: 4-byte big-endian AES-key length, the RSA-OAEP-encrypted AES
    key, a 16-byte IV, an 8-byte big-endian plaintext size, then the
    AES-CBC-encrypted payload. The plaintext is first tried with
    ``torch.load``, then ``pickle.loads`` as a fallback.

    SECURITY: both torch.load and pickle.loads can execute arbitrary code —
    only call this on checkpoints from a trusted source.
    """
    backend = default_backend()
    with open(encrypted_path, 'rb') as f:
        key_length = int.from_bytes(f.read(4), 'big')
        encrypted_aes_key = f.read(key_length)
        iv = f.read(16)
        original_size = int.from_bytes(f.read(8), 'big')
        encrypted_data = f.read()
    try:
        aes_key = private_key.decrypt(
            encrypted_aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )
        cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=backend)
        decryptor = cipher.decryptor()
        decrypted_padded = decryptor.update(
            encrypted_data) + decryptor.finalize()
        # Truncate CBC padding back to the recorded plaintext size.
        decrypted_data = decrypted_padded[:original_size]
        try:
            buffer = io.BytesIO(decrypted_data)
            checkpoint_dict = torch.load(buffer, map_location='cpu')
            return checkpoint_dict
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # are not swallowed; fall back to raw pickle payloads.
            checkpoint_dict = pickle.loads(decrypted_data)
            return checkpoint_dict
    except Exception as e:
        print(f"Error: {e}")
        raise