File size: 6,050 Bytes
cc95e47
 
 
 
aa7be3b
cc95e47
 
 
 
 
505dbb9
cc95e47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505dbb9
cc95e47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505dbb9
cc95e47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import pandas as pd
import matplotlib.pyplot as plt
import pypdb
import biotite.database.rcsb as rcsb
from lynxkite_core.ops import op
import os
import numpy as np
from Bio.PDB import PDBList, PDBParser, Superimposer


@op("LynxKite Graph Analytics", "PDB composite search")
def get_pdb_count(
    *, ligand_id: str, experimental_method: str, max_resolution: float, polymer_count: int
):
    """
    Query the RCSB PDB for entries matching all of the given criteria and
    return their IDs together with their titles.

    Before the combined (AND) search runs, the number of matches for each
    individual criterion is printed as a diagnostic.

    Parameters
    ----------
    ligand_id : str
        Non-polymer component ID to filter on (e.g., 'STI').
    experimental_method : str
        Experimental method to filter by (e.g., 'X-RAY DIFFRACTION').
    max_resolution : float
        Maximum resolution (Å) to include (<= this value).
    polymer_count : int
        Exact value of ``rcsb_entry_info.deposited_polymer_entity_instance_count``
        (number of deposited polymer chains in the entry).

    Returns
    -------
    pd.DataFrame
        One row per matching entry with columns ``pdb_id`` and ``description``.
        ``description`` is the entry's ``struct.title``, or ``None`` when the
        metadata could not be retrieved.
    """
    # 1) Query by ligand ID
    q_ligand = rcsb.FieldQuery(
        "rcsb_nonpolymer_entity_container_identifiers.nonpolymer_comp_id", exact_match=ligand_id
    )
    count_ligand = rcsb.count(q_ligand)
    print(f"Number of matches for ligand '{ligand_id}': {count_ligand}")

    # 2) Query by experimental method
    q_method = rcsb.FieldQuery("exptl.method", exact_match=experimental_method)
    count_method = rcsb.count(q_method)
    print(f"Number of matches for experimental method '{experimental_method}': {count_method}")

    # 3) Query by resolution
    q_resolution = rcsb.FieldQuery(
        "rcsb_entry_info.resolution_combined", less_or_equal=max_resolution
    )
    count_resolution = rcsb.count(q_resolution)
    print(f"Number of matches with resolution ≤ {max_resolution}: {count_resolution}")

    # 4) Query by polymer chain count
    q_polymer = rcsb.FieldQuery(
        "rcsb_entry_info.deposited_polymer_entity_instance_count", equals=polymer_count
    )
    count_polymer = rcsb.count(q_polymer)
    print(f"Number of matches with polymer count == {polymer_count}: {count_polymer}")

    # 5) Composite query (AND all criteria). Issue the search request only once.
    composite_q = rcsb.CompositeQuery([q_ligand, q_method, q_resolution, q_polymer], "and")
    pdb_ids = list(rcsb.search(composite_q))
    print(f"Number of composite matches: {len(pdb_ids)}")

    # 6) Fetch each entry's title. pypdb.get_all_info may return None when
    # retrieval fails; record a missing title instead of aborting the whole op.
    descriptions = []
    for pid in pdb_ids:
        info = pypdb.get_all_info(pid) or {}
        struct = info.get("struct") or {}
        descriptions.append(struct.get("title"))

    return pd.DataFrame({"pdb_id": pdb_ids, "description": descriptions})


@op("LynxKite Graph Analytics", "PDB alignment RMSD")
def compute_pdb_rmsd(
    df: pd.DataFrame, *, pdb_id_col: str = "pdb_id", chain_id: str = "A"
) -> pd.DataFrame:
    """
    Compute a pairwise Cα RMSD matrix for the PDB entries listed in ``df``.

    Each entry is downloaded (PDB format) into ``./pdb_files``. From the first
    model, the chain ``chain_id`` is used, or the first chain if that ID is
    absent. Cα atoms of residues that share the same residue ID in both
    structures are paired. Each pair is superimposed with BioPython's
    ``Superimposer``, and the resulting RMSD is rounded to 0.1 Å. Pairing
    by residue ID assumes both structures use the same residue numbering.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame containing PDB IDs.
    pdb_id_col : str
        Name of the column in `df` with PDB IDs (default 'pdb_id').
    chain_id : str
        Preferred chain to compare in every structure (default 'A').

    Returns
    -------
    pd.DataFrame
        Symmetric square DataFrame of pairwise RMSD values (Å), indexed and
        columned by PDB ID. The diagonal is 0. A pair with no common residue
        IDs gets NaN.

    Raises
    ------
    FileNotFoundError
        If an entry could not be downloaded.
    ValueError
        If an entry's first model contains no chains.
    """
    pdb_dir = "pdb_files"
    os.makedirs(pdb_dir, exist_ok=True)

    ids = df[pdb_id_col].tolist()
    n = len(ids)

    pdbl = PDBList()
    parser = PDBParser(QUIET=True)
    atom_dicts = []
    for pid in ids:
        # retrieve_pdb_file returns the local path of the file even when the
        # download failed (it only prints a message), so check that it exists.
        path = pdbl.retrieve_pdb_file(pid, pdir=pdb_dir, file_format="pdb")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Could not download PDB entry '{pid}' (expected {path}).")
        model = parser.get_structure(pid, path)[0]
        if model.has_id(chain_id):
            chain = model[chain_id]
        else:
            chain = next(model.get_chains(), None)
        if chain is None:
            raise ValueError(f"PDB entry '{pid}' has no chains in its first model.")
        # Map residue ID -> Cα atom so structures can be paired residue-by-residue.
        atom_dicts.append({res.id: res["CA"] for res in chain if res.has_id("CA")})

    rmsd_mat = np.zeros((n, n))
    sup = Superimposer()
    for i in range(n):
        for j in range(i + 1, n):
            common = sorted(atom_dicts[i].keys() & atom_dicts[j].keys())
            if common:
                fixed_atoms = [atom_dicts[i][k] for k in common]
                moving_atoms = [atom_dicts[j][k] for k in common]
                # set_atoms computes the optimal superposition; coordinates are
                # not modified unless sup.apply() is called.
                sup.set_atoms(fixed_atoms, moving_atoms)
                rmsd = round(sup.rms, 1)
            else:
                rmsd = np.nan
            rmsd_mat[i, j] = rmsd_mat[j, i] = rmsd

    return pd.DataFrame(rmsd_mat, index=ids, columns=ids)


@op("LynxKite Graph Analytics", "Plot matrix", view="matplotlib")
def plot_heatmap_from_df(
    df: pd.DataFrame, *, value_label: str = "Value", title: str = None
) -> plt.Figure:
    """
    Plot a DataFrame as an annotated heatmap using matplotlib.

    Columns are laid out along the x axis and the index along the y axis.
    Every cell is annotated with its value formatted to one decimal place.

    Parameters
    ----------
    df : pd.DataFrame
        Numeric DataFrame of values to plot (typically square, e.g. a
        pairwise distance matrix).
    value_label : str
        Label for the color bar.
    title : str, optional
        Title for the plot.

    Returns
    -------
    plt.Figure
        The matplotlib Figure object containing the heatmap.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(df.values)
    # create and label the colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(value_label)

    if title:
        ax.set_title(title)

    # Label each axis from its own labels: x ticks follow the columns and
    # y ticks follow the index. Do not reuse the index for both axes.
    ax.set_xticks(range(df.shape[1]))
    ax.set_xticklabels(df.columns.tolist(), rotation=90)
    ax.set_yticks(range(df.shape[0]))
    ax.set_yticklabels(df.index.tolist())

    # Annotate each cell; imshow places cell (row i, col j) at (x=j, y=i).
    for (i, j), value in np.ndenumerate(df.values):
        ax.text(j, i, f"{value:.1f}", ha="center", va="center")

    # Lay out the figure we return, not whatever pyplot considers "current".
    fig.tight_layout()
    return fig