iflp1908sl commited on
Commit
89947eb
·
1 Parent(s): 6255e2c

now supports repeated generation runs with fixed batch size per run:

Browse files
Files changed (1) hide show
  1. app.py +58 -52
app.py CHANGED
@@ -315,6 +315,7 @@ def save_to_format(xyz_str, idx, fmt="pdb"):
315
  def generate(
316
  model_name,
317
  num_molecules,
 
318
  size_mode,
319
  fixed_size,
320
  diffusion_steps,
@@ -333,7 +334,11 @@ def generate(
333
  if TASK is None:
334
  return "", gr.update(choices=[], value=None), [], None, None
335
 
336
- print(f"Generating {num_molecules} molecules with '{model_name}' (Steps: {diffusion_steps}, Seed: {seed})...")
 
 
 
 
337
  seed_everything(int(seed))
338
 
339
  # 1. Override diffusion steps
@@ -347,21 +352,7 @@ def generate(
347
  TASK.T = int(diffusion_steps)
348
 
349
  try:
350
- # 2. Determine molecule sizes
351
- # Handle "Auto" string check robustly
352
- if "Auto" in size_mode:
353
- if TASK.node_dist_model is not None:
354
- # DistributionNodes.sample returns tensor of sizes
355
- nodesxsample = TASK.node_dist_model.sample(num_molecules)
356
- else:
357
- # Fallback if no dist model
358
- nodesxsample = torch.randint(10, 30, (num_molecules,))
359
- else:
360
- nodesxsample = torch.tensor([fixed_size] * num_molecules)
361
-
362
- nodesxsample = nodesxsample.to(DEVICE).long()
363
-
364
- # 3. Sample (auto-switch unconditional/conditional)
365
  condition_names = get_condition_names(TASK)
366
  is_conditional = len(condition_names) > 0
367
 
@@ -372,45 +363,58 @@ def generate(
372
  negative_values = parse_condition_row(
373
  negative_values_df, condition_names, required=False
374
  )
375
- target_fn = lambda z, t: torch.zeros(
376
- z.size(0), device=z.device, dtype=z.dtype
377
- )
378
- one_hot, charges, x, node_mask = TASK.sample_guidance_conitional(
379
- target_function=target_fn,
380
- target_value=target_values,
381
- negative_target_value=negative_values,
382
- nodesxsample=nodesxsample,
383
- gg_scale=0.0,
384
- cfg_scale=float(cfg_scale),
385
- guidance_ver="cfg",
386
- n_frames=0,
387
- fix_noise=False,
388
- )
389
- else:
390
- # This returns: one_hot [B,N,C], charges [B,N,1], x [B,N,3], node_mask [B,N,1]
391
- one_hot, charges, x, node_mask = TASK.sample(
392
- nodesxsample=nodesxsample,
393
- mode="ddpm",
394
- n_frames=0,
395
- fix_noise=False,
396
- )
397
-
398
- # 4. Post-process
399
  xyz_strings = []
400
  summary_rows = []
401
-
402
- for i in range(num_molecules):
403
- # Extract single molecule data
404
- xyz_str = tensors_to_xyz_string(
405
- one_hot[i],
406
- x[i],
407
- node_mask[i],
408
- TASK.atom_vocab
409
- )
410
- xyz_strings.append(xyz_str)
411
- summary_rows.append(parse_composition(xyz_str))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
- # 5. Output generation – save zip for bulk download
414
  zip_path = create_xyz_zip(xyz_strings)
415
 
416
  # Prepare table with "Name" column
@@ -746,6 +750,7 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft(), head=THREEDMOL
746
  "Mode: **Conditional**" if initial_is_conditional else "Mode: **Unconditional**"
747
  )
748
  num_mol = gr.Slider(1, 12, value=4, step=1, label="Number of Molecules")
 
749
 
750
  size_mode = gr.Radio(
751
  ["Auto (from training data)", "Fixed size"],
@@ -1035,6 +1040,7 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft(), head=THREEDMOL
1035
  inputs=[
1036
  model_selector,
1037
  num_mol,
 
1038
  size_mode,
1039
  fixed_size,
1040
  diffusion_steps,
 
315
  def generate(
316
  model_name,
317
  num_molecules,
318
+ num_runs,
319
  size_mode,
320
  fixed_size,
321
  diffusion_steps,
 
334
  if TASK is None:
335
  return "", gr.update(choices=[], value=None), [], None, None
336
 
337
+ total_requested = int(num_molecules) * int(num_runs)
338
+ print(
339
+ f"Generating {total_requested} molecules as {num_runs} run(s) x batch {num_molecules} "
340
+ f"with '{model_name}' (Steps: {diffusion_steps}, Seed: {seed})..."
341
+ )
342
  seed_everything(int(seed))
343
 
344
  # 1. Override diffusion steps
 
352
  TASK.T = int(diffusion_steps)
353
 
354
  try:
355
+ # 2. Resolve conditioning mode/inputs once
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  condition_names = get_condition_names(TASK)
357
  is_conditional = len(condition_names) > 0
358
 
 
363
  negative_values = parse_condition_row(
364
  negative_values_df, condition_names, required=False
365
  )
366
+
367
+ # 3. Sample across runs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  xyz_strings = []
369
  summary_rows = []
370
+
371
+ target_fn = lambda z, t: torch.zeros(
372
+ z.size(0), device=z.device, dtype=z.dtype
373
+ )
374
+
375
+ for run_idx in range(int(num_runs)):
376
+ # Determine molecule sizes for this run
377
+ if "Auto" in size_mode:
378
+ if TASK.node_dist_model is not None:
379
+ nodesxsample = TASK.node_dist_model.sample(int(num_molecules))
380
+ else:
381
+ nodesxsample = torch.randint(10, 30, (int(num_molecules),))
382
+ else:
383
+ nodesxsample = torch.tensor([fixed_size] * int(num_molecules))
384
+
385
+ nodesxsample = nodesxsample.to(DEVICE).long()
386
+
387
+ if is_conditional:
388
+ one_hot, charges, x, node_mask = TASK.sample_guidance_conitional(
389
+ target_function=target_fn,
390
+ target_value=target_values,
391
+ negative_target_value=negative_values,
392
+ nodesxsample=nodesxsample,
393
+ gg_scale=0.0,
394
+ cfg_scale=float(cfg_scale),
395
+ guidance_ver="cfg",
396
+ n_frames=0,
397
+ fix_noise=False,
398
+ )
399
+ else:
400
+ one_hot, charges, x, node_mask = TASK.sample(
401
+ nodesxsample=nodesxsample,
402
+ mode="ddpm",
403
+ n_frames=0,
404
+ fix_noise=False,
405
+ )
406
+
407
+ for i in range(int(num_molecules)):
408
+ xyz_str = tensors_to_xyz_string(
409
+ one_hot[i],
410
+ x[i],
411
+ node_mask[i],
412
+ TASK.atom_vocab
413
+ )
414
+ xyz_strings.append(xyz_str)
415
+ summary_rows.append(parse_composition(xyz_str))
416
 
417
+ # 4. Output generation – save zip for bulk download
418
  zip_path = create_xyz_zip(xyz_strings)
419
 
420
  # Prepare table with "Name" column
 
750
  "Mode: **Conditional**" if initial_is_conditional else "Mode: **Unconditional**"
751
  )
752
  num_mol = gr.Slider(1, 12, value=4, step=1, label="Number of Molecules")
753
+ num_runs = gr.Slider(1, 20, value=1, step=1, label="Number of Runs")
754
 
755
  size_mode = gr.Radio(
756
  ["Auto (from training data)", "Fixed size"],
 
1040
  inputs=[
1041
  model_selector,
1042
  num_mol,
1043
+ num_runs,
1044
  size_mode,
1045
  fixed_size,
1046
  diffusion_steps,