Spaces:
Build error
Build error
Commit ·
89947eb
1
Parent(s): 6255e2c
Now supports repeated generation runs with a fixed batch size per run.
Browse files
app.py
CHANGED
|
@@ -315,6 +315,7 @@ def save_to_format(xyz_str, idx, fmt="pdb"):
|
|
| 315 |
def generate(
|
| 316 |
model_name,
|
| 317 |
num_molecules,
|
|
|
|
| 318 |
size_mode,
|
| 319 |
fixed_size,
|
| 320 |
diffusion_steps,
|
|
@@ -333,7 +334,11 @@ def generate(
|
|
| 333 |
if TASK is None:
|
| 334 |
return "", gr.update(choices=[], value=None), [], None, None
|
| 335 |
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
seed_everything(int(seed))
|
| 338 |
|
| 339 |
# 1. Override diffusion steps
|
|
@@ -347,21 +352,7 @@ def generate(
|
|
| 347 |
TASK.T = int(diffusion_steps)
|
| 348 |
|
| 349 |
try:
|
| 350 |
-
# 2.
|
| 351 |
-
# Handle "Auto" string check robustly
|
| 352 |
-
if "Auto" in size_mode:
|
| 353 |
-
if TASK.node_dist_model is not None:
|
| 354 |
-
# DistributionNodes.sample returns tensor of sizes
|
| 355 |
-
nodesxsample = TASK.node_dist_model.sample(num_molecules)
|
| 356 |
-
else:
|
| 357 |
-
# Fallback if no dist model
|
| 358 |
-
nodesxsample = torch.randint(10, 30, (num_molecules,))
|
| 359 |
-
else:
|
| 360 |
-
nodesxsample = torch.tensor([fixed_size] * num_molecules)
|
| 361 |
-
|
| 362 |
-
nodesxsample = nodesxsample.to(DEVICE).long()
|
| 363 |
-
|
| 364 |
-
# 3. Sample (auto-switch unconditional/conditional)
|
| 365 |
condition_names = get_condition_names(TASK)
|
| 366 |
is_conditional = len(condition_names) > 0
|
| 367 |
|
|
@@ -372,45 +363,58 @@ def generate(
|
|
| 372 |
negative_values = parse_condition_row(
|
| 373 |
negative_values_df, condition_names, required=False
|
| 374 |
)
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
)
|
| 378 |
-
one_hot, charges, x, node_mask = TASK.sample_guidance_conitional(
|
| 379 |
-
target_function=target_fn,
|
| 380 |
-
target_value=target_values,
|
| 381 |
-
negative_target_value=negative_values,
|
| 382 |
-
nodesxsample=nodesxsample,
|
| 383 |
-
gg_scale=0.0,
|
| 384 |
-
cfg_scale=float(cfg_scale),
|
| 385 |
-
guidance_ver="cfg",
|
| 386 |
-
n_frames=0,
|
| 387 |
-
fix_noise=False,
|
| 388 |
-
)
|
| 389 |
-
else:
|
| 390 |
-
# This returns: one_hot [B,N,C], charges [B,N,1], x [B,N,3], node_mask [B,N,1]
|
| 391 |
-
one_hot, charges, x, node_mask = TASK.sample(
|
| 392 |
-
nodesxsample=nodesxsample,
|
| 393 |
-
mode="ddpm",
|
| 394 |
-
n_frames=0,
|
| 395 |
-
fix_noise=False,
|
| 396 |
-
)
|
| 397 |
-
|
| 398 |
-
# 4. Post-process
|
| 399 |
xyz_strings = []
|
| 400 |
summary_rows = []
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
-
#
|
| 414 |
zip_path = create_xyz_zip(xyz_strings)
|
| 415 |
|
| 416 |
# Prepare table with "Name" column
|
|
@@ -746,6 +750,7 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft(), head=THREEDMOL
|
|
| 746 |
"Mode: **Conditional**" if initial_is_conditional else "Mode: **Unconditional**"
|
| 747 |
)
|
| 748 |
num_mol = gr.Slider(1, 12, value=4, step=1, label="Number of Molecules")
|
|
|
|
| 749 |
|
| 750 |
size_mode = gr.Radio(
|
| 751 |
["Auto (from training data)", "Fixed size"],
|
|
@@ -1035,6 +1040,7 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft(), head=THREEDMOL
|
|
| 1035 |
inputs=[
|
| 1036 |
model_selector,
|
| 1037 |
num_mol,
|
|
|
|
| 1038 |
size_mode,
|
| 1039 |
fixed_size,
|
| 1040 |
diffusion_steps,
|
|
|
|
| 315 |
def generate(
|
| 316 |
model_name,
|
| 317 |
num_molecules,
|
| 318 |
+
num_runs,
|
| 319 |
size_mode,
|
| 320 |
fixed_size,
|
| 321 |
diffusion_steps,
|
|
|
|
| 334 |
if TASK is None:
|
| 335 |
return "", gr.update(choices=[], value=None), [], None, None
|
| 336 |
|
| 337 |
+
total_requested = int(num_molecules) * int(num_runs)
|
| 338 |
+
print(
|
| 339 |
+
f"Generating {total_requested} molecules as {num_runs} run(s) x batch {num_molecules} "
|
| 340 |
+
f"with '{model_name}' (Steps: {diffusion_steps}, Seed: {seed})..."
|
| 341 |
+
)
|
| 342 |
seed_everything(int(seed))
|
| 343 |
|
| 344 |
# 1. Override diffusion steps
|
|
|
|
| 352 |
TASK.T = int(diffusion_steps)
|
| 353 |
|
| 354 |
try:
|
| 355 |
+
# 2. Resolve conditioning mode/inputs once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
condition_names = get_condition_names(TASK)
|
| 357 |
is_conditional = len(condition_names) > 0
|
| 358 |
|
|
|
|
| 363 |
negative_values = parse_condition_row(
|
| 364 |
negative_values_df, condition_names, required=False
|
| 365 |
)
|
| 366 |
+
|
| 367 |
+
# 3. Sample across runs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
xyz_strings = []
|
| 369 |
summary_rows = []
|
| 370 |
+
|
| 371 |
+
target_fn = lambda z, t: torch.zeros(
|
| 372 |
+
z.size(0), device=z.device, dtype=z.dtype
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
for run_idx in range(int(num_runs)):
|
| 376 |
+
# Determine molecule sizes for this run
|
| 377 |
+
if "Auto" in size_mode:
|
| 378 |
+
if TASK.node_dist_model is not None:
|
| 379 |
+
nodesxsample = TASK.node_dist_model.sample(int(num_molecules))
|
| 380 |
+
else:
|
| 381 |
+
nodesxsample = torch.randint(10, 30, (int(num_molecules),))
|
| 382 |
+
else:
|
| 383 |
+
nodesxsample = torch.tensor([fixed_size] * int(num_molecules))
|
| 384 |
+
|
| 385 |
+
nodesxsample = nodesxsample.to(DEVICE).long()
|
| 386 |
+
|
| 387 |
+
if is_conditional:
|
| 388 |
+
one_hot, charges, x, node_mask = TASK.sample_guidance_conitional(
|
| 389 |
+
target_function=target_fn,
|
| 390 |
+
target_value=target_values,
|
| 391 |
+
negative_target_value=negative_values,
|
| 392 |
+
nodesxsample=nodesxsample,
|
| 393 |
+
gg_scale=0.0,
|
| 394 |
+
cfg_scale=float(cfg_scale),
|
| 395 |
+
guidance_ver="cfg",
|
| 396 |
+
n_frames=0,
|
| 397 |
+
fix_noise=False,
|
| 398 |
+
)
|
| 399 |
+
else:
|
| 400 |
+
one_hot, charges, x, node_mask = TASK.sample(
|
| 401 |
+
nodesxsample=nodesxsample,
|
| 402 |
+
mode="ddpm",
|
| 403 |
+
n_frames=0,
|
| 404 |
+
fix_noise=False,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
for i in range(int(num_molecules)):
|
| 408 |
+
xyz_str = tensors_to_xyz_string(
|
| 409 |
+
one_hot[i],
|
| 410 |
+
x[i],
|
| 411 |
+
node_mask[i],
|
| 412 |
+
TASK.atom_vocab
|
| 413 |
+
)
|
| 414 |
+
xyz_strings.append(xyz_str)
|
| 415 |
+
summary_rows.append(parse_composition(xyz_str))
|
| 416 |
|
| 417 |
+
# 4. Output generation – save zip for bulk download
|
| 418 |
zip_path = create_xyz_zip(xyz_strings)
|
| 419 |
|
| 420 |
# Prepare table with "Name" column
|
|
|
|
| 750 |
"Mode: **Conditional**" if initial_is_conditional else "Mode: **Unconditional**"
|
| 751 |
)
|
| 752 |
num_mol = gr.Slider(1, 12, value=4, step=1, label="Number of Molecules")
|
| 753 |
+
num_runs = gr.Slider(1, 20, value=1, step=1, label="Number of Runs")
|
| 754 |
|
| 755 |
size_mode = gr.Radio(
|
| 756 |
["Auto (from training data)", "Fixed size"],
|
|
|
|
| 1040 |
inputs=[
|
| 1041 |
model_selector,
|
| 1042 |
num_mol,
|
| 1043 |
+
num_runs,
|
| 1044 |
size_mode,
|
| 1045 |
fixed_size,
|
| 1046 |
diffusion_steps,
|