Spaces:
Runtime error
Runtime error
igashov
commited on
Commit
·
3c26059
1
Parent(s):
aebc0d2
multiple samples
Browse files
app.py
CHANGED
|
@@ -14,6 +14,8 @@ from src.datasets import get_dataloader, collate_with_fragment_edges, parse_mole
|
|
| 14 |
from src.lightning import DDPM
|
| 15 |
from src.linker_size_lightning import SizeClassifier
|
| 16 |
|
|
|
|
|
|
|
| 17 |
parser = argparse.ArgumentParser()
|
| 18 |
parser.add_argument('--ip', type=str, default=None)
|
| 19 |
args = parser.parse_args()
|
|
@@ -103,10 +105,8 @@ def generate(input_file):
|
|
| 103 |
molecule = read_molecule(path)
|
| 104 |
molecule = Chem.RemoveAllHs(molecule)
|
| 105 |
name = '.'.join(path.split('/')[-1].split('.')[:-1])
|
| 106 |
-
inp_sdf = f'results/{name}
|
| 107 |
-
inp_xyz = f'results/{name}
|
| 108 |
-
out_sdf = f'results/{name}_output.sdf'
|
| 109 |
-
out_xyz = f'results/{name}_output.xyz'
|
| 110 |
except Exception as e:
|
| 111 |
return f'Could not read the molecule: {e}'
|
| 112 |
|
|
@@ -133,8 +133,8 @@ def generate(input_file):
|
|
| 133 |
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
|
| 134 |
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
|
| 135 |
'num_atoms': len(positions),
|
| 136 |
-
}]
|
| 137 |
-
dataloader = get_dataloader(dataset, batch_size=
|
| 138 |
print('Created dataloader')
|
| 139 |
|
| 140 |
for data in dataloader:
|
|
@@ -142,12 +142,21 @@ def generate(input_file):
|
|
| 142 |
print('Generated linker')
|
| 143 |
x = chain[0][:, :, :ddpm.n_dims]
|
| 144 |
h = chain[0][:, :, ddpm.n_dims:]
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
print('Converted to SDF')
|
| 149 |
break
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
input_fragments_content = read_molecule_content(inp_sdf)
|
| 152 |
generated_molecule_content = read_molecule_content(out_sdf)
|
| 153 |
html = output.SAMPLES_RENDERING_TEMPLATE.format(
|
|
@@ -158,7 +167,7 @@ def generate(input_file):
|
|
| 158 |
)
|
| 159 |
return [
|
| 160 |
output.IFRAME_TEMPLATE.format(html=html),
|
| 161 |
-
[inp_sdf, inp_xyz
|
| 162 |
]
|
| 163 |
|
| 164 |
|
|
|
|
| 14 |
from src.lightning import DDPM
|
| 15 |
from src.linker_size_lightning import SizeClassifier
|
| 16 |
|
| 17 |
+
N_SAMPLES = 5
|
| 18 |
+
|
| 19 |
parser = argparse.ArgumentParser()
|
| 20 |
parser.add_argument('--ip', type=str, default=None)
|
| 21 |
args = parser.parse_args()
|
|
|
|
| 105 |
molecule = read_molecule(path)
|
| 106 |
molecule = Chem.RemoveAllHs(molecule)
|
| 107 |
name = '.'.join(path.split('/')[-1].split('.')[:-1])
|
| 108 |
+
inp_sdf = f'results/input_{name}.sdf'
|
| 109 |
+
inp_xyz = f'results/input_{name}.xyz'
|
|
|
|
|
|
|
| 110 |
except Exception as e:
|
| 111 |
return f'Could not read the molecule: {e}'
|
| 112 |
|
|
|
|
| 133 |
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
|
| 134 |
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
|
| 135 |
'num_atoms': len(positions),
|
| 136 |
+
}] * N_SAMPLES
|
| 137 |
+
dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
|
| 138 |
print('Created dataloader')
|
| 139 |
|
| 140 |
for data in dataloader:
|
|
|
|
| 142 |
print('Generated linker')
|
| 143 |
x = chain[0][:, :, :ddpm.n_dims]
|
| 144 |
h = chain[0][:, :, ddpm.n_dims:]
|
| 145 |
+
names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
|
| 146 |
+
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
| 147 |
+
print('Saved XYZ files')
|
|
|
|
| 148 |
break
|
| 149 |
|
| 150 |
+
out_files = []
|
| 151 |
+
for i in range(N_SAMPLES):
|
| 152 |
+
out_xyz = f'results/output_{i+1}_{name}_.xyz'
|
| 153 |
+
out_sdf = f'results/output_{i+1}_{name}_.sdf'
|
| 154 |
+
subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)
|
| 155 |
+
out_files.append(out_xyz)
|
| 156 |
+
out_files.append(out_sdf)
|
| 157 |
+
print('Converted to SDF')
|
| 158 |
+
|
| 159 |
+
out_sdf = f'results/output_1_{name}_.sdf'
|
| 160 |
input_fragments_content = read_molecule_content(inp_sdf)
|
| 161 |
generated_molecule_content = read_molecule_content(out_sdf)
|
| 162 |
html = output.SAMPLES_RENDERING_TEMPLATE.format(
|
|
|
|
| 167 |
)
|
| 168 |
return [
|
| 169 |
output.IFRAME_TEMPLATE.format(html=html),
|
| 170 |
+
[inp_sdf, inp_xyz] + out_files,
|
| 171 |
]
|
| 172 |
|
| 173 |
|