File size: 10,238 Bytes
21f308b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import os
import signal

import psutil
import torch
import yaml
from functools import wraps
import errno
import signal
import numpy as np
from scipy.spatial import KDTree
from math import ceil
from tqdm import tqdm
import line_profiler
import os
import base64
import pickle
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

import io


def num_parameters(model: torch.nn.Module) -> int:
    """Count the trainable parameters of *model*.

    :param model: Any ``torch.nn.Module``.
    :return: Total number of elements across parameters with ``requires_grad``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


class Config:
    """Read configuration from a YAML file and expose its keys as attributes."""

    def __init__(self, yaml_file: str):
        """Load the initial configuration from *yaml_file*.

        Delegates to :meth:`update` — the two previously duplicated the same
        load-and-setattr loop.
        """
        self.update(yaml_file)

    def update(self, new_yaml_file: str):
        """Merge keys from *new_yaml_file*, overwriting any existing attributes."""
        with open(new_yaml_file, "r") as f:
            config = yaml.safe_load(f)
        for k, v in config.items():
            setattr(self, k, v)

    def save(self, yaml_file: str):
        """Dump all current attributes back out to *yaml_file* as YAML."""
        with open(yaml_file, "w") as f:
            yaml.dump(self.__dict__, f)


def memory_usage_psutil():
    """Return this process's memory usage as a percentage, like ``top``."""
    return psutil.Process(os.getpid()).memory_percent()


def is_wandb_running():
    """Return True when running under a wandb sweep (env var is set)."""
    return os.environ.get("WANDB_SWEEP_ID") is not None


class TimeoutError(Exception):
    """Raised by the ``timeout`` decorator when a call exceeds its time limit.

    NOTE(review): shadows the builtin ``TimeoutError``; renaming it would
    break existing ``except TimeoutError`` callers, so it is kept as-is.
    """
    pass


def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """Decorator that aborts the wrapped call with ``TimeoutError`` after *seconds*.

    Uses ``SIGALRM``, so it works only on Unix and only in the main thread.

    :param seconds: Wall-clock limit before the call is interrupted.
    :param error_message: Message attached to the raised ``TimeoutError``.
    """
    def decorator(func):
        def _handle_timeout(signum, frame):
            raise TimeoutError(error_message)

        def wrapper(*args, **kwargs):
            # Save the previous handler so other SIGALRM users are not
            # clobbered (the original leaked the replacement handler).
            old_handler = signal.signal(signal.SIGALRM, _handle_timeout)
            signal.alarm(seconds)
            try:
                result = func(*args, **kwargs)
            finally:
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old_handler)
            return result

        return wraps(func)(wrapper)

    return decorator


def shorten_path(path: str, max_len: int = 30) -> str:
    """Abbreviate *path* with a middle ellipsis when it exceeds *max_len*."""
    if len(path) <= max_len:
        return path
    head = path[:max_len // 2]
    tail = path[-max_len // 2:]
    return f"{head}...{tail}"


def cluster_points(data: torch.Tensor, d: float) -> torch.Tensor:
    """
    Greedily assign cluster ids by Euclidean distance.

    Each still-unlabelled point seeds a new cluster that absorbs every point
    within distance *d* of it. Note that a later seed can relabel points a
    previous seed already claimed — this matches the original behavior.

    :param data: Input data, shape (n_points, n_features).
    :param d: Distance threshold for clustering.
    :return: Cluster indices, shape (n_points,), dtype long.
    """
    pairwise = torch.cdist(data, data)
    labels = torch.full((data.shape[0],), -1, dtype=torch.long)
    next_id = 0
    for idx in range(data.shape[0]):
        if labels[idx] != -1:
            continue
        labels[pairwise[idx] < d] = next_id
        next_id += 1
    return labels


def bron_kerbosch(R, P, X, graph):
    """Yield all maximal cliques via the (pivotless) Bron-Kerbosch recursion.

    :param R: Clique grown so far (a set of nodes).
    :param P: Candidate nodes that may still extend R (consumed in place).
    :param X: Already-processed nodes excluded from R (mutated in place).
    :param graph: Adjacency mapping; ``graph[v]`` iterates v's neighbours.
    """
    if not P and not X:
        yield R
    while P:
        node = P.pop()
        neighbours = set(graph[node])
        yield from bron_kerbosch(R | {node}, P & neighbours, X & neighbours, graph)
        X.add(node)


def find_cliques(graph):
    """
    Enumerate every maximal clique of an undirected graph (Bron-Kerbosch).

    :param graph: NetworkX-style graph: needs ``.nodes()`` plus ``graph[v]``
        neighbour access.
    :return: List of maximal cliques (each a set of nodes).
    """
    candidates = set(graph.nodes())
    return [clique for clique in bron_kerbosch(set(), candidates, set(), graph)]


def segment_cmd(cmd_str: str, max_len: int = 1000):
    """Split a ';'-separated command string into chunks of roughly *max_len*.

    Statements (with their trailing ';') are never split across chunks; a new
    chunk starts when appending the next statement would push the current one
    past *max_len*. Trailing text after the final ';' is dropped, as before.
    """
    chunks = ['']
    start = 0
    for idx, ch in enumerate(cmd_str):
        if ch != ';':
            continue
        statement = cmd_str[start:idx + 1]
        # Original compared the statement length WITHOUT the semicolon.
        if len(chunks[-1]) + len(statement) - 1 > max_len:
            chunks.append('')
        chunks[-1] += statement
        start = idx + 1
    return chunks


def get_color(v):
    """Map v in [0, 1] to a PyMOL-style RGB string on a green-to-brown ramp.

    NOTE(review): this definition is shadowed by an identical ``get_color``
    defined later in this module; only the later one takes effect at runtime.
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    green = np.array([0, 128, 0])
    brown = np.array([165, 42, 42])
    rgb = (green + v * (brown - green)) / 255
    return f'[{rgb[0]:.2f},{rgb[1]:.2f},{rgb[2]:.2f}]'


def generate_pymol_script(possible_sites):
    """Build a PyMOL command string placing a blue pseudoatom at each site.

    :param possible_sites: Iterable of 3D coordinates (indexable by 0..2).
    :return: Concatenated PyMOL commands as a single string.
    """
    parts = []
    for idx, site in enumerate(possible_sites):
        parts.append(
            f"pseudoatom s{idx},pos=[{site[0]:.1f},{site[1]:.1f},{site[2]:.1f}];"
            f"color blue,s{idx};"
        )
    return ''.join(parts)


def remove_close_points_kdtree(points, min_distance):
    """Greedily thin *points* so no two survivors are within *min_distance*.

    Points are scanned in order; each surviving point suppresses all of its
    not-yet-visited neighbours found via a KD-tree ball query.

    :param points: (n, d) array of coordinates.
    :param min_distance: Suppression radius.
    :return: Array containing only the kept points.
    """
    tree = KDTree(points)
    survivors = np.ones(len(points), dtype=bool)
    for idx in range(len(points)):
        if not survivors[idx]:
            continue
        close = tree.query_ball_point(points[idx], min_distance)
        survivors[close] = False
        survivors[idx] = True  # the ball query always contains idx itself
    return points[survivors]


def pack_bit(x: torch.Tensor):
    """Pack a (batch, num_bits) 0/1 tensor into bytes, LSB-first within a byte.

    Bit ``i`` of the input lands in byte ``i // 8`` at bit position ``i % 8``.
    Inverse of ``unpack_bit``.

    NOTE: the ``@line_profiler.profile`` decorator was removed — it made the
    module require line_profiler at import time outside of profiling runs.

    Args:
        x (torch.Tensor): (batch_size, num_bits) integer tensor of 0/1 values.
    Returns:
        torch.Tensor: (batch_size, ceil(num_bits / 8)) ``torch.uint8`` tensor.
    """
    batch_size, num_bits = x.shape
    num_bytes = (num_bits + 7) // 8
    output = torch.zeros(batch_size, num_bytes,
                         dtype=torch.uint8, device=x.device)
    for i in range(num_bits):
        byte_index = i // 8
        bit_index = i % 8
        # OR each input bit into its target position within the packed byte.
        output[:, byte_index] |= (x[:, i] << bit_index).to(torch.uint8)
    return output


def unpack_bit(x: torch.Tensor, num_bits: int):
    """Unpack a byte tensor back into individual bits, LSB-first within a byte.

    Inverse of ``pack_bit``: output bit ``i`` is read from byte ``i // 8`` at
    bit position ``i % 8``.

    NOTE: the ``@line_profiler.profile`` decorator was removed — it made the
    module require line_profiler at import time outside of profiling runs.

    Args:
        x (torch.Tensor): (batch_size, num_bytes) ``torch.uint8`` tensor.
        num_bits (int): Number of bits to recover (<= num_bytes * 8).
    Returns:
        torch.Tensor: (batch_size, num_bits) tensor of 0/1 ``torch.uint8`` values.
    """
    batch_size, num_bytes = x.shape
    output = torch.zeros(batch_size, num_bits,
                         dtype=torch.uint8, device=x.device)
    for i in range(num_bits):
        byte_index = i // 8
        bit_index = i % 8
        # Shift the source byte right and mask off everything but bit 0.
        output[:, i] = (x[:, byte_index] >> bit_index) & 1
    return output


def safe_dist(vec1: torch.Tensor, vec2: torch.Tensor, max_size: int = 100_000_000, p: int = 2):
    """Minimum p-norm distance from each row of *vec2* to the rows of *vec1*.

    The (N, M) distance matrix is computed in column batches so that at most
    roughly *max_size* entries are materialised at once — vec1 (e.g. all atom
    coordinates of a large protein) can be very large.

    :param vec1: (N, 3) reference points; N may be huge.
    :param vec2: (M, 3) query points (e.g. binding sites); M is modest.
    :param max_size: Upper bound on entries per partial distance matrix.
    :param p: Norm order passed to ``torch.cdist``.
    :return: (M,) tensor: distance from each query point to its nearest reference.
    """
    n_ref = vec1.shape[0]
    chunk = ceil(max_size / n_ref)
    mins = [
        torch.cdist(vec1, vec2[start:start + chunk], p=p).min(dim=0).values
        for start in range(0, vec2.shape[0], chunk)
    ]
    return torch.cat(mins)


def safe_filter(nos: torch.Tensor, pos: torch.Tensor, thr: torch.Tensor, all: torch.Tensor, lb: float, max_size: int = 100_000_000):
    """Filter candidate positions against per-site distance windows, in batches.

    For each candidate in *pos*, mark which binding sites in *nos* it satisfies:
    its distance to the site must lie in ``[thr[:, 0], thr[:, 1]]`` AND its
    distance to every atom in *all* must exceed *lb*. Rows with no satisfied
    site are dropped; the surviving boolean rows are bit-packed (``pack_bit``)
    and transposed before concatenation.

    NOTE: the ``@line_profiler.profile`` decorator was removed — it made the
    module require line_profiler at import time outside of profiling runs.
    NOTE(review): the parameter name ``all`` shadows the builtin; kept so
    keyword callers keep working.

    nos: (N, 3) binding-site coordinates
    pos: (M, 3) candidate coordinates, possibly very large (batched over)
    thr: (N, 2) per-site [lower, upper] distance thresholds
    all: (P, 3) coordinates of all atoms in the protein
    lb: lower bound on the candidate-to-protein distance
    max_size: cap on distance-matrix entries per batch

    return: (packed availability matrix of shape (num_bytes, kept), and a
            (M,) bool mask of which candidates were kept)
    """
    N, M, P = nos.shape[0], pos.shape[0], all.shape[0]
    batch_size = ceil(max_size / N)
    output = []
    interests = []
    for i in tqdm(range(0, M, batch_size), leave=False, desc=f'Filtering (batch_size: {batch_size})'):
        dist = torch.cdist(pos[i:i + batch_size], nos)
        # Within the per-site [lower, upper] window.
        dist = (dist <= thr[:, 1].unsqueeze(0)) & \
            (dist >= thr[:, 0].unsqueeze(0))
        # Candidate must also be farther than lb from every protein atom.
        dist_all = safe_dist(all, pos[i:i + batch_size]) > lb
        dist = dist & dist_all.unsqueeze(-1)

        mask = dist.any(dim=1)
        output.append(pack_bit(dist[mask]).T)
        interests.append(mask)
    return torch.cat(output, dim=1), torch.cat(interests)


def backbone(atoms, chain_id):
    """Select the C-alpha (backbone trace) atoms of a single chain.

    :param atoms: Atom table exposing ``chain_id``/``atom_name``/``element``
        columns and boolean-mask indexing (e.g. a biotite AtomArray or a
        pandas DataFrame — presumably; confirm against callers).
    :param chain_id: Chain identifier to keep.
    :return: Subset of *atoms* that are carbon "CA" atoms of that chain.
    """
    chain_mask = atoms.chain_id == chain_id
    ca_mask = atoms.atom_name == "CA"
    carbon_mask = atoms.element == "C"
    return atoms[chain_mask & ca_mask & carbon_mask]


def get_color(v):
    """Interpolate between green (v=0) and brown (v=1); return an RGB string.

    NOTE(review): an identical ``get_color`` appears earlier in this module;
    this later definition is the one actually in effect.
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    lo = np.array([0, 128, 0])   # green endpoint
    hi = np.array([165, 42, 42])  # brown endpoint
    mixed = (v * (hi - lo) + lo) / 255
    return '[' + ','.join(f'{c:.2f}' for c in mixed) + ']'


def load_private_key_from_file(private_key_file=None):
    """Load an RSA private key from a base64-of-PEM file or the environment.

    :param private_key_file: Path to a file whose contents are the base64
        encoding of a PEM private key; when None, the
        ``ModelCheckpointPrivateKey`` environment variable is used instead.
    :return: The deserialized private key object.
    :raises ValueError: If no key material could be found.
    """
    if private_key_file is None:
        private_key_b64 = os.environ.get('ModelCheckpointPrivateKey')
        # Fail with a clear message instead of the opaque TypeError that
        # base64.b64decode(None) would raise below.
        if private_key_b64 is None:
            raise ValueError(
                "ModelCheckpointPrivateKey environment variable is not set "
                "and no private_key_file was provided")
    else:
        with open(private_key_file, 'r') as f:
            private_key_b64 = f.read().strip()

    private_pem = base64.b64decode(private_key_b64)
    private_key = serialization.load_pem_private_key(
        private_pem,
        password=None,
        backend=default_backend()
    )
    return private_key


def decrypt_checkpoint(encrypted_path, private_key):
    """Decrypt a hybrid-encrypted (RSA-OAEP + AES-CBC) checkpoint file.

    File layout: 4-byte big-endian length of the RSA-encrypted AES key, the
    encrypted AES key, a 16-byte IV, an 8-byte big-endian plaintext size,
    then the AES-CBC ciphertext.

    :param encrypted_path: Path to the encrypted checkpoint file.
    :param private_key: RSA private key used to unwrap the AES key.
    :return: The deserialized checkpoint object.
    :raises Exception: Re-raised after logging if decryption/deserialization fails.
    """
    backend = default_backend()

    with open(encrypted_path, 'rb') as f:
        key_length = int.from_bytes(f.read(4), 'big')
        encrypted_aes_key = f.read(key_length)
        iv = f.read(16)
        original_size = int.from_bytes(f.read(8), 'big')
        encrypted_data = f.read()

    try:
        aes_key = private_key.decrypt(
            encrypted_aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )

        cipher = Cipher(algorithms.AES(aes_key),
                        modes.CBC(iv), backend=backend)
        decryptor = cipher.decryptor()
        decrypted_padded = decryptor.update(
            encrypted_data) + decryptor.finalize()

        # CBC padding is stripped by truncating to the recorded plaintext size.
        decrypted_data = decrypted_padded[:original_size]

        # SECURITY: torch.load and pickle.loads can execute arbitrary code;
        # only decrypt checkpoints obtained from a trusted source.
        try:
            buffer = io.BytesIO(decrypted_data)
            checkpoint_dict = torch.load(buffer, map_location='cpu')
            return checkpoint_dict
        except Exception:
            # Was a bare `except:` — narrowed so SystemExit/KeyboardInterrupt
            # are not swallowed. Fall back to a plain pickle payload.
            checkpoint_dict = pickle.loads(decrypted_data)
            return checkpoint_dict

    except Exception as e:
        print(f"Error: {e}")
        raise