| """
|
| compare_3modes_speed.py — Head-to-head comparison of original vs optimized
|
| 3-mode training for N iterations.
|
|
|
| Supports 3 modes:
|
| --pipeline orig Run original pipeline only, save results to JSON
|
| --pipeline opt Run optimized pipeline only, save results to JSON
|
| --pipeline both Run both sequentially (may OOM on large batchsize)
|
|
|
| After running orig and opt separately, use --pipeline compare to print
|
| the comparison table from saved JSONs.
|
|
|
| Usage:
|
| # Separate jobs (recommended for batchsize>=3):
|
| python tests/compare_3modes_speed.py --pipeline orig --device xpu --batchsize 3
|
| python tests/compare_3modes_speed.py --pipeline opt --device xpu --batchsize 3
|
| python tests/compare_3modes_speed.py --pipeline compare
|
|
|
| # Single job (only for small batchsize):
|
| python tests/compare_3modes_speed.py --pipeline both --device xpu --batchsize 1
|
| """
|
|
|
| import os, sys, time, json, random, argparse
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
|
|
# Repository root = parent of the directory that contains this script. It is
# put first on sys.path so the project-local imports used further down resolve
# when the file is run directly.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(_SCRIPT_DIR)
sys.path.insert(0, ROOT_DIR)

# Model / run defaults. IMG_SIZE, BATCHSIZE and NUM_STEPS are overridden by
# the matching command-line options parsed below.
IMG_SIZE = 128
BATCHSIZE = 3
TIMESTEPS = 80
NDIMS = 3
V_SCALE = 5e-5
NOISE_SCALE = 0.05
NET_NAME = "recmulmodmutattnnet"
LR = 1e-5
NUM_STEPS = 10

# Per-step training hyperparameters: augmentation probabilities, loss weights
# and the ratios that schedule the contrastive step / size the pair batch.
MSK_EPS = 0.01
TEXT_EMBED_PROB = 0.7
AUG_RESAMPLE_PROB = 0.5
LOSS_WEIGHTS_DIFF = [2.0, 2.0, 4.0]
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128]
LOSS_WEIGHT_CONTRASTIVE = 1.0
CONTRASTIVE_STEP_RATIO = 2
DIFF_REG_BATCH_RATIO = 2

# Default directory for the per-pipeline result JSON files.
RESULTS_DIR = os.path.join(ROOT_DIR, "Logs")

# NOTE: arguments are parsed at import time, so importing this module from
# another program consumes that program's sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--pipeline",
    type=str,
    default="both",
    choices=["orig", "opt", "both", "compare"],
    help="Which pipeline to run (orig/opt/both/compare)",
)
parser.add_argument(
    "--device",
    type=str,
    default="xpu",
    choices=["cpu", "cuda", "xpu"],
)
# The three integer options share the same shape; register them in the same
# order as before so the generated --help text is unchanged.
for _flag, _default in (
    ("--steps", NUM_STEPS),
    ("--img-size", IMG_SIZE),
    ("--batchsize", BATCHSIZE),
):
    parser.add_argument(_flag, type=int, default=_default)
parser.add_argument(
    "--results-dir",
    type=str,
    default=RESULTS_DIR,
    help="Directory to save/load result JSONs",
)
args = parser.parse_args()

# Command-line values replace the defaults defined above.
DEVICE, NUM_STEPS, IMG_SIZE, BATCHSIZE, RESULTS_DIR = (
    args.device,
    args.steps,
    args.img_size,
    args.batchsize,
    args.results_dir,
)
|
|
|
|
|
def detect_device(device):
    """Return a usable device string, falling back to ``"cpu"`` if needed.

    Args:
        device: Requested device name (``"xpu"``, ``"cuda"`` or anything else).

    Returns:
        ``"cpu"`` when ``"xpu"`` or ``"cuda"`` was requested but is not
        available in this torch build/runtime; otherwise ``device`` unchanged.
        A short status line is printed for the XPU and fallback cases.
    """
    if device == "xpu":
        # Older torch builds have no ``torch.xpu`` attribute at all.
        xpu_ready = hasattr(torch, 'xpu') and torch.xpu.is_available()
        if xpu_ready:
            print(f"XPU available: {torch.xpu.get_device_name(0)}")
            return device
        print("XPU not available, falling back to CPU")
        return "cpu"

    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        return "cpu"

    return device
|
|
|
|
|
def seed_all(seed=42):
    """Seed Python's ``random``, NumPy and torch RNGs with the same value.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)
    # Explicitly seed every CUDA device when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
|
|
|
def generate_dummy_data(num_steps, batchsize, img_size, embd_dim=1024):
    """Build the deterministic random inputs consumed by every training step.

    The RNGs are re-seeded with a fixed value first, so repeated calls with
    the same arguments yield identical tensors.

    Args:
        num_steps: Number of batches to create for each list.
        batchsize: Leading dimension of the individual-image batches.
        img_size: Edge length of the cubic volumes.
        embd_dim: Width of the random embedding vectors.

    Returns:
        ``(indiv_batches, pair_batches)``: two lists of length ``num_steps``.
        Each individual entry is ``(x0, embd)``; each pair entry is
        ``(x1, y1, embd_x, embd_y)`` with a leading dimension of
        ``max(1, batchsize // DIFF_REG_BATCH_RATIO)``.
    """
    seed_all(999)

    pair_size = max(1, batchsize // DIFF_REG_BATCH_RATIO)
    cube = (img_size, img_size, img_size)

    def _volume(count):
        # Uniform [0, 1) single-channel volume batch.
        return torch.rand(count, 1, *cube, dtype=torch.float32)

    def _embedding(count):
        # Standard-normal embedding batch.
        return torch.randn(count, embd_dim, dtype=torch.float32)

    indiv_batches, pair_batches = [], []
    for _ in range(num_steps):
        # Tuple elements are evaluated left to right, which keeps the RNG
        # draw order: x0, embd, then x1, y1, embd_x, embd_y.
        indiv_batches.append((_volume(batchsize), _embedding(batchsize)))
        pair_batches.append((
            _volume(pair_size),
            _volume(pair_size),
            _embedding(pair_size),
            _embedding(pair_size),
        ))
    return indiv_batches, pair_batches
|
|
|
|
|
|
|
|
|
def _sync_device(device):
    """Wait for queued work on ``device`` to finish (no-op for other devices)."""
    if device == "xpu":
        torch.xpu.synchronize()
    elif device == "cuda":
        torch.cuda.synchronize()


def _release_device_cache(device):
    """Ask torch to release cached allocator memory on ``device`` (no-op on CPU)."""
    if device == "xpu":
        torch.xpu.empty_cache()
    elif device == "cuda":
        torch.cuda.empty_cache()


def run_pipeline(DDPMClass, LossModule, indiv_batches, pair_batches, device, label,
                 use_opt_net=False):
    """Run one untimed warmup step plus one timed step per supplied batch.

    Args:
        DDPMClass: DDPM wrapper class to instantiate around the network.
        LossModule: Module providing ``Grad``, ``MRSE``, ``NCC``, ``MSLNCC``
            and ``LMSE`` loss classes.
        indiv_batches: Sequence of ``(x0, embd)`` batches, one per step.
        pair_batches: Sequence of ``(x1, y1, embd_x, embd_y)`` batches, one
            per step; must be at least as long as ``indiv_batches``.
        device: Device string (``"cpu"``, ``"cuda"`` or ``"xpu"``).
        label: Tag used in the progress output.
        use_opt_net: Select ``Diffusion.networks_opt`` instead of
            ``Diffusion.networks`` for the network and STN classes.

    Returns:
        ``(step_losses, step_times, total_time)`` where ``step_losses`` is a
        list of the dicts returned by ``_run_one_step``, ``step_times`` the
        per-step wall time in seconds and ``total_time`` the wall time of the
        whole timed loop.

    Raises:
        ValueError: If ``indiv_batches`` or ``pair_batches`` is empty.
    """
    import gc

    if not indiv_batches or not pair_batches:
        raise ValueError("run_pipeline needs at least one batch for warmup/timing")

    seed_all(42)
    if use_opt_net:
        from Diffusion.networks_opt import get_net_opt, OptSTN
        Net = get_net_opt(NET_NAME)
        stn_cls = OptSTN
    else:
        from Diffusion.networks import get_net, STN
        Net = get_net(NET_NAME)
        stn_cls = STN

    network = Net(n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE)
    ddpm = DDPMClass(
        network=network,
        n_steps=TIMESTEPS,
        image_chw=[1] + [IMG_SIZE] * NDIMS,
        device=device,
        batch_size=BATCHSIZE,
        img_pad_mode="zeros",
        v_scale=V_SCALE,
    )
    ddf_stn = stn_cls(img_sz=IMG_SIZE, ndims=NDIMS, padding_mode="border", device=device)
    ddpm.to(device)
    ddf_stn.to(device)

    loss_reg = LossModule.Grad(penalty=['l1', 'negdetj', 'range'], ndims=NDIMS,
                               outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = LossModule.Grad(penalty=['l1', 'negdetj', 'range'], ndims=NDIMS,
                                outrange_thresh=0.6, outrange_weight=1e3)
    loss_dist = LossModule.MRSE(img_sz=IMG_SIZE)
    loss_ang = LossModule.NCC(img_sz=IMG_SIZE)
    loss_imgsim = LossModule.MSLNCC()
    loss_imgmse = LossModule.LMSE()

    # Only these two loss objects are moved to the device, as before.
    loss_imgsim.to(device)
    loss_imgmse.to(device)

    optimizer = torch.optim.Adam(ddpm.parameters(), lr=LR)

    ddpm.train()
    step_losses = []
    step_times = []
    loss_fns = (loss_reg, loss_reg1, loss_dist, loss_ang, loss_imgsim, loss_imgmse)

    # Untimed warmup: it is a real training step (weights are updated) and
    # absorbs one-off startup costs before measurement begins.
    print(f" [{label}] Warmup step...")
    _run_one_step(0, ddpm, ddf_stn, optimizer, *loss_fns,
                  indiv_batches[0], pair_batches[0], device, seed_offset=9999)
    # Make sure no warmup work is still queued when the clock starts.
    _sync_device(device)

    print(f" [{label}] Running {len(indiv_batches)} timed steps...")
    # perf_counter is monotonic, so intervals cannot be skewed by wall-clock
    # adjustments the way time.time() differences can.
    total_start = time.perf_counter()

    for step, indiv_batch in enumerate(indiv_batches):
        step_start = time.perf_counter()
        losses = _run_one_step(step, ddpm, ddf_stn, optimizer, *loss_fns,
                               indiv_batch, pair_batches[step], device, seed_offset=step)

        # Accelerator work is asynchronous; wait for it before reading the clock.
        _sync_device(device)
        step_time = time.perf_counter() - step_start
        step_losses.append(losses)
        step_times.append(step_time)
        print(f" [{label}] step {step}: diff={losses['diff']:.6f} contra={losses['contra']:.6f} regist={losses['regist']:.6f} | {step_time:.2f}s")

        # NOTE(review): assumed to run once per step (outside the per-step
        # timing, inside the total) - confirm against the original layout.
        if device == "xpu":
            torch.xpu.empty_cache()

    total_time = time.perf_counter() - total_start

    # Drop every reference to model/optimizer/loss state (``network`` keeps
    # the parameters alive even after ``del ddpm``), collect cycles, and only
    # then empty the allocator cache so the freed blocks can be returned.
    del ddpm, ddf_stn, optimizer, network
    del loss_fns, loss_reg, loss_reg1, loss_dist, loss_ang, loss_imgsim, loss_imgmse
    gc.collect()
    _release_device_cache(device)

    return step_losses, step_times, total_time
|
|
|
|
|
def _run_one_step(step, ddpm, ddf_stn, optimizer,
                  loss_reg, loss_reg1, loss_dist, loss_ang, loss_imgsim, loss_imgmse,
                  indiv_batch, pair_batch, device, seed_offset=0):
    """Execute one full 3-mode training step and return its scalar losses.

    The step performs up to three optimizer updates in order: the 'diff'
    update on ``indiv_batch``, the 'contra' update (only when ``step`` is a
    multiple of ``CONTRASTIVE_STEP_RATIO`` and the network exposes a non-None
    ``img_embd``), and the 'regist' update on ``pair_batch``. All RNGs are
    re-seeded from ``seed_offset`` first, so the random augmentation choices
    depend only on that offset.

    Args:
        step: Step index; gates the contrastive update.
        ddpm: DDPM wrapper (callable; provides ``proc_cond_img`` and
            ``network``).
        ddf_stn: Spatial transformer applied to the predicted fields.
        optimizer: Optimizer over ``ddpm.parameters()``.
        loss_reg, loss_reg1: Regularisation losses for the two field outputs.
        loss_dist, loss_ang: Losses on the transformed prediction ('diff').
        loss_imgsim, loss_imgmse: Image losses ('regist').
        indiv_batch: ``(x0, embd)`` tensors.
        pair_batch: ``(x1, y1, embd_x, embd_y)`` tensors; ``embd_x`` is unused.
        device: Target device string.
        seed_offset: Added to 1000 to form the RNG seed for this step.

    Returns:
        ``{'diff': float, 'contra': float, 'regist': float}``; ``'contra'``
        is 0.0 when the contrastive update was skipped.
    """
    import utils
    from Dataloader.dataloader_utils import thresh_img

    # NOTE: the order of random draws below is significant - it determines
    # which augmentations each pipeline sees. Do not reorder statements.
    seed_all(1000 + seed_offset)

    # ------------------------------------------------------------------ diff
    x0, embd = indiv_batch
    x0 = x0.to(device).type(torch.float32)
    embd_dev = embd.to(device).type(torch.float32)
    # Text conditioning is dropped with probability 1 - TEXT_EMBED_PROB.
    if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
        embd_in = embd_dev
    else:
        embd_in = None

    n = x0.shape[0]
    blind_mask = utils.get_random_deformed_mask(x0.shape[2:], apply_possibility=0.6).to(device)

    if NDIMS > 2:
        if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
            x0 = utils.random_resample(x0, deform_scale=0)
        else:
            [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])
    if NOISE_SCALE > 0:
        # NOTE(review): the intensity jitter is assumed to share the same
        # probability gate as the threshold - confirm against the trainer.
        if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
            x0 = thresh_img(x0, [0, 2 * NOISE_SCALE])
            x0 = x0 * (np.random.normal(1, NOISE_SCALE)) + np.random.normal(0, NOISE_SCALE)

    # Timesteps are drawn on the CPU generator, then moved to the device.
    t = torch.randint(0, TIMESTEPS, (n,)).to(device)
    proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
    cond_img, _, cond_ratio = ddpm.proc_cond_img(x0, proc_type=proc_type)

    pre_dvf_I, dvf_I = ddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask, proc_type=[], text=embd_in)

    loss_ddf = loss_reg(pre_dvf_I, img=x0)
    trm_pred = ddf_stn(pre_dvf_I, dvf_I)
    loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
    loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)

    w_ang, w_dist, w_ddf = LOSS_WEIGHTS_DIFF
    loss_tot = w_ang * loss_gen_a + w_dist * loss_gen_d + w_ddf * loss_ddf
    loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot

    optimizer.zero_grad()
    loss_tot.backward()
    optimizer.step()

    diff_val = loss_tot.item()

    # ---------------------------------------------------------------- contra
    contra_val = 0.0
    if step % CONTRASTIVE_STEP_RATIO == 0:
        raw_network = ddpm.network
        t_contra = torch.randint(0, TIMESTEPS, (n,)).to(device)
        # The forward result is discarded; the pass is run so the network can
        # populate ``img_embd``, which is read right after.
        _ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
        if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
            img_embd = raw_network.img_embd
            loss_contra = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean())
            optimizer.zero_grad()
            loss_contra.backward()
            torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.05)
            optimizer.step()
            contra_val = loss_contra.item()

    # ---------------------------------------------------------------- regist
    x1, y1, _, embd_y = pair_batch
    if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
        embd_y = embd_y.to(device).type(torch.float32)
    else:
        embd_y = None
    x1 = x1.to(device).type(torch.float32)
    y1 = y1.to(device).type(torch.float32)
    [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
    if NOISE_SCALE > 0:
        [x1, y1] = thresh_img([x1, y1], [0, 2 * NOISE_SCALE])
        # Same gain/offset for both images of the pair.
        rs = np.random.normal(1, NOISE_SCALE)
        rsh = np.random.normal(0, NOISE_SCALE)
        x1 = x1 * rs + rsh
        y1 = y1 * rs + rsh

    scale_regist = np.random.uniform(0.0, 0.7)
    t_pool = list(range(int(TIMESTEPS * scale_regist), TIMESTEPS))
    select_timestep = min(16, len(t_pool))
    T_regist = sorted(random.sample(t_pool, select_timestep), reverse=True)
    # Repeat every selected timestep once per sample of the pair batch. The
    # count is taken from the batch itself instead of a hard-coded
    # BATCHSIZE // 2, so it stays right if DIFF_REG_BATCH_RATIO or the batch
    # size passed to the data generator changes.
    pair_size = x1.shape[0]
    T_regist = [[t_val] * pair_size for t_val in T_regist]

    proc_type_r = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
    y1_proc, msk_tgt, cond_ratio_r = ddpm.proc_cond_img(y1, proc_type=proc_type_r)
    msk_tgt = msk_tgt + MSK_EPS

    [ddf_comp, _], [img_rec, _, _], _ = ddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[], text=embd_y)
    loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt * (y1 > 0.01))
    loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt * (y1 >= 0.0))
    loss_ddf1 = loss_reg1(ddf_comp, img=y1)

    w_sim, w_mse, w_ddf1 = LOSS_WEIGHTS_REGIST
    loss_regist = w_sim * loss_sim + w_mse * loss_mse + w_ddf1 * loss_ddf1
    loss_regist = torch.sqrt(cond_ratio_r + MSK_EPS) * loss_regist
    optimizer.zero_grad()
    loss_regist.backward()
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.2)
    optimizer.step()

    regist_val = loss_regist.item()

    return {'diff': diff_val, 'contra': contra_val, 'regist': regist_val}
|
|
|
|
|
|
|
|
|
def save_results(label, step_losses, step_times, total_time, results_dir):
    """Write one pipeline's measurements to ``compare_<label>.json``.

    The run configuration (device, image size, batch size, step count) is
    taken from the module-level settings and stored next to the measurements.

    Args:
        label: Pipeline tag; becomes part of the file name.
        step_losses: Per-step loss dicts.
        step_times: Per-step durations in seconds.
        total_time: Duration of the whole timed loop in seconds.
        results_dir: Output directory; created if missing.

    Returns:
        Path of the written JSON file.
    """
    payload = dict(
        label=label,
        device=DEVICE,
        img_size=IMG_SIZE,
        batchsize=BATCHSIZE,
        num_steps=NUM_STEPS,
        step_losses=step_losses,
        step_times=step_times,
        total_time=total_time,
    )

    os.makedirs(results_dir, exist_ok=True)
    out_path = os.path.join(results_dir, f"compare_{label}.json")
    with open(out_path, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"\nResults saved to {out_path}")
    return out_path
|
|
|
|
|
def load_results(label, results_dir):
    """Read back the JSON written for pipeline ``label``.

    Args:
        label: Pipeline tag used in the file name ``compare_<label>.json``.
        results_dir: Directory that holds the result files.

    Returns:
        The decoded JSON content.

    Raises:
        SystemExit: With status 1 (after printing a hint) when the file does
            not exist.
    """
    json_path = os.path.join(results_dir, f"compare_{label}.json")

    if os.path.exists(json_path):
        with open(json_path) as handle:
            return json.load(handle)

    print(f"ERROR: Results file not found: {json_path}")
    print(f"Run with --pipeline {label} first.")
    sys.exit(1)
|
|
|
|
|
|
|
|
|
def print_comparison(orig_data, opt_data, tol=1e-4):
    """Print the loss and timing comparison tables for two saved runs.

    Only the steps present in both runs are compared. A loss counts as a
    match when the absolute difference is below ``tol``.

    Args:
        orig_data: Result dict of the original pipeline (as written by
            ``save_results``).
        opt_data: Result dict of the optimized pipeline.
        tol: Absolute tolerance for the per-step loss match (default 1e-4).

    Returns:
        None. Everything is written to stdout. Warnings are printed when the
        two runs used different settings or recorded a different number of
        steps; "Losses identical" is reported as NO when no step could be
        compared.
    """
    def _speedup(baseline, candidate):
        # Same guard for per-step, average and total ratios.
        return baseline / candidate if candidate > 0 else float('inf')

    rule = "=" * 70
    loss_keys = ('diff', 'contra', 'regist')

    print("\n" + rule)
    print("LOSS COMPARISON (Original vs Optimized)")
    print(f"Device={orig_data['device']}, IMG_SIZE={orig_data['img_size']}, "
          f"BATCHSIZE={orig_data['batchsize']}, STEPS={orig_data['num_steps']}")
    print(rule)

    # The header shows the original run's settings only; flag any setting the
    # optimized run does not share, since the comparison is then unreliable.
    for key in ('device', 'img_size', 'batchsize', 'num_steps'):
        if orig_data.get(key) != opt_data.get(key):
            print(f"WARNING: '{key}' differs between runs "
                  f"(orig={orig_data.get(key)!r}, opt={opt_data.get(key)!r})")

    orig_losses = orig_data['step_losses']
    opt_losses = opt_data['step_losses']
    n = min(len(orig_losses), len(opt_losses))
    if len(orig_losses) != len(opt_losses):
        print(f"WARNING: step counts differ (orig={len(orig_losses)}, "
              f"opt={len(opt_losses)}); comparing the first {n} steps only")

    print(f"{'Step':>4} {'Diff_Orig':>12} {'Diff_Opt':>12} {'Match':>6} "
          f"{'Contra_Orig':>12} {'Contra_Opt':>12} {'Match':>6} "
          f"{'Regist_Orig':>12} {'Regist_Opt':>12} {'Match':>6}")

    # With nothing to compare there is no evidence the losses agree.
    all_match = n > 0
    for i, (o, p) in enumerate(zip(orig_losses, opt_losses)):
        within_tol = [abs(o[k] - p[k]) < tol for k in loss_keys]
        all_match = all_match and all(within_tol)
        dm, cm, rm = ("YES" if ok else "NO" for ok in within_tol)
        print(f"{i:>4} {o['diff']:>12.6f} {p['diff']:>12.6f} {dm:>6} "
              f"{o['contra']:>12.6f} {p['contra']:>12.6f} {cm:>6} "
              f"{o['regist']:>12.6f} {p['regist']:>12.6f} {rm:>6}")

    print("\n" + rule)
    print("TIMING COMPARISON")
    print(rule)
    print(f"{'Step':>4} {'Orig (s)':>10} {'Opt (s)':>10} {'Speedup':>10}")
    orig_times = orig_data['step_times'][:n]
    opt_times = opt_data['step_times'][:n]
    for i, (ot, pt) in enumerate(zip(orig_times, opt_times)):
        print(f"{i:>4} {ot:>10.2f} {pt:>10.2f} {_speedup(ot, pt):>9.2f}x")

    if n > 0:
        avg_orig = np.mean(orig_times)
        avg_opt = np.mean(opt_times)
        avg_speedup = _speedup(avg_orig, avg_opt)
    else:
        # Avoid averaging an empty list (NumPy would warn and return nan).
        avg_orig = avg_opt = avg_speedup = float('nan')

    total_speedup = _speedup(orig_data['total_time'], opt_data['total_time'])

    print(f"\n{'Avg':>4} {avg_orig:>10.2f} {avg_opt:>10.2f} {avg_speedup:>9.2f}x")
    print(f"Total: ORIG={orig_data['total_time']:.2f}s OPT={opt_data['total_time']:.2f}s "
          f"Speedup={total_speedup:.2f}x")

    print("\n" + rule)
    print(f"Losses identical: {'YES' if all_match else 'NO'}")
    print(f"Average speedup: {avg_speedup:.2f}x")
    print(rule)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    selected = args.pipeline
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    # "compare" only reads previously saved results; nothing is trained.
    if selected == "compare":
        saved_orig = load_results("orig", RESULTS_DIR)
        saved_opt = load_results("opt", RESULTS_DIR)
        print_comparison(saved_orig, saved_opt)
        sys.exit(0)

    # Resolve the requested device to one that is actually available.
    DEVICE = detect_device(DEVICE)

    print(heavy_rule)
    print(f"3-Mode Training: Speed Comparison (pipeline={args.pipeline})")
    print(f"Device={DEVICE}, IMG_SIZE={IMG_SIZE}, BATCHSIZE={BATCHSIZE}, STEPS={NUM_STEPS}")
    print(heavy_rule)

    print("\nPre-generating dummy data...")
    indiv_batches, pair_batches = generate_dummy_data(NUM_STEPS, BATCHSIZE, IMG_SIZE)

    run_orig = selected in ("orig", "both")
    run_opt = selected in ("opt", "both")

    if run_orig:
        # Imported lazily so a run of one pipeline never loads the other's modules.
        from Diffusion.diffuser import DeformDDPM as OrigDeformDDPM
        import Diffusion.losses as orig_losses_mod

        print("\n" + light_rule)
        print("Running ORIGINAL pipeline (OM_train_3modes.py logic)")
        print(light_rule)
        orig_outcome = run_pipeline(
            OrigDeformDDPM, orig_losses_mod, indiv_batches, pair_batches, DEVICE, "ORIG",
            use_opt_net=False)
        save_results("orig", *orig_outcome, RESULTS_DIR)

    if run_opt:
        from Diffusion.diffuser_opt import DeformDDPM as OptDeformDDPM
        import Diffusion.losses_opt as opt_losses_mod

        # After the original pipeline has consumed the batches, rebuild them
        # so the optimized run starts from freshly generated inputs.
        if selected == "both":
            indiv_batches, pair_batches = generate_dummy_data(NUM_STEPS, BATCHSIZE, IMG_SIZE)

        print("\n" + light_rule)
        print("Running OPTIMIZED pipeline (OM_train_3modes_opt.py logic)")
        print(light_rule)
        opt_outcome = run_pipeline(
            OptDeformDDPM, opt_losses_mod, indiv_batches, pair_batches, DEVICE, "OPT",
            use_opt_net=True)
        save_results("opt", *opt_outcome, RESULTS_DIR)

    # In "both" mode the table is built from the files just written, the same
    # path the separate "compare" invocation takes.
    if selected == "both":
        saved_orig = load_results("orig", RESULTS_DIR)
        saved_opt = load_results("opt", RESULTS_DIR)
        print_comparison(saved_orig, saved_opt)
|
|
|