Spaces:
Runtime error
Runtime error
Change max batch_size
Browse files
- app.py +17 -7
- src/generation.py +3 -2
app.py
CHANGED
|
@@ -17,6 +17,11 @@ from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
|
|
| 17 |
from zipfile import ZipFile
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
MODELS_METADATA = {
|
| 21 |
'geom_difflinker': {
|
| 22 |
'link': 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
|
|
@@ -329,9 +334,7 @@ def generate_without_pocket(input_file, n_steps, n_atoms, num_samples, selected_
|
|
| 329 |
|
| 330 |
for data in dataloader:
|
| 331 |
try:
|
| 332 |
-
generate_linkers(
|
| 333 |
-
ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=False
|
| 334 |
-
)
|
| 335 |
except Exception as e:
|
| 336 |
e = str(e).replace('\'', '')
|
| 337 |
error = f'Caught exception while generating linkers: {e}'
|
|
@@ -450,7 +453,8 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
|
|
| 450 |
dataset = MOADDataset(data=dataset)
|
| 451 |
ddpm.val_dataset = dataset
|
| 452 |
|
| 453 |
-
|
|
|
|
| 454 |
print('Created dataloader')
|
| 455 |
|
| 456 |
ddpm.edm.T = n_steps
|
|
@@ -470,10 +474,13 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
|
|
| 470 |
def sample_fn(_data):
|
| 471 |
return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
|
| 472 |
|
| 473 |
-
for data in dataloader:
|
| 474 |
try:
|
|
|
|
| 475 |
generate_linkers(
|
| 476 |
-
ddpm=ddpm, data=data,
|
|
|
|
|
|
|
| 477 |
)
|
| 478 |
except Exception as e:
|
| 479 |
e = str(e).replace('\'', '')
|
|
@@ -520,7 +527,10 @@ with demo:
|
|
| 520 |
gr.Markdown('Upload the file of the target protein in .pdb format (optionally):')
|
| 521 |
input_protein_file = gr.File(file_count='single', label='Target Protein (Optional)')
|
| 522 |
|
| 523 |
-
n_steps = gr.Slider(
|
|
|
|
|
|
|
|
|
|
| 524 |
n_atoms = gr.Slider(
|
| 525 |
minimum=0, maximum=20,
|
| 526 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
|
|
|
| 17 |
from zipfile import ZipFile
|
| 18 |
|
| 19 |
|
| 20 |
+
MIN_N_STEPS = 100
|
| 21 |
+
MAX_N_STEPS = 500
|
| 22 |
+
MAX_BATCH_SIZE = 5
|
| 23 |
+
|
| 24 |
+
|
| 25 |
MODELS_METADATA = {
|
| 26 |
'geom_difflinker': {
|
| 27 |
'link': 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
|
|
|
|
| 334 |
|
| 335 |
for data in dataloader:
|
| 336 |
try:
|
| 337 |
+
generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=False)
|
|
|
|
|
|
|
| 338 |
except Exception as e:
|
| 339 |
e = str(e).replace('\'', '')
|
| 340 |
error = f'Caught exception while generating linkers: {e}'
|
|
|
|
| 453 |
dataset = MOADDataset(data=dataset)
|
| 454 |
ddpm.val_dataset = dataset
|
| 455 |
|
| 456 |
+
batch_size = min(num_samples, MAX_BATCH_SIZE)
|
| 457 |
+
dataloader = get_dataloader(dataset, batch_size=batch_size, collate_fn=collate_with_fragment_edges)
|
| 458 |
print('Created dataloader')
|
| 459 |
|
| 460 |
ddpm.edm.T = n_steps
|
|
|
|
| 474 |
def sample_fn(_data):
|
| 475 |
return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
|
| 476 |
|
| 477 |
+
for batch_i, data in enumerate(dataloader):
|
| 478 |
try:
|
| 479 |
+
offset_idx = batch_i * batch_size
|
| 480 |
generate_linkers(
|
| 481 |
+
ddpm=ddpm, data=data,
|
| 482 |
+
sample_fn=sample_fn, name=name, with_pocket=True,
|
| 483 |
+
offset_idx=offset_idx,
|
| 484 |
)
|
| 485 |
except Exception as e:
|
| 486 |
e = str(e).replace('\'', '')
|
|
|
|
| 527 |
gr.Markdown('Upload the file of the target protein in .pdb format (optionally):')
|
| 528 |
input_protein_file = gr.File(file_count='single', label='Target Protein (Optional)')
|
| 529 |
|
| 530 |
+
n_steps = gr.Slider(
|
| 531 |
+
minimum=MIN_N_STEPS, maximum=MAX_N_STEPS,
|
| 532 |
+
label="Number of Denoising Steps", step=10
|
| 533 |
+
)
|
| 534 |
n_atoms = gr.Slider(
|
| 535 |
minimum=0, maximum=20,
|
| 536 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
src/generation.py
CHANGED
|
@@ -10,7 +10,7 @@ from src.utils import FoundNaNException
|
|
| 10 |
from src.datasets import get_one_hot
|
| 11 |
|
| 12 |
|
| 13 |
-
def generate_linkers(ddpm, data,
|
| 14 |
chain = node_mask = None
|
| 15 |
for i in range(5):
|
| 16 |
try:
|
|
@@ -37,7 +37,8 @@ def generate_linkers(ddpm, data, num_samples, sample_fn, name, with_pocket=False
|
|
| 37 |
if with_pocket:
|
| 38 |
node_mask[torch.where(data['pocket_mask'])] = 0
|
| 39 |
|
| 40 |
-
|
|
|
|
| 41 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
| 42 |
print('Saved XYZ files')
|
| 43 |
|
|
|
|
| 10 |
from src.datasets import get_one_hot
|
| 11 |
|
| 12 |
|
| 13 |
+
def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=0):
|
| 14 |
chain = node_mask = None
|
| 15 |
for i in range(5):
|
| 16 |
try:
|
|
|
|
| 37 |
if with_pocket:
|
| 38 |
node_mask[torch.where(data['pocket_mask'])] = 0
|
| 39 |
|
| 40 |
+
batch_size = len(data)
|
| 41 |
+
names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
|
| 42 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
| 43 |
print('Saved XYZ files')
|
| 44 |
|