File size: 3,348 Bytes
084b58f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from collections import defaultdict
from rdkit.Chem.Draw import rdMolDraw2D
import numpy as np

class MoleculeDrawer:
    def __init__(self, output_dir="output/tmp"):
        self.output_dir = os.path.join(output_dir, "raw/images")
        os.makedirs(self.output_dir, exist_ok=True)
        self.aa2color_dict = {
            "Asp": (0.902, 0.039, 0.039), "Glu": (0.961, 0.1, 0.537), "Arg": (0.078, 0.353, 1), "Lys": (0.42, 0.353, 1),
            "His": (0.51, 0.51, 0.824), "Tyr": (0.196, 0.196, 0.667), "Phe": (0.341, 0.196, 0.667), "Trp": (0.706, 0.353, 0.706),
            "Asn": (0, 0.863, 0.863), "Gln": (0.5, 0.82, 0.863), "Met": (0.902, 0.902, 0), "Cys": (0.722, 0.902, 0),
            "Ser": (0.98, 0.588, 0), "Thr": (0, 0.612, 0.412), "Gly": (0.98, 0.922, 0.922), "Ala": (0.784, 0.784, 0.639),
            "Val": (0.059, 0.51, 0.059), "Leu": (0.29, 0.51, 0.059), "Ile": (0.29, 0.51, 0.471), "Pro": (1, 0.588, 0.51)
        }
    
    def sort_atom_highlights(self, mol):
        atom_highlights = defaultdict(list)
        for atom_idx in range(mol.GetNumAtoms()):
            labelled_atom = mol.GetAtomWithIdx(atom_idx)
            AA_label = labelled_atom.GetProp("AA")
            if self.label_belongs_to_AA(AA_label):
                three_letter_label = AA_label[:3]
                atom_highlights[atom_idx].append(self.aa2color_dict[three_letter_label])

        # Convert defaultdict to dict of lists
        return {k: list(v) for k, v in atom_highlights.items()}

    def create_colormap(self):
        legend_data = [(aa[:3], color) for aa, color in self.aa2color_dict.items() if aa != "Unk"]
        fig, ax = plt.subplots(figsize=(1, 1))
        cmap = ListedColormap([color for _, color in legend_data])
        cax = ax.matshow(np.arange(len(legend_data)).reshape(1, -1), cmap=cmap)
        cbar = fig.colorbar(cax, ticks=np.arange(len(legend_data)), aspect=5)
        cbar.set_ticklabels([label for label, _ in legend_data])
        cbar.ax.tick_params(labelsize=3)
        ax.axis("off")
        plt.savefig(os.path.join(self.output_dir, "colormap.png"), bbox_inches="tight", dpi=300)
        plt.close()
    
    def draw_input_mol(self, mol, mol_index, seq, bond_highlights):
        atom_highlights = self.sort_atom_highlights(mol)

        # Ensure bond_highlights is a dict of lists
        bond_highlights = {k: list(v) for k, v in bond_highlights.items()} if bond_highlights else {}

        mol_name = f"mol_{mol_index}"
        legend = f'{mol_name}\nseq: {seq}\n{"8< = peptide bond"}\nAA_NAME:SEEN_COUNT:SEQUENCE_POSITION\n'

        self.draw_mol(mol, atom_highlights, bond_highlights, legend, mol_name)
        self.create_colormap()

    
    def draw_mol(self, mol, atom_highlights, bond_highlights, legend, mol_name):
        view = rdMolDraw2D.MolDraw2DSVG(600, 300)
        view.drawOptions().useBWAtomPalette()
        view.DrawMoleculeWithHighlights(mol, legend, dict(atom_highlights), dict(bond_highlights), {}, {})
        view.FinishDrawing()
        with open(os.path.join(self.output_dir, f"{mol_name}.svg"), "w") as f:
            f.write(view.GetDrawingText())

    def label_belongs_to_AA(self, label):
        shorter_label = label[:3]
        return shorter_label != "Unk" and not label.startswith("X")