File size: 18,185 Bytes
2af0e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
"""

compare_3modes_speed.py — Head-to-head comparison of original vs optimized

3-mode training for N iterations.



Supports 3 run modes, selected with --pipeline (plus a `compare` mode, see below):

  --pipeline orig    Run original pipeline only, save results to JSON

  --pipeline opt     Run optimized pipeline only, save results to JSON

  --pipeline both    Run both sequentially (may OOM on large batchsize)



After running orig and opt separately, use --pipeline compare to print

the comparison table from saved JSONs.



Usage:

    # Separate jobs (recommended for batchsize>=3):

    python tests/compare_3modes_speed.py --pipeline orig --device xpu --batchsize 3

    python tests/compare_3modes_speed.py --pipeline opt  --device xpu --batchsize 3

    python tests/compare_3modes_speed.py --pipeline compare



    # Single job (only for small batchsize):

    python tests/compare_3modes_speed.py --pipeline both --device xpu --batchsize 1

"""

import os, sys, time, json, random, argparse
import numpy as np
import torch
import torch.nn.functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT_DIR)

# ========================== Config ==========================
# Module-level defaults. IMG_SIZE, BATCHSIZE and NUM_STEPS are overwritten
# further down by the parsed command-line arguments.

IMG_SIZE = 128        # edge length of the cubic dummy volumes; also passed as `res`/`img_sz`
BATCHSIZE = 3         # batch size of the individual (diffusion) batches
TIMESTEPS = 80        # passed as `n_steps`; also the exclusive upper bound when sampling t
NDIMS = 3             # number of spatial dimensions
V_SCALE = 5e-5        # forwarded to the DDPM constructor as `v_scale`
NOISE_SCALE = 0.05    # std of the random intensity scale/shift augmentation (0 disables it)
NET_NAME = "recmulmodmutattnnet"  # key handed to get_net / get_net_opt
LR = 1e-5             # Adam learning rate
NUM_STEPS = 10        # default for --steps

# Training constants (must match both scripts)
MSK_EPS = 0.01                           # small offset added to masks / cond-ratio weights
TEXT_EMBED_PROB = 0.7                    # probability that the text embedding is passed to the model
AUG_RESAMPLE_PROB = 0.5                  # probability threshold for the augmentation branches
LOSS_WEIGHTS_DIFF = [2.0, 2.0, 4.0]      # weights applied to [loss_ang, loss_dist, loss_reg] in the diffusion step
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128]   # weights applied to [loss_imgsim, loss_imgmse, loss_reg1] in the registration step
LOSS_WEIGHT_CONTRASTIVE = 1.0            # weight of the cosine-similarity contrastive loss
CONTRASTIVE_STEP_RATIO = 2               # contrastive update runs when step % ratio == 0
DIFF_REG_BATCH_RATIO = 2                 # pair batch size = max(1, batchsize // ratio)

# Output directory for results
RESULTS_DIR = os.path.join(ROOT_DIR, "Logs")

# NOTE: arguments are parsed at import time, so importing this module from
# another program will consume that program's sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument("--pipeline", type=str, default="both",
                    choices=["orig", "opt", "both", "compare"],
                    help="Which pipeline to run (orig/opt/both/compare)")
parser.add_argument("--device", type=str, default="xpu", choices=["cpu", "cuda", "xpu"])
parser.add_argument("--steps", type=int, default=NUM_STEPS)
parser.add_argument("--img-size", type=int, default=IMG_SIZE)
parser.add_argument("--batchsize", type=int, default=BATCHSIZE)
parser.add_argument("--results-dir", type=str, default=RESULTS_DIR,
                    help="Directory to save/load result JSONs")
args = parser.parse_args()

# Command-line values replace the defaults above; the functions below read
# these module-level names directly.
DEVICE = args.device
NUM_STEPS = args.steps
IMG_SIZE = args.img_size
BATCHSIZE = args.batchsize
RESULTS_DIR = args.results_dir


def detect_device(device):
    """Return a usable device string, degrading to ``"cpu"`` when necessary.

    ``"cuda"`` and ``"xpu"`` are checked against the running PyTorch build; if
    the backend is missing or reports no device, a notice is printed and
    ``"cpu"`` is returned. Any other value is returned untouched.
    """
    if device == "cuda":
        if torch.cuda.is_available():
            return device
        print("CUDA not available, falling back to CPU")
        return "cpu"

    if device != "xpu":
        return device

    # Older PyTorch builds have no `torch.xpu` attribute at all, so probe for
    # the attribute before asking it anything.
    xpu_ready = hasattr(torch, 'xpu') and torch.xpu.is_available()
    if xpu_ready:
        print(f"XPU available: {torch.xpu.get_device_name(0)}")
        return device

    print("XPU not available, falling back to CPU")
    return "cpu"


def seed_all(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# ========================== Data Generation ==========================

def generate_dummy_data(num_steps, batchsize, img_size, embd_dim=1024):
    """Build deterministic random inputs for ``num_steps`` training steps.

    The global RNGs are re-seeded with 999 first, so repeated calls with the
    same arguments yield identical tensors.

    Returns:
        ``(indiv_batches, pair_batches)``: two lists of length ``num_steps``.
        Each individual entry is ``(x0, embd)`` with ``batchsize`` samples;
        each pair entry is ``(x1, y1, embd_x, embd_y)`` with
        ``max(1, batchsize // DIFF_REG_BATCH_RATIO)`` samples. Volumes are
        uniform in [0, 1) with shape ``(B, 1, img_size, img_size, img_size)``;
        embeddings are standard normal with shape ``(B, embd_dim)``.
    """
    seed_all(999)
    volume_shape = (1, img_size, img_size, img_size)

    def _volume(count):
        return torch.rand(count, *volume_shape, dtype=torch.float32)

    def _embedding(count):
        return torch.randn(count, embd_dim, dtype=torch.float32)

    indiv_batches, pair_batches = [], []
    for _ in range(num_steps):
        # The draw order below defines the data; do not reorder these calls.
        image = _volume(batchsize)
        text_embd = _embedding(batchsize)
        indiv_batches.append((image, text_embd))

        pair_size = max(1, batchsize // DIFF_REG_BATCH_RATIO)
        moving = _volume(pair_size)
        fixed = _volume(pair_size)
        moving_embd = _embedding(pair_size)
        fixed_embd = _embedding(pair_size)
        pair_batches.append((moving, fixed, moving_embd, fixed_embd))

    return indiv_batches, pair_batches


# ========================== Single Pipeline Run ==========================

def _sync_device(device):
    """Block until all queued work on ``device`` has finished (no-op on CPU)."""
    if device == "xpu":
        torch.xpu.synchronize()
    elif device == "cuda":
        torch.cuda.synchronize()


def _release_device_cache(device):
    """Hand cached allocator blocks on ``device`` back to the driver (no-op on CPU)."""
    if device == "xpu":
        torch.xpu.empty_cache()
    elif device == "cuda":
        torch.cuda.empty_cache()


def run_pipeline(DDPMClass, LossModule, indiv_batches, pair_batches, device, label,
                 use_opt_net=False):
    """Run one untimed warmup step plus ``len(indiv_batches)`` timed 3-mode steps.

    The model, STN, losses and optimizer are built from the module-level
    configuration (NET_NAME, TIMESTEPS, NDIMS, IMG_SIZE, BATCHSIZE, V_SCALE,
    LR) after seeding all RNGs with 42.

    Args:
        DDPMClass: DDPM wrapper class to instantiate (original or optimized).
        LossModule: Module providing ``Grad``, ``MRSE``, ``NCC``, ``MSLNCC``
            and ``LMSE``.
        indiv_batches: Per-step ``(x0, embd)`` tuples.
        pair_batches: Per-step ``(x1, y1, embd_x, embd_y)`` tuples; must be at
            least as long as ``indiv_batches``.
        device: ``"cpu"``, ``"cuda"`` or ``"xpu"``.
        label: Tag shown in the progress output.
        use_opt_net: Build the network/STN from ``Diffusion.networks_opt``
            instead of ``Diffusion.networks``.

    Returns:
        ``(step_losses, step_times, total_time)``: a list of per-step loss
        dicts, a list of per-step durations in seconds (measured after a
        device synchronize), and the duration of the whole timed loop.
    """
    import gc

    seed_all(42)
    if use_opt_net:
        from Diffusion.networks_opt import get_net_opt, OptSTN
        Net = get_net_opt(NET_NAME)
        stn_cls = OptSTN
    else:
        from Diffusion.networks import get_net, STN
        Net = get_net(NET_NAME)
        stn_cls = STN

    network = Net(n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE)
    ddpm = DDPMClass(
        network=network,
        n_steps=TIMESTEPS,
        image_chw=[1] + [IMG_SIZE] * NDIMS,
        device=device,
        batch_size=BATCHSIZE,
        img_pad_mode="zeros",
        v_scale=V_SCALE,
    )
    ddf_stn = stn_cls(img_sz=IMG_SIZE, ndims=NDIMS, padding_mode="border", device=device)
    ddpm.to(device)
    ddf_stn.to(device)

    loss_reg = LossModule.Grad(penalty=['l1', 'negdetj', 'range'], ndims=NDIMS,
                               outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = LossModule.Grad(penalty=['l1', 'negdetj', 'range'], ndims=NDIMS,
                                outrange_thresh=0.6, outrange_weight=1e3)
    loss_dist = LossModule.MRSE(img_sz=IMG_SIZE)
    loss_ang = LossModule.NCC(img_sz=IMG_SIZE)
    loss_imgsim = LossModule.MSLNCC()
    loss_imgmse = LossModule.LMSE()

    # Move buffer-based losses to device
    loss_imgsim.to(device)
    loss_imgmse.to(device)

    optimizer = torch.optim.Adam(ddpm.parameters(), lr=LR)

    ddpm.train()
    step_losses = []
    step_times = []

    # Warmup: 1 step (not timed). Skipped when there is nothing to run, so an
    # empty batch list yields empty results instead of an IndexError.
    if indiv_batches:
        print(f"  [{label}] Warmup step...")
        _run_one_step(0, ddpm, ddf_stn, optimizer,
                      loss_reg, loss_reg1, loss_dist, loss_ang, loss_imgsim, loss_imgmse,
                      indiv_batches[0], pair_batches[0], device, seed_offset=9999)
        # Make sure no warmup work is still queued when the clock starts.
        _sync_device(device)

    print(f"  [{label}] Running {len(indiv_batches)} timed steps...")
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted.
    total_start = time.perf_counter()

    for step, indiv_batch in enumerate(indiv_batches):
        step_start = time.perf_counter()
        losses = _run_one_step(step, ddpm, ddf_stn, optimizer,
                               loss_reg, loss_reg1, loss_dist, loss_ang, loss_imgsim, loss_imgmse,
                               indiv_batch, pair_batches[step], device, seed_offset=step)
        # Device kernels run asynchronously; wait for them before stopping the clock.
        _sync_device(device)
        step_time = time.perf_counter() - step_start
        step_losses.append(losses)
        step_times.append(step_time)
        print(f"  [{label}] step {step}: diff={losses['diff']:.6f} contra={losses['contra']:.6f} regist={losses['regist']:.6f} | {step_time:.2f}s")

        # Free XPU memory between steps to avoid fragmentation-induced OOM
        if device == "xpu":
            torch.xpu.empty_cache()

    total_time = time.perf_counter() - total_start

    # Cleanup: drop every local that keeps parameters or buffers alive
    # (including `network`, which `ddpm` merely wraps), collect reference
    # cycles, and only then flush the allocator cache. Flushing first would
    # leave the still-referenced blocks cached.
    del ddpm, network, ddf_stn, optimizer
    del loss_reg, loss_reg1, loss_dist, loss_ang, loss_imgsim, loss_imgmse
    gc.collect()
    _release_device_cache(device)

    return step_losses, step_times, total_time


def _run_one_step(step, ddpm, ddf_stn, optimizer,

                  loss_reg, loss_reg1, loss_dist, loss_ang, loss_imgsim, loss_imgmse,

                  indiv_batch, pair_batch, device, seed_offset=0):
    """Execute one full 3-mode training step and return its loss values.

    Up to three optimizer updates are performed, in this order:

    1. Diffusion: always runs, on ``indiv_batch``.
    2. Contrastive: runs only when ``step % CONTRASTIVE_STEP_RATIO == 0`` and
       the network exposes a non-None ``img_embd`` after its forward pass.
    3. Registration: always runs, on ``pair_batch``.

    All RNGs are re-seeded with ``1000 + seed_offset`` at the start, so the
    random augmentation choices depend only on ``seed_offset``. The order of
    the random draws below is therefore part of the behavior.

    Args:
        step: Step index; only used to decide whether the contrastive update runs.
        ddpm: DDPM wrapper; called directly and via ``proc_cond_img``/``network``.
        ddf_stn: Spatial transformer applied to the predicted fields.
        optimizer: Optimizer shared by all three updates.
        loss_reg, loss_dist, loss_ang: Losses for the diffusion update.
        loss_imgsim, loss_imgmse, loss_reg1: Losses for the registration update.
        indiv_batch: ``(x0, embd)`` tuple.
        pair_batch: ``(x1, y1, embd_x, embd_y)`` tuple; ``embd_x`` is unused.
        device: Device string the inputs are moved to.
        seed_offset: Added to 1000 to form the per-step seed.

    Returns:
        Dict with float entries ``'diff'``, ``'contra'`` (0.0 when the
        contrastive update was skipped) and ``'regist'``.
    """
    # Project-local helpers, imported lazily so `--pipeline compare` works
    # without them being importable.
    import utils
    from Dataloader.dataloader_utils import thresh_img

    # Seed for reproducibility of augmentation/proc_type choices
    seed_all(1000 + seed_offset)

    # ------------------------------------------------------------------
    # Mode 1: diffusion
    # ------------------------------------------------------------------
    x0, embd = indiv_batch
    x0 = x0.to(device).type(torch.float32)
    embd_dev = embd.to(device).type(torch.float32)
    # Pass the text embedding with probability TEXT_EMBED_PROB, otherwise None.
    if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
        embd_in = embd_dev
    else:
        embd_in = None

    n = x0.size()[0]
    blind_mask = utils.get_random_deformed_mask(x0.shape[2:], apply_possibility=0.6).to(device)

    # Spatial augmentation (3-D only): either a random resample or a random
    # permutation of the last three axes.
    if NDIMS > 2:
        if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
            x0 = utils.random_resample(x0, deform_scale=0)
        else:
            [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])
    # Intensity augmentation: optional thresholding, then a random global
    # scale (around 1) and shift (around 0), both with std NOISE_SCALE.
    if NOISE_SCALE > 0:
        if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
            x0 = thresh_img(x0, [0, 2 * NOISE_SCALE])
        x0 = x0 * (np.random.normal(1, NOISE_SCALE)) + np.random.normal(0, NOISE_SCALE)

    # Timesteps are sampled on the CPU generator and then moved, so the draw
    # is governed by torch.manual_seed regardless of the device.
    t = torch.randint(0, TIMESTEPS, (n,)).to(device)
    # 'uncon' is listed three times, so it is picked 3 times out of 8.
    proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
    cond_img, _, cond_ratio = ddpm.proc_cond_img(x0, proc_type=proc_type)

    pre_dvf_I, dvf_I = ddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask, proc_type=[], text=embd_in)

    loss_ddf = loss_reg(pre_dvf_I, img=x0)
    trm_pred = ddf_stn(pre_dvf_I, dvf_I)
    loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
    loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)

    loss_tot = LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d + LOSS_WEIGHTS_DIFF[2] * loss_ddf
    # Scale the total by sqrt(1 + MSK_EPS - cond_ratio).
    loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot

    # Unlike the two updates below, this one applies no gradient clipping.
    optimizer.zero_grad()
    loss_tot.backward()
    optimizer.step()

    diff_val = loss_tot.item()

    # --- Contrastive ---
    contra_val = 0.0
    if step % CONTRASTIVE_STEP_RATIO == 0:
        raw_network = ddpm.network
        t_contra = torch.randint(0, TIMESTEPS, (n,)).to(device)
        # The forward output is discarded; the pass is run for the `img_embd`
        # attribute that is read from the network right after it.
        _ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
        if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
            img_embd = raw_network.img_embd
            # 1 - mean cosine similarity between image and text embeddings.
            loss_contra = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean())
            optimizer.zero_grad()
            loss_contra.backward()
            torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.05)
            optimizer.step()
            contra_val = loss_contra.item()

    # --- Registration ---
    x1, y1, _, embd_y = pair_batch
    if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
        embd_y = embd_y.to(device).type(torch.float32)
    else:
        embd_y = None
    x1 = x1.to(device).type(torch.float32)
    y1 = y1.to(device).type(torch.float32)
    [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
    if NOISE_SCALE > 0:
        [x1, y1] = thresh_img([x1, y1], [0, 2 * NOISE_SCALE])
        # The same random scale and shift are applied to both images.
        rs = np.random.normal(1, NOISE_SCALE)
        rsh = np.random.normal(0, NOISE_SCALE)
        x1 = x1 * rs + rsh
        y1 = y1 * rs + rsh

    # Pick up to 16 distinct timesteps from [TIMESTEPS * scale_regist, TIMESTEPS),
    # sorted from largest to smallest.
    scale_regist = np.random.uniform(0.0, 0.7)
    select_timestep = 16  # fixed for both
    t_pool = list(range(int(TIMESTEPS * scale_regist), TIMESTEPS))
    select_timestep = min(select_timestep, len(t_pool))
    T_regist = sorted(random.sample(t_pool, select_timestep), reverse=True)
    # Repeat each timestep once per pair-batch sample.
    # NOTE(review): the divisor 2 is hard-coded here, while generate_dummy_data
    # sizes the pair batch with DIFF_REG_BATCH_RATIO; the two only agree while
    # that ratio is 2 - confirm whether this should track x1.shape[0].
    T_regist = [[t_val for _ in range(max(1, BATCHSIZE // 2))] for t_val in T_regist]

    # 'none' is listed twice, so it is picked 2 times out of 5.
    proc_type_r = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
    y1_proc, msk_tgt, cond_ratio_r = ddpm.proc_cond_img(y1, proc_type=proc_type_r)
    msk_tgt = msk_tgt + MSK_EPS

    # Only the first element of each of the first two returned groups is used.
    [ddf_comp, _], [img_rec, _, _], _ = ddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[], text=embd_y)
    loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt * (y1 > 0.01))
    loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt * (y1 >= 0.0))
    loss_ddf1 = loss_reg1(ddf_comp, img=y1)

    loss_regist = LOSS_WEIGHTS_REGIST[0] * loss_sim + LOSS_WEIGHTS_REGIST[1] * loss_mse + LOSS_WEIGHTS_REGIST[2] * loss_ddf1
    # Scale the total by sqrt(cond_ratio_r + MSK_EPS).
    loss_regist = torch.sqrt(cond_ratio_r + MSK_EPS) * loss_regist
    optimizer.zero_grad()
    loss_regist.backward()
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.2)
    optimizer.step()

    regist_val = loss_regist.item()

    return {'diff': diff_val, 'contra': contra_val, 'regist': regist_val}


# ========================== Save / Load Results ==========================

def save_results(label, step_losses, step_times, total_time, results_dir):
    """Write one pipeline's measurements to ``<results_dir>/compare_<label>.json``.

    The record also stores the active run configuration (DEVICE, IMG_SIZE,
    BATCHSIZE, NUM_STEPS). ``results_dir`` is created if it does not exist.

    Returns:
        The path of the file that was written.
    """
    path = os.path.join(results_dir, f"compare_{label}.json")
    record = dict(
        label=label,
        device=DEVICE,
        img_size=IMG_SIZE,
        batchsize=BATCHSIZE,
        num_steps=NUM_STEPS,
        step_losses=step_losses,
        step_times=step_times,
        total_time=total_time,
    )

    os.makedirs(results_dir, exist_ok=True)
    with open(path, 'w') as out_file:
        json.dump(record, out_file, indent=2)

    print(f"\nResults saved to {path}")
    return path


def load_results(label, results_dir):
    """Read back ``<results_dir>/compare_<label>.json`` as a dict.

    If the file does not exist, an error and a hint are printed and the
    process exits with status 1.
    """
    path = os.path.join(results_dir, f"compare_{label}.json")

    if os.path.exists(path):
        with open(path) as in_file:
            return json.load(in_file)

    # Missing file: tell the user which run has to happen first, then stop.
    print(f"ERROR: Results file not found: {path}")
    print(f"Run with --pipeline {label} first.")
    sys.exit(1)


# ========================== Compare ==========================

def _speedup(baseline, candidate):
    """Return ``baseline / candidate``, or ``inf`` when ``candidate`` is not positive."""
    return baseline / candidate if candidate > 0 else float('inf')


def print_comparison(orig_data, opt_data):
    """Print the per-step loss and timing comparison of two saved runs.

    Args:
        orig_data: Result dict of the original pipeline (as written by
            ``save_results``).
        opt_data: Result dict of the optimized pipeline.

    Only the steps present in both runs are compared. Two loss values count
    as matching when they differ by less than 1e-4. A warning is printed when
    the two runs were recorded with different configurations, and the tables
    are skipped entirely when there is no step to compare.
    """
    rule = "=" * 70
    tolerance = 1e-4
    loss_keys = ('diff', 'contra', 'regist')

    print("\n" + rule)
    print("LOSS COMPARISON (Original vs Optimized)")
    print(f"Device={orig_data['device']}, IMG_SIZE={orig_data['img_size']}, "
          f"BATCHSIZE={orig_data['batchsize']}, STEPS={orig_data['num_steps']}")
    # The two runs usually come from separate jobs; speedups are meaningless
    # if they were not launched with the same settings.
    differing = [key for key in ('device', 'img_size', 'batchsize', 'num_steps')
                 if opt_data.get(key) != orig_data[key]]
    if differing:
        print(f"WARNING: run configurations differ in: {', '.join(differing)}")
    print(rule)

    n = min(len(orig_data['step_losses']), len(opt_data['step_losses']))
    if n == 0:
        # Avoid averaging empty lists and vacuously reporting identical losses.
        print("No steps to compare.")
        return

    print(f"{'Step':>4}  {'Diff_Orig':>12} {'Diff_Opt':>12} {'Match':>6}  "
          f"{'Contra_Orig':>12} {'Contra_Opt':>12} {'Match':>6}  "
          f"{'Regist_Orig':>12} {'Regist_Opt':>12} {'Match':>6}")

    all_match = True
    step_pairs = zip(orig_data['step_losses'][:n], opt_data['step_losses'][:n])
    for i, (o, p) in enumerate(step_pairs):
        verdict = {key: "YES" if abs(o[key] - p[key]) < tolerance else "NO"
                   for key in loss_keys}
        if "NO" in verdict.values():
            all_match = False
        print(f"{i:>4}  {o['diff']:>12.6f} {p['diff']:>12.6f} {verdict['diff']:>6}  "
              f"{o['contra']:>12.6f} {p['contra']:>12.6f} {verdict['contra']:>6}  "
              f"{o['regist']:>12.6f} {p['regist']:>12.6f} {verdict['regist']:>6}")

    print("\n" + rule)
    print("TIMING COMPARISON")
    print(rule)
    print(f"{'Step':>4}  {'Orig (s)':>10} {'Opt (s)':>10} {'Speedup':>10}")
    orig_times = orig_data['step_times'][:n]
    opt_times = opt_data['step_times'][:n]
    for i, (ot, pt) in enumerate(zip(orig_times, opt_times)):
        print(f"{i:>4}  {ot:>10.2f} {pt:>10.2f} {_speedup(ot, pt):>9.2f}x")

    avg_orig = np.mean(orig_times)
    avg_opt = np.mean(opt_times)
    avg_speedup = _speedup(avg_orig, avg_opt)
    # Guarded like the per-step ratios, so a zero total no longer raises.
    total_speedup = _speedup(orig_data['total_time'], opt_data['total_time'])

    print(f"\n{'Avg':>4}  {avg_orig:>10.2f} {avg_opt:>10.2f} {avg_speedup:>9.2f}x")
    print(f"Total: ORIG={orig_data['total_time']:.2f}s  OPT={opt_data['total_time']:.2f}s  "
          f"Speedup={total_speedup:.2f}x")

    print("\n" + rule)
    print(f"Losses identical: {'YES' if all_match else 'NO'}")
    print(f"Average speedup:  {avg_speedup:.2f}x")
    print(rule)


# ========================== Main ==========================

if __name__ == "__main__":
    selected = args.pipeline
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    if selected == "compare":
        # Comparison-only mode: nothing is trained, the two saved JSONs are
        # read back and tabulated.
        saved_orig = load_results("orig", RESULTS_DIR)
        saved_opt = load_results("opt", RESULTS_DIR)
        print_comparison(saved_orig, saved_opt)
        sys.exit(0)

    # Resolve the requested device, falling back to CPU when unavailable.
    DEVICE = detect_device(DEVICE)

    print(heavy_rule)
    print(f"3-Mode Training: Speed Comparison (pipeline={selected})")
    print(f"Device={DEVICE}, IMG_SIZE={IMG_SIZE}, BATCHSIZE={BATCHSIZE}, STEPS={NUM_STEPS}")
    print(heavy_rule)

    print("\nPre-generating dummy data...")
    indiv_batches, pair_batches = generate_dummy_data(NUM_STEPS, BATCHSIZE, IMG_SIZE)

    run_both = selected == "both"

    if run_both or selected == "orig":
        # Imported only when needed, so a single-pipeline job never loads the
        # other implementation.
        from Diffusion.diffuser import DeformDDPM as OrigDeformDDPM
        import Diffusion.losses as orig_losses_mod

        print(f"\n{light_rule}")
        print("Running ORIGINAL pipeline (OM_train_3modes.py logic)")
        print(light_rule)
        orig_losses_list, orig_times, orig_total = run_pipeline(
            OrigDeformDDPM, orig_losses_mod, indiv_batches, pair_batches, DEVICE,
            "ORIG", use_opt_net=False)
        save_results("orig", orig_losses_list, orig_times, orig_total, RESULTS_DIR)

    if run_both or selected == "opt":
        from Diffusion.diffuser_opt import DeformDDPM as OptDeformDDPM
        import Diffusion.losses_opt as opt_losses_mod

        if run_both:
            # Rebuild the inputs after the original run; the generator is
            # seeded, so both pipelines see identical data.
            indiv_batches, pair_batches = generate_dummy_data(NUM_STEPS, BATCHSIZE, IMG_SIZE)

        print(f"\n{light_rule}")
        print("Running OPTIMIZED pipeline (OM_train_3modes_opt.py logic)")
        print(light_rule)
        opt_losses_list, opt_times, opt_total = run_pipeline(
            OptDeformDDPM, opt_losses_mod, indiv_batches, pair_batches, DEVICE,
            "OPT", use_opt_net=True)
        save_results("opt", opt_losses_list, opt_times, opt_total, RESULTS_DIR)

    if run_both:
        # Both result files were just written; read them back for the table.
        saved_orig = load_results("orig", RESULTS_DIR)
        saved_opt = load_results("opt", RESULTS_DIR)
        print_comparison(saved_orig, saved_opt)