| | import torch |
| | import numpy as np |
| | import os |
| | import glob |
| | import random |
| | import matplotlib |
| | import imageio |
| |
|
| | matplotlib.use('Agg') |
| | import matplotlib.pyplot as plt |
| | from analysis.molecule_builder import get_bond_order |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0, |
| | name='molecule', batch_mask=None): |
| | try: |
| | os.makedirs(path) |
| | except OSError: |
| | pass |
| |
|
| | if batch_mask is None: |
| | batch_mask = torch.zeros(len(one_hot)) |
| |
|
| | for batch_i in torch.unique(batch_mask): |
| | cur_batch_mask = (batch_mask == batch_i) |
| | n_atoms = int(torch.sum(cur_batch_mask).item()) |
| | f = open(path + name + '_' + "%03d.xyz" % (batch_i + id_from), "w") |
| | f.write("%d\n\n" % n_atoms) |
| | atoms = torch.argmax(one_hot[cur_batch_mask], dim=1) |
| | batch_pos = positions[cur_batch_mask] |
| | for atom_i in range(n_atoms): |
| | atom = atoms[atom_i] |
| | atom = atom_decoder[atom] |
| | f.write("%s %.9f %.9f %.9f\n" % (atom, batch_pos[atom_i, 0], batch_pos[atom_i, 1], batch_pos[atom_i, 2])) |
| | f.close() |
| |
|
| |
|
| | def load_molecule_xyz(file, dataset_info): |
| | with open(file, encoding='utf8') as f: |
| | n_atoms = int(f.readline()) |
| | one_hot = torch.zeros(n_atoms, len(dataset_info['atom_decoder'])) |
| | positions = torch.zeros(n_atoms, 3) |
| | f.readline() |
| | atoms = f.readlines() |
| | for i in range(n_atoms): |
| | atom = atoms[i].split(' ') |
| | atom_type = atom[0] |
| | one_hot[i, dataset_info['atom_encoder'][atom_type]] = 1 |
| | position = torch.Tensor([float(e) for e in atom[1:]]) |
| | positions[i, :] = position |
| | return positions, one_hot |
| |
|
| |
|
| | def load_xyz_files(path, shuffle=True): |
| | files = glob.glob(path + "/*.xyz") |
| | if shuffle: |
| | random.shuffle(files) |
| | return files |
| |
|
| |
|
| | |
| | |
| | |
| | def draw_sphere(ax, x, y, z, size, color, alpha): |
| | u = np.linspace(0, 2 * np.pi, 100) |
| | v = np.linspace(0, np.pi, 100) |
| |
|
| | xs = size * np.outer(np.cos(u), np.sin(v)) |
| | ys = size * np.outer(np.sin(u), np.sin(v)) * 0.8 |
| | zs = size * np.outer(np.ones(np.size(u)), np.cos(v)) |
| | |
| | |
| |
|
| | ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2, color=color, |
| | linewidth=0, |
| | alpha=alpha) |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, |
| | dataset_info): |
| | |
| | |
| |
|
| | x = positions[:, 0] |
| | y = positions[:, 1] |
| | z = positions[:, 2] |
| | |
| |
|
| | |
| | colors_dic = np.array(dataset_info['colors_dic']) |
| | radius_dic = np.array(dataset_info['radius_dic']) |
| | area_dic = 1500 * radius_dic ** 2 |
| | |
| |
|
| | areas = area_dic[atom_type] |
| | radii = radius_dic[atom_type] |
| | colors = colors_dic[atom_type] |
| |
|
| | if spheres_3d: |
| | for i, j, k, s, c in zip(x, y, z, radii, colors): |
| | draw_sphere(ax, i.item(), j.item(), k.item(), 0.7 * s, c, alpha) |
| | else: |
| | ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, |
| | c=colors) |
| |
|
| | for i in range(len(x)): |
| | for j in range(i + 1, len(x)): |
| | p1 = np.array([x[i], y[i], z[i]]) |
| | p2 = np.array([x[j], y[j], z[j]]) |
| | dist = np.sqrt(np.sum((p1 - p2) ** 2)) |
| | atom1, atom2 = dataset_info['atom_decoder'][atom_type[i]], \ |
| | dataset_info['atom_decoder'][atom_type[j]] |
| | s = (atom_type[i], atom_type[j]) |
| |
|
| | draw_edge_int = get_bond_order(dataset_info['atom_decoder'][s[0]], |
| | dataset_info['atom_decoder'][s[1]], |
| | dist) |
| | line_width = 2 |
| |
|
| | draw_edge = draw_edge_int > 0 |
| | if draw_edge: |
| | if draw_edge_int == 4: |
| | linewidth_factor = 1.5 |
| | else: |
| | |
| | |
| | linewidth_factor = 1 |
| | ax.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]], |
| | linewidth=line_width * linewidth_factor, |
| | c=hex_bg_color, alpha=alpha) |
| |
|
| |
|
| | def plot_data3d(positions, atom_type, dataset_info, camera_elev=0, |
| | camera_azim=0, save_path=None, spheres_3d=False, |
| | bg='black', alpha=1.): |
| | black = (0, 0, 0) |
| | white = (1, 1, 1) |
| | hex_bg_color = '#FFFFFF' if bg == 'black' else '#666666' |
| |
|
| | from mpl_toolkits.mplot3d import Axes3D |
| | fig = plt.figure() |
| | ax = fig.add_subplot(projection='3d') |
| | ax.set_aspect('auto') |
| | ax.view_init(elev=camera_elev, azim=camera_azim) |
| | if bg == 'black': |
| | ax.set_facecolor(black) |
| | else: |
| | ax.set_facecolor(white) |
| | |
| | ax.xaxis.pane.set_alpha(0) |
| | ax.yaxis.pane.set_alpha(0) |
| | ax.zaxis.pane.set_alpha(0) |
| | ax._axis3don = False |
| |
|
| | if bg == 'black': |
| | ax.w_xaxis.line.set_color("black") |
| | else: |
| | ax.w_xaxis.line.set_color("white") |
| |
|
| | plot_molecule(ax, positions, atom_type, alpha, spheres_3d, |
| | hex_bg_color, dataset_info) |
| |
|
| | |
| | max_value = positions.abs().max().item() |
| |
|
| | |
| | axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2)) |
| | ax.set_xlim(-axis_lim, axis_lim) |
| | ax.set_ylim(-axis_lim, axis_lim) |
| | ax.set_zlim(-axis_lim, axis_lim) |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | dpi = 120 if spheres_3d else 50 |
| |
|
| | if save_path is not None: |
| | plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi) |
| |
|
| | if spheres_3d: |
| | img = imageio.imread(save_path) |
| | img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8') |
| | imageio.imsave(save_path, img_brighter) |
| | else: |
| | plt.show() |
| | plt.close() |
| |
|
| |
|
| | def plot_data3d_uncertainty( |
| | all_positions, all_atom_types, dataset_info, camera_elev=0, |
| | camera_azim=0, |
| | save_path=None, spheres_3d=False, bg='black', alpha=1.): |
| | black = (0, 0, 0) |
| | white = (1, 1, 1) |
| | hex_bg_color = '#FFFFFF' if bg == 'black' else '#666666' |
| |
|
| | from mpl_toolkits.mplot3d import Axes3D |
| | fig = plt.figure() |
| | ax = fig.add_subplot(projection='3d') |
| | ax.set_aspect('auto') |
| | ax.view_init(elev=camera_elev, azim=camera_azim) |
| | if bg == 'black': |
| | ax.set_facecolor(black) |
| | else: |
| | ax.set_facecolor(white) |
| | |
| | ax.xaxis.pane.set_alpha(0) |
| | ax.yaxis.pane.set_alpha(0) |
| | ax.zaxis.pane.set_alpha(0) |
| | ax._axis3don = False |
| |
|
| | if bg == 'black': |
| | ax.w_xaxis.line.set_color("black") |
| | else: |
| | ax.w_xaxis.line.set_color("white") |
| |
|
| | for i in range(len(all_positions)): |
| | positions = all_positions[i] |
| | atom_type = all_atom_types[i] |
| | plot_molecule(ax, positions, atom_type, alpha, spheres_3d, |
| | hex_bg_color, dataset_info) |
| |
|
| | if 'qm9' in dataset_info['name']: |
| | max_value = all_positions[0].abs().max().item() |
| |
|
| | |
| | axis_lim = min(40, max(max_value + 0.3, 3.2)) |
| | ax.set_xlim(-axis_lim, axis_lim) |
| | ax.set_ylim(-axis_lim, axis_lim) |
| | ax.set_zlim(-axis_lim, axis_lim) |
| | elif dataset_info['name'] == 'geom': |
| | max_value = all_positions[0].abs().max().item() |
| |
|
| | |
| | axis_lim = min(40, max(max_value / 2 + 0.3, 3.2)) |
| | ax.set_xlim(-axis_lim, axis_lim) |
| | ax.set_ylim(-axis_lim, axis_lim) |
| | ax.set_zlim(-axis_lim, axis_lim) |
| | elif dataset_info['name'] == 'pdbbind': |
| | max_value = all_positions[0].abs().max().item() |
| |
|
| | |
| | axis_lim = min(40, max(max_value / 2 + 0.3, 3.2)) |
| | ax.set_xlim(-axis_lim, axis_lim) |
| | ax.set_ylim(-axis_lim, axis_lim) |
| | ax.set_zlim(-axis_lim, axis_lim) |
| | else: |
| | raise ValueError(dataset_info['name']) |
| |
|
| | dpi = 120 if spheres_3d else 50 |
| |
|
| | if save_path is not None: |
| | plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi) |
| |
|
| | if spheres_3d: |
| | img = imageio.imread(save_path) |
| | img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8') |
| | imageio.imsave(save_path, img_brighter) |
| | else: |
| | plt.show() |
| | plt.close() |
| |
|
| |
|
| | def plot_grid(): |
| | import matplotlib.pyplot as plt |
| | from mpl_toolkits.axes_grid1 import ImageGrid |
| |
|
| | im1 = np.arange(100).reshape((10, 10)) |
| | im2 = im1.T |
| | im3 = np.flipud(im1) |
| | im4 = np.fliplr(im2) |
| |
|
| | fig = plt.figure(figsize=(10., 10.)) |
| | grid = ImageGrid(fig, 111, |
| | nrows_ncols=(6, 6), |
| | axes_pad=0.1, |
| | ) |
| |
|
| | for ax, im in zip(grid, [im1, im2, im3, im4]): |
| | |
| |
|
| | ax.imshow(im) |
| |
|
| | plt.show() |
| |
|
| |
|
| | def visualize(path, dataset_info, max_num=25, wandb=None, spheres_3d=False): |
| | files = load_xyz_files(path)[0:max_num] |
| | for file in files: |
| | positions, one_hot = load_molecule_xyz(file, dataset_info) |
| | atom_type = torch.argmax(one_hot, dim=1).numpy() |
| | dists = torch.cdist(positions.unsqueeze(0), |
| | positions.unsqueeze(0)).squeeze(0) |
| | dists = dists[dists > 0] |
| | |
| | plot_data3d(positions, atom_type, dataset_info=dataset_info, |
| | save_path=file[:-4] + '.png', |
| | spheres_3d=spheres_3d) |
| |
|
| | if wandb is not None: |
| | path = file[:-4] + '.png' |
| | |
| | im = plt.imread(path) |
| | wandb.log({'molecule': [wandb.Image(im, caption=path)]}) |
| |
|
| |
|
| | def visualize_chain(path, dataset_info, wandb=None, spheres_3d=False, |
| | mode="chain"): |
| | files = load_xyz_files(path) |
| | files = sorted(files) |
| | save_paths = [] |
| |
|
| | for i in range(len(files)): |
| | file = files[i] |
| |
|
| | positions, one_hot = load_molecule_xyz(file, dataset_info=dataset_info) |
| |
|
| | atom_type = torch.argmax(one_hot, dim=1).numpy() |
| | fn = file[:-4] + '.png' |
| | plot_data3d(positions, atom_type, dataset_info=dataset_info, |
| | save_path=fn, spheres_3d=spheres_3d, alpha=1.0) |
| | save_paths.append(fn) |
| |
|
| | imgs = [imageio.imread(fn) for fn in save_paths] |
| | dirname = os.path.dirname(save_paths[0]) |
| | gif_path = dirname + '/output.gif' |
| | print(f'Creating gif with {len(imgs)} images') |
| | |
| | |
| | imageio.mimsave(gif_path, imgs, subrectangles=True) |
| |
|
| | if wandb is not None: |
| | wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]}) |
| |
|
| |
|
| | def visualize_chain_uncertainty( |
| | path, dataset_info, wandb=None, spheres_3d=False, mode="chain"): |
| | files = load_xyz_files(path) |
| | files = sorted(files) |
| | save_paths = [] |
| |
|
| | for i in range(len(files)): |
| | if i + 2 == len(files): |
| | break |
| |
|
| | file = files[i] |
| | file2 = files[i + 1] |
| | file3 = files[i + 2] |
| |
|
| | positions, one_hot, _ = load_molecule_xyz(file, |
| | dataset_info=dataset_info) |
| | positions2, one_hot2, _ = load_molecule_xyz( |
| | file2, dataset_info=dataset_info) |
| | positions3, one_hot3, _ = load_molecule_xyz( |
| | file3, dataset_info=dataset_info) |
| |
|
| | all_positions = torch.stack([positions, positions2, positions3], dim=0) |
| | one_hot = torch.stack([one_hot, one_hot2, one_hot3], dim=0) |
| |
|
| | all_atom_type = torch.argmax(one_hot, dim=2).numpy() |
| | fn = file[:-4] + '.png' |
| | plot_data3d_uncertainty( |
| | all_positions, all_atom_type, dataset_info=dataset_info, |
| | save_path=fn, spheres_3d=spheres_3d, alpha=0.5) |
| | save_paths.append(fn) |
| |
|
| | imgs = [imageio.imread(fn) for fn in save_paths] |
| | dirname = os.path.dirname(save_paths[0]) |
| | gif_path = dirname + '/output.gif' |
| | print(f'Creating gif with {len(imgs)} images') |
| | |
| | |
| | imageio.mimsave(gif_path, imgs, subrectangles=True) |
| |
|
| | if wandb is not None: |
| | wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]}) |
| |
|
| |
|
| | if __name__ == '__main__': |
| | |
| | import qm9.dataset as dataset |
| | from configs.datasets_config import qm9_with_h, geom_with_h |
| |
|
| | matplotlib.use('macosx') |
| |
|
| | task = "visualize_molecules" |
| | task_dataset = 'geom' |
| |
|
| | if task_dataset == 'qm9': |
| | dataset_info = qm9_with_h |
| |
|
| |
|
| | class Args: |
| | batch_size = 1 |
| | num_workers = 0 |
| | filter_n_atoms = None |
| | datadir = 'qm9/temp' |
| | dataset = 'qm9' |
| | remove_h = False |
| |
|
| |
|
| | cfg = Args() |
| |
|
| | dataloaders, charge_scale = dataset.retrieve_dataloaders(cfg) |
| |
|
| | for i, data in enumerate(dataloaders['train']): |
| | positions = data['positions'].view(-1, 3) |
| | positions_centered = positions - positions.mean(dim=0, keepdim=True) |
| | one_hot = data['one_hot'].view(-1, 5).type(torch.float32) |
| | atom_type = torch.argmax(one_hot, dim=1).numpy() |
| |
|
| | plot_data3d( |
| | positions_centered, atom_type, dataset_info=dataset_info, |
| | spheres_3d=True) |
| |
|
| | elif task_dataset == 'geom': |
| | files = load_xyz_files('outputs/data') |
| | matplotlib.use('macosx') |
| | for file in files: |
| | x, one_hot, _ = load_molecule_xyz(file, dataset_info=geom_with_h) |
| |
|
| | positions = x.view(-1, 3) |
| | positions_centered = positions - positions.mean(dim=0, keepdim=True) |
| | one_hot = one_hot.view(-1, 16).type(torch.float32) |
| | atom_type = torch.argmax(one_hot, dim=1).numpy() |
| |
|
| | mask = (x == 0).sum(1) != 3 |
| | positions_centered = positions_centered[mask] |
| | atom_type = atom_type[mask] |
| |
|
| | plot_data3d( |
| | positions_centered, atom_type, dataset_info=geom_with_h, |
| | spheres_3d=False) |
| |
|
| | else: |
| | raise ValueError(dataset) |
| |
|