| | import os |
| | from collections import defaultdict |
| | from io import BytesIO |
| | from typing import Any, ContextManager, Sequence, TypeVar |
| | from warnings import warn |
| |
|
| | import huggingface_hub |
| | import numpy as np |
| | import torch |
| | import zstd |
| |
|
| | from src.data.esm.utils.constants.esm3 import CHAIN_BREAK_STR |
| | from src.data.esm.utils.types import FunctionAnnotation |
| |
|
| | MAX_SUPPORTED_DISTANCE = 1e6 |
| |
|
| |
|
| | TSequence = TypeVar("TSequence", bound=Sequence) |
| |
|
| |
|
| | def slice_python_object_as_numpy( |
| | obj: TSequence, idx: int | list[int] | slice | np.ndarray |
| | ) -> TSequence: |
| | """ |
| | Slice a python object (like a list, string, or tuple) as if it was a numpy object. |
| | |
| | Example: |
| | >>> obj = "ABCDE" |
| | >>> slice_python_object_as_numpy(obj, [1, 3, 4]) |
| | "BDE" |
| | |
| | >>> obj = [1, 2, 3, 4, 5] |
| | >>> slice_python_object_as_numpy(obj, np.arange(5) < 3) |
| | [1, 2, 3] |
| | """ |
| | if isinstance(idx, int): |
| | idx = [idx] |
| |
|
| | if isinstance(idx, np.ndarray) and idx.dtype == bool: |
| | sliced_obj = [obj[i] for i in np.where(idx)[0]] |
| | elif isinstance(idx, slice): |
| | sliced_obj = obj[idx] |
| | else: |
| | sliced_obj = [obj[i] for i in idx] |
| |
|
| | match obj, sliced_obj: |
| | case str(), list(): |
| | sliced_obj = "".join(sliced_obj) |
| | case _: |
| | sliced_obj = obj.__class__(sliced_obj) |
| |
|
| | return sliced_obj |
| |
|
| |
|
def rbf(values, v_min, v_max, n_bins=16):
    """
    Encode ``values`` with radial basis functions in a new trailing dimension.

    Places ``n_bins`` Gaussian bumps evenly spaced on [v_min, v_max]; the
    shared std is the bin spacing (v_max - v_min) / n_bins.
    """
    centers = torch.linspace(
        v_min, v_max, n_bins, device=values.device, dtype=values.dtype
    )
    # Reshape centers to broadcast against `values` along a new last axis.
    centers = centers.view([1] * values.dim() + [-1])
    std = (v_max - v_min) / n_bins
    normalized = (values.unsqueeze(-1) - centers) / std
    return torch.exp(-normalized.square())
| |
|
| |
|
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gather values from ``data`` along ``dim`` using ``inds``, keeping the
    leading ``no_batch_dims`` dimensions as batch dimensions.

    Args:
        data: source tensor whose first ``no_batch_dims`` dims are batch dims.
        inds: integer index tensor; its leading dims must broadcast against
            ``data``'s batch dims.
        dim: dimension of ``data`` (counted including batch dims) to gather
            along; may be negative.
        no_batch_dims: number of leading batch dimensions.

    Returns:
        Tensor of gathered values with batch dimensions intact.
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        # Reshape the arange so it broadcasts against `inds` along dim i only.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: advanced indexing with a plain Python list of
    # tensors/slices is deprecated and removed in modern PyTorch.
    return data[tuple(ranges)]
| |
|
| |
|
def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    """Gather per-node features ``s`` for each neighbor index listed in ``edges``."""
    expanded = s.unsqueeze(-3)
    return batched_gather(expanded, edges, dim=-2, no_batch_dims=len(s.shape) - 1)
| |
|
| |
|
def knn_graph(
    coords: torch.Tensor,
    coord_mask: torch.Tensor,
    padding_mask: torch.Tensor,
    sequence_id: torch.Tensor,
    *,
    no_knn: int,
):
    """Build a k-nearest-neighbor graph over per-position coordinates.

    Neighbors are chosen primarily by Euclidean distance; pairs lacking
    valid coordinates fall back to sequence-separation distance (offset far
    beyond any structural distance so they always rank after structural
    neighbors). Padded positions and cross-sequence pairs (when binpacked
    via ``sequence_id``) are excluded entirely.

    Args:
        coords: coordinates with positions on dim -2 — assumed [..., L, 3];
            TODO confirm trailing dim against callers.
        coord_mask: True where a position has valid coordinates.
        padding_mask: True where a position is padding.
        sequence_id: per-position sequence ids for binpacked batches, or
            None to disable cross-sequence masking.
        no_knn: number of neighbors to keep per position.

    Returns:
        (chosen_edges, chosen_mask): neighbor indices [..., L, k] and a
        boolean mask marking which of those edges are real (finite score).
    """
    L = coords.shape[-2]
    # Cannot pick more neighbors than positions.
    num_by_dist = min(no_knn, L)
    device = coords.device

    coords = coords.nan_to_num()
    # NOTE: after this line `coord_mask` is INVERTED and pairwise:
    # True marks pairs where at least one endpoint lacks valid coordinates.
    coord_mask = ~(coord_mask[..., None, :] & coord_mask[..., :, None])
    # True where either endpoint is padding.
    padding_pairwise_mask = padding_mask[..., None, :] | padding_mask[..., :, None]
    if sequence_id is not None:
        # Forbid edges between different packed sequences.
        padding_pairwise_mask |= torch.unsqueeze(sequence_id, 1) != torch.unsqueeze(
            sequence_id, 2
        )
    # Pairwise Euclidean distances [..., L, L].
    dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1)
    arange = torch.arange(L, device=device)
    # Pairwise |i - j| sequence separation.
    seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs()

    max_dist = MAX_SUPPORTED_DISTANCE
    # Guard the offset scheme: all real structural distances must stay below
    # max_dist so the sequence-distance fallback always sorts after them.
    torch._assert_async((dists[~coord_mask] < max_dist).all())
    # Use structural distance where both endpoints have coords; otherwise use
    # scaled sequence distance shifted past max_dist. Masked pairs get +inf
    # so they sort last and are flagged invalid below.
    struct_then_seq_dist = (
        seq_dists.to(dists.dtype)
        .mul(1e2)
        .add(max_dist)
        .where(coord_mask, dists)
        .masked_fill(padding_pairwise_mask, torch.inf)
    )
    dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False)
    # Keep the k closest; infinite scores mark edges that must be ignored.
    chosen_edges = edges[..., :num_by_dist]
    chosen_mask = dists[..., :num_by_dist].isfinite()
    return chosen_edges, chosen_mask
| |
|
| |
|
| | def stack_variable_length_tensors( |
| | sequences: Sequence[torch.Tensor], |
| | constant_value: int | float = 0, |
| | dtype: torch.dtype | None = None, |
| | ) -> torch.Tensor: |
| | """Automatically stack tensors together, padding variable lengths with the |
| | value in constant_value. Handles an arbitrary number of dimensions. |
| | |
| | Examples: |
| | >>> tensor1, tensor2 = torch.ones([2]), torch.ones([5]) |
| | >>> stack_variable_length_tensors(tensor1, tensor2) |
| | tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones. |
| | |
| | >>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3]) |
| | >>> stack_variable_length_tensors(tensor1, tensor2) |
| | tensor of shape [2, 5, 4] |
| | """ |
| | batch_size = len(sequences) |
| | shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist() |
| |
|
| | if dtype is None: |
| | dtype = sequences[0].dtype |
| | device = sequences[0].device |
| |
|
| | array = torch.full(shape, constant_value, dtype=dtype, device=device) |
| | for arr, seq in zip(array, sequences): |
| | arrslice = tuple(slice(dim) for dim in seq.shape) |
| | arr[arrslice] = seq |
| |
|
| | return array |
| |
|
| |
|
def unbinpack(
    tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
):
    """Undo sequence binpacking: split each packed row into its constituent
    sequences and re-stack them with padding.

    Args:
        tensor (Tensor): [B, L, ...]
        sequence_id: per-position sequence ids, or None for no-op.
        pad_value: fill value for padded positions in the output.

    Returns:
        Tensor: [B_unbinpacked, L_unbinpack, ...]
    """
    if sequence_id is None:
        return tensor

    pieces = []
    # Each row packs (max id + 1) sequences.
    counts = sequence_id.max(dim=-1).values + 1
    for row_idx, (row_seqid, row_count) in enumerate(zip(sequence_id, counts)):
        pieces.extend(tensor[row_idx, row_seqid == sid] for sid in range(row_count))
    return stack_variable_length_tensors(pieces, pad_value)
| |
|
| |
|
| | def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: |
| | """ |
| | Returns an autocast context manager that disables downcasting by AMP. |
| | |
| | Args: |
| | device_type: The device type ('cpu' or 'cuda') |
| | |
| | Returns: |
| | An autocast context manager with the specified behavior. |
| | """ |
| | if device_type == "cpu": |
| | return torch.amp.autocast(device_type, enabled=False) |
| | elif device_type == "cuda": |
| | return torch.amp.autocast(device_type, dtype=torch.float32) |
| | else: |
| | raise ValueError(f"Unsupported device type: {device_type}") |
| |
|
| |
|
| | def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]: |
| | """Merge overlapping ranges into sorted, non-overlapping segments. |
| | |
| | Args: |
| | ranges: collection of ranges to merge. |
| | merge_gap_max: optionally merge neighboring ranges that are separated by a gap |
| | no larger than this size. |
| | Returns: |
| | non-overlapping ranges merged from the inputs, sorted by position. |
| | """ |
| | ranges = sorted(ranges, key=lambda r: r.start) |
| | merge_gap_max = merge_gap_max if merge_gap_max is not None else 0 |
| | assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}" |
| |
|
| | merged = [] |
| | for r in ranges: |
| | if not merged: |
| | merged.append(r) |
| | else: |
| | last = merged[-1] |
| | if last.stop + merge_gap_max >= r.start: |
| | merged[-1] = range(last.start, max(last.stop, r.stop)) |
| | else: |
| | merged.append(r) |
| | return merged |
| |
|
| |
|
def merge_annotations(
    annotations: list[FunctionAnnotation], merge_gap_max: int | None = None
) -> list[FunctionAnnotation]:
    """Merges annotations into non-overlapping segments.

    Args:
        annotations: annotations to merge.
        merge_gap_max: optionally merge neighboring ranges that are separated by a gap
            no larger than this size.
    Returns:
        non-overlapping annotations with gaps merged.
    """
    by_label: dict[str, list[range]] = defaultdict(list)
    for annotation in annotations:
        # Annotation ends are inclusive; ranges are half-open, hence end + 1.
        by_label[annotation.label].append(range(annotation.start, annotation.end + 1))

    result: list[FunctionAnnotation] = []
    for label, label_ranges in by_label.items():
        for merged_range in merge_ranges(label_ranges, merge_gap_max=merge_gap_max):
            result.append(
                FunctionAnnotation(
                    label=label,
                    start=merged_range.start,
                    end=merged_range.stop - 1,
                )
            )
    return result
| |
|
| |
|
def replace_inf(data):
    """Replace every +/-inf entry in ``data`` with -1.

    Args:
        data: array-like of numbers, or None.

    Returns:
        Nested Python list (float32 precision) with infinities replaced by
        -1, or None if ``data`` is None.
    """
    if data is None:
        return None
    # np.asarray instead of np.array(..., copy=False): under NumPy 2.x,
    # copy=False raises ValueError whenever a copy is required (e.g. when
    # converting a Python list), while asarray copies only if needed.
    array = np.asarray(data, dtype=np.float32)
    array = np.where(np.isinf(array), -1, array)
    return array.tolist()
| |
|
| |
|
| | def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: |
| | if x is None: |
| | return None |
| | if convert_none_to_nan: |
| | x = np.array(x, copy=False, dtype=np.float32) |
| | x = np.where(x is None, np.nan, x) |
| | return torch.tensor(x) |
| |
|
| |
|
| | def maybe_list(x, convert_nan_to_none: bool = False) -> list | None: |
| | if x is None: |
| | return None |
| | if not convert_nan_to_none: |
| | return x.tolist() |
| | nan_mask = torch.isnan(x) |
| | np_arr = x.cpu().numpy().astype(object) |
| | np_arr[nan_mask.cpu().numpy()] = None |
| | return np_arr.tolist() |
| |
|
| |
|
def huggingfacehub_login():
    """Authenticate with the Hugging Face Hub.

    Uses the HF_TOKEN environment variable when set; when the token is
    None, huggingface_hub.login falls back to prompting the user.
    """
    huggingface_hub.login(token=os.environ.get("HF_TOKEN"))
| |
|
| |
|
def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarray:
    """Return (num_chains, 2) index pairs delimiting chains in ``sequence``.

    Chains are separated by CHAIN_BREAK_STR tokens; each row is the
    [start, end) span of one chain (the break token's own position is the
    end index of the preceding chain).

    Raises:
        ValueError: if the sequence ends with a chain break token.
    """
    boundaries = [0]
    last = len(sequence) - 1
    for pos, token in enumerate(sequence):
        if token != CHAIN_BREAK_STR:
            continue
        if pos == last:
            raise ValueError(
                "Encountered chain break token at end of sequence, this is unexpected."
            )
        if pos == last - 1:
            warn(
                "Encountered chain break token at penultimate position, this is unexpected."
            )
        # Close the previous chain at the break and open the next one after it.
        boundaries.extend((pos, pos + 1))
    boundaries.append(len(sequence))
    assert len(boundaries) % 2 == 0
    return np.array(boundaries).reshape(-1, 2)
| |
|
| |
|
def deserialize_tensors(b: bytes) -> Any:
    """Decompress a zstd-compressed buffer and torch.load its payload to CPU.

    Args:
        b: zstd-compressed bytes containing a torch-serialized payload.

    Returns:
        Whatever object was serialized (typically tensors or containers of
        tensors — TODO confirm against the serializer).

    NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    objects, which can execute code — only call this on trusted data.
    """
    buf = BytesIO(zstd.ZSTD_uncompress(b))
    d = torch.load(buf, map_location="cpu", weights_only=False)
    return d
| |
|