|
|
import os |
|
|
import signal |
|
|
|
|
|
import psutil |
|
|
import torch |
|
|
import yaml |
|
|
from functools import wraps |
|
|
import errno |
|
|
import signal |
|
|
import numpy as np |
|
|
from scipy.spatial import KDTree |
|
|
from math import ceil |
|
|
from tqdm import tqdm |
|
|
import line_profiler |
|
|
import os |
|
|
import base64 |
|
|
import pickle |
|
|
from cryptography.hazmat.primitives.asymmetric import padding |
|
|
from cryptography.hazmat.primitives import serialization, hashes |
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes |
|
|
from cryptography.hazmat.backends import default_backend |
|
|
|
|
|
import io |
|
|
|
|
|
|
|
|
def num_parameters(model: torch.nn.Module) -> int:
    """Count the trainable parameters of `model`.

    Only parameters with requires_grad=True contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
|
|
|
|
class Config:
    """Read configuration from a YAML file and expose each top-level key
    as an instance attribute."""

    def __init__(self, yaml_file: str):
        """Load the initial configuration.

        :param yaml_file: path to a YAML file whose top-level mapping
            becomes this object's attributes.
        """
        # Delegate to update() so the loading logic lives in one place
        # (the original duplicated the read/setattr loop in both methods).
        self.update(yaml_file)

    def update(self, new_yaml_file: str):
        """Load `new_yaml_file` and set/overwrite attributes from its keys.

        :param new_yaml_file: path to a YAML file with a top-level mapping.
        """
        with open(new_yaml_file, "r") as f:
            # An empty YAML file parses to None; treat it as "no settings"
            # instead of crashing on .items().
            config = yaml.safe_load(f) or {}

        for k, v in config.items():
            setattr(self, k, v)

    def save(self, yaml_file: str):
        """Write the current attributes back out as YAML.

        :param yaml_file: destination path (overwritten if it exists).
        """
        with open(yaml_file, "w") as f:
            yaml.dump(self.__dict__, f)
|
|
|
|
|
|
|
|
def memory_usage_psutil():
    """Return this process's memory usage as a percentage of total RAM.

    Mirrors the %MEM column shown by `top`, computed via psutil.
    """
    current = psutil.Process(os.getpid())
    return current.memory_percent()
|
|
|
|
|
|
|
|
def is_wandb_running():
    """Return True when running under a wandb sweep agent.

    The agent exports WANDB_SWEEP_ID into the environment; its presence is
    used as the detection signal.
    """
    return os.environ.get("WANDB_SWEEP_ID") is not None
|
|
|
|
|
|
|
|
class TimeoutError(Exception):

    """Raised by the `timeout` decorator when a call exceeds its time limit.

    NOTE(review): this shadows the builtin TimeoutError (Python 3.3+), so
    `except TimeoutError` in this module catches this class, not the
    builtin — confirm callers expect that.
    """

    pass
|
|
|
|
|
|
|
|
def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """Decorator factory: abort the wrapped call after `seconds` seconds.

    Uses SIGALRM, so it only works on Unix and only in the main thread.

    :param seconds: whole seconds before the call is aborted (signal.alarm
        only supports integer resolution).
    :param error_message: message carried by the raised TimeoutError.
    :raises TimeoutError: when the wrapped call exceeds the limit.
    """
    def decorator(func):
        def _handle_timeout(signum, frame):
            raise TimeoutError(error_message)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Remember whatever handler was installed so it can be restored;
            # the original version permanently clobbered any existing
            # SIGALRM handler.
            old_handler = signal.signal(signal.SIGALRM, _handle_timeout)
            signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Always cancel the pending alarm and restore the previous
                # handler, even when func raises.
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old_handler)

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
def shorten_path(path: str, max_len: int = 30) -> str:
    """Elide the middle of `path` when it exceeds `max_len` characters.

    Keeps the first max_len // 2 and last -(-max_len // 2) characters,
    joined by "...".
    """
    if len(path) <= max_len:
        return path
    head = path[:max_len // 2]
    tail = path[-max_len // 2:]
    return f"{head}...{tail}"
|
|
|
|
|
|
|
|
def cluster_points(data: torch.Tensor, d: float) -> torch.Tensor:
    """
    Greedily cluster points by Euclidean distance: each still-unassigned
    point seeds a new cluster containing every unassigned point within d.

    :param data: Input data, shape (n_points, n_features), type torch.Tensor.
    :param d: Distance threshold for clustering.
    :return: Cluster indices, shape (n_points,), type torch.Tensor.
    """
    dist = torch.cdist(data, data)
    indices = torch.full((data.shape[0],), -1, dtype=torch.long)
    cluster_id = 0
    for i in range(data.shape[0]):
        if indices[i] == -1:
            # Only claim points that are still unassigned. The original
            # assigned *all* points within d of i, silently moving
            # already-clustered neighbors into the later cluster.
            indices[(dist[i] < d) & (indices == -1)] = cluster_id
            cluster_id += 1
    return indices
|
|
|
|
|
|
|
|
def bron_kerbosch(R, P, X, graph):
    """Yield maximal cliques via the (pivotless) Bron–Kerbosch recursion.

    R: nodes already in the clique; P: candidate nodes; X: excluded nodes.
    `graph[v]` must yield the neighbors of v.
    """
    if not P and not X:
        yield R
    while P:
        v = P.pop()
        neighbors = set(graph[v])
        yield from bron_kerbosch(R | {v}, P & neighbors, X & neighbors, graph)
        X.add(v)
|
|
|
|
|
|
|
|
def find_cliques(graph):
    """
    Enumerate every maximal clique of an undirected graph using the
    Bron–Kerbosch algorithm (without pivoting).

    :param graph: NetworkX-style graph (must provide `.nodes()` and `graph[v]`)
    :return: List of maximal cliques
    """
    all_nodes = set(graph.nodes())
    cliques = bron_kerbosch(set(), all_nodes, set(), graph)
    return list(cliques)
|
|
|
|
|
|
|
|
def segment_cmd(cmd_str: str, max_len: int = 1000):
    """Split a ';'-separated command string into chunks near max_len chars.

    Chunks always end on a ';' boundary; a single command longer than
    max_len still goes into one chunk.
    NOTE(review): text after the last ';' is dropped — presumably commands
    are always ';'-terminated; confirm with callers.
    """
    segments = ['']
    start = 0
    for idx, ch in enumerate(cmd_str):
        if ch != ';':
            continue
        body = cmd_str[start:idx]
        # Open a new chunk when appending this command would exceed the budget
        # (the ';' itself is not counted, matching the original).
        if len(segments[-1]) + len(body) > max_len:
            segments.append('')
        segments[-1] += body + ';'
        start = idx + 1
    return segments
|
|
|
|
|
|
|
|
def get_color(v):
    """Linearly interpolate between green (v=0) and brown (v=1).

    :param v: scalar in [0, 1]
    :return: string '[r,g,b]' with channels in [0, 1], 2 decimal places
        (a PyMOL-style color literal).
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'

    color1 = np.array([0, 128, 0])
    color2 = np.array([165, 42, 42])
    # Compute out-of-place in float: the original did `v /= 255` in place,
    # which raises a numpy casting error when v is an int (e.g. v=0 or v=1
    # makes the intermediate an integer array).
    rgb = (v * (color2 - color1) + color1) / 255.0
    return f'[{rgb[0]:.2f},{rgb[1]:.2f},{rgb[2]:.2f}]'
|
|
|
|
|
|
|
|
def generate_pymol_script(possible_sites):
    """Build a PyMOL command string placing a blue pseudoatom at each site.

    :param possible_sites: iterable of 3D coordinates (indexable by 0/1/2)
    :return: concatenated ';'-terminated PyMOL commands
    """
    parts = []
    for idx, site in enumerate(possible_sites):
        parts.append(
            f"pseudoatom s{idx},pos=[{site[0]:.1f},{site[1]:.1f},{site[2]:.1f}];"
            f"color blue,s{idx};"
        )
    return ''.join(parts)
|
|
|
|
|
|
|
|
def remove_close_points_kdtree(points, min_distance):
    """Greedily thin `points` so no two kept points are within min_distance.

    Scans in order; each surviving point suppresses every neighbor within
    `min_distance` (the point itself appears in its own neighbor list, so
    it is re-marked as kept afterwards).

    :param points: (n, d) array of coordinates
    :param min_distance: suppression radius
    :return: the surviving rows of `points`
    """
    tree = KDTree(points)
    keep = np.ones(len(points), dtype=bool)
    for idx in range(len(points)):
        if not keep[idx]:
            continue
        close = tree.query_ball_point(points[idx], min_distance)
        keep[close] = False
        keep[idx] = True
    return points[keep]
|
|
|
|
|
|
|
|
@line_profiler.profile
def pack_bit(x: torch.Tensor):
    """ Pack a binary tensor into bytes, 8 bits per byte, LSB-first.

    Vectorized replacement for the original per-bit Python loop, which
    launched one tensor op per bit of a profiled hot path.

    Args:
        x (torch.Tensor): (batch, num_bits) tensor of bit values.
            NOTE(review): assumes values are 0/1 as the docstring implies;
            other values would pack differently than the original OR-loop.
    Returns:
        torch.Tensor: (batch, ceil(num_bits / 8)) uint8 tensor.
    """
    batch_size, num_bits = x.shape
    num_bytes = (num_bits + 7) // 8
    # Zero-pad the bit axis up to a whole number of bytes so it can be
    # viewed as (batch, num_bytes, 8).
    padded = torch.zeros(batch_size, num_bytes * 8,
                         dtype=torch.uint8, device=x.device)
    padded[:, :num_bits] = x
    # Weight of each bit within a byte, LSB first (matches `1 << bit_index`).
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128],
                           dtype=torch.uint8, device=x.device)
    return (padded.view(batch_size, num_bytes, 8) * weights).sum(dim=-1).to(torch.uint8)
|
|
|
|
|
|
|
|
@line_profiler.profile
def unpack_bit(x: torch.Tensor, num_bits: int):
    """ Unpack a byte tensor back into individual bits, LSB-first.

    Vectorized replacement for the original per-bit Python loop, which
    launched one tensor op per bit of a profiled hot path.

    Args:
        x (torch.Tensor): (batch, num_bytes) uint8 tensor produced by pack_bit.
        num_bits (int): number of leading bits to return (padding dropped).
    Returns:
        torch.Tensor: (batch, num_bits) uint8 tensor of 0/1 values.
    """
    batch_size, num_bytes = x.shape
    # Shift amounts 0..7 extract each bit of a byte, LSB first
    # (matches pack_bit's weight order).
    shifts = torch.arange(8, dtype=torch.uint8, device=x.device)
    bits = (x.unsqueeze(-1) >> shifts) & 1
    # (batch, num_bytes, 8) -> (batch, num_bytes * 8), then drop padding bits.
    return bits.view(batch_size, num_bytes * 8)[:, :num_bits].to(torch.uint8)
|
|
|
|
|
|
|
|
def safe_dist(vec1: torch.Tensor, vec2: torch.Tensor, max_size: int = 100_000_000, p: int = 2):
    """ Minimum p-norm distance from each point of `vec2` to the set `vec1`.

    The full (len(vec1) x len(vec2)) distance matrix may not fit in memory,
    so `vec2` is processed in chunks of at most max_size / len(vec1) columns.

    vec1: (N, 3), N could be very large, e.g. all atom coordinates of a protein

    vec2: (M, 3), M is usually small, e.g. the binding-site coordinates

    max_size: the maximum number of distance-matrix entries per chunk

    p: the p-norm to use for distance calculation

    return: (M, ) the minimum distance of each binding site to the protein
    """
    n_ref = vec1.shape[0]
    chunk = ceil(max_size / n_ref)
    minima = []
    for start in range(0, vec2.shape[0], chunk):
        block = torch.cdist(vec1, vec2[start:start + chunk], p=p)
        minima.append(block.min(dim=0).values)
    return torch.cat(minima)
|
|
|
|
|
|
|
|
@line_profiler.profile
def safe_filter(nos: torch.Tensor, pos: torch.Tensor, thr: torch.Tensor, all: torch.Tensor, lb: float, max_size: int = 100_000_000):
    """ Filter candidate positions against per-site distance thresholds,
    processing `pos` in chunks to bound peak memory.

    nos: (N, 3), the coordinates of the binding sites
    pos: (M, 3), candidate positions to test; M could be very very large
    thr: (N, 2), per-site [lower, upper] distance thresholds
    all: (P, 3), the coordinates of all atoms in the protein
        NOTE(review): parameter name shadows the builtin `all`; rename
        would touch callers, so only flagged here.
    lb: the lower bound of the distance to any protein atom
    max_size: maximum number of distance-matrix entries per chunk

    return: tuple of
        - packed boolean matrix of shape roughly (ceil(K/8)-bytes, kept
          positions) built from pack_bit(...).T chunks concatenated along
          dim 1 — presumably consumers unpack with unpack_bit; confirm.
        - (M,) boolean mask of which positions were kept.
    """
    N, M, P = nos.shape[0], pos.shape[0], all.shape[0]
    # Chunk size chosen so each (chunk x N) distance matrix stays <= max_size.
    batch_size = ceil(max_size / N)
    output = []
    interests = []
    for i in tqdm(range(0, M, batch_size), leave=False, desc=f'Filtering (batch_size: {batch_size})'):
        # (chunk, N) distances from candidate positions to binding sites.
        dist = torch.cdist(pos[i:i + batch_size], nos)
        # Keep a (position, site) pair only when the distance lies inside
        # that site's [lower, upper] threshold band.
        dist = (dist <= thr[:, 1].unsqueeze(0)) & \
           (dist >= thr[:, 0].unsqueeze(0))
        # Additionally require each position to be farther than lb from
        # every protein atom (clash avoidance — presumably; confirm).
        dist_all = safe_dist(all, pos[i:i + batch_size]) > lb
        dist = dist & dist_all.unsqueeze(-1)

        # Drop positions that satisfy no site at all, pack the surviving
        # boolean rows into bytes to save memory, and remember which
        # positions of this chunk were kept.
        mask = dist.any(dim=1)
        output.append(pack_bit(dist[mask]).T)
        interests.append(mask)
    return torch.cat(output, dim=1), torch.cat(interests)
|
|
|
|
|
|
|
|
def backbone(atoms, chain_id):
    """ Select the alpha-carbon (CA, element C) atoms of the given chain. """
    mask = (
        (atoms.chain_id == chain_id)
        & (atoms.atom_name == "CA")
        & (atoms.element == "C")
    )
    return atoms[mask]
|
|
|
|
|
|
|
|
def get_color(v):
    """Linearly interpolate between green (v=0) and brown (v=1).

    NOTE(review): this is an exact duplicate of an earlier `get_color`
    definition in this file and shadows it — consider removing one.

    :param v: scalar in [0, 1]
    :return: string '[r,g,b]' with channels in [0, 1], 2 decimal places
        (a PyMOL-style color literal).
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'

    color1 = np.array([0, 128, 0])
    color2 = np.array([165, 42, 42])
    # Compute out-of-place in float: the original did `v /= 255` in place,
    # which raises a numpy casting error when v is an int (e.g. v=0 or v=1
    # makes the intermediate an integer array).
    rgb = (v * (color2 - color1) + color1) / 255.0
    return f'[{rgb[0]:.2f},{rgb[1]:.2f},{rgb[2]:.2f}]'
|
|
|
|
|
|
|
|
def load_private_key_from_file(private_key_file=None):
    """Load an RSA private key from a file or from the environment.

    When `private_key_file` is None the base64-encoded PEM is read from the
    ModelCheckpointPrivateKey environment variable; otherwise the file's
    contents are used (also base64-encoded PEM, unencrypted).
    """
    if private_key_file is None:
        private_key_b64 = os.environ.get('ModelCheckpointPrivateKey')
    else:
        with open(private_key_file, 'r') as f:
            private_key_b64 = f.read().strip()

    pem_bytes = base64.b64decode(private_key_b64)
    return serialization.load_pem_private_key(
        pem_bytes,
        password=None,
        backend=default_backend()
    )
|
|
|
|
|
|
|
|
def decrypt_checkpoint(encrypted_path, private_key):
    """Decrypt a hybrid-encrypted checkpoint file and deserialize it.

    File layout: 4-byte big-endian length of the RSA-wrapped AES key,
    the wrapped AES key, a 16-byte CBC IV, an 8-byte big-endian original
    payload size, then the AES-CBC ciphertext.

    :param encrypted_path: path to the encrypted checkpoint file
    :param private_key: RSA private key used to unwrap the AES session key
        (see load_private_key_from_file)
    :return: the deserialized checkpoint object
    :raises Exception: re-raises any decryption/deserialization failure
        after printing it.
    """
    backend = default_backend()

    with open(encrypted_path, 'rb') as f:
        key_length = int.from_bytes(f.read(4), 'big')
        encrypted_aes_key = f.read(key_length)
        iv = f.read(16)
        original_size = int.from_bytes(f.read(8), 'big')
        encrypted_data = f.read()

    try:
        # Unwrap the AES session key with RSA-OAEP (SHA-256).
        aes_key = private_key.decrypt(
            encrypted_aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )

        cipher = Cipher(algorithms.AES(aes_key),
                        modes.CBC(iv), backend=backend)
        decryptor = cipher.decryptor()
        decrypted_padded = decryptor.update(
            encrypted_data) + decryptor.finalize()

        # Trim the block-cipher padding back to the recorded payload size.
        decrypted_data = decrypted_padded[:original_size]

        try:
            buffer = io.BytesIO(decrypted_data)
            checkpoint_dict = torch.load(buffer, map_location='cpu')
            return checkpoint_dict
        except Exception:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit — narrowed to Exception.
            # SECURITY: pickle.loads executes arbitrary code during
            # deserialization; acceptable here only because the payload was
            # authenticated by decrypting with our private key.
            checkpoint_dict = pickle.loads(decrypted_data)
            return checkpoint_dict

    except Exception as e:
        print(f"Error: {e}")
        raise
|
|
|