Spaces:
Runtime error
Runtime error
igashov
commited on
Commit
·
c438a2a
1
Parent(s):
bec2844
update COM
Browse files
app.py
CHANGED
|
@@ -160,6 +160,12 @@ def generate(input_file, n_steps):
|
|
| 160 |
print('Generated linker')
|
| 161 |
x = chain[0][:, :, :ddpm.n_dims]
|
| 162 |
h = chain[0][:, :, ddpm.n_dims:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
|
| 164 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
| 165 |
print('Saved XYZ files')
|
|
|
|
| 160 |
print('Generated linker')
|
| 161 |
x = chain[0][:, :, :ddpm.n_dims]
|
| 162 |
h = chain[0][:, :, ddpm.n_dims:]
|
| 163 |
+
|
| 164 |
+
pos_masked = data['positions'] * data['fragment_mask']
|
| 165 |
+
N = data['fragment_mask'].sum(1, keepdims=True)
|
| 166 |
+
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
|
| 167 |
+
x = x + mean * node_mask
|
| 168 |
+
|
| 169 |
names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
|
| 170 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
| 171 |
print('Saved XYZ files')
|