nas / PFMBench /src /data /esm /utils /misc.py
yuccaaa's picture
Add files using upload-large-folder tool
9627ce0 verified
raw
history blame
10.6 kB
import os
from collections import defaultdict
from io import BytesIO
from typing import Any, ContextManager, Sequence, TypeVar
from warnings import warn
import huggingface_hub
import numpy as np
import torch
import zstd
from src.data.esm.utils.constants.esm3 import CHAIN_BREAK_STR
from src.data.esm.utils.types import FunctionAnnotation
MAX_SUPPORTED_DISTANCE = 1e6
TSequence = TypeVar("TSequence", bound=Sequence)
def slice_python_object_as_numpy(
obj: TSequence, idx: int | list[int] | slice | np.ndarray
) -> TSequence:
"""
Slice a python object (like a list, string, or tuple) as if it was a numpy object.
Example:
>>> obj = "ABCDE"
>>> slice_python_object_as_numpy(obj, [1, 3, 4])
"BDE"
>>> obj = [1, 2, 3, 4, 5]
>>> slice_python_object_as_numpy(obj, np.arange(5) < 3)
[1, 2, 3]
"""
if isinstance(idx, int):
idx = [idx]
if isinstance(idx, np.ndarray) and idx.dtype == bool:
sliced_obj = [obj[i] for i in np.where(idx)[0]]
elif isinstance(idx, slice):
sliced_obj = obj[idx]
else:
sliced_obj = [obj[i] for i in idx]
match obj, sliced_obj:
case str(), list():
sliced_obj = "".join(sliced_obj)
case _:
sliced_obj = obj.__class__(sliced_obj) # type: ignore
return sliced_obj # type: ignore
def rbf(values, v_min, v_max, n_bins=16):
"""
Returns RBF encodings in a new dimension at the end.
"""
rbf_centers = torch.linspace(
v_min, v_max, n_bins, device=values.device, dtype=values.dtype
)
rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
rbf_std = (v_max - v_min) / n_bins
z = (values.unsqueeze(-1) - rbf_centers) / rbf_std
return torch.exp(-(z**2))
def batched_gather(data, inds, dim=0, no_batch_dims=0):
ranges = []
for i, s in enumerate(data.shape[:no_batch_dims]):
r = torch.arange(s)
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r)
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
return data[ranges]
def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1)
def knn_graph(
coords: torch.Tensor,
coord_mask: torch.Tensor,
padding_mask: torch.Tensor,
sequence_id: torch.Tensor,
*,
no_knn: int,
):
L = coords.shape[-2]
num_by_dist = min(no_knn, L)
device = coords.device
coords = coords.nan_to_num()
coord_mask = ~(coord_mask[..., None, :] & coord_mask[..., :, None])
padding_pairwise_mask = padding_mask[..., None, :] | padding_mask[..., :, None]
if sequence_id is not None:
padding_pairwise_mask |= torch.unsqueeze(sequence_id, 1) != torch.unsqueeze(
sequence_id, 2
)
dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1)
arange = torch.arange(L, device=device)
seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs()
# We only support up to a certain distance, above that, we use sequence distance
# instead. This is so that when a large portion of the structure is masked out,
# the edges are built according to sequence distance.
max_dist = MAX_SUPPORTED_DISTANCE
torch._assert_async((dists[~coord_mask] < max_dist).all())
struct_then_seq_dist = (
seq_dists.to(dists.dtype)
.mul(1e2)
.add(max_dist)
.where(coord_mask, dists)
.masked_fill(padding_pairwise_mask, torch.inf)
)
dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False)
# This is a L x L tensor, where we index by rows first,
# and columns are the edges we should pick.
chosen_edges = edges[..., :num_by_dist]
chosen_mask = dists[..., :num_by_dist].isfinite()
return chosen_edges, chosen_mask
def stack_variable_length_tensors(
sequences: Sequence[torch.Tensor],
constant_value: int | float = 0,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
"""Automatically stack tensors together, padding variable lengths with the
value in constant_value. Handles an arbitrary number of dimensions.
Examples:
>>> tensor1, tensor2 = torch.ones([2]), torch.ones([5])
>>> stack_variable_length_tensors(tensor1, tensor2)
tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones.
>>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3])
>>> stack_variable_length_tensors(tensor1, tensor2)
tensor of shape [2, 5, 4]
"""
batch_size = len(sequences)
shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()
if dtype is None:
dtype = sequences[0].dtype
device = sequences[0].device
array = torch.full(shape, constant_value, dtype=dtype, device=device)
for arr, seq in zip(array, sequences):
arrslice = tuple(slice(dim) for dim in seq.shape)
arr[arrslice] = seq
return array
def unbinpack(
tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
):
"""
Args:
tensor (Tensor): [B, L, ...]
Returns:
Tensor: [B_unbinpacked, L_unbinpack, ...]
"""
if sequence_id is None:
return tensor
unpacked_tensors = []
num_sequences = sequence_id.max(dim=-1).values + 1
for batch_idx, (batch_seqid, batch_num_sequences) in enumerate(
zip(sequence_id, num_sequences)
):
for seqid in range(batch_num_sequences):
mask = batch_seqid == seqid
unpacked = tensor[batch_idx, mask]
unpacked_tensors.append(unpacked)
return stack_variable_length_tensors(unpacked_tensors, pad_value)
def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: # type: ignore
"""
Returns an autocast context manager that disables downcasting by AMP.
Args:
device_type: The device type ('cpu' or 'cuda')
Returns:
An autocast context manager with the specified behavior.
"""
if device_type == "cpu":
return torch.amp.autocast(device_type, enabled=False) # type: ignore
elif device_type == "cuda":
return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore
else:
raise ValueError(f"Unsupported device type: {device_type}")
def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]:
"""Merge overlapping ranges into sorted, non-overlapping segments.
Args:
ranges: collection of ranges to merge.
merge_gap_max: optionally merge neighboring ranges that are separated by a gap
no larger than this size.
Returns:
non-overlapping ranges merged from the inputs, sorted by position.
"""
ranges = sorted(ranges, key=lambda r: r.start)
merge_gap_max = merge_gap_max if merge_gap_max is not None else 0
assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}"
merged = []
for r in ranges:
if not merged:
merged.append(r)
else:
last = merged[-1]
if last.stop + merge_gap_max >= r.start:
merged[-1] = range(last.start, max(last.stop, r.stop))
else:
merged.append(r)
return merged
def merge_annotations(
annotations: list[FunctionAnnotation], merge_gap_max: int | None = None
) -> list[FunctionAnnotation]:
"""Merges annotations into non-overlapping segments.
Args:
annotations: annotations to merge.
merge_gap_max: optionally merge neighboring ranges that are separated by a gap
no larger than this size.
Returns:
non-overlapping annotations with gaps merged.
"""
grouped: dict[str, list[range]] = defaultdict(list)
for a in annotations:
# +1 since FunctionAnnotation.end is inlcusive.
grouped[a.label].append(range(a.start, a.end + 1))
merged = []
for label, ranges in grouped.items():
merged_ranges = merge_ranges(ranges, merge_gap_max=merge_gap_max)
for range_ in merged_ranges:
annotation = FunctionAnnotation(
label=label,
start=range_.start,
end=range_.stop - 1, # convert range.stop exclusive -> inclusive.
)
merged.append(annotation)
return merged
def replace_inf(data):
if data is None:
return None
array = np.array(data, dtype=np.float32, copy=False)
array = np.where(np.isinf(array), -1, array)
return array.tolist()
def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
if x is None:
return None
if convert_none_to_nan:
x = np.array(x, copy=False, dtype=np.float32)
x = np.where(x is None, np.nan, x)
return torch.tensor(x)
def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
if x is None:
return None
if not convert_nan_to_none:
return x.tolist()
nan_mask = torch.isnan(x)
np_arr = x.cpu().numpy().astype(object)
np_arr[nan_mask.cpu().numpy()] = None
return np_arr.tolist()
def huggingfacehub_login():
"""Authenticates with the Hugging Face Hub using the HF_TOKEN environment
variable, else by prompting the user"""
token = os.environ.get("HF_TOKEN")
huggingface_hub.login(token=token)
def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarray:
chain_boundaries = [0]
for i, aa in enumerate(sequence):
if aa == CHAIN_BREAK_STR:
if i == (len(sequence) - 1):
raise ValueError(
"Encountered chain break token at end of sequence, this is unexpected."
)
if i == (len(sequence) - 2):
warn(
"Encountered chain break token at penultimate position, this is unexpected."
)
chain_boundaries.append(i)
chain_boundaries.append(i + 1)
chain_boundaries.append(len(sequence))
assert len(chain_boundaries) % 2 == 0
chain_boundaries = np.array(chain_boundaries).reshape(-1, 2)
return chain_boundaries
def deserialize_tensors(b: bytes) -> Any:
buf = BytesIO(zstd.ZSTD_uncompress(b))
d = torch.load(buf, map_location="cpu", weights_only=False)
return d