| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import copy |
| | import logging |
| | import re |
| | import warnings |
| | from dataclasses import dataclass |
| | from functools import partial |
| | from typing import Dict, List, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | import src.data.protein.polyseq as polyseq |
| | import src.data.protein.starparser as sp |
| | from src.data import constants |
| | import gzip |
| | import io |
| |
|
@dataclass
class SystemAssemblyInfo:
    """A class for representing the assembly information for System objects.

    assemblies (dict): a dictionary of assemblies with keys being assembly IDs
        and values being dictionaries with the following structure:
        {
            "details": "complete icosahedral assembly",
            "instructions": [
                {
                    "oper_expression": "(1-60)",
                    "chains": [0, 1, 2],

                    # Each assembly instruction has information for generating
                    # one or more images, with image `i` generated by applying
                    # the sequence of operations with IDs in `operations[i]` to the
                    # list of chains in `chains`. The corresponding operations
                    # are described under `assembly_info["operations"][ID]`.
                    "operations": [["X0", "1", "2", "3"], ["X0", "4", "5", "6"]],
                },
                ...
            ],
        }

    operations (dict): a dictionary with symmetry operations. Keys are operation IDs
        and values being dictionaries with the following structure (as built by
        ``make_operation``):
        {
            "type": "identity operation",
            "name": "1_555",
            "matrix": np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
            "vector": np.array([[0.], [0.], [0.]]),  # shape (3, 1)
        },
        ...
    """

    assemblies: dict
    operations: dict

    def __init__(self, assemblies: dict | None = None, operations: dict | None = None):
        # Build fresh dicts per instance: a `dict()` default is evaluated once,
        # so every default-constructed instance would share (and mutate) it.
        # Explicitly passed dicts are stored as-is (not copied).
        self.assemblies = {} if assemblies is None else assemblies
        self.operations = {} if operations is None else operations

    @staticmethod
    def make_operation(type: str, name: str, matrix: list, vector: list):
        """Build a symmetry-operation dictionary from flat numeric values.

        Args:
            type (str): operation type description (e.g. "identity operation").
            name (str): operation name (e.g. "1_555").
            matrix (list): the 9 entries of the 3x3 matrix in row-major order;
                each entry must be convertible with ``float``.
            vector (list): the 3 entries of the translation vector.

        Returns:
            dict: keys "type", "name", "matrix" (float array of shape (3, 3))
            and "vector" (float array of shape (3, 1)).

        Raises:
            AssertionError: if ``matrix`` does not have 9 or ``vector`` 3 entries.
        """
        assert len(matrix) == 9, "expected 9 elements in rotation matrix"
        assert len(vector) == 3, "expected 3 elements in translation vector"
        return {
            "type": type,
            "name": name,
            "matrix": np.array([float(v) for v in matrix]).reshape(3, 3),
            "vector": np.array([float(v) for v in vector]).reshape(3, 1),
        }

    def delete_chain(self, cid: str):
        """Deletes the mention of the chain from assembly information.

        Args:
            cid (str): Chain ID to delete.
        """
        for assembly in self.assemblies.values():
            for ins in assembly["instructions"]:
                ins["chains"] = [_id for _id in ins["chains"] if _id != cid]

    def rename_chain(self, old_cid: str, new_cid: str):
        """Renames all mentions of a chain to its new chain ID.

        Args:
            old_cid (str): Chain ID to rename.
            new_cid (str): Newly assigned Chain ID.
        """
        for assembly in self.assemblies.values():
            for ins in assembly["instructions"]:
                ins["chains"] = [
                    new_cid if cid == old_cid else cid for cid in ins["chains"]
                ]
| |
|
| |
|
class StringList:
    """A list of strings stored in a highly memory-efficient manner.

    All elements are concatenated into the single string ``self.string``;
    ``self.rng`` holds one ``[start, length]`` row per element pointing into it.
    Access is constant time, but modification is linear time in the length of
    the list.
    """

    def __init__(self, init_list: List[str] | None = None):
        self.string = ""
        self.rng = ArrayList(2, dtype=int)
        # `None` default instead of a shared mutable `[]`
        for s in init_list if init_list is not None else []:
            self.append(s)

    def _normalize(self, i: int) -> int:
        """Return ``i`` as a non-negative position, raising IndexError if out of range.

        The in-place shifts below slice ``self.rng[i + 1 :]``; with a raw negative
        ``i`` (e.g. -1 -> ``rng[0:]``) that slice would cover the wrong elements.
        """
        n = len(self)
        pos = i + n if i < 0 else i
        if not 0 <= pos < n:
            raise IndexError(f"index {i} out of range for stringList of length {n}")
        return pos

    def __getitem__(self, i: int):
        beg, length = self.rng[i]
        return self.string[beg : beg + length]

    def __setitem__(self, i: int, new_string: str):
        i = self._normalize(i)
        beg, length = self.rng[i]
        self.string = self.string[:beg] + new_string + self.string[beg + length :]
        if len(new_string) != length:
            self.rng[i, 1] = len(new_string)
            # every later element's start moves by the change in length
            self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) - length

    def __str__(self):
        return self.string

    def __len__(self):
        return len(self.rng)

    def copy(self):
        new_list = StringList()
        new_list.string = self.string
        new_list.rng = self.rng.copy()
        return new_list

    def append(self, new_string: str):
        self.rng.append([len(self.string), len(new_string)])
        self.string = self.string + new_string

    def insert(self, i: int, new_string: str):
        """Insert ``new_string`` so that it ends up at position ``i`` (0 <= i <= len)."""
        if i < len(self):
            ix, _ = self.rng[i]
        elif i == len(self):
            # inserting at the end: start right after the last element
            ix = 0 if len(self) == 0 else self.rng[i - 1].sum()
        else:
            raise Exception(
                f"cannot insert in position {i} for stringList of length {len(self)}"
            )
        self.string = self.string[0:ix] + new_string + self.string[ix:]
        self.rng.insert(i, [ix, len(new_string)])
        if len(new_string) > 0:
            self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string)

    def pop(self, i: int):
        """Remove and return the string at position ``i`` (negative indices allowed)."""
        i = self._normalize(i)
        beg, length = self.rng[i]
        val = self.string[beg : beg + length]
        self.string = self.string[0:beg] + self.string[beg + length :]
        self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] - len(val)
        self.rng.pop(i)
        return val

    def delete_range(self, rng: range):
        """Delete the contiguous block of elements whose positions are in ``rng``.

        An empty ``rng`` is a no-op.
        """
        positions = sorted(rng)
        if not positions:
            return
        i, j = positions[0], positions[-1]
        beg = self.rng[i][0]
        end = self.rng[j].sum()  # exclusive end of element j
        self.string = self.string[0:beg] + self.string[end:]
        # exactly `end - beg` characters were removed
        self.rng[j + 1 :, 0] = self.rng[j + 1 :, 0] - (end - beg)
        self.rng.delete_range(positions)
| |
|
| |
|
class NameList:
    """A list of "names"--i.e., strings that tend to have generic values, such
    that many repeat values are expected in a given list.

    Each distinct name is stored once in ``unique_names``; the list itself is an
    ``ArrayList`` of integer indices into it. ``index_use`` counts how many list
    entries reference each index so that stale names can be purged.
    """

    def __init__(self, init_list: List[str] | None = None):
        # `None` default instead of a shared mutable `[]`
        self._reindex(init_list if init_list is not None else [])

    def _reindex(self, init_list: List[str]):
        """Rebuild all bookkeeping from scratch so it contains exactly ``init_list``."""
        self.unique_names = []
        self.name_indicies = dict()  # name -> index into `unique_names`
        self.index_use = dict()  # index -> number of entries referencing it
        self.indices = ArrayList(1, dtype=int)
        for name in init_list:
            self.append(name)

    def copy(self):
        new_list = NameList()
        new_list.unique_names = self.unique_names.copy()
        new_list.name_indicies = self.name_indicies.copy()
        new_list.index_use = self.index_use.copy()
        new_list.indices = self.indices.copy()
        return new_list

    def _check_index(self):
        """Compact the name table once unused names dominate it.

        Reindexes when registered names outnumber in-use ones by more than 2x
        and there are over 10 unused names.
        """
        L = len(self.unique_names)
        I = len(self.index_use)
        if (L > 2 * I) and (L - I > 10):
            self._reindex([self[i] for i in range(len(self))])

    def __getitem__(self, i: int):
        try:
            idx = self.indices[i].item()
        except IndexError as e:
            raise IndexError(f"index {i} out of range for nameList\n" + str(e)) from e
        return self.unique_names[idx]

    def __setitem__(self, i: int, new_name: str):
        try:
            old_idx = self.indices[i].item()
        except IndexError as e:
            raise IndexError(f"index {i} out of range for nameList\n" + str(e)) from e
        self._update_use(old_idx, -1)
        new_idx = self._get_name_index(new_name)
        self.indices[i] = new_idx
        self._update_use(new_idx, 1)
        self._check_index()

    def __str__(self):
        return str([self[i] for i in range(len(self))])

    def __len__(self):
        return len(self.indices)

    def _update_use(self, idx, delta):
        """Add ``delta`` to the use count of ``idx``, dropping it when it reaches 0."""
        self.index_use[idx] = self.index_use.get(idx, 0) + delta
        if self.index_use[idx] <= 0:
            del self.index_use[idx]

    def _get_name_index(self, name: str):
        """Return the index of ``name``, registering it first if it is new."""
        if name not in self.name_indicies:
            idx = len(self.name_indicies)
            self.name_indicies[name] = idx
            self.unique_names.append(name)
        else:
            idx = self.name_indicies[name]
        return idx

    def append(self, name: str):
        idx = self._get_name_index(name)
        self.indices.append(idx)
        self._update_use(idx, 1)

    def insert(self, i: int, new_string: str):
        idx = self._get_name_index(new_string)
        self.indices.insert(i, idx)
        self._update_use(idx, 1)

    def pop(self, i: int):
        """Remove and return the name at position ``i`` (negative indices allowed)."""
        if i < 0:
            # the underlying ArrayList shifts slices by position, so normalize
            i += len(self)
        idx = self.indices.pop(i).item()
        val = self.unique_names[idx]
        self._update_use(idx, -1)
        self._check_index()
        return val

    def delete_range(self, rng: range):
        """Delete every position in ``rng`` (highest first, so positions stay valid)."""
        for i in sorted(rng, reverse=True):
            self.pop(i)
| |
|
| |
|
class ArrayList:
    """A growable NumPy-backed list of scalars (``ndims == 1``) or fixed-width rows.

    Elements occupy the first ``length`` entries of the over-allocated buffer
    ``_array``; ``array`` is a view of exactly those entries. The buffer is
    reallocated to twice the needed length when it overflows or when fewer than
    a third of its slots are in use.
    """

    def __init__(self, ndims: int, dtype: type, length: int = 0, val=0):
        shape = (max(length, 2),) if ndims == 1 else (max(length, 2), ndims)
        self._array = np.ndarray(shape=shape, dtype=dtype)
        self.ndims = ndims
        self._array[:] = val
        self.length = length
        self.array = self._array[: self.length]

    def convert_negative_slice(self, slice_obj):
        """Return an equivalent slice with negative/None bounds made explicit."""
        start = slice_obj.start if slice_obj.start is not None else 0
        stop = slice_obj.stop if slice_obj.stop is not None else self.length
        if start < 0:
            start = self.length + start
        if stop < 0:
            stop = self.length + stop
        return slice(start, stop, slice_obj.step)

    def copy(self):
        new_list = ArrayList(ndims=self.ndims, dtype=self.array.dtype, length=len(self))
        new_list[:] = self[:]
        return new_list

    def __len__(self):
        return self.length

    def capacity(self):
        return self._array.shape[0]

    def __getitem__(self, i: int):
        return self.array[i]

    def __setitem__(self, i: int, row: list):
        self.array[i] = row

    def resize(self, delta):
        """Change the logical length by ``delta``, reallocating when needed."""
        new_length = self.length + delta
        cap = self._array.shape[0]
        if (new_length > cap) or (new_length < cap / 3):
            self._resize(2 * new_length)
        self.length = new_length
        self.array = self._array[: self.length]

    def _resize(self, new_size):
        """Reallocate the backing buffer to ``new_size`` elements.

        A fresh zero-filled buffer is allocated and the live prefix copied over,
        rather than ``ndarray.resize(refcheck=False)``, which may free memory that
        existing views (such as ``self.array``) still point to.
        """
        shape = (new_size,) if self.ndims == 1 else (new_size, self.ndims)
        new_array = np.zeros(shape, dtype=self._array.dtype)
        keep = min(self.length, new_size)
        new_array[:keep] = self._array[:keep]
        self._array = new_array

    def items(self):
        """Yield each element (a scalar for 1-D lists, a row view otherwise)."""
        for i in range(self.length):
            # `array[i, :]` would fail for 1-D lists; `array[i]` works for both
            yield self.array[i]

    def append(self, row: list):
        self.resize(1)
        self.array[-1] = row

    def insert(self, i: int, row: list):
        """Insert the row such that it ends up being at index ``i`` in the new arrayList.

        Like ``list.insert``, negative ``i`` counts from the end and out-of-range
        positions are clamped to the start/end.
        """
        if i < 0:
            i = max(i + self.length, 0)
        i = min(i, self.length)
        self.resize(1)
        # shift the tail one slot to the right (NumPy handles the overlap)
        self.array[i + 1 :] = self.array[i:-1]
        self.array[i] = row

    def pop(self, i: int = -1):
        """Remove and return (a copy of) the element at index ``i``.

        Negative indices count from the end.

        Raises:
            IndexError: if ``i`` is out of range.
        """
        if i < 0:
            i += self.length
        if not 0 <= i < self.length:
            raise IndexError("pop index out of range")
        # copy before the tail shift overwrites this slot
        row = self.array[i].copy()
        self.array[i:-1] = self.array[i + 1 :]
        self.resize(-1)
        return row

    def delete_range(self, rng: range):
        """Delete the contiguous block of elements from ``min(rng)`` to ``max(rng)``.

        An empty ``rng`` is a no-op.
        """
        positions = list(rng)
        if not positions:
            return
        i, j = min(positions), max(positions)
        cut_length = j - i + 1
        new_length = len(self) - cut_length
        # move the tail down over the deleted block, then shrink
        self.array[i:new_length] = self.array[j + 1 :]
        self.resize(-cut_length)

    def __str__(self):
        return str([self[i] for i in range(len(self))])
| |
|
| |
|
@dataclass
class HierarchicList:
    """A utility class that represents a hierarchy of lists.

    Each level represents a list of elements, each element having a set of
    properties (each property being stored as an array-like object over
    elements). Further, each element ``i`` has ``_num_children[i]`` children,
    stored as a contiguous range of elements in the lower-hierarchy list
    ``_child_list``. ``_child_offset`` optionally caches where each element's
    range starts; it is set to ``None`` whenever it may be stale.
    """

    _properties: dict
    _parent_list: HierarchicList
    _child_list: HierarchicList
    _num_children: ArrayList
    _child_offset: ArrayList

    def __init__(
        self,
        properties: dict,
        parent_list: HierarchicList = None,
        num_children: ArrayList = ArrayList(1, dtype=int),
    ):
        """Create a level holding copies of ``properties``.

        Args:
            properties: mapping of property name -> array-like container; each
                container is copied.
            parent_list: the level above, if any; its ``_child_list`` is set to
                this new object.
            num_children: per-element child counts. It is copied, so the default
                instance (created once at class definition) is never mutated.
                ``None`` marks a leaf level that tracks no children.
        """
        self._properties = dict()
        for key in properties:
            self._properties[key] = properties[key].copy()
        self._parent_list = parent_list
        if self._parent_list is not None:
            self._parent_list._child_list = self
        self._child_list = None
        self._num_children = num_children.copy() if num_children is not None else None
        self._child_offset = None

    def copy(self):
        # NOTE(review): the constructor re-points `self._parent_list._child_list`
        # to the new copy, and the copy shares `self._child_list`; callers appear
        # to rely on re-linking levels themselves -- confirm before changing.
        new_list = HierarchicList(
            self._properties, self._parent_list, self._num_children
        )
        new_list._child_list = self._child_list
        if self._child_offset is None:
            new_list._child_offset = None
        else:
            new_list._child_offset = self._child_offset.copy()
        return new_list

    def set_parent(self, parent_list: HierarchicList):
        self._parent_list = parent_list

    def child_index(self, i: int, at: int):
        """Index in the child list of the ``at``-th child of element ``i``."""
        if self._child_offset is not None:
            return self._child_offset[i] + at
        return self._num_children[0:i].sum() + at

    def reindex(self):
        """Recompute the cached child offsets (prefix sums of the child counts)."""
        if self._num_children is not None:
            self._child_offset = ArrayList(
                1, dtype=int, length=len(self._num_children), val=0
            )
            for i in range(1, len(self)):
                self._child_offset[i] = (
                    self._child_offset[i - 1] + self._num_children[i - 1]
                )

    def append_child(self, properties):
        """Append a child to the last element (existing offsets stay valid)."""
        self._num_children[len(self._num_children) - 1] += 1
        self._child_list.append(properties)

    def insert_child(self, i: int, at: int, properties):
        """Insert a child at position ``at`` among element ``i``'s children."""
        idx = self.child_index(i, at)
        self._num_children[i] += 1
        self._child_offset = None
        self._child_list.insert(idx, properties)
        return idx

    def delete_child(self, i: int, at: int):
        """Delete the ``at``-th child of element ``i`` (and its descendants)."""
        idx = self.child_index(i, at)
        self._num_children[i] -= 1
        self._child_offset = None
        self._child_list.delete(idx)

    def append(self, properties):
        if set(properties.keys()) != set(self._properties.keys()):
            raise Exception(f"unexpected set of attributes '{list(properties.keys())}")
        for key, value in properties.items():
            self._properties[key].append(value)
        if self._child_offset is not None:
            self._child_offset.append(
                self._child_offset[-1:].sum() + self._num_children[-1:].sum()
            )
        if self._num_children is not None:
            self._num_children.append(0)

    def insert(self, i: int, properties):
        if set(properties.keys()) != set(self._properties.keys()):
            raise Exception(f"unexpected set of attributes '{list(properties.keys())}")
        for key, value in properties.items():
            self._properties[key].insert(i, value)
        if self._child_offset is not None:
            # the new element has no children, so later offsets are unchanged
            if i >= len(self._child_offset):
                off = self._child_offset[-1:].sum() + self._num_children[-1:].sum()
            else:
                off = self._child_offset[i]
            self._child_offset.insert(i, off)
        if self._num_children is not None:
            self._num_children.insert(i, 0)

    def delete(self, i: int):
        """Delete element ``i`` together with all of its descendants."""
        for key in self._properties:
            self._properties[key].pop(i)
        # leaf levels have no child counts to maintain
        if self._num_children is not None:
            # delete children back-to-front so earlier child positions stay valid
            for at in range(int(self._num_children[i]) - 1, -1, -1):
                self.delete_child(i, at)
            self._num_children.pop(i)
        self._child_offset = None

    def delete_range(self, rng: range):
        """Delete the contiguous elements at positions ``rng`` and all descendants.

        An empty ``rng`` is a no-op.
        """
        positions = sorted(rng)
        if not positions:
            return
        if self._num_children is not None:
            # Remove descendants from the highest position down, so child
            # offsets of lower positions are unaffected while we work.
            for i in reversed(positions):
                n_children = int(self._num_children[i])
                if n_children:
                    start = self.child_index(i, 0)
                    self._child_list.delete_range(range(start, start + n_children))
            # drop the per-element child counts along with the elements
            self._num_children.delete_range(positions)
        for key in self._properties:
            self._properties[key].delete_range(positions)
        self._child_offset = None

    def __len__(self):
        # all property containers share one length; None if there are none
        for key in self._properties:
            return len(self._properties[key])
        return None

    def __getitem__(self, i: str):
        return self._properties[i]

    def num_children(self, i: int):
        return self._num_children[i]

    def has_children(self, i: int):
        return self._num_children is not None and self._num_children[i]

    def __str__(self):
        string = "Properties:\n"
        for key in self._properties:
            string += f"{key}: {str(self._properties[key])}\n"
        string += f"num_children: {str(self._num_children)}\n"
        string += f"child_offset: {str(self._child_offset)}\n"
        string += "----\n"
        string += str(self._child_list)
        return string
| |
|
| |
|
| | @dataclass |
| | class System: |
| | """A class for storing, accessing, managing, and manipulating a molecular |
| | system's structure, sequence, and topological information. The class is |
| | organized as a hierarchy of objects: |
| | |
| | System: top-level class containing all information about a molecular system |
| | -> Chain: a sub-portion of the System; for polymers this is generally a |
| | chemically connected molecular graph belong to a System (e.g., for |
| | protein complexes, this would be one of the proteins). |
| | -> Residue: a generally chemically-connected molecular unit (for polymers, |
| | the repeating unit), belonging to a Chain. |
| | -> Atom: an atom belonging to a Residue with zero, one, or more locations. |
| | -> AtomLocation: the location of an Atom (3D coordinates and other information). |
| | |
| | Attributes: |
| | name (str): given name for System |
| | _chains (list): a list of Chain objects |
| | _entities (dict): a dictionary of SystemEntity objects, with keys being entity IDs |
| | _chain_entities (list): `chain_entities[ci]` stores entity IDs (i.e., keys into |
| | `entities`) corresponding to the entity for chain `ci` |
| | _extra_models (list): a list of hierarchicList object, representing locations |
| | for alternative models |
| | _labels (dict): a dictionary of residue labels. A label is a string value, |
| | under some category (also a string), associated with a residue. E.g., |
| | the category could be "SSE" and the value could be "H" or "S". If entry |
| | `labels[category][gti]` exists and is equal to `value`, this means that |
| | residue with global template index `gti` has the label `category:value`. |
| | _selections (dict): a dictionary of selections. Keys are selection names and |
| | values are lists of corresponding gti indices. |
| | _assembly_info (SystemAssemblyInfo): information on symmetric assemblies that can |
| | be constructed from components of the molecular system. See ``SystemAssemblyInfo``. |
| | """ |
| |
|
| | name: str |
| | _chains: HierarchicList |
| | _residues: HierarchicList |
| | _atoms: HierarchicList |
| | _locations: HierarchicList |
| | _entities: Dict[int, SystemEntity] |
| | _chain_entities: List[int] |
| | _extra_models: List[HierarchicList] |
| | _labels: Dict[str, Dict[int, str]] |
| | _selections: Dict[str, List[int]] |
| | _assembly_info: SystemAssemblyInfo |
| |
|
def __init__(self, name: str = "system"):
    """Create an empty System called ``name``.

    Sets up the four linked levels chains -> residues -> atoms -> locations
    (each registered as the parent of the next) and empty containers for
    entities, chain-entity links, extra models, labels, selections and
    assembly information.
    """
    self.name = name

    chain_props = {
        "cid": StringList(),
        "segid": StringList(),
        "authid": StringList(),
    }
    residue_props = {
        "name": NameList(),
        "resnum": ArrayList(1, dtype=int),
        "authresid": StringList(),
        "icode": ArrayList(1, dtype="U1"),
    }
    atom_props = {"name": NameList(), "het": ArrayList(1, dtype=bool)}
    # each location row has 5 float columns
    location_props = {
        "coor": ArrayList(5, dtype=float),
        "alt": ArrayList(1, dtype="U1"),
    }

    self._chains = HierarchicList(properties=chain_props)
    self._residues = HierarchicList(properties=residue_props, parent_list=self._chains)
    self._atoms = HierarchicList(properties=atom_props, parent_list=self._residues)
    # locations are the leaf level, so they keep no child counts
    self._locations = HierarchicList(
        properties=location_props, parent_list=self._atoms, num_children=None
    )

    self._entities = {}
    self._chain_entities = []
    self._extra_models = []
    self._labels = {}
    self._selections = {}
    self._assembly_info = SystemAssemblyInfo()
| |
|
def _reindex(self):
    """Rebuild cached child offsets on every level, from chains down to locations."""
    for level in (self._chains, self._residues, self._atoms, self._locations):
        level.reindex()
| |
|
def _print_indexing(self):
    """Debugging aid: print, level by level, the half-open range of child
    indices owned by each chain, residue and atom, and whether each location
    reports having children."""
    chain_level = self._chains
    residue_level = self._residues
    atom_level = self._atoms
    location_level = self._locations
    for ch in self.chains():
        first = chain_level.child_index(ch._ix, 0)
        count = chain_level.num_children(ch._ix)
        print(f"chain {ch._ix}, {ch}: [{first} - {first + count})")
        for rs in ch.residues():
            first = residue_level.child_index(rs._ix, 0)
            count = residue_level.num_children(rs._ix)
            print(f"\tresidue {rs._ix}, {rs}: [{first} - {first + count})")
            for at in rs.atoms():
                first = atom_level.child_index(at._ix, 0)
                count = atom_level.num_children(at._ix)
                print(f"\t\tatom {at._ix}, {at}: [{first} - {first + count})")
                for lc in at.locations():
                    flag = location_level.has_children(lc._ix)
                    print(f"\t\t\tlocation {lc._ix}, {lc}: has children? {flag}")
| |
|
@classmethod
def from_XCS(
    cls,
    X: torch.Tensor,
    C: torch.Tensor,
    S: torch.Tensor,
    alternate_alphabet: str = None,
) -> System:
    """Convert an XCS set of pytorch tensors to a new System object.

    Only a batch size of one is supported.

    Args:
        X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`.
            `num_atoms` is 14 for all-atom input (side chains are then added for
            residues found in `constants.AA_GEOMETRY`); otherwise only the
            backbone atoms `constants.ATOMS_BB` are read.
        C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes
            positions as 0 when masked, positive integers for chain indices,
            and negative integers to represent missing residues of the
            corresponding positive integers.
        S (torch.LongTensor): Sequence with shape `(1, num_residues)`.
        alternate_alphabet (str, optional): Alphabet used to decode `S`.
            Defaults to `constants.AA20`.

    Returns:
        System: A System object with the new XCS data. Each chain gets a single
        entity, and residues are numbered from 1 within their chain.

    Raises:
        AssertionError: if the batch size is not 1 or the residue dimensions
            of `X`, `C` and `S` disagree.
    """
    alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet
    all_atom = X.shape[2] == 14

    assert X.shape[0] == 1
    assert C.shape[0] == 1
    assert S.shape[0] == 1
    assert X.shape[1] == S.shape[1]
    assert X.shape[1] == C.shape[1]

    X, C, S = [T.squeeze(0).cpu().data.numpy() for T in [X, C, S]]

    # missing residues carry the negated chain index; abs() groups them with their chain
    chain_ids = np.abs(C)

    new_system = cls("system")
    for chain_id in np.unique(chain_ids):
        if chain_id == 0:  # masked positions belong to no chain
            continue

        chain_bool = chain_ids == chain_id
        X_chain = X[chain_bool, :, :].tolist()
        C_chain = C[chain_bool].tolist()
        S_chain = S[chain_bool].tolist()

        # NOTE(review): every chain is created with chain ID "A"; callers that
        # need distinct IDs must rename chains afterwards.
        chain = new_system.add_chain("A")
        for chain_ix, (X_i, C_i, S_i) in enumerate(zip(X_chain, C_chain, S_chain)):
            resname = polyseq.to_triple(alphabet[int(S_i)])
            residue = chain.add_residue(
                resname, chain_ix + 1, str(chain_ix + 1), " "
            )

            # negative chain-map entries mark missing residues: no atoms added
            if C_i <= 0:
                continue

            atom_names = constants.ATOMS_BB
            if all_atom and resname in constants.AA_GEOMETRY:
                atom_names = atom_names + constants.AA_GEOMETRY[resname]["atoms"]

            for atom_ix, atom_name in enumerate(atom_names):
                x, y, z = X_i[atom_ix]
                residue.add_atom(atom_name, False, x, y, z, 1.0, 0.0, " ")

    # add an entity for each chain (copied from the chain's residues)
    for ci, chain in enumerate(new_system.chains()):
        seq = [None] * chain.num_residues()
        het = [None] * chain.num_residues()
        for ri, res in enumerate(chain.residues()):
            seq[ri] = res.name
            het[ri] = all(a.het for a in res.atoms())
        entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq)
        entity = SystemEntity(
            entity_type, f"chain {chain.cid}", polymer_type, seq, het
        )
        new_system.add_new_entity(entity, [ci])

    return new_system
| |
|
def to_XCS(
    self,
    all_atom: bool = False,
    batch_dimension: bool = True,
    mask_unknown: bool = True,
    unknown_token: int = 0,
    reorder_chain: bool = True,
    alternate_alphabet=None,
    alternate_atoms=None,
    get_indices=False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert System object to XCS format.

    `C` has one entry per residue: a positive chain index for residues with
    usable coordinates, or the negated chain index for residues that lack
    structure, are missing any requested atom, or (with `mask_unknown`) are
    not found in the alphabet.

    `S` maps each residue's one-letter code to its index in the alphabet,
    defaulting to `unknown_token` when the code is not found.

    Args:
        all_atom (bool): Include side chain atoms (14 atom slots per residue).
            Residues without a known side-chain geometry contribute backbone
            atoms only. Default is `False`.
        batch_dimension (bool): Include a batch dimension. Default is `True`.
        mask_unknown (bool): Mask residues not found in the alphabet. Default is
            `True`.
        unknown_token (int): Default token index if a residue is not found in
            the alphabet. Default is `0`.
        reorder_chain (bool): If True, chains are numbered 1, 2, ... in order;
            otherwise a chain's index in `constants.CHAIN_ALPHABET` is used (or
            the smallest unused index if its ID is not in that alphabet).
            Default is `True`.
        alternate_alphabet (str): Alternative alphabet if not `None`.
        alternate_atoms (list): Alternate atom name subset for `X` if not `None`
            (ignored when `all_atom=True`).
        get_indices (bool): Also return the location indices corresponding to the
            returned `X` tensor.

    Returns:
        X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`
            (without the leading 1 if `batch_dimension=False`). `num_atoms` is 14
            if `all_atom=True`, otherwise the number of selected atom names.
        C (torch.LongTensor): Chain map with shape `(1, num_residues)`.
        S (torch.LongTensor): Sequence with shape `(1, num_residues)`.
        location_indices (np.ndarray, optional): returned only if `get_indices`;
            location index for each flattened `(residue, atom slot)` of `X`
            (0 for slots that were not filled).
    """
    alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet
    # letter -> token; the first occurrence wins, matching `alphabet.index`
    alphabet_index = {}
    for k, letter in enumerate(alphabet):
        alphabet_index.setdefault(letter, k)

    # Build the chain map: one chain index per residue.
    chain_alphabet = list(constants.CHAIN_ALPHABET)
    C = []
    for ch_id, chain in enumerate(self.chains()):
        if reorder_chain:
            map_ch_id = ch_id + 1
        elif chain.cid in chain_alphabet:
            map_ch_id = chain_alphabet.index(chain.cid)
        else:
            # smallest chain index (from 1) not used by earlier chains
            map_ch_id = np.setdiff1d(
                np.arange(1, len(chain_alphabet)), np.unique(C)
            )[0]
        C += [map_ch_id] * chain.num_residues()

    one_letter_seq = self.sequence(format="one-letter-string")
    if len(one_letter_seq) != len(C):
        logging.warning("Warning, System and chain_map length don't agree")

    # Atom name -> column of X. In all-atom mode it is rebuilt per residue.
    if all_atom:
        num_atoms = 14
        atom_index = None
    else:
        names = alternate_atoms if alternate_atoms is not None else constants.ATOMS_BB
        num_atoms = len(names)
        atom_index = {a: k for k, a in enumerate(names)}

    num_residues = self.num_residues()
    X = np.zeros([num_residues, num_atoms, 3])
    location_indices = (
        np.zeros([num_residues * num_atoms], dtype=int) if get_indices else None
    )

    S = []
    for i in range(num_residues):
        letter = one_letter_seq[i]
        token = alphabet_index.get(letter)
        is_mask = False
        if token is None:
            S.append(unknown_token)
            is_mask = bool(mask_unknown)
        else:
            S.append(token)

        res = self.get_residue(i)
        if res is None or not res.has_structure():
            is_mask = True

        if is_mask:
            C[i] = -abs(C[i])
            continue

        if all_atom:
            # Residues without a 3-letter code or side-chain geometry would
            # otherwise raise KeyError; fall back to backbone atoms only.
            try:
                side_chain = constants.AA_GEOMETRY[constants.AA20_1_TO_3[letter]][
                    "atoms"
                ]
            except KeyError:
                side_chain = []
            atom_index = {
                a: k
                for k, a in enumerate(list(constants.ATOMS_BB) + list(side_chain))
            }

        # NaN marks slots not yet filled; they are zeroed at the end
        X[i, :] = np.nan
        num_rem = len(atom_index)
        for atom in res.atoms():
            name = System.protein_backbone_atom_type(atom.name, False, True)
            if name is None:
                name = atom.name
            ix = atom_index.get(name, None)
            # skip atoms not requested, and duplicates of an already-filled slot
            if ix is None or not np.isnan(X[i, ix, 0]):
                continue
            for loc in atom.locations():
                # only the first location of each atom is used
                X[i, ix] = loc.coors
                if location_indices is not None:
                    location_indices[i * num_atoms + ix] = loc.get_index()
                num_rem -= 1
                break
            if num_rem == 0:
                break
        if num_rem != 0:
            # incomplete residue: mask it and drop its partial coordinates
            C[i] = -abs(C[i])
            X[i, :] = 0
        np.nan_to_num(X[i, :], copy=False, nan=0)

    X = torch.tensor(X).float()
    C = torch.tensor(C).type(torch.long)
    S = torch.tensor(S).type(torch.long)

    if batch_dimension:
        X = X.unsqueeze(0)
        C = C.unsqueeze(0)
        S = S.unsqueeze(0)

    if location_indices is not None:
        return X, C, S, location_indices

    return X, C, S
| |
|
def update_with_XCS(self, X, C=None, S=None, alternate_alphabet=None):
    """Update the System with XCS coordinates. NOTE: if the System has
    more than one model, and if the shape of the System changes (i.e.,
    atoms are added or deleted), the additional models will be wiped.

    Residues without structure, or whose chain-map entry is <= 0, are left
    untouched. For every other residue the name is set from `S` ("UNK" if the
    token is outside the alphabet), atoms not expected for that name (and
    duplicate atom names) are deleted, and expected atoms are updated or added.

    Args:
        X (Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`.
            `num_atoms` will be 14 if `all_atom=True` or 4 otherwise.
        C (LongTensor): Chain map with shape `(1, num_residues)`. It codes
            positions as 0 when masked, positive integers for chain indices,
            and negative integers to represent missing residues of the
            corresponding positive integers. Defaults to the current System's
            chain map.
        S (LongTensor): Sequence with shape `(1, num_residues)`. Defaults to
            the current System's sequence.
        alternate_alphabet (str, optional): Alphabet used to decode `S`.
            Defaults to `constants.AA20`.

    Raises:
        Exception: if the residue dimensions of `X`, `C`, `S` and the System
            disagree.
    """
    if C is None or S is None:
        _, _C, _S = self.to_XCS()
        if C is None:
            C = _C
        if S is None:
            S = _S

    # batched inputs: dimension 1 is the residue dimension
    if not (
        (X.shape[1] == self.num_residues())
        and (X.shape[1] == C.shape[1])
        and (X.shape[1] == S.shape[1])
    ):
        raise Exception(
            f"input tensor sizes {X.shape}, {C.shape}, and {S.shape}, disagree with System size {self.num_residues()}"
        )

    def _process_inputs(T):
        # drop the batch dimension (2-D C/S, 4-D X) and convert to numpy
        if T is not None:
            if len(T.shape) == 2 or len(T.shape) == 4:
                T = T.squeeze(0)
            T = T.to("cpu").detach().numpy()
        return T

    X, C, S = map(_process_inputs, [X, C, S])

    shape_changed = False
    alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet
    for i, res in enumerate(self.residues()):
        if not res.has_structure() or C[i] <= 0:
            continue

        # A negative token must not wrap around to the end of the alphabet.
        resname = "UNK"
        if S is not None and 0 <= S[i] < len(alphabet):
            resname = polyseq.to_triple(alphabet[S[i]])
        if res.name != resname:
            res.rename(resname)

        # atoms this residue should have, limited to the atom slots in X
        atoms_XCS = constants.ATOMS_BB
        if resname in constants.AA_GEOMETRY:
            atoms_XCS = atoms_XCS + constants.AA_GEOMETRY[resname]["atoms"]
        atoms_XCS = atoms_XCS[: X.shape[1]]

        # delete unexpected atoms and later duplicates of an atom name
        atoms_sys = [atom.name for atom in res.atoms()]
        to_delete = [
            atom
            for ix_a, atom in enumerate(res.atoms())
            if atom.name not in atoms_XCS or atom.name in atoms_sys[:ix_a]
        ]
        if len(to_delete) > 0:
            shape_changed = True
            res.delete_atoms(to_delete)

        # update existing coordinates, adding locations/atoms where missing
        for x_id, atom_name in enumerate(atoms_XCS):
            atom = res.find_atom(atom_name)
            x, y, z = [X[i][x_id][k].item() for k in range(3)]
            if atom is not None and atom.num_locations() > 0:
                atom.x = x
                atom.y = y
                atom.z = z
            else:
                shape_changed = True
                if atom is not None:
                    atom.add_location(x, y, z)
                else:
                    res.add_atom(atom_name, False, x, y, z, 1.0, 0.0)

    if shape_changed:
        self._extra_models = []
| |
|
def __str__(self):
    """Return a short label for the System: ``"system "`` followed by its name."""
    prefix = "system "
    return prefix + self.name
| |
|
def chains(self):
    """Generate a ChainView for every chain, in chain-index order.

    The chain count is read once, when iteration starts.
    """
    yield from (ChainView(index, self) for index in range(len(self._chains)))
| |
|
def get_chain(self, ci: int):
    """Return a view onto the chain stored at index ``ci``.

    Args:
        ci (int): Zero-based chain index. No bounds check happens here.

    Returns:
        ChainView wrapping chain ``ci`` of this System.
    """
    view = ChainView(ci, self)
    return view
| |
|
def get_chain_by_id(self, cid: str, segid=False):
    """Look up a chain by its string identifier.

    Args:
        cid (str): Identifier to look for.
        segid (bool, optional): When truthy, match against each chain's
            segment ID; otherwise (default) match against its chain ID.

    Returns:
        ChainView for the first matching chain, or None if nothing matches.
    """
    attribute = "segid" if segid else "cid"
    for index, candidate in enumerate(self.chains()):
        if cid == getattr(candidate, attribute):
            return ChainView(index, self)
    return None
| |
|
def get_chains(self):
    """Return a list holding one ChainView per chain, ordered by chain index."""
    views = []
    for index in range(len(self._chains)):
        views.append(ChainView(index, self))
    return views
| |
|
def get_chains_of_entity(self, entity_id: int, by=None):
    """Collect the chains mapped to entity ``entity_id``.

    Args:
        entity_id (int): Entity ID to match.
        by (str, optional): Pass ``"index"`` to get plain chain indices;
            anything else yields ChainView objects.

    Returns:
        List of chain indices or ChainView objects, in chain order.
    """
    indices = []
    for index, owner in enumerate(self._chain_entities):
        if entity_id == owner:
            indices.append(index)
    if by != "index":
        return [ChainView(index, self) for index in indices]
    return indices
| |
|
def residues(self):
    """Generate every residue of the System, walking chains in order."""
    for chain in self.chains():
        yield from chain.residues()
| |
|
def get_residue(self, gti: int):
    """Fetch a residue by its global (System-wide) index.

    Args:
        gti (int): Global residue index, counted across chains in order.

    Returns:
        ResidueView object for that residue.

    Raises:
        Exception: for a negative index or one past the last residue.
    """
    if gti < 0:
        raise Exception(f"negative residue index: {gti}")
    remaining = gti
    for chain in self.chains():
        count = chain.num_residues()
        if remaining < count:
            return chain.get_residue(remaining)
        remaining -= count
    raise Exception(
        f"residue index {gti} out of range for System, which has {self.num_residues()} residues"
    )
| |
|
def atoms(self):
    """Generate every atom of the System (chain -> residue -> atom order)."""
    for chain in self.chains():
        for residue in chain.residues():
            yield from residue.atoms()
| |
|
def get_atom(self, aidx: int):
    """Return the atom at the given global atom index.

    Args:
        aidx (int): Global atom index, counted across all chains in order.

    Returns:
        AtomView object corresponding to the index (as produced by the
        owning chain's ``get_atom``).

    Raises:
        Exception: if ``aidx`` is negative or not smaller than the total
            number of atoms in the System.
    """
    if aidx < 0:
        raise Exception(f"negative atom index: {aidx}")
    offset = 0
    for chain in self.chains():
        count = chain.num_atoms()
        if aidx < offset + count:
            # translate the global index into the chain-local one
            return chain.get_atom(aidx - offset)
        offset += count
    raise Exception(
        f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms"
    )
| |
|
def locations(self):
    """Iterator over all atom locations in this System (generator function).

    Locations are yielded chain by chain, residue by residue, atom by atom.
    An atom with several (alternative) locations contributes all of them;
    an atom without locations contributes none.
    """
    for chain in self.chains():
        for residue in chain.residues():
            for atom in residue.atoms():
                yield from atom.locations()
| |
|
def _new_locations(self):
    """Return a copy of ``self._locations`` whose every coordinate row is
    overwritten with five NaNs (one per stored ``coor`` field)."""
    blank = self._locations.copy()
    nan_row = [np.nan] * 5
    for row in range(len(blank)):
        # hand each row its own list so rows never alias one another
        blank["coor"][row] = list(nan_row)
    return blank
| |
|
def select(self, expression: str, left_associativity: bool = True):
    """Run a selection expression and return the atoms it picks.

    Args:
        expression (str): selection expression.
        left_associativity (bool, optional): whether operators in the
            expression associate to the left.

    Returns:
        List of AtomView's, ordered by global atom index.
    """
    hits, info = self._select(expression, left_associativity=left_associativity)
    everything = info["all_atoms"]
    return [everything[index].atom for index in sorted(hits)]
| |
|
def select_residues(
    self,
    expression: str,
    gti: bool = False,
    allow_unstructured=False,
    left_associativity: bool = True,
):
    """Run a selection expression and report every residue it touches.

    A residue is reported if at least one of its atoms is selected.

    Args:
        expression (str): selection expression.
        gti (bool): when True (default False) return global residue indices
            instead of ResidueView's.
        allow_unstructured (bool): when True (default False) residues without
            structure may be selected.
        left_associativity (bool, optional): whether operators in the
            expression associate to the left.

    Returns:
        Sorted list of residue indices, or ResidueView's ordered by index.
    """
    hits, info = self._select(
        expression,
        unstructured=allow_unstructured,
        left_associativity=left_associativity,
    )
    mapped = [info["all_atoms"][index] for index in hits]
    if gti:
        return sorted({entry.rix for entry in mapped})
    by_rix = {entry.rix: entry.atom.residue for entry in mapped}
    return [by_rix[rix] for rix in sorted(by_rix)]
| |
|
def select_chains(
    self, expression: str, allow_unstructured=False, left_associativity: bool = True
):
    """Evaluates the given selection expression and returns all chains with any
    atoms involved in the result as a list of ChainView's.

    Args:
        expression (str): selection expression.
        allow_unstructured (bool): If True (default is False), will allow
            unstructured chains to be selected.
        left_associativity (bool, optional): determines whether operators
            in the expression are left-associative.

    Returns:
        List of ChainView's (as given by each selected atom's ``chain``
        attribute), one per chain, ordered by chain index.
    """
    val, selex_info = self._select(
        expression,
        unstructured=allow_unstructured,
        left_associativity=left_associativity,
    )
    # keyed by chain index so each chain is reported once, in chain order
    chains_by_cix = {}
    for i in val:
        mapped = selex_info["all_atoms"][i]
        chains_by_cix[mapped.cix] = mapped.atom.chain
    return [chains_by_cix[cix] for cix in sorted(chains_by_cix)]
| |
|
def _select(
    self,
    expression: str,
    unstructured: bool = False,
    left_associativity: bool = True,
):
    """Evaluate a selection expression and return the matching atom indices.

    Shared back end of ``select``, ``select_residues`` and ``select_chains``.
    Every atom of the System is wrapped in a ``MappableAtom`` carrying its
    global atom (``aix``), residue (``rix``) and chain (``cix``) indices.
    Residues that have no atoms get one placeholder ``DummyAtomView`` entry,
    flagged with a ``dummy`` attribute, so they can still be selected. The
    expression is then evaluated by ``ExpressionTreeEvaluator``, with
    ``_selex_eval`` as the per-operator callback.

    Args:
        expression (str): selection expression.
        unstructured (bool, optional): if False (default), placeholder entries
            of atom-less residues are dropped from the result.
        left_associativity (bool, optional): forwarded to the evaluator;
            determines whether operators are left-associative.

    Returns:
        Tuple ``(val, selex_info)``. ``val`` holds the selected indices into
        ``selex_info["all_atoms"]`` (a set once the ``unstructured`` filter
        has run). ``selex_info`` holds ``"all_atoms"`` (list of MappableAtom)
        and ``"all_indices_set"`` (set of every ``aix``).
    """

    # Lightweight record tying an atom view to its global indices; hashing by
    # aix makes records usable in sets/dicts.
    @dataclass(frozen=True)
    class MappableAtom:
        atom: AtomView
        aix: int
        rix: int
        cix: int

        def __hash__(self) -> int:
            return self.aix

    # Pre-size for the real atoms; placeholder entries are appended below.
    all_atoms = [None] * self.num_atoms()
    cix, rix, aix = 0, 0, 0
    for chain in self.chains():
        for residue in chain.residues():
            for atom in residue.atoms():
                all_atoms[aix] = MappableAtom(atom, aix, rix, cix)
                aix = aix + 1

            # A residue without atoms gets a dummy entry so that residue-level
            # selections can still reach it; selection works on atom indices.
            if residue.num_atoms() == 0:
                view = DummyAtomView(residue)
                view.dummy = True

                # grow the list by one slot for this extra entry
                all_atoms.append(None)
                all_atoms[aix] = MappableAtom(view, aix, rix, cix)
                aix = aix + 1
            rix = rix + 1
        cix = cix + 1

    _selex_info = {"all_atoms": all_atoms}
    _selex_info["all_indices_set"] = set([a.aix for a in all_atoms])

    # Operator vocabulary handed to the parser: nullary, unary and binary ops.
    # NOTE(review): "resix" is registered here, but _selex_eval in this file
    # has no branch for it (it falls through to `return None`). Conversely,
    # _selex_eval handles "occ", but "occ" is not registered here. Confirm
    # which is intended.
    tree = ExpressionTreeEvaluator(
        ["hyd", "all", "none"],
        ["not", "byres", "bychain", "first", "last",
         "chain", "authchain", "segid", "namesel", "gti", "resix", "resid",
         "authresid", "resname", "re", "x", "y", "z", "b", "icode", "name"],
        ["and", "or", "around", "saround"],
        eval_function=partial(self._selex_eval, _selex_info),
        left_associativity=left_associativity,
        debug=False,
    )

    val = tree.evaluate(expression)

    # Drop placeholder (dummy) entries unless unstructured residues are wanted.
    if not unstructured:
        val = {
            i for i in val if not hasattr(_selex_info["all_atoms"][i].atom, "dummy")
        }

    return val, _selex_info
| |
|
def save_selection(
    self,
    expression: Optional[str] = None,
    gti: Optional[List[int]] = None,
    selname: str = "_default",
    allow_unstructured=False,
    left_associativity: bool = True,
):
    """Performs a selection on the System and saves the indices of the residues
    involved in the result (global template indices) under the given name.

    Args:
        expression (str): (optional) selection expression.
        gti (list of int): (optional) explicit residue indices to save. Takes
            precedence over ``expression`` (with a warning if both are given).
        selname (str): selection name.
        allow_unstructured (bool): If True (default is False), will allow
            unstructured residues to be selected.
        left_associativity (bool, optional): determines whether operators
            in the expression are left-associative.

    Raises:
        ValueError: if neither ``expression`` nor ``gti`` is provided.
    """
    if gti is not None:
        if expression is not None:
            warnings.warn(
                "Expression and gti are both not null, expression will be ignored"
                " and gti will be used!"
            )
        result = sorted(gti)
    elif expression is not None:
        result = self.select_residues(
            expression,
            allow_unstructured=allow_unstructured,
            left_associativity=left_associativity,
            gti=True,
        )
    else:
        # Previously None was passed to the expression evaluator, which failed
        # with an unrelated error.
        raise ValueError("save_selection requires either `expression` or `gti`")

    self._selections[selname] = result
| |
|
def get_selected(self, selname: str = "_default"):
    """Look up a saved selection.

    Args:
        selname (str): selection name.

    Returns:
        The list of global template indices stored under ``selname``.

    Raises:
        Exception: if no selection with that name exists.
    """
    try:
        return self._selections[selname]
    except KeyError:
        raise Exception(
            f"selection by name '{selname}' does not exist in the System"
        ) from None
| |
|
def has_selection(self, selname: str = "_default"):
    """Tell whether a selection has been saved under ``selname``.

    Args:
        selname (str): selection name.

    Returns:
        True if the named selection exists in the System, else False.
    """
    saved = self._selections
    return selname in saved
| |
|
def get_selection_names(self):
    """List the names of all saved selections, in insertion order."""
    return [name for name in self._selections]
| |
|
def remove_selection(self, selname: str = "_default"):
    """Forget the selection saved under ``selname``.

    Args:
        selname (str): selection name.

    Raises:
        Exception: if no selection with that name exists.
    """
    try:
        del self._selections[selname]
    except KeyError:
        raise Exception(
            f"selection by name '{selname}' does not exist in the System"
        ) from None
| |
|
def _selex_eval(self, _selex_info, op: str, left, right):
    """Evaluate one operator node of a selection expression.

    Used as the ``eval_function`` callback of ``ExpressionTreeEvaluator``
    (bound to ``_selex_info`` via ``partial`` in ``_select``).

    Operands are either ``None`` (absent), a dict ``{"result": indices}``
    produced by an earlier evaluation, or a list of raw string tokens from
    the expression. The list-of-strings form is what the range parsers below
    require; assumed to hold for all tokens.

    Args:
        _selex_info (dict): ``"all_atoms"`` (list of records with ``atom``,
            ``aix``, ``rix`` and ``cix`` attributes) and ``"all_indices_set"``
            (set of every ``aix``).
        op (str): operator name.
        left, right: operands as described above.

    Returns:
        ``{"result": set_of_atom_indices}`` on success, or ``None`` when the
        operator is unknown or its operands are not of the expected form.
    """
    all_atoms = _selex_info["all_atoms"]

    def _is_numeric(string) -> bool:
        """True if ``string`` converts to float."""
        try:
            float(string)
            return True
        except (TypeError, ValueError):
            return False

    def _is_int(string) -> bool:
        """True if ``string`` converts to int."""
        try:
            int(string)
            return True
        except (TypeError, ValueError):
            return False

    def _is_token_list(operand) -> bool:
        """True for a list whose items are all strings."""
        return isinstance(operand, list) and all(isinstance(v, str) for v in operand)

    def _single_token(operand):
        """Return the sole item of a one-element list/tuple operand, else None."""
        if isinstance(operand, (list, tuple)) and len(operand) == 1:
            return operand[0]
        return None

    def _unpack_operands(operands, dests):
        """Validate and convert each operand according to its expected kind.

        Returns ``(unpacked, succ)``; ``succ`` is False as soon as any operand
        does not match its destination kind.
        """
        assert len(operands) == len(dests)
        unpacked = [None] * len(operands)
        for i, (operand, dest) in enumerate(zip(operands, dests)):
            if dest is None:
                if operand is not None:
                    return unpacked, False
            elif dest == "result":
                if not (isinstance(operand, dict) and "result" in operand):
                    return unpacked, False
                unpacked[i] = operand["result"]
            elif dest == "string":
                token = _single_token(operand)
                if not isinstance(token, str):
                    return unpacked, False
                unpacked[i] = token
            elif dest == "strings":
                if not _is_token_list(operand):
                    return unpacked, False
                unpacked[i] = operand
            elif dest == "float":
                token = _single_token(operand)
                if not _is_numeric(token):
                    return unpacked, False
                unpacked[i] = float(token)
            elif dest == "floats":
                if not (isinstance(operand, list) and all(_is_numeric(v) for v in operand)):
                    return unpacked, False
                unpacked[i] = [float(v) for v in operand]
            elif dest == "range":
                test = _parse_range(operand)
                if test is None:
                    return unpacked, False
                unpacked[i] = test
            elif dest == "int":
                token = _single_token(operand)
                if not _is_int(token):
                    return unpacked, False
                unpacked[i] = int(token)
            elif dest == "ints":
                if not (isinstance(operand, list) and all(_is_int(v) for v in operand)):
                    return unpacked, False
                unpacked[i] = [int(v) for v in operand]
            elif dest == "int_range":
                test = _parse_int_range(operand)
                if test is None:
                    return unpacked, False
                unpacked[i] = test
            elif dest == "int_range_string":
                test = _parse_int_range(operand, string=True)
                if test is None:
                    return unpacked, False
                unpacked[i] = test
        return unpacked, True

    def _parse_range(operands):
        """Build a numeric predicate from tokens of the form ``< n``, ``> n``,
        ``n:m`` (exclusive on both ends) or ``n`` (equality). Tokens are
        concatenated first, so spacing does not matter. Returns None when the
        tokens do not describe a valid range."""
        if not _is_token_list(operands):
            return None
        text = "".join(operands)
        if text.startswith((">", "<")):
            if not _is_numeric(text[1:]):
                return None
            cut = float(text[1:])
            if text.startswith(">"):
                return lambda x, cut=cut: x > cut
            return lambda x, cut=cut: x < cut
        if ":" in text:
            parts = text.split(":")
            if len(parts) != 2 or not all(_is_numeric(p) for p in parts):
                return None
            lo, hi = (float(p) for p in parts)
            return lambda x, lo=lo, hi=hi: lo < x < hi
        if _is_numeric(text):
            target = float(text)
            return lambda x, t=target: x == t
        return None

    def _parse_int_range(operands, string: bool = False):
        """Build an integer-membership predicate from tokens such as ``n``,
        ``n-m`` (inclusive) and ``+``-joined combinations (e.g. ``1-3+7``).
        With ``string=True`` the predicate tests string values instead.
        Returns None when the tokens do not describe a valid range."""
        if not _is_token_list(operands):
            return None
        ranges = []
        for part in "".join(operands).split("+"):
            m = re.fullmatch(r"(.*\d)-(.+)", part)
            if m:
                if not all(_is_int(g) for g in m.groups()):
                    return None
                ranges.append(range(int(m.group(1)), int(m.group(2)) + 1))
            else:
                if not _is_int(part):
                    return None
                ranges.append({part} if string else {int(part)})
        if string:
            ranges = [{str(x) for x in r} for r in ranges]
        return lambda x, ranges=ranges: any(x in r for r in ranges)

    result = set()
    if op in ("and", "or"):
        (Si, Sj), succ = _unpack_operands([left, right], ["result", "result"])
        if not succ:
            return None
        if op == "and":
            result = set(Si).intersection(Sj)
        else:
            result = set(Si).union(Sj)
    elif op == "not":
        (_, S), succ = _unpack_operands([left, right], [None, "result"])
        if not succ:
            return None
        result = _selex_info["all_indices_set"].difference(S)
    elif op == "all":
        (_, _), succ = _unpack_operands([left, right], [None, None])
        if not succ:
            return None
        # copy so later set operations can never mutate the shared index set
        result = set(_selex_info["all_indices_set"])
    elif op == "none":
        (_, _), succ = _unpack_operands([left, right], [None, None])
        if not succ:
            return None
    elif op == "around":
        (S, rad), succ = _unpack_operands([left, right], ["result", "float"])
        if not succ:
            return None
        # every (atom index, location) pair; atoms with several locations
        # contribute several rows
        located = [(a.aix, loc) for a in all_atoms for loc in a.atom.locations()]
        atom_indices = np.asarray([aix for aix, _ in located], dtype=int)
        # reshape(-1, 3) keeps the arrays 2-D even when empty (no selection or
        # no coordinates), so the broadcast below still works
        X_i = np.asarray(
            [[loc.x, loc.y, loc.z] for _, loc in located], dtype=float
        ).reshape(-1, 3)
        X_j = np.asarray(
            [[loc.x, loc.y, loc.z] for j in S for loc in all_atoms[j].atom.locations()],
            dtype=float,
        ).reshape(-1, 3)
        D = np.sqrt(((X_j[np.newaxis, :, :] - X_i[:, np.newaxis, :]) ** 2).sum(-1))
        hit = (D <= rad).any(axis=1)
        result = set(atom_indices[hit].tolist())
    elif op == "saround":
        (S, srad), succ = _unpack_operands([left, right], ["result", "int"])
        if not succ:
            return None
        # residue indices of the selected atoms, grouped by chain index; cix
        # identifies the chain of each record (assigned per chain in _select)
        centers = {}
        for j in S:
            aj = all_atoms[j]
            centers.setdefault(aj.cix, set()).add(aj.rix)
        for a in all_atoms:
            if any(abs(a.rix - r) <= srad for r in centers.get(a.cix, ())):
                result.add(a.aix)
    elif op == "byres":
        (_, S), succ = _unpack_operands([left, right], [None, "result"])
        if not succ:
            return None
        gtis = {all_atoms[j].rix for j in S}
        result = {a.aix for a in all_atoms if a.rix in gtis}
    elif op == "bychain":
        (_, S), succ = _unpack_operands([left, right], [None, "result"])
        if not succ:
            return None
        cixs = {all_atoms[j].cix for j in S}
        result = {a.aix for a in all_atoms if a.cix in cixs}
    elif op in ("first", "last"):
        (_, S), succ = _unpack_operands([left, right], [None, "result"])
        if not succ:
            return None
        # an empty selection has no first/last atom: result stays empty
        if S:
            pick = min if op == "first" else max
            result.add(pick(all_atoms[i].aix for i in S))
    elif op == "name":
        (_, name), succ = _unpack_operands([left, right], [None, "string"])
        if not succ:
            return None
        result = {a.aix for a in all_atoms if a.atom.name == name}
    elif op in ("re", "hyd"):
        if op == "re":
            (_, regex), succ = _unpack_operands([left, right], [None, "string"])
        else:
            (_, _), succ = _unpack_operands([left, right], [None, None])
            regex = "[0123456789]?H.*"
        if not succ:
            return None
        try:
            ex = re.compile(regex)
        except re.error:
            # an invalid user-supplied pattern is an invalid operand
            return None
        for a in all_atoms:
            if a.atom.name is not None and ex.fullmatch(a.atom.name):
                result.add(a.aix)
    elif op in ("chain", "authchain", "segid"):
        (_, match_id), succ = _unpack_operands([left, right], [None, "string"])
        if not succ:
            return None
        prop = {"chain": "cid", "authchain": "authid", "segid": "segid"}[op]
        for a in all_atoms:
            if getattr(a.atom.residue.chain, prop) == match_id:
                result.add(a.aix)
    elif op == "resid":
        (_, test), succ = _unpack_operands([left, right], [None, "int_range"])
        if not succ:
            return None
        result = {a.aix for a in all_atoms if test(a.atom.residue.num)}
    elif op in ("resname", "icode"):
        (_, match_id), succ = _unpack_operands([left, right], [None, "string"])
        if not succ:
            return None
        prop = "name" if op == "resname" else "icode"
        result = {a.aix for a in all_atoms if getattr(a.atom.residue, prop) == match_id}
    elif op == "authresid":
        (_, test), succ = _unpack_operands([left, right], [None, "int_range_string"])
        if not succ:
            return None
        result = {a.aix for a in all_atoms if test(a.atom.residue.authid)}
    elif op == "gti":
        (_, test), succ = _unpack_operands([left, right], [None, "int_range"])
        if not succ:
            return None
        result = {a.aix for a in all_atoms if test(a.rix)}
    elif op in ("x", "y", "z", "b", "occ"):
        (_, test), succ = _unpack_operands([left, right], [None, "range"])
        if not succ:
            return None
        prop = "B" if op == "b" else op
        # an atom matches if any of its locations satisfies the predicate
        for a in all_atoms:
            if any(test(getattr(loc, prop)) for loc in a.atom.locations()):
                result.add(a.aix)
    elif op == "namesel":
        (_, selname), succ = _unpack_operands([left, right], [None, "string"])
        if not succ:
            return None
        if selname not in self._selections:
            return None
        gtis = set(self._selections[selname])
        result = {a.aix for a in all_atoms if a.rix in gtis}
    else:
        return None

    return {"result": result}
| |
|
def __getitem__(self, chain_idx: int):
    """Index the System like a sequence of chains: ``system[i]`` is chain ``i``."""
    chain = self.get_chain(chain_idx)
    return chain
| |
|
def add_chain(
    self,
    cid: str,
    segid: str = None,
    authid: str = None,
    entity_id: int = None,
    auto_rename: bool = True,
    at: int = None,
):
    """Adds a new chain to the System and returns a reference to it.

    Args:
        cid (str): Chain ID.
        segid (str, optional): Segment ID. Defaults to the (final) chain ID.
        authid (str, optional): Author chain ID. Defaults to the (final) chain ID.
        entity_id (int, optional): Entity ID of the entity corresponding to this chain.
        auto_rename (bool, optional): If True, will pick a unique chain ID if the specified
            one clashes with an already existing chain.
        at (int, optional): Chain index at which to insert the new chain.
            Defaults to appending after the last chain.

    Returns:
        View of the newly added chain, as returned by ``get_chain``.
    """
    if auto_rename:
        cid = self._pick_unique_chain_name(cid)
    if segid is None:
        segid = cid
    if authid is None:
        authid = cid
    record = {"cid": cid, "segid": segid, "authid": authid}
    if at is None:
        at = self.num_chains()
        self._chains.append(record)
        self._chain_entities.append(entity_id)
    else:
        self._chains.insert(at, record)
        self._chain_entities.insert(at, entity_id)
    return self.get_chain(at)
| |
|
def _append_residue(self, name: str, num: int, authid: str, icode: str):
    """Append a residue record after the last residue. Internal, do not use.

    The record is handed to ``self._chains.append_child``.

    Args:
        name (str): Residue name.
        num (int): Residue number (i.e., residue ID).
        authid (str): Author residue ID.
        icode (str): Insertion code.

    Returns:
        Global index of the residue just added.
    """
    record = {"name": name, "resnum": num, "authresid": authid, "icode": icode}
    self._chains.append_child(record)
    last = len(self._residues) - 1
    return last
| |
|
def _append_atom(
    self,
    name: str,
    het: bool,
    x: float = None,
    y: float = None,
    z: float = None,
    occ: float = None,
    B: float = None,
    alt: str = None,
):
    """Adds a new atom to the end of this System. Internal method, do not use.

    Only the atom's name and hetero flag are recorded (via
    ``self._residues.append_child``). The location arguments ``x``, ``y``,
    ``z``, ``occ``, ``B`` and ``alt`` are accepted but NOT used by this
    method; to give the atom a location, call ``_append_location``.
    NOTE(review): presumably kept for call-site compatibility — confirm.

    Args:
        name (str): Atom name.
        het (bool): Whether it is a hetero-atom.
        x, y, z, occ, B, alt: ignored (see above).

    Returns:
        Global index to the newly added atom.
    """
    self._residues.append_child({"name": name, "het": het})
    return len(self._atoms) - 1
| |
|
def _append_location(self, x, y, z, occ, B, alt):
    """Append a location record after the last one. Internal, do not use.

    The record stores ``[x, y, z, occ, B]`` under ``"coor"`` plus the
    alternate-location character, and goes through
    ``self._atoms.append_child``.

    Args:
        x, y, z (float): coordinates of the location.
        occ (float): occupancy for the location.
        B (float): B-factor for the location.
        alt (str): alternative location character.

    Returns:
        Global index of the location just added.
    """
    coordinates = [x, y, z, occ, B]
    self._atoms.append_child({"coor": coordinates, "alt": alt})
    last = len(self._locations) - 1
    return last
| |
|
def add_new_entity(self, entity: SystemEntity, chain_indices: list):
    """Register ``entity`` under a fresh ID and point the given chains at it.

    The new ID starts at the current entity count and is bumped until it is
    not already in use.

    Args:
        entity (SystemEntity): The new entity to add to the System.
        chain_indices (list): indices of the chains to assign to this entity.

    Returns:
        The entity ID given to the new entity.
    """
    fresh_id = len(self._entities)
    while fresh_id in self._entities:
        fresh_id += 1
    self._entities[fresh_id] = entity
    for index in chain_indices:
        self._chain_entities[index] = fresh_id
    return fresh_id
| |
|
def delete_entity(self, entity_id: int):
    """Deletes the entity with the specified ID and unlinks every chain that
    belonged to it (their entity mapping becomes None).

    Args:
        entity_id (int): Entity ID.

    Raises:
        KeyError: if no entity with this ID exists; nothing is modified then.
    """
    # Delete first so an unknown ID fails before any chain is touched.
    del self._entities[entity_id]
    # Walk integer chain indices directly. The old code indexed
    # `_chain_entities` with the ChainView objects returned by
    # `get_chains_of_entity` (default `by`), which fails as a list index.
    for ci, eid in enumerate(self._chain_entities):
        if entity_id == eid:
            self._chain_entities[ci] = None
| |
|
def _pick_unique_chain_name(self, hint: str, verbose=False):
    """Choose a chain ID not used by any existing chain.

    The candidates are tried in this order:

    1. ``hint`` itself.
    2. Each single character A-Z, a-z, 0-9.
    3. ``<base><k>`` for ``k`` in 0..999, with ``base`` running over ``hint``
       and then over each single character.

    Args:
        hint (str): preferred chain ID.
        verbose (bool, optional): warn when falling back to multi-character IDs.

    Returns:
        An unused chain ID.

    Raises:
        Exception: if every candidate is already taken.
    """
    good_names = list(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    )
    taken = {chain.cid for chain in self.chains()}

    for cid in [hint] + good_names:
        if cid not in taken:
            return cid
    if verbose:
        warnings.warn(
            "ran out of reasonable single-letter chain names, will use more than one character (PDB sctructure may be repeating chain IDs upon writing, but should still have unique segment IDs)!"
        )

    # Index single characters directly. The old `goodNames[i : i + 1]`
    # sliced a list and produced names like "['A']0".
    for base in [hint] + good_names:
        if base == "":
            continue
        for k in range(1000):
            long_name = f"{base}{k}"
            if long_name not in taken:
                return long_name
    raise Exception(
        "ran out of even multi-character chain names; PDB structure appears to have an enormous number of chains"
    )
| |
|
def _ensure_unique_entity(self, ci: int):
    """Give chain ``ci`` an entity it does not share, so the entity can be edited.

    If another chain points at the same entity, the chain is moved to a
    deep copy of that entity (registered via ``add_new_entity``). Otherwise
    it keeps its current entity.

    Args:
        ci (int): Index of the Chain whose entity information is to be updated.

    Returns:
        The entity ID now owned exclusively by the chain, or None if the
        chain has no entity.
    """
    chain = self.get_chain(ci)
    entity_id = chain.get_entity_id()
    if entity_id is None:
        return None

    shared = any(
        other != chain and entity_id == other.get_entity_id()
        for other in self.chains()
    )
    if not shared:
        return entity_id

    duplicate = copy.deepcopy(self._entities[entity_id])
    return self.add_new_entity(duplicate, [ci])
| |
|
def num_chains(self):
    """Return how many chains the System holds."""
    count = len(self._chains)
    return count
| |
|
def num_chains_of_entity(self, entity_id: int):
    """Count the chains whose entity mapping equals ``entity_id``.

    Args:
        entity_id (int): Entity ID.

    Returns:
        Number of chains mapped to the entity.
    """
    return sum(1 for owner in self._chain_entities if entity_id == owner)
| |
|
def num_molecules_of_entity(self, entity_id: int):
    """Count the molecules of entity ``entity_id``.

    For a polymer entity this is the number of chains mapped to it.
    Otherwise it is the total residue count across the chains mapped to it.

    Args:
        entity_id (int): Entity ID.

    Returns:
        The molecule count described above.
    """
    if self._entities[entity_id].is_polymer():
        return self.num_chains_of_entity(entity_id)
    total = 0
    for index, owner in enumerate(self._chain_entities):
        if owner == entity_id:
            total += self[index].num_residues()
    return total
| |
|
def num_entities(self):
    """Return how many entities the System defines."""
    count = len(self._entities)
    return count
| |
|
def num_residues(self):
    """Return how many residues the System holds (structured or not)."""
    count = len(self._residues)
    return count
| |
|
def num_structured_residues(self):
    """Total, over all chains, of each chain's structured-residue count."""
    total = 0
    for chain in self.chains():
        total += chain.num_structured_residues()
    return total
| |
|
def num_atoms(self):
    """Return how many atoms the System holds (with or without locations)."""
    count = len(self._atoms)
    return count
| |
|
def num_structured_atoms(self):
    """Count the atoms that have at least one location."""
    return sum(
        atom.num_locations() > 0
        for chain in self.chains()
        for residue in chain.residues()
        for atom in residue.atoms()
    )
| |
|
def num_atom_locations(self):
    """Return the total number of stored atom locations.

    Alternative locations are counted separately, so this can exceed the
    number of atoms.
    """
    count = len(self._locations)
    return count
| |
|
def num_models(self):
    """Return how many models (conformations) the System carries.

    This is the active model plus every entry in ``_extra_models``.
    """
    extra = len(self._extra_models)
    return 1 + extra
| |
|
def swap_model(self, i: int):
    """Make model ``i`` the active one, moving the active model into its slot.

    Args:
        i (int): Model index; 0 (the active model) is a no-op.

    Raises:
        Exception: if ``i`` is negative or not below ``num_models()``.
    """
    if i == 0:
        return
    if not 0 <= i < self.num_models():
        raise Exception(f"model index {i} out of range")
    slot = i - 1
    self._locations, self._extra_models[slot] = self._extra_models[slot], self._locations
| |
|
def add_model(self, other: System):
    """Append a copy of ``other``'s active model as an extra model of this System.

    The copied locations are re-parented onto this System's atoms, so both
    Systems must have the same number of atom locations (with matching
    atoms/residues).

    Args:
        other (System): The System to take the model from.

    Raises:
        Exception: if the location counts differ.
    """
    mine = len(self._locations)
    theirs = len(other._locations)
    if mine != theirs:
        raise Exception(
            f"System has {mine} atom locations while {theirs} were provided"
        )
    model = other._locations.copy()
    self._extra_models.append(model)
    model.set_parent(self._atoms)
| |
|
def add_model_from_X(self, X: torch.Tensor):
    """Append a new model whose xyz coordinates are taken from ``X``.

    The new model starts as a copy of the current model's locations. Only
    the first three ``coor`` columns are overwritten, so the remaining
    per-location columns (occupancy and B-factor in files read by this
    class) are inherited from the current model.

    Args:
        X (torch.Tensor): Coordinate tensor of shape
            `(residues, atoms (4 or 14), coordinates (3))`. The first two
            dimensions are flattened, so ``X`` must hold exactly one xyz
            triple per atom location of the current model, in order.

    Raises:
        Exception: if ``X`` does not hold ``3 * num_atom_locations()``
            values.
    """
    expected_locations = X.numel() / 3
    if len(self._locations) != expected_locations:
        raise Exception(
            f"System has {len(self._locations)} atom locations while provided tensor shape is {X.shape}"
        )
    host_X = X.detach().cpu()
    new_model = self._locations.copy()
    self._extra_models.append(new_model)
    new_model["coor"][:, 0:3] = host_X.flatten(0, 1)
    return None
| |
|
def num_assemblies(self):
    """Return how many biological assemblies are recorded for this System."""
    assembly_table = self._assembly_info.assemblies
    return len(assembly_table)
| |
|
@staticmethod
def from_CIF_string(cif_string: str):
    """Parse CIF text and return the resulting System.

    Only the System is returned; the success flag produced by the CIF
    reader is discarded.
    """
    import io

    return System._read_cif(io.StringIO(cif_string))[0]
| |
|
@staticmethod
def from_CIF(input_file: str):
    """Initializes and returns a System object from a CIF file.

    Paths ending in ``.gz`` (e.g. ``.cif.gz``, case-insensitive) are
    gzip-decompressed and decoded as UTF-8. Any other path is read as plain
    text. Previously, names other than ``.cif``/``.cif.gz`` (e.g. ``.mmcif``
    or ``.CIF``) failed with an unbound-variable error. The file handle is
    always closed before returning.

    Args:
        input_file (str): Path to a CIF file, optionally gzip-compressed.

    Returns:
        System: the parsed System (the reader's success flag is dropped).
    """
    if input_file.lower().endswith(".gz"):
        with gzip.open(input_file, "rb") as f_in:
            file_content = f_in.read()
        # Decompressed bytes are wrapped as a text stream (universal
        # newlines, UTF-8) for the CIF reader.
        with io.TextIOWrapper(io.BytesIO(file_content), encoding="utf-8") as f:
            return System._read_cif(f)[0]
    with open(input_file, "r") as f:
        return System._read_cif(f)[0]
| |
|
@staticmethod
def _read_cif(f, strict=False):
    """Parse an mmCIF stream into a System.

    Recognized categories are ``_entry`` (``id`` only), ``_entity``,
    ``_entity_poly``, ``_entity_poly_seq``, ``_pdbx_struct_assembly``,
    ``_pdbx_struct_assembly_gen``, ``_pdbx_struct_oper_list``,
    ``_atom_site``, and the custom ``_generate_selections`` and
    ``_generate_labels`` categories. Every other line is skipped.

    After parsing:
      * polymer entities whose ``pdbx_number_of_molecules`` exceeds the
        number of chains read get extra chains (with no atoms from
        ``_atom_site``);
      * residues present in an entity's sequence but missing from a chain
        are inserted;
      * each chain's sequence is checked against its entity.

    Args:
        f: A text stream consumed through ``sp.peek_line`` / ``sp.advance``
            (e.g. an open file or ``io.StringIO``).
        strict (bool, optional): if True, inconsistencies raise an
            Exception; otherwise they are reported with ``warnings.warn``.

    Returns:
        tuple: ``(system, success)``. ``success`` is False when a later
        model in ``_atom_site`` does not match the first model; in that
        case all extra models are discarded.

    Raises:
        Exception: if the ``_entity``, ``_entity_poly``,
            ``_entity_poly_seq`` or ``_atom_site`` category appears more
            than once, or (with ``strict``) on any reported inconsistency.
    """

    def _warn_or_error(strict: bool, msg: str):
        # Raise in strict mode; otherwise only emit a warning.
        if strict:
            raise Exception(msg)
        else:
            warnings.warn(msg)

    # Categories that may be read at most once.
    is_read = {
        part: False for part in ["coors", "entities", "sequence", "entity_poly"]
    }
    # Category currently being read, and whether it was introduced by a
    # `loop_` (tabular form) rather than single key-value items.
    category = ""
    (in_loop, success) = (False, True)
    # One-line look-ahead buffer shared with the `sp` helpers.
    peeked = sp.PeekedLine("", 0)
    # 0-based entity index -> `pdbx_number_of_molecules` from `_entity`.
    num_of_mols = dict()

    system = System("system")
    while sp.peek_line(f, peeked):
        if peeked.line.startswith("#"):
            # Comment / category separator: skip.
            sp.advance(f, peeked)
        elif peeked.line.startswith("data_"):
            # Data-block header: skip.
            sp.advance(f, peeked)
        elif peeked.line.startswith("loop_"):
            in_loop = True
            category = ""
            sp.advance(f, peeked)
        else:
            (cat, name, val) = ("", "", "")
            if peeked.line.startswith("_"):
                (cat, name, val) = sp.star_item_parse(peeked.line)
                # Entering a different category ends the previous loop.
                if cat != category:
                    if category != "":
                        in_loop = False
                    category = cat

            if (cat == "_entry") and (name == "id"):
                if val != "":
                    system.name = val
                sp.advance(f, peeked)
            elif cat == "_entity_poly":
                if is_read["entity_poly"]:
                    raise Exception("entity_poly block encountered multiple times")
                tab = sp.star_read_data(f, ["entity_id", "type"], in_loop)
                for row in tab:
                    # Entity IDs are 1-based in the file, 0-based internally.
                    ent_id = int(row[0]) - 1
                    if ent_id not in system._entities:
                        system._entities[ent_id] = SystemEntity(
                            None, None, row[1], None, None
                        )
                    else:
                        system._entities[ent_id]._polymer_type = row[1]
                is_read["entity_poly"] = True
            elif cat == "_entity":
                if is_read["entities"]:
                    raise Exception(
                        f"entities block encountered multiple times: {peeked.line}"
                    )
                tab = sp.star_read_data(
                    f,
                    ["id", "type", "pdbx_description", "pdbx_number_of_molecules"],
                    in_loop,
                )
                for row in tab:
                    ent_id = int(row[0]) - 1
                    if ent_id not in system._entities:
                        system._entities[ent_id] = SystemEntity(
                            row[1], row[2], None, None, None
                        )
                    else:
                        system._entities[ent_id]._type = row[1]
                        system._entities[ent_id]._desc = row[2]
                    # Non-numeric molecule counts (e.g. "?") are ignored.
                    if row[3].isnumeric():
                        num_of_mols[ent_id] = int(row[3])
                is_read["entities"] = True
            elif cat == "_entity_poly_seq":
                if is_read["sequence"]:
                    raise Exception(f"sequence block encountered multiple times")
                tab = sp.star_read_data(
                    f, ["entity_id", "num", "mon_id", "hetero"], in_loop
                )
                (seq, het) = ([], [])
                for i in range(len(tab)):
                    # Rows are accumulated in runs of equal entity_id. A run
                    # is stored when the next row has another entity_id, or
                    # at the last row. NOTE(review): this assumes each
                    # entity's rows are contiguous, and that the entity was
                    # already created by `_entity` / `_entity_poly`.
                    # Otherwise the `system._entities[ent_id]` lookup
                    # presumably fails.
                    seq.append(tab[i][2])
                    het.append(tab[i][3].startswith("y"))
                    if (i == len(tab) - 1) or (tab[i][0] != tab[i + 1][0]):
                        ent_id = int(tab[i][0]) - 1
                        system._entities[ent_id]._seq = seq
                        system._entities[ent_id]._het = het
                        (seq, het) = ([], [])
                is_read["sequence"] = True
            elif cat == "_pdbx_struct_assembly":
                tab = sp.star_read_data(f, ["id", "details"], in_loop)
                for row in tab:
                    system._assembly_info.assemblies[row[0]] = {"details": row[1]}
            elif cat == "_pdbx_struct_assembly_gen":
                tab = sp.star_read_data(
                    f, ["assembly_id", "oper_expression", "asym_id_list"], in_loop
                )
                for row in tab:
                    # NOTE(review): the assembly must already have been read
                    # from `_pdbx_struct_assembly`; otherwise this lookup
                    # raises.
                    assembly = system._assembly_info.assemblies[row[0]]
                    if "instructions" not in assembly:
                        assembly["instructions"] = []
                    # Comma-separated list of chain (asym) IDs.
                    chain_ids = [cid.strip() for cid in row[2].strip().split(",")]
                    assembly["instructions"].append(
                        {"oper_expression": row[1], "chains": chain_ids}
                    )
            elif cat == "_pdbx_struct_oper_list":
                tab = sp.star_read_data(
                    f,
                    [
                        "id",
                        "type",
                        "name",
                        "matrix[1][1]",
                        "matrix[1][2]",
                        "matrix[1][3]",
                        "matrix[2][1]",
                        "matrix[2][2]",
                        "matrix[2][3]",
                        "matrix[3][1]",
                        "matrix[3][2]",
                        "matrix[3][3]",
                        "vector[1]",
                        "vector[2]",
                        "vector[3]",
                    ],
                    in_loop,
                )
                for row in tab:
                    # Columns 3..11 hold the 3x3 matrix in row-major order;
                    # columns 12..14 hold the translation vector.
                    system._assembly_info.operations[
                        row[0]
                    ] = SystemAssemblyInfo.make_operation(
                        row[1], row[2], row[3:12], row[12:15]
                    )
            elif cat == "_generate_selections":
                tab = sp.star_read_data(f, ["name", "indices"], in_loop)
                for row in tab:
                    # Whitespace-separated integer indices.
                    system._selections[row[0]] = [
                        int(gti.strip()) for gti in row[1].strip().split()
                    ]
            elif cat == "_generate_labels":
                tab = sp.star_read_data(f, ["name", "index", "value"], in_loop)
                for row in tab:
                    if row[0] not in system._labels:
                        system._labels[row[0]] = dict()
                    idx = int(row[1])
                    system._labels[row[0]][int(row[1])] = row[2]
            elif cat == "_atom_site":
                if is_read["coors"]:
                    raise Exception(f"ATOM_SITE block encountered multiple times")

                tab = sp.star_read_data(
                    f,
                    [
                        "group_PDB",
                        "id",
                        "label_atom_id",
                        "label_alt_id",
                        "label_comp_id",
                        "label_asym_id",
                        "label_entity_id",
                        "label_seq_id",
                        "pdbx_PDB_ins_code",
                        "Cartn_x",
                        "Cartn_y",
                        "Cartn_z",
                        "occupancy",
                        "B_iso_or_equiv",
                        "pdbx_PDB_model_num",
                        "auth_seq_id",
                        "auth_asym_id",
                    ],
                    in_loop,
                    cols=False,
                    has_blocks=False,
                )

                # Column positions within each row of `tab` (same order as
                # the field list requested above).
                groupCol = 0
                idxCol = 1
                atomNameCol = 2
                altIdCol = 3
                resNameCol = 4
                chainNameCol = 5
                entityIdCol = 6
                seqIdCol = 7
                insCodeCol = 8
                xCol = 9
                yCol = 10
                zCol = 11
                occCol = 12
                bCol = 13
                modelCol = 14
                authSeqIdCol = 15
                authChainNameCol = 16

                (
                    atom,
                    residue,
                    chain,
                    prev_chain,
                    prev_residue,
                    prev_atom,
                    prev_entity_id,
                    prev_seq_id,
                    prev_auth_seq_id,
                ) = (None, None, None, None, None, None, None, None, None)
                loc = None
                # Position within the extra model currently being filled.
                aIdx = 0
                for i in range(len(tab)):
                    if i == 0:
                        first_model = tab[i][modelCol]
                        prev_model = first_model
                    elif (tab[i][modelCol] != prev_model) or (
                        tab[i][modelCol] != first_model
                    ):
                        # Row belonging to a model other than the first. The
                        # first model defines the topology; later models only
                        # supply coordinates. These are matched
                        # position-by-position against the first model's
                        # locations.
                        if tab[i][modelCol] != prev_model:
                            # Start of a new model: allocate a fresh location
                            # table and restart iteration over the current
                            # (first-model) locations.
                            aIdx = 0
                            num_loc = system.num_atom_locations()
                            system._extra_models.append(system._new_locations())
                            prev_model = tab[i][modelCol]
                            locations_generator = (l for l in system.locations())

                        loc = next(locations_generator, None)
                        if aIdx >= num_loc:
                            _warn_or_error(
                                strict,
                                f"at atom id: {tab[i][idxCol]} -- too many atoms in model {tab[i][modelCol]} relative to first model {first_model}",
                            )
                            success = False
                            system._extra_models.clear()
                            break

                        # The row must name the same chain, residue name,
                        # residue number and atom name as the corresponding
                        # first-model location.
                        same = (
                            (loc is not None)
                            and (tab[i][chainNameCol] == loc.atom.residue.chain.cid)
                            and (tab[i][resNameCol] == loc.atom.residue.name)
                            and (
                                int(
                                    sp.star_value(
                                        tab[i][seqIdCol], loc.atom.residue.num
                                    )
                                )
                                == loc.atom.residue.num
                            )
                            and (tab[i][atomNameCol] == loc.atom.name)
                        )
                        if not same:
                            _warn_or_error(
                                strict,
                                f"at atom id: {tab[i][idxCol]} -- atoms in model {tab[i][modelCol]} do not correspond exactly to atoms in first model",
                            )
                            success = False
                            system._extra_models.clear()
                            break

                        # x, y, z, occupancy, B-factor for this location.
                        coor = [
                            float(tab[i][c])
                            for c in [xCol, yCol, zCol, occCol, bCol]
                        ]
                        system._extra_models[-1]["coor"][aIdx] = coor
                        system._extra_models[-1]["alt"][aIdx] = sp.star_value(
                            tab[i][altIdCol], " "
                        )[0]
                        aIdx = aIdx + 1
                        continue

                    # First-model row: build the topology.
                    # Start a new chain when the entity or label chain ID
                    # changes; the author chain ID falls back to the label ID.
                    if (
                        (chain is None)
                        or (prev_entity_id != tab[i][entityIdCol])
                        or (tab[i][chainNameCol] != chain.cid)
                    ):
                        authid = (
                            tab[i][authChainNameCol]
                            if (tab[i][authChainNameCol] != "")
                            else tab[i][chainNameCol]
                        )
                        chain = system.add_chain(
                            tab[i][chainNameCol],
                            tab[i][chainNameCol],
                            authid,
                            int(tab[i][entityIdCol]) - 1,
                        )

                    # Start a new residue when the chain, label seq ID or
                    # author seq ID changes. Residues without a defined label
                    # seq ID are numbered sequentially within the chain.
                    if (
                        (residue is None)
                        or (chain != prev_chain)
                        or (prev_seq_id != tab[i][seqIdCol])
                        or (prev_auth_seq_id != tab[i][authSeqIdCol])
                    ):
                        resnum = (
                            int(tab[i][seqIdCol])
                            if sp.star_value_defined(tab[i][seqIdCol])
                            else chain.num_residues() + 1
                        )
                        ri = system._append_residue(
                            tab[i][resNameCol],
                            resnum,
                            tab[i][authSeqIdCol],
                            sp.star_value(tab[i][insCodeCol], " ")[0],
                        )
                        residue = ResidueView(ri, chain)

                    # Start a new atom unless this row repeats the previous
                    # atom name within the same residue (e.g. an alternate
                    # location). In that case only another location is
                    # appended to the existing atom.
                    x, y, z, occ, B = [
                        float(tab[i][col])
                        for col in [xCol, yCol, zCol, occCol, bCol]
                    ]
                    alt = sp.star_value(tab[i][altIdCol], " ")[0]
                    if (
                        (atom is None)
                        or (residue != prev_residue)
                        or (tab[i][atomNameCol] != atom.name)
                    ):
                        ai = system._append_atom(
                            tab[i][atomNameCol], (tab[i][groupCol] == "HETATM")
                        )
                        atom = AtomView(ai, residue)
                    system._append_location(x, y, z, occ, B, alt)

                    prev_chain = chain
                    prev_residue = residue
                    prev_entity_id = tab[i][entityIdCol]
                    prev_seq_id = tab[i][seqIdCol]
                    prev_auth_seq_id = tab[i][authSeqIdCol]
                is_read["coors"] = True
            else:
                sp.advance(f, peeked)

    # Polymer entities that declare more molecules than the chains read get
    # additional chains (auto-renamed, with no atoms), so that every
    # declared molecule is represented.
    for entity_id in num_of_mols:
        if system._entities[entity_id].is_polymer():
            rem = num_of_mols[entity_id] - system.num_chains_of_entity(entity_id)
            for _ in range(rem):
                system.add_chain("A", None, None, entity_id, auto_rename=True)

    # Insert residues that appear in the entity sequence but not in the
    # chain, treating residue numbers as 1-based positions in the entity
    # sequence (missing positions are filled from `entity._seq`).
    for chain in system.chains():
        entity = chain.get_entity()
        if not entity.is_polymer() or entity._seq is None:
            continue
        k = 0
        for ri in range(len(entity._seq)):
            cur_res = chain.get_residue(k) if k < chain.num_residues() else None
            if (cur_res is None) or (cur_res.num > ri + 1):
                # Position ri + 1 is missing from the chain: insert it at k.
                chain.add_residue(entity._seq[ri], ri + 1, str(ri + 1), " ", at=k)
            elif cur_res.num < ri + 1:
                _warn_or_error(
                    strict, f"inconsistent numbering in chain {chain.cid}"
                )
                break
            k = k + 1

    # Every chain's residue names should now agree with its entity.
    for chain in system.chains():
        if not chain.check_sequence():
            _warn_or_error(
                strict,
                f"chain {chain.cid} did not pass sequence check against corresponding entity",
            )

    system._reindex()
    return system, success
| |
|
@staticmethod
def from_PDB_string(cif_string: str, options=""):
    """Initializes and returns a System object from a PDB string.

    Args:
        cif_string (str): Contents of a PDB file. The parameter name is kept
            for backward compatibility; the text is PDB, not CIF.
        options (str, optional): Reader options forwarded to ``_read_pdb``.
            They are matched case-insensitively and include "USESEGID",
            "CHARMM", "CHARMM19", "ALLOW DUPLICATE CIDS", "ALLOW ILE CD",
            "IGNORE-TER" and "QUIET".

    Returns:
        System: the parsed System, named "from_string".
    """
    import io

    f = io.StringIO(cif_string)
    # `_read_pdb(f, strict=False, options="")`: pass `options` by keyword.
    # Passing it positionally bound it to `strict`, so every option was
    # silently ignored.
    system = System._read_pdb(f, options=options)
    system.name = "from_string"
    return system
| |
|
@staticmethod
def from_PDB(input_file: str, options=""):
    """Initializes and returns a System object from a PDB file.

    Args:
        input_file (str): Path to the PDB file; also used as the System name.
        options (str, optional): Reader options forwarded to ``_read_pdb``.
            They are matched case-insensitively and include "USESEGID",
            "CHARMM", "CHARMM19", "ALLOW DUPLICATE CIDS", "ALLOW ILE CD",
            "IGNORE-TER" and "QUIET".

    Returns:
        System: the parsed System.
    """
    # The context manager closes the file (previously leaked). `options` is
    # passed by keyword: positionally it would bind to `strict` and be
    # ignored.
    with open(input_file, "r") as f:
        system = System._read_pdb(f, options=options)
    system.name = input_file
    return system
| |
|
@staticmethod
def _read_pdb(f, strict=False, options=""):
    """Parse PDB-format text into a System.

    ATOM/HETATM records are read by fixed-column slicing. Record handling:
      * MODEL/ENDMDL delimit models. The first model defines the topology;
        each later model is read into a scratch System and attached with
        ``add_model`` (a failure is warned about and that model skipped).
      * TER forces a new chain.
      * END stops reading.
    After reading, chains whose IDs repeated in the first model are given
    unique names. One entity is then created per chain from its residue
    names.

    Args:
        f: An iterable of text lines (e.g. an open file or ``io.StringIO``).
        strict (bool, optional): Not used by this reader.
        options (str, optional): Case-insensitive option string. Recognized
            sub-strings are "USESEGID", "CHARMM", "CHARMM19",
            "ALLOW DUPLICATE CIDS", "ALLOW ILE CD", "IGNORE-TER" and "QUIET".

    Returns:
        System: the parsed System.
    """

    def _to_float(strval, default):
        # Optional columns (occupancy, B-factor) may be blank or malformed;
        # fall back to `default` only on a parse failure. (A bare `except:`
        # here used to swallow KeyboardInterrupt/SystemExit as well.)
        try:
            return float(strval)
        except ValueError:
            return default

    last_resnum = None
    last_resname = None
    last_icode = None
    last_chain_id = None
    last_alt = None
    chain = None
    residue = None

    # Force a new chain at the first ATOM/HETATM record.
    ter = True

    options = options.upper()

    # Use the segment ID (columns 73-76) as the chain ID.
    use_segid = "USESEGID" in options

    # CHARMM-style records: residue number in columns 24-27, no insertion code.
    charmm_format = "CHARMM" in options

    # Rename CHARMM19 histidine residue names on input.
    charmm19_format = "CHARMM19" in options

    # Rename chains whose ID repeats, unless duplicates are explicitly allowed.
    uniq_chain_ids = "ALLOW DUPLICATE CIDS" not in options

    # Rename ILE atom "CD" to "CD1", unless explicitly allowed.
    fix_Ile_CD = "ALLOW ILE CD" not in options

    # Residues that differ only by insertion code are separate residues.
    icodes_as_sep_res = True

    # Do not let TER records split chains.
    ignore_ter = "IGNORE-TER" in options

    # Warn about repeated chain names unless "QUIET" is given.
    verbose = "QUIET" not in options

    # First-model chains whose IDs repeated; renamed after reading.
    chains_to_rename = []

    # `all_system` holds the first model's topology plus all later models;
    # `system` is the System being filled for the current model.
    system = System("system")
    all_system = system
    model_index = 0
    for line in f:
        line = line.strip()
        if line.startswith("ENDMDL"):
            # Every model after the first is attached to `all_system`.
            if model_index:
                try:
                    all_system.add_model(system)
                except Exception as e:
                    warnings.warn(
                        f"error when adding model {model_index + 1}: {str(e)}, skipping model..."
                    )
            system = System("system")
            model_index = model_index + 1
            last_resnum = None
            last_resname = None
            last_icode = None
            last_chain_id = None
            last_alt = None
            chain = None
            residue = None
            continue
        if line.startswith("END"):
            break
        if line.startswith("MODEL"):
            # Model boundaries are handled at ENDMDL.
            continue
        if line.startswith("TER") and not ignore_ter:
            ter = True
            continue
        if not (line.startswith("ATOM") or line.startswith("HETATM")):
            continue

        # Read the atom record. Lines may omit trailing optional columns, so
        # pad the line to avoid slicing past its end.
        line += " " * 100
        # NOTE: the serial number is parsed (a malformed one raises) but not
        # otherwise used.
        atominx = int(line[6:11])
        atomname = line[12:16].strip()
        alt = line[16:17]
        resname = line[17:21].strip()
        chain_id = line[21:22].strip()
        resnum = int(line[23:27]) if charmm_format else int(line[22:26])
        icode = " " if charmm_format else line[26:27]
        x = float(line[30:38])
        y = float(line[38:46])
        z = float(line[46:54])
        seg_id = line[72:76].strip()
        B = _to_float(line[60:66], 0.0)
        occ = _to_float(line[54:60], 0.0)
        het = line.startswith("HETATM")

        # Chain ID: the segment ID when requested. Otherwise, if the chain
        # column is blank, use the first character of an alphanumeric
        # segment ID.
        if use_segid:
            chain_id = seg_id
        elif (chain_id == "") and (len(seg_id) > 0) and seg_id[0].isalnum():
            chain_id = seg_id[0:1]

        # Start a new chain when the chain ID changes or after a TER record.
        if (chain_id != last_chain_id) or ter:
            cid_used = system.get_chain_by_id(chain_id) is not None
            chain = system.add_chain(chain_id, seg_id, chain_id, auto_rename=False)
            # A repeated chain ID is tagged with a temporary suffix. Only
            # first-model chains are queued for renaming after reading.
            if uniq_chain_ids and cid_used:
                chain.cid = chain.cid + f"|to rename {len(chains_to_rename)}"
                if model_index == 0:
                    chains_to_rename.append(chain)
                    if verbose:
                        warnings.warn(
                            "chain name '"
                            + chain_id
                            + "' was repeated while reading, will rename at the end..."
                        )

            # Force a new residue at the start of every chain.
            last_resnum = None
            last_resname = None
            ter = False

        if charmm19_format:
            # NOTE(review): these checks cascade, so both HSE and HSD end up
            # as HIS (_write_pdb's CHARMM19 mapping does not cascade).
            # Confirm this is intended.
            if resname == "HSE":
                resname = "HSD"
            if resname == "HSD":
                resname = "HIS"
            if resname == "HSC":
                resname = "HSP"

        # Standardize the isoleucine side-chain atom name.
        if fix_Ile_CD and (resname == "ILE") and (atomname == "CD"):
            atomname = "CD1"

        # The record starts a new residue, adds a location to an existing
        # atom, or adds a new atom to the current residue.
        really_new_atom = True
        if (
            (resnum != last_resnum)
            or (resname != last_resname)
            or (icodes_as_sep_res and (icode != last_icode))
        ):
            # Same residue number (and insertion code), a different residue
            # name and a different alternate-location ID: an alternative
            # residue identity at the same position. The record is skipped,
            # and the `last_*` trackers are deliberately left unchanged.
            if (
                (resnum == last_resnum)
                and (resname != last_resname)
                and (alt != last_alt)
                and (not icodes_as_sep_res or (icode == last_icode))
            ):
                continue

            residue = chain.add_residue(
                resname, chain.num_residues() + 1, str(resnum), icode[0]
            )
        elif alt != " ":
            # Same residue with an alternate-location ID: an already-present
            # atom gets an additional location instead of a duplicate atom.
            a = residue.find_atom(atomname)
            if a is not None:
                really_new_atom = False
                a.add_location(x, y, z, occ, B, alt[0])

        if really_new_atom:
            a = residue.add_atom(atomname, het, x, y, z, occ, B, alt[0])

        last_resnum = resnum
        last_icode = icode
        last_resname = resname
        last_chain_id = chain_id
        last_alt = alt

    # Give each repeated first-model chain a unique name.
    for chain in chains_to_rename:
        parts = chain.cid.split("|")
        assert (
            len(parts) > 1
        ), "something went wrong when renaming a chain at the end of reading"
        name = all_system._pick_unique_chain_name(parts[0], verbose)
        chain.cid = name
        if len(name):
            chain.segid = name

    # One entity per chain, built from its residue names. A residue is
    # flagged hetero only if all of its atoms came from HETATM records.
    for ci, chain in enumerate(all_system.chains()):
        seq = [None] * chain.num_residues()
        het = [None] * chain.num_residues()
        for ri, res in enumerate(chain.residues()):
            seq[ri] = res.name
            het[ri] = all(a.het for a in res.atoms())
        entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq)
        entity = SystemEntity(
            entity_type, f"chain {chain.cid}", polymer_type, seq, het
        )
        all_system.add_new_entity(entity, [ci])

    return all_system
| |
|
def to_CIF(self, output_file: str):
    """Writes the System to a CIF file.

    Args:
        output_file (str): Path of the file to write (overwritten if it
            exists).
    """
    # The context manager guarantees the file is flushed and closed, even
    # if writing fails part-way. Previously the handle was never closed.
    with open(output_file, "w") as f:
        self._write_cif(f)
| |
|
def to_CIF_string(self):
    """Serialize the System to CIF and return the text."""
    import io

    with io.StringIO("") as buffer:
        self._write_cif(buffer)
        return buffer.getvalue()
| |
|
def _write_cif(self, f):
    """Write the System in mmCIF format to the text stream ``f``.

    Categories are written in this order:
      * ``_entry.id``, ``_entity``, ``_entity_poly_seq``, ``_entity_poly``;
      * when assemblies are defined: ``_pdbx_struct_assembly``,
        ``_pdbx_struct_assembly_gen``, ``_pdbx_struct_oper_list``;
      * ``_atom_site`` for every model;
      * the custom ``_generate_selections`` and ``_generate_labels``
        categories, when non-empty.

    Entity indices are 0-based internally and written 1-based.

    Args:
        f: A writable text stream (e.g. an open file or ``io.StringIO``).
    """
    # Two-letter element symbols recognized at the start of an atom name.
    _specials_atom_names = {
        "MG", "CL", "FE", "ZN", "MN", "NI", "SE", "CU", "BR", "CO", "AS",
        "BE", "RU", "RB", "ZR", "OS", "SR", "GD", "MO", "AU", "AG", "PT",
        "AL", "XE", "CS", "EU", "IR", "AM", "TE", "BA", "SB",
    }
    # Names that can be an element or an ordinary atom name (e.g. CA).
    # They are taken as the element only when the residue name equals the
    # atom name.
    _ambiguous_atom_names = {"CA", "CD", "NA", "HG", "PB"}

    def _guess_type(atom_name, res_name):
        # Guess the `type_symbol` (element) from an atom name.
        if len(atom_name) > 0 and atom_name[0] == '"':
            atom_name = atom_name.replace('"', "")
        if atom_name[:2] in _specials_atom_names:
            return atom_name[:2]
        if atom_name in _ambiguous_atom_names and res_name == atom_name:
            return atom_name
        if atom_name == "UNK":
            return "X"
        return atom_name[:1]

    entry_id = self.name.strip()
    if entry_id == "":
        entry_id = "system"
    f.write(
        "data_GNR8\n#\n"
        + "_entry.id "
        + sp.star_string_escape(entry_id)
        + "\n#\n"
    )

    # Entities.
    sp.star_loop_header_write(
        f, "_entity", ["id", "type", "pdbx_description", "pdbx_number_of_molecules"]
    )
    for entity_idx, entity in self._entities.items():
        num_mol = self.num_molecules_of_entity(entity_idx)
        f.write(
            f"{entity_idx + 1} {sp.star_string_escape(entity._type)} {sp.star_string_escape(entity._desc)} {num_mol}\n"
        )
    f.write("#\n")

    # Full sequence of every entity that has one.
    sp.star_loop_header_write(
        f, "_entity_poly_seq", ["entity_id", "num", "mon_id", "hetero"]
    )
    for entity_idx, entity in self._entities.items():
        if entity._seq is not None:
            for i, (res, het) in enumerate(zip(entity._seq, entity._het)):
                f.write(f"{entity_idx + 1} {i + 1} {res} {'y' if het else 'n'}\n")
    f.write("#\n")

    # Polymer type of every polymer entity.
    sp.star_loop_header_write(f, "_entity_poly", ["entity_id", "type"])
    for entity_idx, entity in self._entities.items():
        if entity.is_polymer():
            f.write(f"{entity_idx + 1} {sp.star_string_escape(entity._polymer_type)}\n")
    f.write("#\n")

    if self.num_assemblies():
        assemblies = self._assembly_info.assemblies
        ops = self._assembly_info.operations

        sp.star_loop_header_write(f, "_pdbx_struct_assembly", ["id", "details"])
        for assembly_id, assembly in assemblies.items():
            f.write(f"{assembly_id} {sp.star_string_escape(assembly['details'])}\n")
        f.write("#\n")

        # Generation instructions. An assembly with no "instructions" entry
        # (e.g. read without matching `_pdbx_struct_assembly_gen` rows)
        # contributes no rows instead of raising KeyError.
        sp.star_loop_header_write(
            f,
            "_pdbx_struct_assembly_gen",
            ["assembly_id", "oper_expression", "asym_id_list"],
        )
        for assembly_id, assembly in assemblies.items():
            for instruction in assembly.get("instructions", []):
                chain_list = ",".join([str(ci) for ci in instruction["chains"]])
                f.write(
                    f"{assembly_id} {sp.star_string_escape(instruction['oper_expression'])} {chain_list}\n"
                )
        f.write("#\n")

        # Symmetry operations: 3x3 matrix (row-major) and translation vector.
        sp.star_loop_header_write(
            f,
            "_pdbx_struct_oper_list",
            [
                "id",
                "type",
                "name",
                "matrix[1][1]",
                "matrix[1][2]",
                "matrix[1][3]",
                "matrix[2][1]",
                "matrix[2][2]",
                "matrix[2][3]",
                "matrix[3][1]",
                "matrix[3][2]",
                "matrix[3][3]",
                "vector[1]",
                "vector[2]",
                "vector[3]",
            ],
        )
        for op_id, op in ops.items():
            f.write(
                f"{op_id} {sp.star_string_escape(op['type'])} {sp.star_string_escape(op['name'])} "
            )
            f.write(
                f"{float(op['matrix'][0][0]):g} {float(op['matrix'][0][1]):g} {float(op['matrix'][0][2]):g} "
            )
            f.write(
                f"{float(op['matrix'][1][0]):g} {float(op['matrix'][1][1]):g} {float(op['matrix'][1][2]):g} "
            )
            f.write(
                f"{float(op['matrix'][2][0]):g} {float(op['matrix'][2][1]):g} {float(op['matrix'][2][2]):g} "
            )
            f.write(
                f"{float(op['vector'][0]):g} {float(op['vector'][1]):g} {float(op['vector'][2]):g}\n"
            )
        f.write("#\n")

    sp.star_loop_header_write(
        f,
        "_atom_site",
        [
            "group_PDB",
            "id",
            "label_atom_id",
            "label_alt_id",
            "label_comp_id",
            "label_asym_id",
            "label_entity_id",
            "label_seq_id",
            "pdbx_PDB_ins_code",
            "Cartn_x",
            "Cartn_y",
            "Cartn_z",
            "occupancy",
            "B_iso_or_equiv",
            "pdbx_PDB_model_num",
            "auth_seq_id",
            "auth_asym_id",
            "type_symbol",
        ],
    )
    # NOTE: `idx` advances once per atom (not per location) and is not reset
    # between models. Alternate locations of one atom therefore share an id.
    # The model number written is the 0-based model index.
    idx = -1
    for model_index in range(self.num_models()):
        # Temporarily make model `model_index` the current one. Always swap
        # it back, even if writing fails, so the System is left unchanged.
        self.swap_model(model_index)
        try:
            for chain, entity_id in zip(self.chains(), self._chain_entities):
                authchainid = (
                    chain.authid if sp.star_value_defined(chain.authid) else chain.cid
                )
                for residue in chain.residues():
                    authresid = (
                        residue.authid
                        if sp.star_value_defined(residue.authid)
                        else residue.num
                    )
                    for atom in residue.atoms():
                        idx = idx + 1
                        for location in atom.locations():
                            # Skip locations without coordinates.
                            if not location.defined():
                                continue

                            coor = location.coor_info
                            f.write("HETATM " if atom.het else "ATOM ")
                            f.write(
                                f"{idx + 1} {atom.name} {sp.atom_site_token(location.alt)} "
                            )
                            entity_id_str = (
                                f"{entity_id + 1}" if entity_id is not None else "?"
                            )
                            f.write(
                                f"{residue.name} {chain.cid} {entity_id_str} {residue.num} "
                            )
                            f.write(
                                f"{sp.atom_site_token(residue.icode)} {coor[0]:g} {coor[1]:g} {coor[2]:g} "
                            )
                            f.write(f"{coor[3]:g} {coor[4]:g} {model_index} ")
                            f.write(
                                f"{authresid} {authchainid} {_guess_type(atom.name, residue.name)}\n"
                            )
        finally:
            self.swap_model(model_index)
    f.write("#\n")

    # Named selections: whitespace-separated indices in one quoted field.
    if len(self._selections):
        sp.star_loop_header_write(f, "_generate_selections", ["name", "indices"])
        for name, indices in self._selections.items():
            f.write(
                f"{sp.star_string_escape(name)} \"{' '.join([str(i) for i in indices])}\"\n"
            )
        f.write("#\n")

    # Labels: one row per (category, index, value).
    if len(self._labels):
        sp.star_loop_header_write(f, "_generate_labels", ["name", "index", "value"])
        for category, label_dict in self._labels.items():
            for gti, label in label_dict.items():
                f.write(
                    f"{sp.star_string_escape(category)} {gti} {sp.star_string_escape(label)}\n"
                )
        f.write("#\n")
| |
|
def to_PDB(self, output_file: str, options: str = "", mask_indices=None, seq=None):
    """Writes the System to a PDB file.

    Args:
        output_file (str): output PDB file name.
        options (str, optional): a string specifying various options for
            the writing process. The presence of certain sub-strings will
            trigger specific behaviors. Currently recognized sub-strings
            include "CHARMM", "CHARMM19", "CHARMM22", "RENUMBER", "NOEND",
            "NOTER", and "NOALT". This option is case-insensitive.
        mask_indices (optional): residue indices (within each chain),
            forwarded to ``_write_pdb``.
        seq (optional): per-residue names indexed by residue position within
            a chain, forwarded to ``_write_pdb``. When given, they replace the
            residue names.
    """
    # The context manager flushes and closes the file even if writing fails.
    # Previously the handle was never closed.
    with open(output_file, "w") as f:
        self._write_pdb(f, options, mask_indices=mask_indices, seq=seq)
| |
|
def to_PDB_string(self, options=""):
    """Render the System as PDB text and return it.

    ``options`` has the same meaning as in ``to_PDB``.
    """
    import io

    with io.StringIO("") as buffer:
        self._write_pdb(buffer, options)
        return buffer.getvalue()
| |
|
def _write_pdb(self, f, options="", mask_indices=None, seq=None):
    """Write the System in PDB format to the text stream ``f``.

    One ATOM/HETATM record is written for every location of every atom,
    chain by chain and residue by residue. Atom serial numbers start at 1
    and run across the whole output (they are not reset per chain).

    Args:
        f: A writable text stream.
        options (str, optional): Case-insensitive option string. Recognized
            sub-strings are "CHARMM", "CHARMM19", "CHARMM22", "RENUMBER",
            "NOEND", "NOTER" and "NOALT". "NOALT" is parsed into a local but
            not otherwise used below.
        mask_indices (optional): residue indices (within each chain), tested
            with ``in``. Only consulted when "RENUMBER" is not given.
            NOTE(review): it assigns ``a.atom.B``, while the B-factor written
            comes from the location's ``B`` (``loc.B`` in ``_pdb_line``).
            Confirm whether masking has any effect on the output.
        seq (optional): indexed by residue position within a chain. When
            given, ``str(seq[ri])`` replaces the residue name. CHARMM
            renaming, which keys off the original residue name, may
            overwrite it.

    Raises:
        Exception: if both "CHARMM19" and "CHARMM22" are given.
    """

    def _pdb_line(loc: AtomLocationView, ai: int, ri=None, rn=None, an=None):
        # Format one fixed-column ATOM/HETATM record for location `loc` with
        # serial number `ai`. `ri`, `rn` and `an` override the residue
        # number, residue name and atom name; by default they come from `loc`.
        if rn is None:
            rn = loc.atom.residue.name
        if ri is None:
            ri = loc.atom.residue.num
        if an is None:
            an = loc.atom.name
        icode = loc.atom.residue.icode
        # The chain-ID column is one character wide.
        cid = loc.atom.residue.chain.cid
        if len(cid) > 1:
            cid = cid[0]
        # The segment-ID column is four characters wide.
        segid = loc.atom.residue.chain.segid
        if len(segid) > 4:
            segid = segid[0:4]

        # Atom names shorter than four characters start one column into the
        # atom-name field.
        if len(an) < 4:
            an_str = " %-.3s" % an
        else:
            an_str = "%.4s" % an

        # The serial number and residue number wrap (modulo) to fit their
        # 5- and 4-character fields.
        line = (
            "%6s%5d %-4s%c%-4s%.1s%4d%c   %8.3f%8.3f%8.3f%6.2f%6.2f      %.4s"
            % (
                "HETATM" if loc.atom.het else "ATOM  ",
                ai % 100000,
                an_str,
                loc.alt,
                rn,
                cid,
                ri % 10000,
                icode,
                loc.x,
                loc.y,
                loc.z,
                loc.occ,
                loc.B,
                segid,
            )
        )

        return line

    options = options.upper()
    # CHARMM naming: ILE CD1 -> CD, terminal O/OXT -> OT1/OT2, HOH -> TIP3.
    charmmFormat = True if "CHARMM" in options else False

    # CHARMM19 histidine naming (see the mapping below).
    charmmFormat = charmmFormat
    charmm19Format = True if "CHARMM19" in options else False

    # CHARMM22 histidine naming (see the mapping below).
    charmm22Format = True if "CHARMM22" in options else False

    # Residue numbers written as 1-based positions within the chain.
    renumber = True if "RENUMBER" in options else False

    # Suppress the final END record.
    noend = True if "NOEND" in options else False

    # Suppress TER records after each chain.
    noter = True if "NOTER" in options else False

    # Parsed but not used below.
    writeAlt = True if "NOALT" in options else False

    # Generic-format renaming branch; currently disabled.
    genericFormat = False

    if charmm19Format and charmm22Format:
        raise Exception(
            "CHARMM 19 and 22 formatting options cannot be specified together"
        )

    atomIndex = 1
    for ci, chain in enumerate(self.chains()):
        for ri, residue in enumerate(chain.residues()):
            for ai, atom in enumerate(residue.atoms()):
                # Per-format renaming. Decisions key off the stored names.
                atomname = atom.name
                resname = residue.name
                if seq is not None:
                    resname = str(seq[ri])
                if charmmFormat:
                    if (residue.name == "ILE") and (atom.name == "CD1"):
                        atomname = "CD"
                    if (atom.name == "O") and (ri == chain.num_residues() - 1):
                        atomname = "OT1"
                    if (atom.name == "OXT") and (ri == chain.num_residues() - 1):
                        atomname = "OT2"
                    if residue.name == "HOH":
                        resname = "TIP3"

                if charmm19Format:
                    if residue.name == "HSD":
                        resname = "HIS"
                    if residue.name == "HSE":
                        resname = "HSD"
                    if residue.name == "HSC":
                        resname = "HSP"
                elif charmm22Format:
                    """This will convert from CHARMM19 to CHARMM22 as well as from a generic downlodaded
                    * PDB file to one ready for use in CHARMM22. The latter is because in the all-hydrogen
                    * topology, HIS protonation state must be explicitly specified, so there is no HIS per se.
                    * Whereas in typical downloaded PDB files HIS is used for all histidines (usually, one
                    * does not even really know the protonation state). Whether sometimes people do specify it
                    * nevertheless, and what naming format they use to do so, I am not sure (welcome to the
                    * PDB file format). But certainly almost always it is just HIS. Below HIS is renamed to
                    * HSD, the neutral form with proton on ND1. This is an assumption; not a perfect one, but
                    * something needs to be assumed. Doing this renaming will make the PDB file work in MM
                    * packages with the all-hydrogen model."""
                    if residue.name == "HSD":
                        resname = "HSE"
                    if residue.name == "HIS":
                        resname = "HSD"
                    if residue.name == "HSP":
                        resname = "HSC"
                elif genericFormat:
                    if residue.name in ["HSD", "HSP", "HSE", "HSC"]:
                        resname = "HIS"
                    if (residue.name == "ILE") and (atom.name == "CD"):
                        atomname = "CD1"

                # One record per location; the serial number advances for
                # every record written.
                for li in range(atom.num_locations()):
                    if renumber:
                        f.write(
                            _pdb_line(
                                atom.get_location(li),
                                atomIndex,
                                ri=ri + 1,
                                rn=resname,
                                an=atomname,
                            )
                            + "\n"
                        )
                    else:
                        a = atom.get_location(li)
                        if mask_indices is not None:
                            if ri in mask_indices:
                                a.atom.B = 0
                            else:
                                a.atom.B = 1
                        f.write(
                            _pdb_line(
                                a,
                                atomIndex,
                                rn=resname,
                                an=atomname,
                            )
                            + "\n"
                        )
                    atomIndex = atomIndex + 1

            # TER after the last residue of each chain.
            if not noter and (ri == chain.num_residues() - 1):
                f.write("TER\n")
        # END after the last chain.
        if not noend and (ci == self.num_chains() - 1):
            f.write("END\n")
| |
|
| | def canonicalize_protein( |
| | self, |
| | level=2, |
| | drop_coors_unknowns=False, |
| | drop_coors_missing_backbone=False, |
| | filter_by_entity=False, |
| | ): |
| | """Canonicalize the calling System object (in place) by assuming that it represents |
| | a protein molecular system. Different canonicalization rigor and options |
| | can be specified but are all optional. |
| | |
| | Args: |
| | level (int): Canonicalization level that determines which nonstandard-to-standard |
| | residue mappings are performed. Possible values are 1, 2 or 3, with 2 being |
| | the default and higher values meaning more rigorous (and less conservative) |
| | canonicalization. With level 1, only truly equivalent mappings are performed |
| | (e.g., different His protonation states are mapped to the canonical residue |
| | name HIS that does not specify protonation). Level 2 adds to this some less |
| | exact but still quite close mappings--i.e., seleno-methionine (MSE) and seleno- |
| | cystine (SEC) to methionine (MET) and cystine (CYS). Level 3 further adds |
| | even less equivalent but still reasonable mappings--i.e., phosphorylated SER, |
| | THR, TYR, and HIS to their unphosphorylated counterparts as well as S-oxy Cys |
| | to Cys. |
| | drop_coors_unknowns (bool, optional): if True, will discard structural information |
| | for all residues that are not natural or mappable under the current level. |
| | NOTE: any sequence record for these residues (i.e., if they are part of a |
| | polymer entity) will be preserved. |
| | drop_coors_missing_backbone (bool, optional): if True, will discard structural |
| | information for residues that do not have at least the N, CA, C, and O |
| | backbone atoms. Same note applies regarding the full sequence record as in |
| | the above. |
| | filter_by_entity (bool, optional): if True, will remove any chains that do not |
| | represent polymer/polypeptide entities. This is convenient for cases where a |
| | System object has both protein and non-protein components. However, depending |
| | on how the System object was generated, entity metadata may not have been filled, |
| | so applying this canonicalization approach will remove the entire structure. |
| | For this reason, the option is False by default. |
| | """ |
| |
|
| | def _mod_to_standard_aa_mappings( |
| | less_standard: bool, almost_standard: bool, standard: bool |
| | ): |
| | |
| | standard_map = {"HSD": "HIS", "HSE": "HIS", "HSC": "HIS", "HSP": "HIS"} |
| |
|
| | |
| | |
| | almost_standard_map = {"MSE": "MET", "SEC": "CYS"} |
| |
|
| | |
| | |
| | |
| | less_standard_map = { |
| | "HIP": "HIS", |
| | "CSX": "CYS", |
| | "SEP": "SER", |
| | "TPO": "THR", |
| | "PTR": "TYR", |
| | } |
| |
|
| | ret = dict() |
| | if standard: |
| | ret.update(standard_map) |
| | if almost_standard: |
| | ret.update(almost_standard_map) |
| | if less_standard: |
| | ret.update(less_standard_map) |
| | return ret |
| |
|
| | def _to_standard_aa_mappings( |
| | less_standard: bool, almost_standard: bool, standard: bool |
| | ): |
| | |
| | mapping = _mod_to_standard_aa_mappings( |
| | less_standard, almost_standard, standard |
| | ) |
| |
|
| | |
| | import src.chroma.utility.polyseq as polyseq |
| |
|
| | for aa in polyseq.canonical_amino_acids(): |
| | mapping[aa] = aa |
| |
|
| | return mapping |
| |
|
| | less_standard, almost_standard, standard = False, False, False |
| | if level == 3: |
| | less_standard, almost_standard, standard = True, True, True |
| | elif level == 2: |
| | less_standard, almost_standard, standard = False, True, True |
| | elif level == 1: |
| | less_standard, almost_standard, standard = False, False, True |
| | else: |
| | raise Exception(f"unknown canonicalization level {level}") |
| |
|
| | to_standard = _to_standard_aa_mappings(less_standard, almost_standard, standard) |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | chains_to_delete = [] |
| | residues_to_rename = dict() |
| | for ci, chain in enumerate(self.chains()): |
| | entity = chain.get_entity() |
| | if filter_by_entity: |
| | if ( |
| | (entity is None) |
| | or (entity._type != "polymer") |
| | or ("polypeptide" not in entity.polymer_time) |
| | ): |
| | chains_to_delete.append(chain) |
| | continue |
| |
|
| | |
| | cleared_residues = 0 |
| | for residue in reversed(list(chain.residues())): |
| | aa = residue.name |
| | delete_atoms = False |
| | |
| | if aa in to_standard: |
| | aa_new = to_standard[aa] |
| | if aa != aa_new: |
| | |
| | if ( |
| | (aa == "HSD") |
| | or (aa == "HSE") |
| | or (aa == "HSC") |
| | or (aa == "HSP") |
| | ) and (aa_new == "HIS"): |
| | pass |
| | elif ((aa == "MSE") and (aa_new == "MET")) or ( |
| | (aa == "SEC") and (aa_new == "CYS") |
| | ): |
| | SE = residue.find_atom("SE") |
| | if SE is not None: |
| | if aa == "MSE": |
| | SE.residue.rename("SD") |
| | else: |
| | SE.residue.rename("SG") |
| | elif ( |
| | ((aa == "HIP") and (aa_new == "HIS")) |
| | or ((aa == "SEP") and (aa_new == "SER")) |
| | or ((aa == "TPO") and (aa_new == "THR")) |
| | or ((aa == "PTR") and (aa_new == "TYR")) |
| | ): |
| | |
| | for atomname in ["P", "O1P", "O2P", "O3P", "HOP2", "HOP3"]: |
| | a = residue.find_atom(atomname) |
| | if a is not None: |
| | a.delete() |
| | elif (aa == "CSX") and (aa_new == "CYS"): |
| | a = residue.find_atom("OD") |
| | if a is not None: |
| | a.delete() |
| |
|
| | |
| | entity_id = chain.get_entity_id() |
| | if entity_id is None: |
| | residue.rename(aa_new) |
| | else: |
| | if entity_id not in residues_to_rename: |
| | residues_to_rename[entity_id] = dict() |
| | if ci not in residues_to_rename[entity_id]: |
| | residues_to_rename[entity_id][ci] = list() |
| | residues_to_rename[entity_id][ci].append( |
| | (residue.get_index_in_chain(), aa_new) |
| | ) |
| | else: |
| | if aa == "ARG": |
| | A = {an: None for an in ["CD", "NE", "CZ", "NH1", "NH2"]} |
| | for an in A: |
| | atom = residue.find_atom(an) |
| | if atom is not None and atom.num_locations(): |
| | A[an] = atom.get_location(0) |
| | if all([a is not None for n, a in A.items()]): |
| | dihe1 = System.dihedral( |
| | A["CD"], A["NE"], A["CZ"], A["NH1"] |
| | ) |
| | dihe2 = System.dihedral( |
| | A["CD"], A["NE"], A["CZ"], A["NH2"] |
| | ) |
| | if abs(dihe1) > abs(dihe2): |
| | A["NH1"].name = "NH2" |
| | A["NH2"].name = "NH1" |
| | elif drop_coors_unknowns: |
| | delete_atoms = True |
| |
|
| | if not drop_coors_missing_backbone: |
| | if not delete_atoms and not residue.has_full_backbone(): |
| | delete_atoms = True |
| |
|
| | if delete_atoms: |
| | residue.delete_atoms() |
| | cleared_residues += 1 |
| |
|
| | |
| | |
| | |
| | |
| | if ( |
| | not filter_by_entity |
| | and (cleared_residues != 0) |
| | and (cleared_residues == chain.num_residues()) |
| | ): |
| | chains_to_delete.append(chain) |
| |
|
| | |
| | |
| | for entity_id, ops in residues_to_rename.items(): |
| | chain_indices = set(ops.keys()) |
| | entity_chains = set(self.get_chains_of_entity(entity_id, by="index")) |
| | unique_renames = set([tuple(v) for v in ops.values()]) |
| | fork = True |
| | if (chain_indices == entity_chains) and (len(unique_renames) == 1): |
| | |
| | fork = False |
| | for ci, renames in ops.items(): |
| | chain = self.get_chain(ci) |
| | for ri, new_name in renames: |
| | chain.get_residue(ri).rename(new_name, fork_entity=fork) |
| |
|
| | |
| | for chain in reversed(chains_to_delete): |
| | chain.delete() |
| |
|
| | self._reindex() |
| |
|
| | def sequence(self, format="three-letter-list"): |
| | """Returns the full sequence of this System, concatenated over all |
| | chains in their order within the System. |
| | |
| | Args: |
| | format (str): sequence format. Possible options are either |
| | "three-letter-list" (default) or "one-letter-string". |
| | |
| | Returns: |
| | List (default) or string. |
| | """ |
| | if format == "three-letter-list": |
| | seq = [] |
| | else: |
| | seq = "" |
| |
|
| | for chain in self.chains(): |
| | seq = seq + chain.sequence(format) |
| | return seq |
| |
|
| | @staticmethod |
| | def distance(a1: AtomLocationView, a2: AtomLocationView): |
| | """Computes the distance between atom locations `a1` and `a2`.""" |
| | v21 = a1.coors - a2.coors |
| | return np.linalg.norm(v21) |
| |
|
| | @staticmethod |
| | def angle( |
| | a1: AtomLocationView, a2: AtomLocationView, a3: AtomLocationView, radians=False |
| | ): |
| | """Computes the angle formed by three 3D points represented by AtomLocationView objects. |
| | |
| | Args: |
| | a1, a2, a3 (AtomLocationView): three 3D points. |
| | radian (bool, optional): if True (default False), will return the angle in radians. |
| | Otherwise, in degrees. |
| | |
| | Returns: |
| | Angle `a1`-`a2`-`a3`. |
| | """ |
| | v21 = a1.coors - a2.coors |
| | v23 = a3.coors - a2.coors |
| | v21 = v21 / np.linalg.norm(v21) |
| | v23 = v23 / np.linalg.norm(v23) |
| | c = np.dot(v21, v23) |
| | return np.arctan2(np.sqrt(1 - c * c), c) * (1 if radians else 180.0 / np.pi) |
| |
|
| | @staticmethod |
| | def dihedral( |
| | a1: AtomLocationView, |
| | a2: AtomLocationView, |
| | a3: AtomLocationView, |
| | a4: AtomLocationView, |
| | radians=False, |
| | ): |
| | """Computes the dihedral angle formed by four 3D points represented by AtomLocationView objects. |
| | |
| | Args: |
| | a1, a2, a3, a4 (AtomLocationView): four 3D points. |
| | radian (bool, optional): if True (default False), will return the angle in radians. |
| | Otherwise, in degrees. |
| | |
| | Returns: |
| | Dihedral angle `a1`-`a2`-`a3`-`a4`. |
| | """ |
| | AB = a1.coors - a2.coors |
| | CB = a3.coors - a2.coors |
| | DC = a4.coors - a3.coors |
| |
|
| | if min([np.linalg.norm(p) for p in [AB, CB, DC]]) == 0.0: |
| | raise Exception("some points coincide in dihedral calculation") |
| |
|
| | ABxCB = np.cross(AB, CB) |
| | ABxCB = ABxCB / np.linalg.norm(ABxCB) |
| | DCxCB = np.cross(DC, CB) |
| | DCxCB = DCxCB / np.linalg.norm(DCxCB) |
| |
|
| | |
| | dotp = np.dot(ABxCB, DCxCB) |
| | if dotp > 1.0: |
| | dotp = 1.0 |
| | elif dotp < -1.0: |
| | dotp = -1.0 |
| |
|
| | angle = np.arccos(dotp) |
| | if np.dot(ABxCB, DC) > 0: |
| | angle *= -1 |
| | if not radians: |
| | angle *= 180.0 / np.pi |
| |
|
| | return angle |
| |
|
| | @staticmethod |
| | def protein_backbone_atom_type(atom_name: str, no_hyd=True, by_name=True): |
| | """Backbone atoms can be either nitrogens, carbons, oxigens, or hydrogens. |
| | Specifically, possible known names in each category are: |
| | 'N', 'NT' |
| | 'CA', 'C', 'CY', 'CAY' |
| | 'OY', 'O', 'OCT*', 'OXT', 'OT1', 'OT2' |
| | 'H', 'HY*', 'HA*', 'HN', 'HT*', '1H', '2H', '3H' |
| | """ |
| | array = ["N", "CA", "C", "O", "H"] if by_name else [0, 1, 2, 3, 4] |
| | if atom_name in ["N", "NT"]: |
| | return array[0] |
| | if atom_name == "CA": |
| | return array[1] |
| | if (atom_name == "C") or (atom_name == "CY"): |
| | return array[2] |
| | if atom_name in ["O", "OY", "OXT", "OT1", "OT2"] or atom_name.startswith("OCT"): |
| | return array[3] |
| | if not no_hyd: |
| | if atom_name in ["H", "HA", "HN"]: |
| | return array[4] |
| | if atom_name.startswith("HT") or atom_name.startswith("HY"): |
| | return array[4] |
| | |
| | if ( |
| | atom_name.startswith("1H") |
| | or atom_name.startswith("2H") |
| | or atom_name.startswith("3H") |
| | ): |
| | return array[4] |
| | return None |
| |
|
| |
|
@dataclass
class SystemEntity:
    """A molecular entity represented in a molecular system."""

    _type: str  # entity type; "polymer" marks polymer entities (see `is_polymer`)
    _desc: str  # description text, exposed as `description`
    _polymer_type: str  # polymer type name, or None when it could not be determined
    _seq: list  # residue names making up the entity's sequence
    _het: list  # exposed as `hetero`; TODO(review): confirm the exact per-residue contents

    def is_polymer(self):
        """Returns whether the entity represents a polymer."""
        return self._type == "polymer"

    @classmethod
    def guess_entity_and_polymer_type(cls, seq: List):
        """Guess the entity type and polymer type of a sequence of residue names.

        Args:
            seq (list): residue names (any iterable is accepted).

        Returns:
            Tuple ``(entity_type, polymer_type)``. ``entity_type`` is "polymer" if
            more than 80% of the residues pass `polyseq.is_polymer_residue(res, None)`,
            otherwise "unknown". ``polymer_type`` is the name of the first
            `polyseq.polymerType` matched by more than 80% of the residues, or None
            (always None for non-polymers and for an empty sequence).
        """
        seq = list(seq)
        if not seq:
            # np.mean of an empty list is NaN and emits a RuntimeWarning; an empty
            # sequence cannot be classified, so report it as unknown directly.
            return "unknown", None

        def _fraction(ptype):
            # fraction of residues accepted by polyseq.is_polymer_residue for `ptype`
            return np.mean([polyseq.is_polymer_residue(res, ptype) for res in seq])

        polymer_type = None
        if _fraction(None) > 0.8:
            entity_type = "polymer"
            for ptype in polyseq.polymerType:
                if _fraction(ptype) > 0.8:
                    polymer_type = polyseq.polymer_type_name(ptype)
                    break
        else:
            entity_type = "unknown"

        return entity_type, polymer_type

    @property
    def type(self):
        return self._type

    @property
    def description(self):
        return self._desc

    @property
    def polymer_type(self):
        return self._polymer_type

    @property
    def sequence(self):
        return self._seq

    @property
    def hetero(self):
        return self._het
| |
|
| |
|
@dataclass
class BaseView:
    """Common base for the lightweight "views" into parts of a System.

    A view holds an index into one level of the System's storage plus a
    reference to the object one level up the hierarchy.
    """

    _ix: int  # index of the viewed item; set to -1 by the subclasses' delete()
    _parent: object  # the parent view (the System itself for chains)

    def get_index(self):
        """Return the index of the viewed item within its level of System storage."""
        return self._ix

    def is_valid(self):
        """Return whether the view still points at something (non-negative index and a parent)."""
        return not (self._ix < 0 or self._parent is None)

    def _delete(self):
        """Remove the viewed item from its parent's children."""
        parent_siblings = self.parent._siblings
        first_child = parent_siblings.child_index(self.parent._ix, 0)
        parent_siblings.delete_child(self.parent._ix, self._ix - first_child)

    @property
    def parent(self):
        """The object one level up the hierarchy."""
        return self._parent
| |
|
| |
|
@dataclass
class ChainView(BaseView):
    """A Chain view, allowing hierarchical exploration and editing."""

    def __init__(self, ix: int, system: System):
        self._ix = ix
        self._parent = system
        self._siblings = system._chains

    def __str__(self):
        return f"{self.cid} ({self.segid}/{self.authid}) -> {str(self.system)}"

    def residues(self):
        """Yield a ResidueView for each residue of the Chain, in order."""
        for rn in range(self.num_residues()):
            ri = self._siblings.child_index(self._ix, rn)
            yield ResidueView(ri, self)

    def num_residues(self):
        """Returns the number of residues in the Chain."""
        return self._siblings.num_children(self._ix)

    def num_structured_residues(self):
        """Returns the number of residues with at least one atom location."""
        return sum(res.has_structure() for res in self.residues())

    def num_atoms(self):
        """Returns the total number of atoms over all residues of the Chain."""
        return sum(res.num_atoms() for res in self.residues())

    def num_atom_locations(self):
        """Returns the total number of atom locations over all residues of the Chain."""
        return sum(res.num_atom_locations() for res in self.residues())

    def sequence(self, format="three-letter-list"):
        """Returns the sequence of this chain.

        Args:
            format (str): "three-letter-list" (default) for a list of residue
                names, or "one-letter-string" for a string of one-letter codes
                (converted with `polyseq.to_single`).

        Raises:
            Exception: if `format` is not one of the above.
        """
        if format == "three-letter-list":
            return [residue.name for residue in self.residues()]
        elif format == "one-letter-string":
            import src.data.protein.polyseq as polyseq

            return "".join(
                polyseq.to_single(residue.name) for residue in self.residues()
            )
        else:
            raise Exception(f"unknown sequence format {format}")

    def get_residue(self, ri: int):
        """Get the residue at the specified index within the Chain.

        Args:
            ri (int): Residue index within the Chain.

        Returns:
            ResidueView object corresponding to the residue in question.

        Raises:
            Exception: if `ri` is out of range.
        """
        if ri < 0 or ri >= self.num_residues():
            raise Exception(
                f"residue index {ri} out of range for Chain, which has {self.num_residues()} residues"
            )
        ri = self._siblings.child_index(self._ix, ri)
        return ResidueView(ri, self)

    def get_residue_index(self, residue: ResidueView):
        """Get the index of the given residue in this Chain."""
        return residue._ix - self._siblings.child_index(self._ix, 0)

    def get_atom(self, aidx: int):
        """Get the atom at index `aidx` within this chain (atoms counted over
        residues in order).

        Raises:
            Exception: if `aidx` is negative or out of range for the Chain.
        """
        if aidx < 0:
            raise Exception(f"negative atom index: {aidx}")
        off = 0
        for residue in self.residues():
            na = residue.num_atoms()
            if aidx < off + na:
                return residue.get_atom(aidx - off)
            off = off + na
        raise Exception(
            f"atom index {aidx} out of range for Chain, which has {self.num_atoms()} atoms"
        )

    def get_atoms(self):
        """Return a list of all atoms in this chain."""
        atoms_views = []
        for residue in self.residues():
            atoms_views.extend(residue.get_atoms())
        return atoms_views

    def __getitem__(self, res_idx: int):
        return self.get_residue(res_idx)

    def get_entity_id(self):
        """Return the entity ID corresponding to this chain."""
        return self.system._chain_entities[self._ix]

    def get_entity(self):
        """Return the entity this chain belongs to (None if it has no entity)."""
        entity_id = self.get_entity_id()
        if entity_id is None:
            return None
        return self.system._entities[entity_id]

    def check_sequence(self):
        """Compare the list of residue names of this chain to the corresponding entity
        sequence record. Returns True when there is no polymer entity to compare to."""
        entity = self.get_entity()
        if entity is None or not entity.is_polymer():
            return True
        if self.num_residues() != len(entity._seq):
            return False
        return all(
            res.name == ent_aan for res, ent_aan in zip(self.residues(), entity._seq)
        )

    def add_residue(self, name: str, num: int, authid: str, icode: str = " ", at=None):
        """Add a new residue to this chain.

        Args:
            name (str): Residue name.
            num (int): Residue number (i.e., residue ID).
            authid (str): Author residue ID.
            icode (str): Insertion code.
            at (int, optional): Index at which to insert the residue. Default
                is to append to the end of the chain (i.e., equivalent of ``at``
                being equal to the present length of the chain).

        Returns:
            ResidueView object corresponding to the newly added residue.
        """
        if at is None:
            at = self.num_residues()
        ri = self._siblings.insert_child(
            self._ix,
            at,
            {"name": name, "resnum": num, "authresid": authid, "icode": icode},
        )
        return ResidueView(ri, self)

    def delete(self, keep_entity=False):
        """Deletes this Chain from its System.

        Args:
            keep_entity (bool, optional): If False (default) and if the chain
                being deleted happens to be the last representative of the
                entity it belongs to, the entity will be deleted. If True, the
                entity will always be kept.
        """
        # delete mention of the chain from assembly information
        self.system._assembly_info.delete_chain(self.cid)

        # Optionally delete the entity when this chain is its last representative.
        # This chain has not been removed yet (neither from `_chain_entities` nor
        # from chain storage), so it is still counted: "last" means a count of 1.
        if not keep_entity:
            eid = self.get_entity_id()
            if eid is not None and self.system.num_chains_of_entity(eid) == 1:
                self.system.delete_entity(eid)

        self.system._chain_entities.pop(self._ix)
        self._siblings.delete(self._ix)
        self._ix = -1

    @property
    def system(self):
        return self._parent

    @property
    def cid(self):
        return self._siblings["cid"][self._ix]

    @property
    def segid(self):
        return self._siblings["segid"][self._ix]

    @property
    def authid(self):
        return self._siblings["authid"][self._ix]

    @cid.setter
    def cid(self, val):
        self._siblings["cid"][self._ix] = val

    @segid.setter
    def segid(self, val):
        self._siblings["segid"][self._ix] = val

    @authid.setter
    def authid(self, val):
        self._siblings["authid"][self._ix] = val
| |
|
| |
|
@dataclass
class ResidueView(BaseView):
    """A Residue view, allowing hierarchical exploration and editing."""

    def __init__(self, ix: int, chain: ChainView):
        self._ix = ix
        self._parent = chain
        self._siblings = chain.system._residues

    def __str__(self):
        return f"{self.name} {self.num} ({self.authid}) -> {str(self.chain)}"

    def atoms(self):
        """Yield an AtomView for each atom of the Residue, in order."""
        off = self._siblings.child_index(self._ix, 0)
        for an in range(self.num_atoms()):
            yield AtomView(off + an, self)

    def num_atoms(self):
        """Returns the number of atoms in the Residue."""
        return self._siblings.num_children(self._ix)

    def num_atom_locations(self):
        """Returns the total number of locations over all atoms of the Residue."""
        return sum(a.num_locations() for a in self.atoms())

    def has_structure(self):
        """Returns whether any atom of the Residue has structural information
        (i.e., one or more locations)."""
        return any(a.num_locations() for a in self.atoms())

    def get_atom(self, ai: int):
        """Get the atom at the specified index within the Residue.

        Args:
            ai (int): Atom index within the Residue.

        Returns:
            AtomView object corresponding to the atom in question.

        Raises:
            Exception: if `ai` is out of range.
        """
        if ai < 0 or ai >= self.num_atoms():
            raise Exception(
                f"atom index {ai} out of range for Residue, which has {self.num_atoms()} atoms"
            )
        ai = self._siblings.child_index(self._ix, ai)
        return AtomView(ai, self)

    def get_atom_index(self, atom: AtomView):
        """Get the index of the given atom in this Residue."""
        return atom._ix - self._siblings.child_index(self._ix, 0)

    def find_atom(self, name):
        """Find and return the first atom (as AtomView object) with the given name
        within the Residue or None."""
        return next((atom for atom in self.atoms() if atom.name == name), None)

    def __getitem__(self, atom_idx: int):
        return self.get_atom(atom_idx)

    def get_index_in_chain(self):
        """Return the index of the Residue in its parent Chain."""
        return self.chain.get_residue_index(self)

    def rename(self, new_name: str, fork_entity=True):
        """Assigns the residue a new name with all proper updates.

        Args:
            new_name (str): New residue name.
            fork_entity (bool, optional): If True (default) and the old name
                recorded in the entity sequence differs from the new name, the
                parent chain is first given its own entity (via
                `System._ensure_unique_entity`), which is then edited, leaving
                any shared entity unchanged. If False, the chain's current entity
                is edited directly. NOTE: setting this to False can create an
                inconsistent state between chain and entity sequence information.
        """
        entity_id = self.chain.get_entity_id()
        if entity_id is not None:
            entity = self.system._entities[entity_id]
            ri = self.get_index_in_chain()
            if fork_entity and (entity._seq[ri] != new_name):
                ci = self.chain.get_index()
                entity_id = self.system._ensure_unique_entity(ci)
                entity = self.system._entities[entity_id]
            entity._seq[ri] = new_name
        self._siblings["name"][self._ix] = new_name

    def add_atom(
        self,
        name: str,
        het: bool,
        x: float = None,
        y: float = None,
        z: float = None,
        occ: float = 1.0,
        B: float = 0.0,
        alt: str = " ",
        at=None,
    ):
        """Adds a new atom to the residue (appending it at the end by default) and
        returns an AtomView to it. If `x` is given, a location is also added to
        the atom.

        Args:
            name (str): Atom name.
            het (bool): Whether it is a hetero-atom.
            x, y, z (float): Atom location coordinates.
            occ (float): Occupancy.
            B (float): B-factor.
            alt (str): Alternative position character.
            at (int, optional): Index at which to insert the atom. Default
                is to append to the end of the residue (i.e., equivalent of
                ``at`` being equal to the number of atoms in the residue).

        Returns:
            AtomView object corresponding to the newly added atom.
        """
        if at is None:
            at = self.num_atoms()
        ai = self._siblings.insert_child(self._ix, at, {"name": name, "het": het})
        atom = AtomView(ai, self)

        # only the presence of `x` decides whether a location is added
        if x is not None:
            atom.add_location(x, y, z, occ, B, alt)

        return atom

    def delete(self, fork_entity=True):
        """Deletes this residue from its Chain/System.

        Args:
            fork_entity (bool, optional): If True (default) and the parent chain
                has an entity, the chain is first given its own entity (via
                `System._ensure_unique_entity`), which is then edited, leaving
                any shared entity unchanged. If False, the chain's current
                entity is edited directly. NOTE: setting this to False can
                create an inconsistent state between chain and entity sequence
                information.
        """
        # remove the residue from the entity sequence record first
        entity_id = self.chain.get_entity_id()
        if entity_id is not None:
            entity = self.system._entities[entity_id]
            ri = self.get_index_in_chain()
            if fork_entity:
                ci = self.chain.get_index()
                entity_id = self.system._ensure_unique_entity(ci)
                entity = self.system._entities[entity_id]
            entity._seq.pop(ri)

        # then remove the residue itself
        self._delete()
        self._ix = -1

    def delete_atoms(self, atoms=None):
        """Delete either the specified list of atoms or all atoms from the residue.

        Args:
            atoms (list, optional): List of AtomView objects corresponding to the
                atoms to delete. If not specified, will delete all atoms in the residue.

        Raises:
            Exception: if one of the atoms does not belong to this Residue.
        """
        if atoms is None:
            atoms = list(self.atoms())
        # delete back to front so earlier atom indices are not shifted
        for atom in reversed(atoms):
            if atom.residue != self:
                raise Exception(f"Atom {atom} does not belong to Residue {self}")
            atom.delete()

    @property
    def chain(self):
        return self._parent

    @property
    def system(self):
        return self.chain.system

    @property
    def name(self):
        return self._siblings["name"][self._ix]

    @property
    def num(self):
        return self._siblings["resnum"][self._ix]

    @property
    def authid(self):
        return self._siblings["authresid"][self._ix]

    @property
    def icode(self):
        return self._siblings["icode"][self._ix]

    def get_backbone(self, no_hyd=True):
        """Assuming that this is a protein residue (i.e., an amino acid), returns the
        list of atoms corresponding to the residue's backbone, in the order:
        backbone amide (N), alpha carbon (CA), carbonyl carbon (C), carbonyl oxygen (O),
        and amide hydrogen (H, optional).

        Args:
            no_hyd (bool, optional): If True (default), will exclude the amide hydrogen
                and only return four atoms. If False, will include the amide hydrogen.

        Returns:
            A list with each entry being an AtomView object corresponding to the backbone
            atom in the order above or None if the atom does not exist in the residue.
            For each slot, the first matching atom in residue order is used.
        """
        bb = [None] * (4 if no_hyd else 5)
        left = len(bb)
        for atom in self.atoms():
            # by_name=False so the backbone type comes back as a list index (0-4)
            i = System.protein_backbone_atom_type(atom.name, no_hyd, by_name=False)
            if i is None or bb[i] is not None:
                continue
            bb[i] = atom
            left -= 1
            if left == 0:
                break
        return bb

    def has_full_backbone(self, no_hyd=True):
        """Assuming that this is a protein residue (i.e., an amino acid), returns
        whether the residue harbors a structurally defined backbone (i.e., has
        all backbone atoms each of which has location information).

        Args:
            no_hyd (bool, optional): If True (default), will ignore whether the amide
                hydrogen exists or not (if False will consider it).

        Returns:
            Boolean indicating whether there is a full backbone in the residue.
        """
        bb = self.get_backbone(no_hyd)
        return all((a is not None) and a.num_locations() for a in bb)

    def delete_non_backbone(self, no_hyd=True):
        """Assuming that this is a protein residue (i.e., an amino acid), deletes
        all atoms except backbone atoms.

        Args:
            no_hyd (bool, optional): If True (default), will not consider the amide
                hydrogen as a backbone atom (if False will consider it).
        """
        to_delete = []
        for atom in self.atoms():
            if System.protein_backbone_atom_type(atom.name, no_hyd) is None:
                to_delete.append(atom)
        self.delete_atoms(to_delete)
| |
|
| |
|
@dataclass
class AtomView(BaseView):
    """An Atom view, allowing hierarchical exploration and editing."""

    def __init__(self, ix: int, residue: ResidueView):
        self._ix = ix
        self._parent = residue
        self._siblings = residue.system._atoms

    def __str__(self):
        string = self.name + (" (HET) " if self.het else " ")
        if self.num_locations() > 0:
            string = string + str(self.get_location(0))
        string = string + f" ({self.num_locations()})"
        return string + " -> " + str(self.residue)

    def locations(self):
        """Yield an AtomLocationView for each location of the atom, in order."""
        off = self._siblings.child_index(self._ix, 0)
        for ln in range(self.num_locations()):
            yield AtomLocationView(off + ln, self)

    def num_locations(self):
        """Returns the number of locations of the atom."""
        return self._siblings.num_children(self._ix)

    def __getitem__(self, loc_idx: int):
        return self.get_location(loc_idx)

    def get_location(self, li: int = 0):
        """Returns the (li+1)-th location of the atom.

        Raises:
            Exception: if `li` is out of range.
        """
        if li < 0 or li >= self.num_locations():
            raise Exception(
                f"location index {li} out of range for Atom with {self.num_locations()} locations"
            )
        li = self._siblings.child_index(self._ix, li)
        return AtomLocationView(li, self)

    def add_location(self, x, y, z, occ=1.0, B=0.0, alt=" ", at=None):
        """Adds a location to this atom (appended at the end by default) and
        returns an AtomLocationView to it.

        Args:
            x, y, z (float): coordinates of the location.
            occ (float): occupancy for the location.
            B (float): B-factor for the location.
            alt (str): alternative location character.
            at (int, optional): Index at which to insert the location. Default
                is to append at the end (i.e., equivalent of ``at`` being equal
                to the current number of locations).
        """
        if at is None:
            at = self.num_locations()
        li = self._siblings.insert_child(
            self._ix, at, {"coor": [x, y, z, occ, B], "alt": alt}
        )
        return AtomLocationView(li, self)

    def delete(self):
        """Deletes this atom from its Residue/Chain/System."""
        self._delete()
        self._ix = -1

    @property
    def residue(self):
        return self._parent

    @property
    def chain(self):
        return self.residue.chain

    @property
    def system(self):
        return self.chain.system

    @property
    def name(self):
        return self._siblings["name"][self._ix]

    @property
    def het(self):
        return self._siblings["het"][self._ix]

    # Location information getters and setters below operate on the default
    # (first) location of this atom and raise an Exception ("atom has no
    # locations") if the atom has no locations.

    def _first_location_index(self):
        """Return the System-level index of this atom's first location.

        Raises:
            Exception: if the atom has no locations.
        """
        if self._siblings.num_children(self._ix) == 0:
            raise Exception("atom has no locations")
        return self._siblings.child_index(self._ix, 0)

    def _get_coor(self, col):
        """Read column(s) `col` of the first location's "coor" record
        (columns: x, y, z, occupancy, B-factor)."""
        ix = self._first_location_index()
        return self.system._locations["coor"][ix, col]

    def _set_coor(self, col, val):
        """Write column `col` of the first location's "coor" record."""
        ix = self._first_location_index()
        self.system._locations["coor"][ix, col] = val

    @property
    def x(self):
        return self._get_coor(0)

    @x.setter
    def x(self, val):
        self._set_coor(0, val)

    @property
    def y(self):
        return self._get_coor(1)

    @y.setter
    def y(self, val):
        self._set_coor(1, val)

    @property
    def z(self):
        return self._get_coor(2)

    @z.setter
    def z(self, val):
        self._set_coor(2, val)

    @property
    def coors(self):
        return self._get_coor(slice(0, 3))

    @property
    def occ(self):
        return self._get_coor(3)

    @occ.setter
    def occ(self, val):
        self._set_coor(3, val)

    @property
    def B(self):
        return self._get_coor(4)

    @B.setter
    def B(self, val):
        self._set_coor(4, val)

    @property
    def alt(self):
        ix = self._first_location_index()
        return self.system._locations["alt"][ix]

    @alt.setter
    def alt(self, val):
        ix = self._first_location_index()
        self.system._locations["alt"][ix] = val
| |
|
| |
|
class DummyAtomView(AtomView):
    """A dummy atom view that can be attached to a residue but that has no
    locations and carries no other information.

    Its index is -1 and its only state is the parent residue. Location-related
    queries either return an empty value (`locations`, `num_locations`,
    `__getitem__`, `name`, `het`) or raise an `Exception` (coordinates,
    occupancy, B-factor, alt flag, `get_location`, `add_location`).
    """

    def __init__(self, residue: ResidueView):
        # AtomView.__init__ is intentionally not called, so attributes it may
        # set are absent here (AtomView's location accessors read
        # `self._siblings`). Every location-based accessor is therefore
        # overridden below so it never reaches AtomView's implementation.
        self._ix = -1
        self._parent = residue

    def __str__(self):
        return "DUMMY -> " + str(self.residue)

    def locations(self):
        """Iterate over the atom's locations; a dummy atom has none."""
        # The unreachable `yield` turns this into a generator function, so
        # calling it returns an empty iterator rather than None.
        return
        yield

    def num_locations(self):
        """Return the number of locations, which is always 0."""
        return 0

    def __getitem__(self, loc_idx: int):
        """Return None for any location index."""
        return None

    def get_location(self, li: int = 0):
        """Always raises: there are no locations to return."""
        raise Exception("no locations in DUMMY atom")

    def add_location(self, x, y, z, occ, B, alt, at=None):
        """Always raises: locations cannot be added to a dummy atom."""
        raise Exception("can't add locations to DUMMY atom")

    @property
    def residue(self):
        return self._parent

    @property
    def chain(self):
        return self.residue.chain

    @property
    def system(self):
        return self.chain.system

    @property
    def name(self):
        return None

    @property
    def het(self):
        return None

    @property
    def x(self):
        raise Exception("no coordinates in DUMMY atom")

    @property
    def y(self):
        raise Exception("no coordinates in DUMMY atom")

    @property
    def z(self):
        raise Exception("no coordinates in DUMMY atom")

    @property
    def coors(self):
        # Without this override AtomView.coors would run and read
        # `self._siblings`, which is never set on a dummy atom (and would be
        # queried with index -1). Fail the same way as x/y/z instead.
        raise Exception("no coordinates in DUMMY atom")

    @property
    def occ(self):
        raise Exception("no occupancy in DUMMY atom")

    @property
    def B(self):
        raise Exception("no B-factor in DUMMY atom")

    @property
    def alt(self):
        raise Exception("no alt flag in DUMMY atom")

    @x.setter
    def x(self, val):
        raise Exception("can't set coordinate for DUMMY atom")

    @y.setter
    def y(self, val):
        raise Exception("can't set coordinate for DUMMY atom")

    @z.setter
    def z(self, val):
        raise Exception("can't set coordinate for DUMMY atom")

    @occ.setter
    def occ(self, val):
        raise Exception("can't set occupancy for DUMMY atom")

    @B.setter
    def B(self, val):
        raise Exception("can't set B-factor for DUMMY atom")

    @alt.setter
    def alt(self, val):
        raise Exception("can't set alt flag for DUMMY atom")
| |
|
| |
|
@dataclass
class AtomLocationView(BaseView):
    """A view onto a single location of an atom, allowing hierarchical
    exploration and editing.

    Every accessor works on row ``self._ix`` of the owning System's location
    table. Columns 0-4 of ``system._locations["coor"]`` hold x, y, z,
    occupancy and B-factor; ``system._locations["alt"]`` holds the
    alternate-location flag. Setters assign into those tables directly.
    """

    # NOTE(review): `@dataclass` is applied although this class declares no
    # annotated fields of its own. Unless BaseView is itself a dataclass with
    # fields, the generated __eq__ compares an empty field tuple (any two
    # instances compare equal) and __hash__ is set to None. Confirm this is intended.

    def __init__(self, ix: int, atom: AtomView):
        self._ix = ix
        self._parent = atom
        # Handle to the shared location table. The accessors below do not use
        # it; they go through `self.system` instead.
        self._siblings = atom.system._locations

    def __str__(self):
        return "{} {} {}".format(self.x, self.y, self.z)

    def swap(self, other: AtomLocationView):
        """Exchange coordinates, occupancy, B-factor and alt flag with `other`.

        Args:
            other (AtomLocationView): the other atom location to swap with.
        """
        for attr in ("x", "y", "z", "occ", "B", "alt"):
            theirs = getattr(other, attr)
            mine = getattr(self, attr)
            setattr(self, attr, theirs)
            setattr(other, attr, mine)

    def defined(self):
        """Return True when none of x, y and z is None."""
        return all(getattr(self, axis) is not None for axis in ("x", "y", "z"))

    @property
    def atom(self):
        return self._parent

    @property
    def residue(self):
        return self.atom.residue

    @property
    def chain(self):
        return self.residue.chain

    @property
    def system(self):
        return self.chain.system

    @property
    def x(self):
        coor = self.system._locations["coor"]
        return coor[self._ix, 0]

    @x.setter
    def x(self, val):
        coor = self.system._locations["coor"]
        coor[self._ix, 0] = val

    @property
    def y(self):
        coor = self.system._locations["coor"]
        return coor[self._ix, 1]

    @y.setter
    def y(self, val):
        coor = self.system._locations["coor"]
        coor[self._ix, 1] = val

    @property
    def z(self):
        coor = self.system._locations["coor"]
        return coor[self._ix, 2]

    @z.setter
    def z(self, val):
        coor = self.system._locations["coor"]
        coor[self._ix, 2] = val

    @property
    def occ(self):
        coor = self.system._locations["coor"]
        return coor[self._ix, 3]

    @occ.setter
    def occ(self, val):
        coor = self.system._locations["coor"]
        coor[self._ix, 3] = val

    @property
    def B(self):
        coor = self.system._locations["coor"]
        return coor[self._ix, 4]

    @B.setter
    def B(self, val):
        coor = self.system._locations["coor"]
        coor[self._ix, 4] = val

    @property
    def alt(self):
        flags = self.system._locations["alt"]
        return flags[self._ix]

    @alt.setter
    def alt(self, val):
        flags = self.system._locations["alt"]
        flags[self._ix] = val

    @property
    def coors(self):
        """A copy (via np.array) of this location's x, y, z."""
        row = self.system._locations["coor"][self._ix, 0:3]
        return np.array(row)

    @coors.setter
    def coors(self, val):
        coor = self.system._locations["coor"]
        coor[self._ix, 0:3] = val

    @property
    def coor_info(self):
        """A copy (via np.array) of this location's full coordinate-table row."""
        row = self.system._locations["coor"][self._ix]
        return np.array(row)

    @coor_info.setter
    def coor_info(self, val):
        coor = self.system._locations["coor"]
        coor[self._ix] = val
| |
|
| |
|
class ExpressionTreeEvaluator:
    """A class for evaluating custom logical parenthetical expressions. The
    implementation is very generic, supports nullary, unary, and binary
    operators, and does not know anything about what the expressions actually
    mean. Instead the class interprets the expression as a tree of sub-
    expressions, governed by parentheses and operators, and traverses the
    tree, calling upon a user-specified evaluation function to evaluate leaf
    nodes as the tree is gradually collapsed into a single node. This
    can be used for evaluating set expressions, algebraic expressions, and
    others.

    Args:
        operators_nullary (list): A list of strings designating nullary operators
            (i.e., operators that do not have any operands). E.g., if the language
            describes selection algebra, these could be "hyd", "all", or "none".
        operators_unary (list): A list of strings designating unary operators
            (i.e., operators that have one operand, which must come to the right
            of the operator). E.g., if the language describes selection algebra,
            these could be "name", "resid", or "chain".
        operators_binary (list): A list of strings designating binary operators
            (i.e., operators that have two operands, one on each side of the
            operator). E.g., if the language describes selection algebra, these
            could be "and", "or", or "around".
        eval_function (callable): A function that is able to evaluate a leaf node
            of the expression tree. It shall accept three parameters:

            operator (str): name of the operator
            left: the left operand. Will be None if the left operand is missing or
                not relevant. Otherwise, can be either a list of strings, which
                should represent an evaluatable sub-expression corresponding to the
                left operand, or the result of a prior evaluation of this function.
            right: Same as `left` but for the right operand.

            The function should attempt to evaluate the resulting expression and
            return None in the case of failing or a dictionary with the result of
            the evaluation stored under key "result".
        left_associativity (bool): If True (the default), operators are taken to be
            left-associative, meaning that "X and Y or Z" (with X, Y, Z being
            sub-expressions such as "name CA") is "(X and Y) or Z". If False, the
            operators are taken to be right-associative, such that the same
            expression becomes "X and (Y or Z)". NOTE: MST is right-associative
            but often human intuition tends to be left-associative. Associativity
            only affects a right-hand operand that starts with a nullary or unary
            operator or with "("; a run of plain operand tokens directly after a
            binary operator is always consumed and evaluated immediately.
        debug (bool): If True (default is false), will print a great deal of debugging
            messages to help diagnose any evaluation problems.
    """

    def __init__(
        self,
        operators_nullary: list,
        operators_unary: list,
        operators_binary: list,
        eval_function: function,
        left_associativity: bool = True,
        debug: bool = False,
    ):
        self.operators_nullary = operators_nullary
        self.operators_unary = operators_unary
        self.operators_binary = operators_binary
        self.operators = operators_nullary + operators_unary + operators_binary
        self.eval_function = eval_function
        self.debug = debug
        self.left_associativity = left_associativity

    def _traverse_expression_tree(self, E, i=0, eval_all=True, debug=False):
        """Evaluate the token list `E`, starting from token index `i`.

        The traversal keeps a small state of (`left`, `op`, `right`). Tokens are
        consumed until that state forms a complete operation, which is handed
        to `eval_function`; its result becomes the new `left`.

        Args:
            E (list): expression tokens, as produced by `evaluate`.
            i (int): index of the first token to consume.
            eval_all (bool): if True, consume tokens until the end of `E`. If
                False, return as soon as one complete operand has been evaluated
                (a parenthesized group, a nullary operator, or a unary/binary
                operation). This is how the right operand of a binary operator
                is obtained under left associativity.
            debug (bool): if True, print tracing messages.

        Returns:
            tuple: `(value, next_index)` on success, where `value` is either an
            object returned by `eval_function` or a list of plain operand
            tokens; `(None, error_message)` on failure.
        """

        def _collect_operands(E, j):
            # Gather the run of non-operator tokens starting at index j. Returns
            # the tokens and the index of the first token not consumed.
            k = j
            while k < len(E) and E[k] not in self.operators:
                k += 1
            return list(E[j:k]), k

        def _find_matching_close_paren(E, beg: int):
            # Index of the ")" closing the "(" at `beg`, or None if unbalanced.
            c = 0
            for i in range(beg, len(E)):
                if E[i] == "(":
                    c = c + 1
                elif E[i] == ")":
                    c = c - 1
                    if c == 0:
                        return i
            return None

        def _my_eval(op, left, right, debug=False):
            if debug:
                print(
                    f"\t-> evaluating {operand_str(left)} | {op} | {operand_str(right)}"
                )
            result = self.eval_function(op, left, right)
            if debug:
                print(f"\t-> got result {operand_str(result)}")
            return result

        def operand_str(operand):
            # Abbreviate long results so debug and error messages stay readable.
            if isinstance(operand, dict) and "result" in operand:
                try:
                    n = len(operand["result"])
                except TypeError:
                    # Result without a length (e.g. a scalar): print it whole.
                    n = 0
                if n > 15:
                    vec = list(operand["result"])
                    beg = ", ".join([str(i) for i in vec[:5]])
                    end = ", ".join([str(i) for i in vec[-5:]])
                    return "{'result': " + f"{beg} ... {end} ({len(vec)} long)" + "}"
            return str(operand)

        left, right, op = None, None, None
        if debug:
            print(f"-> received {E[i:]}")

        while i < len(E):
            if left is None and right is None and op is None:
                # At the start of a (sub-)expression.
                if E[i] == "(":
                    end = _find_matching_close_paren(E, i)
                    if end is None:
                        return None, f"parenthesis imbalance starting with {E[i:]}"

                    left, rem = self._traverse_expression_tree(
                        E[i + 1 : end], 0, eval_all=True, debug=debug
                    )
                    if left is None:
                        return None, rem
                    i = end + 1
                    if not eval_all:
                        return left, i
                elif E[i] in self.operators_nullary:
                    left = _my_eval(E[i], None, None, debug)
                    if left is None:
                        return None, f"failed to evaluate nullary operator '{E[i]}'"
                    i = i + 1
                    # A nullary operator is a complete operand on its own. When
                    # only one operand was requested, stop here just as for a
                    # parenthesized group. Otherwise a following binary operator
                    # would be swallowed, making the grouping right-associative.
                    if not eval_all:
                        return left, i
                elif E[i] in self.operators_unary:
                    op = E[i]
                    i = i + 1
                elif E[i] in self.operators:
                    # Only binary operators remain at this point.
                    return None, f"unexpected binary operator in the context {E[i:]}"
                else:
                    left, i = _collect_operands(E, i)
            elif (left is not None) and (op is None) and (right is None):
                # Have a left operand; the next token must be a binary operator.
                if E[i] not in self.operators_binary:
                    return (
                        None,
                        f"expected end or a binary operator when got '{E[i]}' in expression: {E}",
                    )
                op = E[i]
                i = i + 1
            elif (
                (left is None) and (op in self.operators_unary) and (right is None)
            ) or (
                (left is not None) and (op in self.operators_binary) and (right is None)
            ):
                # An operator is waiting for its right operand.
                if (
                    E[i] in (self.operators_nullary + self.operators_unary)
                    or E[i] == "("
                ):
                    # The operand is itself an expression. Under left
                    # associativity take just one unit of it; otherwise consume
                    # everything that follows.
                    right, i = self._traverse_expression_tree(
                        E, i, eval_all=not self.left_associativity, debug=debug
                    )
                    if right is None:
                        return None, i
                else:
                    right, i = _collect_operands(E, i)
                    if not right:
                        # The next token is a binary operator, so the operand is missing.
                        return (
                            None,
                            f"missing operand after operator '{op}' in expression: {E}",
                        )

                result = _my_eval(op, left, right, debug)
                if result is None:
                    return (
                        None,
                        f"failed to evaluate operator '{op}' (in expression {E}) with operands {operand_str(left)} and {operand_str(right)}",
                    )
                if not eval_all:
                    return result, i
                left = result
                op, right = None, None

            else:
                return (
                    None,
                    f"encountered an unexpected condition when evaluating {E}: left is {operand_str(left)}, op is {op}, or right {operand_str(right)}",
                )

        if (op is not None) or (right is not None):
            return None, f"expression ended unexpectedly"
        if left is None:
            return None, f"failed to evaluate expression: {E}"

        return left, i

    def evaluate(self, expression: str):
        """Evaluates the expression and returns the result.

        Args:
            expression (str): the expression text. Parentheses always become
                separate tokens; the rest is split on whitespace.

        Returns:
            The value stored under "result" in the final object produced by
            `eval_function`.

        Raises:
            Exception: if the expression cannot be evaluated, including when it
                consists only of operand tokens with no operator.
        """

        def _split_tokens(expr):
            # Isolate "(" and ")" as standalone tokens, then split the remaining
            # text on whitespace, dropping empty strings.
            parts = re.split(r"([()])", expr)
            return [
                t.strip()
                for p in parts
                for t in re.split(r"\s+", p.strip())
                if t.strip() != ""
            ]

        E = _split_tokens(expression)
        val, rem = self._traverse_expression_tree(E, debug=self.debug)
        if val is None:
            raise Exception(
                f"failed to evaluate expression: '{expression}', reason: {rem}"
            )
        if isinstance(val, list):
            # Only bare operand tokens were found, so eval_function was never
            # applied and there is no "result" to return.
            raise Exception(
                f"failed to evaluate expression: '{expression}', reason: "
                f"no operator found (got bare operands {val})"
            )

        return val["result"]
| |
|