# DELM/src/models/utils.py
# (lines below were scraped from the repository page view:
#  author xushijie, commit 21f308b, "add app")
import os
import signal
import psutil
import torch
import yaml
from functools import wraps
import errno
import signal
import numpy as np
from scipy.spatial import KDTree
from math import ceil
from tqdm import tqdm
import line_profiler
import os
import base64
import pickle
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import io
def num_parameters(model: torch.nn.Module) -> int:
    """Count the trainable parameters of *model*.

    Only parameters with ``requires_grad`` set contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
class Config:
    """Read configuration from a YAML file and expose each key as an attribute."""

    def __init__(self, yaml_file: str):
        """Load the initial configuration from *yaml_file*.

        Delegates to :meth:`update` — the loading logic was previously
        duplicated verbatim in both methods.
        """
        self.update(yaml_file)

    def update(self, new_yaml_file: str):
        """Merge keys from *new_yaml_file*, overwriting existing attributes."""
        with open(new_yaml_file, "r") as f:
            config = yaml.safe_load(f)
        for k, v in config.items():
            setattr(self, k, v)

    def save(self, yaml_file: str):
        """Dump every current attribute back to *yaml_file* as YAML."""
        with open(yaml_file, "w") as f:
            yaml.dump(self.__dict__, f)
def memory_usage_psutil():
    """Return this process's memory usage as a percentage (like ``top``)."""
    current = psutil.Process(os.getpid())
    return current.memory_percent()
def is_wandb_running():
    """Return True when running inside a wandb sweep (env var set by wandb)."""
    return os.environ.get("WANDB_SWEEP_ID") is not None
class TimeoutError(Exception):
    """Raised by the ``timeout`` decorator when a call exceeds its time limit.

    NOTE(review): this shadows the builtin ``TimeoutError`` (an ``OSError``
    subclass); kept as-is so existing ``except TimeoutError`` handlers in
    this module's callers keep catching this exact class.
    """
    pass
def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """Decorator that aborts the wrapped call after *seconds* seconds.

    Implemented with ``SIGALRM``, so it only works on Unix and only in the
    main thread.

    Args:
        seconds: Wall-clock budget for each call to the wrapped function.
        error_message: Message attached to the raised ``TimeoutError``
            (evaluated once, at decoration-module import time).

    Raises:
        TimeoutError: (the module-level class) when the call times out.
    """
    def decorator(func):
        def _handle_timeout(signum, frame):
            raise TimeoutError(error_message)

        def wrapper(*args, **kwargs):
            # Bug fix: remember and restore the previous SIGALRM handler so
            # repeated/nested use doesn't permanently clobber it.
            old_handler = signal.signal(signal.SIGALRM, _handle_timeout)
            signal.alarm(seconds)
            try:
                result = func(*args, **kwargs)
            finally:
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old_handler)
            return result
        return wraps(func)(wrapper)
    return decorator
def shorten_path(path: str, max_len: int = 30) -> str:
    """Shorten *path* to at most *max_len* characters.

    Long paths keep a head and a tail joined by ``"..."``.  Bug fix: the
    original returned ``max_len + 3`` characters because the ellipsis was
    not counted against the budget; the result now never exceeds *max_len*.
    Paths already short enough are returned unchanged.
    """
    if len(path) <= max_len:
        return path
    # Budget for real characters once the 3-char ellipsis is accounted for
    # (floor of 2 so both head and tail keep at least one character).
    keep = max(max_len - 3, 2)
    tail = keep // 2
    head = keep - tail
    return path[:head] + "..." + path[-tail:]
def cluster_points(data: torch.Tensor, d: float) -> torch.Tensor:
    """
    Greedily cluster points by Euclidean distance.

    Each not-yet-labelled point becomes the seed of a new cluster and claims
    every point within distance ``d`` of itself (a later seed may relabel
    points claimed earlier).

    :param data: Input data, shape (n_points, n_features), type torch.Tensor.
    :param d: Distance threshold for clustering.
    :return: Cluster indices, shape (n_points,), type torch.Tensor.
    """
    n_points = data.shape[0]
    pairwise = torch.cdist(data, data)
    labels = torch.full((n_points,), -1, dtype=torch.long)
    next_id = 0
    for seed in range(n_points):
        if labels[seed] != -1:
            continue
        labels[pairwise[seed] < d] = next_id
        next_id += 1
    return labels
def bron_kerbosch(R, P, X, graph):
    """Yield all maximal cliques via the (pivotless) Bron–Kerbosch recursion.

    R: set of vertices in the clique under construction.
    P: candidate vertices that may still extend R — consumed (mutated) here.
    X: vertices already exhausted for this branch (prevents re-reporting);
       also mutated in place.
    graph: adjacency mapping — ``graph[v]`` must iterate v's neighbours
       (a NetworkX graph qualifies).
    Yields each maximal clique as a set of vertices.
    """
    if not P and not X:
        yield R
    while P:
        v = P.pop()
        yield from bron_kerbosch(
            R | {v},
            P & set(graph[v]),
            X & set(graph[v]),
            graph
        )
        X.add(v)
def find_cliques(graph):
    """
    Find all maximal cliques in an undirected graph with the Bron–Kerbosch algorithm.
    :param graph: Input graph as a NetworkX graph
    :return: List of maximal cliques
    """
    candidates = set(graph.nodes())
    return [clique for clique in bron_kerbosch(set(), candidates, set(), graph)]
def segment_cmd(cmd_str: str, max_len: int = 1000):
    """Split a ';'-separated command string into chunks of about *max_len* chars.

    Commands are never cut mid-statement: the string is only split at ';'
    boundaries, starting a new chunk once appending the next command would
    exceed *max_len*.  Bug fix: a trailing command without a terminating
    ';' was previously dropped; it is now kept.

    :param cmd_str: Concatenated commands separated by ';'.
    :param max_len: Soft upper bound on each chunk's length (a single
        command longer than this still becomes one chunk).
    :return: List of command chunks; at least ``['']`` for empty input.
    """
    cmds = ['']
    prev = 0
    for i, c in enumerate(cmd_str):
        if c == ';':
            # Open a new chunk if this command would overflow the current one.
            if len(cmds[-1]) + len(cmd_str[prev:i]) > max_len:
                cmds.append('')
            cmds[-1] += cmd_str[prev:i + 1]
            prev = i + 1
    # Bug fix: append any trailing command that lacks a final ';'.
    tail = cmd_str[prev:]
    if tail:
        if len(cmds[-1]) + len(tail) > max_len:
            cmds.append('')
        cmds[-1] += tail
    return cmds
def get_color(v):
    """Map a scalar in [0, 1] to a PyMOL-style RGB string (green → brown).

    :param v: Interpolation factor; 0 gives green, 1 gives brown.
    :return: String like ``'[0.00,0.50,0.00]'`` with channels scaled to [0, 1].
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    # green to brown
    color1 = np.array([0, 128, 0])
    color2 = np.array([165, 42, 42])
    # Bug fix: force float arithmetic — with an integer v (0 or 1) the
    # interpolation stayed integer-typed and the in-place division below
    # raised a TypeError (unsafe cast on an int array).
    rgb = float(v) * (color2 - color1) + color1
    rgb /= 255
    return f'[{rgb[0]:.2f},{rgb[1]:.2f},{rgb[2]:.2f}]'
def generate_pymol_script(possible_sites):
    """Build a PyMOL command string placing a blue pseudoatom at each site.

    :param possible_sites: Iterable of (x, y, z) coordinates.
    :return: Single ';'-separated PyMOL command string.
    """
    parts = []
    for idx, site in enumerate(possible_sites):
        parts.append(
            f"pseudoatom s{idx},pos=[{site[0]:.1f},{site[1]:.1f},{site[2]:.1f}];"
            f"color blue,s{idx};"
        )
    return ''.join(parts)
def remove_close_points_kdtree(points, min_distance):
    """Greedily thin *points* so no two kept points lie within *min_distance*.

    Earlier points win: the first point of each close group is kept and all
    of its neighbours (within the radius) are discarded.

    :param points: Array of coordinates, shape (n, dims).
    :param min_distance: Exclusion radius around each kept point.
    :return: The surviving subset of *points*, original order preserved.
    """
    tree = KDTree(points)
    survivors = np.ones(len(points), dtype=bool)
    for idx in range(len(points)):
        if not survivors[idx]:
            continue
        close = tree.query_ball_point(points[idx], min_distance)
        survivors[close] = False
        survivors[idx] = True  # the anchor point itself always stays
    return points[survivors]
@line_profiler.profile
def pack_bit(x: torch.Tensor):
    """ Pack a (batch, num_bits) 0/1 tensor into bytes, LSB-first per byte.
    Args:
        x (torch.Tensor): Bit tensor of shape (batch_size, num_bits).
    Returns:
        torch.Tensor: uint8 tensor of shape (batch_size, ceil(num_bits / 8)).
    """
    batch_size, num_bits = x.shape
    num_bytes = (num_bits + 7) // 8
    packed = torch.zeros(batch_size, num_bytes,
                         dtype=torch.uint8, device=x.device)
    for bit in range(num_bits):
        # Bit `bit` lands in byte bit // 8 at position bit % 8 (LSB-first).
        packed[:, bit // 8] |= (x[:, bit] << (bit % 8)).to(torch.uint8)
    return packed
@line_profiler.profile
def unpack_bit(x: torch.Tensor, num_bits: int):
    """ Unpack bytes back into a (batch, num_bits) bit tensor, LSB-first.
    Args:
        x (torch.Tensor): uint8 tensor of shape (batch_size, num_bytes).
        num_bits (int): Number of bits to recover per row.
    Returns:
        torch.Tensor: uint8 tensor of shape (batch_size, num_bits), values 0/1.
    """
    batch_size, num_bytes = x.shape
    bits = torch.zeros(batch_size, num_bits,
                       dtype=torch.uint8, device=x.device)
    for bit in range(num_bits):
        # Bit `bit` comes from byte bit // 8 at position bit % 8 (LSB-first).
        bits[:, bit] = (x[:, bit // 8] >> (bit % 8)) & 1
    return bits
def safe_dist(vec1: torch.Tensor, vec2: torch.Tensor, max_size: int = 100_000_000, p: int = 2):
    """ compute the minimum distance from each point of vec2 to the cloud vec1.

    The (N, M) distance matrix is built in column batches so no intermediate
    exceeds roughly `max_size` entries.

    vec1: (N, 3), N could be very very large, i.e., all atoms' coordinates in a large protein
    vec2: (M, 3), M are not very large, usually the coordinates of the binding sites
    max_size: the maximum size of the distance matrix to compute at once
    p: the p-norm to use for distance calculation
    return: (M, ) the minimum distance of each binding site to the protein
    """
    n_ref = vec1.shape[0]
    n_query = vec2.shape[0]
    chunk = ceil(max_size / n_ref)
    minima = [
        torch.cdist(vec1, vec2[start:start + chunk], p=p).min(dim=0).values
        for start in range(0, n_query, chunk)
    ]
    return torch.cat(minima)
@line_profiler.profile
def safe_filter(nos: torch.Tensor, pos: torch.Tensor, thr: torch.Tensor, all: torch.Tensor, lb: float, max_size: int = 100_000_000):
    """Batch-filter candidate positions against per-site distance thresholds.

    A candidate row of ``pos`` is kept when (a) its distance to at least one
    binding site in ``nos`` lies inside that site's [lower, upper] band and
    (b) its minimum distance to every atom in ``all`` exceeds ``lb``.  Work
    is done in row batches of ``pos`` so no distance matrix exceeds roughly
    ``max_size`` entries.

    nos: (N, 3) binding-site coordinates.
    pos: (M, 3) candidate coordinates; M can be very large.
    thr: (N, 2) per-site [lower, upper] distance thresholds.
    all: (P, 3) coordinates of all atoms in the protein.
         NOTE(review): parameter name shadows the builtin ``all``.
    lb: minimum allowed distance from a kept candidate to any protein atom.
    max_size: cap on entries per batched distance computation.
    return: tuple of
        - uint8 tensor of bit-packed feasibility flags, one column per kept
          candidate (each column packs that candidate's N site-flags via
          ``pack_bit``), concatenated over batches;
        - bool tensor of shape (M,) marking which rows of ``pos`` were kept.
        NOTE(review): the original docstring claimed an (N, M) return,
        which does not match the packed/transposed output above.
    """
    N, M, P = nos.shape[0], pos.shape[0], all.shape[0]
    batch_size = ceil(max_size / N)
    output = []
    interests = []
    for i in tqdm(range(0, M, batch_size), leave=False, desc=f'Filtering (batch_size: {batch_size})'):
        # (batch, N) candidate-to-site distances, then the per-site band test.
        dist = torch.cdist(pos[i:i + batch_size], nos)
        dist = (dist <= thr[:, 1].unsqueeze(0)) & \
            (dist >= thr[:, 0].unsqueeze(0))
        # Candidates must also stay strictly farther than `lb` from every atom.
        dist_all = safe_dist(all, pos[i:i + batch_size]) > lb
        dist = dist & dist_all.unsqueeze(-1)
        # Drop candidates that satisfy no site at all before packing.
        mask = dist.any(dim=1)
        output.append(pack_bit(dist[mask]).T)
        interests.append(mask)
    return torch.cat(output, dim=1), torch.cat(interests)
def backbone(atoms, chain_id):
    """ return the atoms of the backbone of a chain

    Selects the C-alpha carbon atoms of the requested chain; `atoms` is any
    structure supporting attribute-style column access and boolean indexing.
    """
    in_chain = atoms.chain_id == chain_id
    is_calpha = (atoms.atom_name == "CA") & (atoms.element == "C")
    return atoms[in_chain & is_calpha]
def get_color(v):
    """Map a scalar in [0, 1] to a PyMOL-style RGB string (green → brown).

    NOTE(review): this redefines the identically-named function earlier in
    the module (only this later definition is effective at import time);
    consider removing one copy.

    :param v: Interpolation factor; 0 gives green, 1 gives brown.
    :return: String like ``'[0.00,0.50,0.00]'`` with channels scaled to [0, 1].
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    # green to brown
    color1 = np.array([0, 128, 0])
    color2 = np.array([165, 42, 42])
    # Bug fix: force float arithmetic — with an integer v (0 or 1) the
    # interpolation stayed integer-typed and the in-place division below
    # raised a TypeError (unsafe cast on an int array).
    rgb = float(v) * (color2 - color1) + color1
    rgb /= 255
    return f'[{rgb[0]:.2f},{rgb[1]:.2f},{rgb[2]:.2f}]'
def load_private_key_from_file(private_key_file=None):
    """Load a PEM private key from a base64-encoded source.

    :param private_key_file: Path to a file holding the base64-encoded PEM.
        When None, the value is read from the ``ModelCheckpointPrivateKey``
        environment variable instead.
    :return: The deserialized private-key object.
    """
    if private_key_file is not None:
        with open(private_key_file, 'r') as f:
            encoded = f.read().strip()
    else:
        encoded = os.environ.get('ModelCheckpointPrivateKey')
    pem_bytes = base64.b64decode(encoded)
    return serialization.load_pem_private_key(
        pem_bytes,
        password=None,
        backend=default_backend()
    )
def decrypt_checkpoint(encrypted_path, private_key):
    """Decrypt a hybrid-encrypted (RSA-OAEP + AES-CBC) checkpoint file.

    Expected file layout: 4-byte big-endian length of the RSA-encrypted AES
    key, the encrypted AES key itself, a 16-byte IV, an 8-byte big-endian
    plaintext size, then the AES-CBC ciphertext.

    :param encrypted_path: Path to the encrypted checkpoint file.
    :param private_key: RSA private key used to unwrap the AES key.
    :return: The deserialized checkpoint (torch payload or pickled object).
    :raises Exception: any decryption/deserialization failure is logged and
        re-raised for the caller.
    """
    backend = default_backend()
    with open(encrypted_path, 'rb') as f:
        key_length = int.from_bytes(f.read(4), 'big')
        encrypted_aes_key = f.read(key_length)
        iv = f.read(16)
        original_size = int.from_bytes(f.read(8), 'big')
        encrypted_data = f.read()
    try:
        aes_key = private_key.decrypt(
            encrypted_aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )
        cipher = Cipher(algorithms.AES(aes_key),
                        modes.CBC(iv), backend=backend)
        decryptor = cipher.decryptor()
        decrypted_padded = decryptor.update(
            encrypted_data) + decryptor.finalize()
        # Ciphertext was padded to the AES block size; trim back to the
        # recorded plaintext length.
        decrypted_data = decrypted_padded[:original_size]
        try:
            buffer = io.BytesIO(decrypted_data)
            checkpoint_dict = torch.load(buffer, map_location='cpu')
            return checkpoint_dict
        except Exception:
            # Bug fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit during the torch.load attempt.
            # SECURITY: pickle.loads executes arbitrary code — only ever
            # decrypt checkpoints from trusted sources.
            checkpoint_dict = pickle.loads(decrypted_data)
            return checkpoint_dict
    except Exception as e:
        print(f"Error: {e}")
        raise