import torch import joblib import numpy as np import mdtraj as md import matplotlib.pyplot as plt import pyemma.coordinates as coor from .utils import compute_dihedral from matplotlib.colors import LinearSegmentedColormap class Plot: def __init__(self, args, mds): self.device = args.device self.save_dir = args.save_dir self.molecule = args.molecule self.start_state = args.start_state self.num_samples = args.num_samples self.start_position = mds.start_position self.target_position = mds.target_position self.energy_function = mds.energy_function def __call__(self): positions, potentials = [], [] for i in range(self.num_samples): position = np.load(f"{self.save_dir}/positions/{i}.npy").astype(np.float32) potential = self.energy_function(position)[1] positions.append(torch.from_numpy(position).to(self.device)) potentials.append(potential) self.paths(positions) def paths(self, positions): zorder = 32 circle_size = 500 saddle_size = 2400 custom_colors_1 = ["#05009E", "#6B67EE", "#50B2D7", "#B0ADF1"] custom_colors_2 = ["#05009E", "#6B67EE", "#50B2D7", "#F7EFFF"] custom_cmap_1 = LinearSegmentedColormap.from_list("my_cmap", custom_colors_1) custom_cmap_2 = LinearSegmentedColormap.from_list("my_cmap", custom_colors_2) if self.molecule == "aldp": angle_1 = [6, 8, 14, 16] angle_2 = [1, 6, 8, 14] plt.clf() plt.close() fig = plt.figure(figsize=(7, 7)) ax = fig.add_subplot(111) plt.xlim([-np.pi, np.pi]) plt.ylim([-np.pi, np.pi]) with open("./data/aldp/landscape.dat") as f: lines = f.readlines() dims = [90, 90] locations = torch.zeros((int(dims[0]), int(dims[1]), 2)) data = torch.zeros((int(dims[0]), int(dims[1]))) i = 0 for line in lines[1:]: splits = line[0:-1].split(" ") vals = [y for y in splits if y != ""] x = float(vals[0]) y = float(vals[1]) val = float(vals[-1]) locations[i // 90, i % 90, :] = torch.tensor([x, y]) data[i // 90, i % 90] = val i = i + 1 xs = np.arange(-np.pi, np.pi + 0.1, 0.1) ys = np.arange(-np.pi, np.pi + 0.1, 0.1) x, y = np.meshgrid(xs, ys) inp = torch.tensor(np.array([x, y])).view(2, -1).T loc = locations.view(-1, 2) distances = torch.cdist(inp, loc.double(), p=2) index = distances.argmin(dim=1) a = torch.div(index, locations.shape[0], rounding_mode="trunc") b = index % locations.shape[0] z = data[a, b] z = z.view(y.shape[0], y.shape[1]) plt.contourf(xs, ys, z, levels=100, zorder=0, cmap=custom_cmap_2) # changed cm = custom_cmap_2 """ax.set_prop_cycle( color=[cm(1.0 * i / len(positions)) for i in range(len(positions))] )""" for position in positions: psi = compute_dihedral(position[:, angle_1, :]).detach().cpu().numpy() phi = compute_dihedral(position[:, angle_2, :]).detach().cpu().numpy() # changed to white paths ax.plot( phi, psi, marker="o", linestyle="None", markersize=2, alpha=1.0, markerfacecolor="white", markeredgecolor="none", # no edge markeredgewidth=0, ) #mark endpoint end_phi, end_psi = phi[-1], psi[-1] ax.scatter( [end_phi], [end_psi], s=70, c="#D577FF", edgecolors="w", linewidths=0.8, zorder=zorder + 1, marker="o" # use "D" or "*" if you prefer ) start_psi = ( compute_dihedral(self.start_position[:, angle_1, :]) .detach() .cpu() .numpy() ) start_phi = ( compute_dihedral(self.start_position[:, angle_2, :]) .detach() .cpu() .numpy() ) target_psi = ( compute_dihedral(self.target_position[:, angle_1, :]) .detach() .cpu() .numpy() ) target_phi = ( compute_dihedral(self.target_position[:, angle_2, :]) .detach() .cpu() .numpy() ) phis_saddle = [-0.035, -0.017] psis_saddle = [1.605, -0.535] """ax.scatter( phis_saddle, psis_saddle, edgecolors="black", c="w", zorder=zorder, s=saddle_size, marker="*", )""" ax.scatter( start_phi, start_psi, edgecolors="w", c="#9793F8", zorder=zorder, s=circle_size, marker="*", ) ax.scatter( target_phi, target_psi, edgecolors="w", c="#9793F8", zorder=zorder, s=circle_size, marker="*", ) plt.xlabel("\u03A6", fontsize=35, fontweight="medium") plt.ylabel("\u03A8", fontsize=35, fontweight="medium") else: fig = plt.figure(figsize=(7, 7)) ax = fig.add_subplot(111) cm = plt.get_cmap("gist_rainbow") """ax.set_prop_cycle( color=[cm(1.0 * i / len(positions)) for i in range(len(positions))] )""" pmf = np.load(f"./data/{self.molecule}/pmf.npy") xs = np.load(f"./data/{self.molecule}/xs.npy") ys = np.load(f"./data/{self.molecule}/ys.npy") plt.pcolormesh(xs, ys, pmf.T, cmap=custom_cmap_1) tica_model = joblib.load(f"./data/{self.molecule}/tica_model.pkl") feat = coor.featurizer(f"./data/{self.molecule}/{self.start_state}.pdb") feat.add_backbone_torsions(cossin=True) for position in positions: traj = md.Trajectory( position.cpu().numpy(), md.load(f"./data/{self.molecule}/{self.start_state}.pdb").topology, ) feature = feat.transform(traj) tica = tica_model.transform(feature) ax.plot( tica[:, 0], tica[:, 1], marker="o", linestyle="None", markersize=2, alpha=1.0, markerfacecolor="white", markeredgecolor="none", # no edge markeredgewidth=0, ) end_x, end_y = tica[-1, 0], tica[-1, 1] ax.scatter( [end_x], [end_y], s=70, c="#D577FF", edgecolors="w", linewidths=0.8, zorder=zorder + 1, marker="o" ) start_position = md.Trajectory( self.start_position.cpu().numpy(), md.load(f"./data/{self.molecule}/{self.start_state}.pdb").topology, ) feature = feat.transform(start_position) start_tica = tica_model.transform(feature) ax.scatter( start_tica[:, 0], start_tica[:, 1], edgecolors="w", c="#9793F8", zorder=zorder, s=circle_size, marker="*", ) target_position = md.Trajectory( self.target_position.cpu().numpy(), md.load(f"./data/{self.molecule}/{self.start_state}.pdb").topology, ) feature = feat.transform(target_position) target_tica = tica_model.transform(feature) ax.scatter( target_tica[:, 0], target_tica[:, 1], edgecolors="w", c="#9793F8", zorder=zorder, s=circle_size, marker="*", ) plt.xlabel("TIC 1", fontsize=35, fontweight="medium") plt.ylabel("TIC 2", fontsize=35, fontweight="medium") plt.xlim(xs.min(), xs.max()) plt.ylim(ys.min(), ys.max()) plt.tick_params( left=False, right=False, labelleft=False, labelbottom=False, bottom=False, ) plt.tight_layout() plt.savefig(f"{self.save_dir}/paths.png", dpi=300, bbox_inches="tight") plt.show() plt.close() return fig