DiffSBDD / analysis /visualization.py
mority's picture
Upload 48 files
4742cab verified
raw
history blame
16.2 kB
import torch
import numpy as np
import os
import glob
import random
import matplotlib
import imageio
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from analysis.molecule_builder import get_bond_order
##############
### Files ####
###########-->
def save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0,
                  name='molecule', batch_mask=None):
    """Write one XYZ file per molecule contained in a batch of atoms.

    Each file is named ``<name>_<NNN>.xyz`` where ``NNN`` is the molecule id
    taken from ``batch_mask`` plus ``id_from``, zero-padded to three digits.
    A file holds the atom count, an empty comment line, and one
    ``symbol x y z`` row per atom with nine decimals.

    Args:
        path: Output directory. It is created when missing. A trailing
            separator is optional.
        one_hot: Tensor with one row per atom; the argmax of a row is the
            atom-type index.
        positions: Tensor with one row per atom; the first three columns are
            written as coordinates.
        atom_decoder: Indexable mapping from atom-type index to the symbol
            written to the file.
        id_from: Offset added to every molecule id to form the file number.
        name: Prefix of the generated file names.
        batch_mask: Optional 1-D tensor giving the molecule id of each atom.
            When ``None`` all atoms are treated as a single molecule with
            id 0.
    """
    # An empty path means "current directory"; os.makedirs('') would raise.
    if path:
        os.makedirs(path, exist_ok=True)

    if batch_mask is None:
        batch_mask = torch.zeros(len(one_hot))

    for batch_i in torch.unique(batch_mask):
        cur_batch_mask = batch_mask == batch_i
        type_indices = torch.argmax(one_hot[cur_batch_mask], dim=1).tolist()
        symbols = [atom_decoder[idx] for idx in type_indices]
        coords = positions[cur_batch_mask].tolist()

        file_name = "%s_%03d.xyz" % (name, int(batch_i) + id_from)
        # The context manager guarantees the handle is closed even when a
        # write fails half-way through.
        with open(os.path.join(path, file_name), "w") as f:
            f.write("%d\n\n" % len(symbols))
            f.writelines(
                "%s %.9f %.9f %.9f\n" % (symbol, xyz[0], xyz[1], xyz[2])
                for symbol, xyz in zip(symbols, coords)
            )
def load_molecule_xyz(file, dataset_info):
    """Read a single-molecule XYZ file into position and one-hot tensors.

    The first line must hold the atom count and the second line is skipped.
    Each following line is ``symbol x y z``; fields may be separated by any
    run of whitespace, and columns after the third coordinate are ignored.

    Args:
        file: Path of the XYZ file (read as UTF-8).
        dataset_info: Mapping providing ``'atom_decoder'`` (its length sets
            the one-hot width) and ``'atom_encoder'`` (symbol -> column).

    Returns:
        ``(positions, one_hot)`` where ``positions`` has shape
        ``(n_atoms, 3)`` and ``one_hot`` has shape
        ``(n_atoms, len(atom_decoder))``.

    Raises:
        IndexError: If the file lists fewer atom rows than its header claims.
        KeyError: If an atom symbol is missing from ``atom_encoder``.
    """
    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        f.readline()  # second line of an XYZ file is a free-form comment
        atom_lines = f.readlines()

    atom_encoder = dataset_info['atom_encoder']
    one_hot = torch.zeros(n_atoms, len(dataset_info['atom_decoder']))
    positions = torch.zeros(n_atoms, 3)

    for i in range(n_atoms):
        # split() without an argument tolerates tabs and aligned columns,
        # which split(' ') would turn into empty, unparsable fields.
        fields = atom_lines[i].split()
        one_hot[i, atom_encoder[fields[0]]] = 1
        positions[i, :] = torch.tensor([float(value) for value in fields[1:4]])

    return positions, one_hot
def load_xyz_files(path, shuffle=True):
    """Return the paths of all ``*.xyz`` files directly inside ``path``.

    Args:
        path: Directory to search (string; the glob pattern is appended).
        shuffle: When true, the list is shuffled in place with the global
            ``random`` module before it is returned.

    Returns:
        List of matching file paths.
    """
    pattern = path + "/" + "*.xyz"
    xyz_paths = glob.glob(pattern)
    if not shuffle:
        return xyz_paths
    random.shuffle(xyz_paths)
    return xyz_paths
# <----########
### Files ####
##############
def draw_sphere(ax, x, y, z, size, color, alpha):
    """Render a sphere centred at ``(x, y, z)`` as a surface on ``ax``.

    Args:
        ax: Axes object exposing ``plot_surface``.
        x, y, z: Centre of the sphere.
        size: Radius used for the x and z extents (y is scaled, see below).
        color: Surface colour forwarded to ``plot_surface``.
        alpha: Surface opacity forwarded to ``plot_surface``.
    """
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)
    sin_polar = np.sin(polar)

    surface_x = x + size * np.outer(np.cos(azimuth), sin_polar)
    # Correct for matplotlib: the y extent is squashed by a factor of 0.8.
    surface_y = y + size * np.outer(np.sin(azimuth), sin_polar) * 0.8
    surface_z = z + size * np.outer(np.ones_like(azimuth), np.cos(polar))

    ax.plot_surface(
        surface_x,
        surface_y,
        surface_z,
        rstride=2,
        cstride=2,
        color=color,
        linewidth=0,
        alpha=alpha,
    )
def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color,
                  dataset_info):
    """Draw the atoms of one molecule, and the bonds inferred between them.

    Atoms are drawn either as shaded spheres (``spheres_3d``) or as a scatter
    plot. Every pair of atoms is then passed to ``get_bond_order`` together
    with their Euclidean distance; a line is drawn for each pair whose
    returned order is positive.

    Args:
        ax: 3D axes to draw on.
        positions: 2-D torch tensor or array; columns 0-2 are x, y, z.
        atom_type: Integer index array with one atom-type index per atom.
        alpha: Opacity of spheres and bonds; scatter markers use
            ``0.9 * alpha``.
        spheres_3d: Draw spheres via ``draw_sphere`` instead of a scatter.
        hex_bg_color: Colour used for the bond lines.
        dataset_info: Mapping with ``'colors_dic'``, ``'radius_dic'`` (both
            indexed by atom type) and ``'atom_decoder'`` (index -> symbol).
    """
    # Convert once to numpy so the O(n^2) pair loop below does not index a
    # torch tensor element by element.
    if torch.is_tensor(positions):
        coords = positions.detach().cpu().numpy()
    else:
        coords = np.asarray(positions)
    coords = coords[:, :3]
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    # Per-type lookup tables, e.g. Hydrogen, Carbon, Nitrogen, Oxygen, Fluorine
    colors_dic = np.array(dataset_info['colors_dic'])
    radius_dic = np.array(dataset_info['radius_dic'])
    colors = colors_dic[atom_type]

    if spheres_3d:
        radii = radius_dic[atom_type]
        for cx, cy, cz, radius, color in zip(x, y, z, radii, colors):
            draw_sphere(ax, cx.item(), cy.item(), cz.item(), 0.7 * radius,
                        color, alpha)
    else:
        # Marker area grows with the squared radius of the atom type.
        areas = (1500 * radius_dic ** 2)[atom_type]
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)

    atom_decoder = dataset_info['atom_decoder']
    symbols = [atom_decoder[t] for t in atom_type]
    base_line_width = 2
    n_atoms = len(x)

    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            dist = np.sqrt(np.sum((coords[i] - coords[j]) ** 2))
            bond_order = get_bond_order(symbols[i], symbols[j], dist)
            if bond_order <= 0:
                continue
            # Order 4 gets a thicker line (presumably aromatic - confirm
            # against get_bond_order); all other orders share one width.
            linewidth_factor = 1.5 if bond_order == 4 else 1
            ax.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=base_line_width * linewidth_factor,
                    c=hex_bg_color, alpha=alpha)
def plot_data3d(positions, atom_type, dataset_info, camera_elev=0,
                camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1.):
    """Render one molecule in 3D and either save the image or show it.

    Args:
        positions: Torch tensor of atom coordinates, one row per atom.
        atom_type: Integer index array with one atom-type index per atom.
        dataset_info: Dataset description forwarded to ``plot_molecule``.
        camera_elev: Camera elevation passed to ``view_init``.
        camera_azim: Camera azimuth passed to ``view_init``.
        save_path: Image path to write. When ``None`` the figure is shown
            with ``plt.show()`` instead.
        spheres_3d: Draw atoms as shaded spheres; uses a higher DPI and
            brightens the saved image by a factor of 1.4.
        bg: ``'black'`` for a black background, anything else for white.
        alpha: Opacity forwarded to ``plot_molecule``.
    """
    # Importing Axes3D registers the '3d' projection on older matplotlib
    # releases; the name itself is not used.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    dark = bg == 'black'
    face_color = (0, 0, 0) if dark else (1, 1, 1)
    # Bond colour that contrasts with the chosen background.
    hex_bg_color = '#FFFFFF' if dark else '#666666'

    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(face_color)

        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.set_alpha(0)
        ax.set_axis_off()
        # `ax.w_xaxis` has been removed from recent matplotlib releases;
        # `ax.xaxis` is the same object and works on old and new versions.
        ax.xaxis.line.set_color("black" if dark else "white")

        plot_molecule(ax, positions, atom_type, alpha, spheres_3d,
                      hex_bg_color, dataset_info)

        # Symmetric cubic view box: at least 3.2, at most 40 units.
        max_value = positions.abs().max().item()
        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        dpi = 120 if spheres_3d else 50
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0,
                        dpi=dpi)
            if spheres_3d:
                # Shaded spheres come out dark; brighten the saved image.
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
def plot_data3d_uncertainty(
        all_positions, all_atom_types, dataset_info, camera_elev=0,
        camera_azim=0,
        save_path=None, spheres_3d=False, bg='black', alpha=1.):
    """Overlay several molecules in one 3D figure and save or show it.

    Args:
        all_positions: Sequence of position tensors, one per molecule. The
            view box is derived from the first entry only.
        all_atom_types: Sequence of atom-type index arrays matching
            ``all_positions``.
        dataset_info: Dataset description; ``dataset_info['name']`` selects
            how the view box is scaled.
        camera_elev: Camera elevation passed to ``view_init``.
        camera_azim: Camera azimuth passed to ``view_init``.
        save_path: Image path to write. When ``None`` the figure is shown
            with ``plt.show()`` instead.
        spheres_3d: Draw atoms as shaded spheres; uses a higher DPI and
            brightens the saved image by a factor of 1.4.
        bg: ``'black'`` for a black background, anything else for white.
        alpha: Opacity forwarded to ``plot_molecule`` for every molecule.

    Raises:
        ValueError: If the dataset name is not one containing ``'qm9'`` and
            is neither ``'geom'`` nor ``'pdbbind'``.
    """
    # Importing Axes3D registers the '3d' projection on older matplotlib
    # releases; the name itself is not used.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    dark = bg == 'black'
    face_color = (0, 0, 0) if dark else (1, 1, 1)
    # Bond colour that contrasts with the chosen background.
    hex_bg_color = '#FFFFFF' if dark else '#666666'

    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(face_color)

        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.set_alpha(0)
        ax.set_axis_off()
        # `ax.w_xaxis` has been removed from recent matplotlib releases;
        # `ax.xaxis` is the same object and works on old and new versions.
        ax.xaxis.line.set_color("black" if dark else "white")

        for positions, atom_type in zip(all_positions, all_atom_types):
            plot_molecule(ax, positions, atom_type, alpha, spheres_3d,
                          hex_bg_color, dataset_info)

        dataset_name = dataset_info['name']
        max_value = all_positions[0].abs().max().item()
        if 'qm9' in dataset_name:
            extent = max_value
        elif dataset_name in ('geom', 'pdbbind'):
            # These datasets are drawn at half the extent.
            extent = max_value / 2
        else:
            raise ValueError(dataset_name)

        # Symmetric cubic view box: at least 3.2, at most 40 units.
        axis_lim = min(40, max(extent + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        dpi = 120 if spheres_3d else 50
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0,
                        dpi=dpi)
            if spheres_3d:
                # Shaded spheres come out dark; brighten the saved image.
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Always release the figure, including on the ValueError path.
        plt.close(fig)
def plot_grid():
    """Display a demo ``ImageGrid`` filled with four small test images."""
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid

    base = np.arange(100).reshape((10, 10))
    transposed = base.T
    images = [base, transposed, np.flipud(base), np.fliplr(transposed)]

    fig = plt.figure(figsize=(10., 10.))
    grid = ImageGrid(
        fig,
        111,  # similar to subplot(111)
        nrows_ncols=(6, 6),  # creates a 6x6 grid of axes
        axes_pad=0.1,  # pad between axes in inch.
    )

    # Iterating over the grid yields its Axes; zip stops after the four
    # images, so the remaining axes stay empty.
    for ax, im in zip(grid, images):
        ax.imshow(im)

    plt.show()
def visualize(path, dataset_info, max_num=25, wandb=None, spheres_3d=False):
    """Render up to ``max_num`` XYZ files from ``path`` as PNG images.

    The files are taken from ``load_xyz_files`` with its default shuffling,
    so which files are rendered is a random selection. Each PNG is written
    next to its source file with the ``.xyz`` suffix replaced by ``.png``.

    Args:
        path: Directory containing ``*.xyz`` files.
        dataset_info: Dataset description used for parsing and plotting.
        max_num: Maximum number of molecules to render.
        wandb: Optional wandb-like object; when given, every image is logged
            under the key ``'molecule'``.
        spheres_3d: Forwarded to ``plot_data3d``.
    """
    files = load_xyz_files(path)[0:max_num]
    for file in files:
        positions, one_hot = load_molecule_xyz(file, dataset_info)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Files come from a '*.xyz' glob, so stripping four characters
        # removes exactly the extension.
        png_path = file[:-4] + '.png'
        plot_data3d(positions, atom_type, dataset_info=dataset_info,
                    save_path=png_path,
                    spheres_3d=spheres_3d)

        if wandb is not None:
            # Log image(s)
            im = plt.imread(png_path)
            wandb.log({'molecule': [wandb.Image(im, caption=png_path)]})
def visualize_chain(path, dataset_info, wandb=None, spheres_3d=False,
                    mode="chain"):
    """Render every XYZ file in ``path`` and assemble the frames into a GIF.

    Files are processed in sorted name order. Each frame is saved as a PNG
    next to its source file, and the animation is written to
    ``output.gif`` in the same directory.

    Args:
        path: Directory containing the ``*.xyz`` frames.
        dataset_info: Dataset description used for parsing and plotting.
        wandb: Optional wandb-like object; when given, the GIF is logged as
            a video under the key ``mode``.
        spheres_3d: Forwarded to ``plot_data3d``.
        mode: Log key used for the wandb video.

    Raises:
        IndexError: If ``path`` contains no ``*.xyz`` files.
    """
    # The list is sorted anyway, so skip the shuffle: it would only consume
    # state of the global `random` generator.
    files = sorted(load_xyz_files(path, shuffle=False))

    save_paths = []
    imgs = []
    for file in files:
        positions, one_hot = load_molecule_xyz(file, dataset_info=dataset_info)
        atom_type = torch.argmax(one_hot, dim=1).numpy()
        fn = file[:-4] + '.png'
        plot_data3d(positions, atom_type, dataset_info=dataset_info,
                    save_path=fn, spheres_3d=spheres_3d, alpha=1.0)
        save_paths.append(fn)
        imgs.append(imageio.imread(fn))

    dirname = os.path.dirname(save_paths[0])
    gif_path = dirname + '/output.gif'
    print(f'Creating gif with {len(imgs)} images')
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})
def visualize_chain_uncertainty(
        path, dataset_info, wandb=None, spheres_3d=False, mode="chain"):
    """Render a GIF in which each frame overlays three consecutive XYZ files.

    Files are processed in sorted name order. Frame ``i`` overlays files
    ``i``, ``i + 1`` and ``i + 2`` at half opacity and is saved as a PNG next
    to file ``i``; the animation is written to ``output.gif`` in the same
    directory.

    Args:
        path: Directory containing the ``*.xyz`` frames. All files must hold
            the same number of atoms, because they are stacked.
        dataset_info: Dataset description used for parsing and plotting.
        wandb: Optional wandb-like object; when given, the GIF is logged as
            a video under the key ``mode``.
        spheres_3d: Forwarded to ``plot_data3d_uncertainty``.
        mode: Log key used for the wandb video.

    Raises:
        IndexError: If ``path`` contains fewer than three ``*.xyz`` files.
    """
    # The list is sorted anyway, so skip the shuffle.
    files = sorted(load_xyz_files(path, shuffle=False))

    # Parse each file once; every file takes part in up to three windows.
    # load_molecule_xyz returns exactly (positions, one_hot).
    molecules = [load_molecule_xyz(file, dataset_info=dataset_info)
                 for file in files]

    save_paths = []
    for i in range(len(files) - 2):
        window = molecules[i:i + 3]
        all_positions = torch.stack([pos for pos, _ in window], dim=0)
        one_hot = torch.stack([oh for _, oh in window], dim=0)
        all_atom_type = torch.argmax(one_hot, dim=2).numpy()

        fn = files[i][:-4] + '.png'
        plot_data3d_uncertainty(
            all_positions, all_atom_type, dataset_info=dataset_info,
            save_path=fn, spheres_3d=spheres_3d, alpha=0.5)
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    dirname = os.path.dirname(save_paths[0])
    gif_path = dirname + '/output.gif'
    print(f'Creating gif with {len(imgs)} images')
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})
if __name__ == '__main__':
    # Manual smoke test: show molecules of one dataset interactively.
    # plot_grid()
    import qm9.dataset as dataset
    from configs.datasets_config import qm9_with_h, geom_with_h

    # The module selects the non-interactive 'Agg' backend at import time;
    # switch to an interactive one so plt.show() opens a window.
    # NOTE(review): 'macosx' is only available on macOS.
    matplotlib.use('macosx')

    task_dataset = 'geom'

    if task_dataset == 'qm9':
        dataset_info = qm9_with_h

        class Args:
            batch_size = 1
            num_workers = 0
            filter_n_atoms = None
            datadir = 'qm9/temp'
            dataset = 'qm9'
            remove_h = False

        cfg = Args()
        dataloaders, charge_scale = dataset.retrieve_dataloaders(cfg)

        for i, data in enumerate(dataloaders['train']):
            positions = data['positions'].view(-1, 3)
            positions_centered = positions - positions.mean(dim=0, keepdim=True)
            one_hot = data['one_hot'].view(-1, 5).type(torch.float32)
            atom_type = torch.argmax(one_hot, dim=1).numpy()

            plot_data3d(
                positions_centered, atom_type, dataset_info=dataset_info,
                spheres_3d=True)

    elif task_dataset == 'geom':
        files = load_xyz_files('outputs/data')

        for file in files:
            # load_molecule_xyz returns exactly (positions, one_hot).
            x, one_hot = load_molecule_xyz(file, dataset_info=geom_with_h)
            positions = x.view(-1, 3)
            positions_centered = positions - positions.mean(dim=0, keepdim=True)
            # The one-hot width already equals the decoder length, so take
            # it from the tensor instead of hard-coding it.
            one_hot = one_hot.view(-1, one_hot.shape[-1]).type(torch.float32)
            atom_type = torch.argmax(one_hot, dim=1).numpy()

            # Drop rows whose three coordinates are all exactly zero.
            mask = (x == 0).sum(1) != 3
            positions_centered = positions_centered[mask]
            atom_type = atom_type[mask]

            plot_data3d(
                positions_centered, atom_type, dataset_info=geom_with_h,
                spheres_3d=False)

    else:
        # Report the unknown dataset name, not the imported `dataset` module.
        raise ValueError(task_dataset)