| | import argparse |
| | import os |
| | import yaml |
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | import torchdiffeq |
| | import utils |
| | from diff2flow import VPDiffusionFlow, dict2namespace |
| | import datasets |
| | from tqdm import tqdm |
| |
|
| |
|
def ode_inverse_solve(
    flow_model,
    x_data,
    x_cond,
    steps=100,
    method="dopri5",
    patch_size=64,
    atol=1e-5,
    rtol=1e-5,
):
    """Integrate the flow ODE from t=0 (data) to t=1 (noise).

    Args:
        flow_model: Model exposing ``get_velocity(x, t, x_cond, patch_size=...)``.
        x_data: Batch of data samples at t=0 (the integration initial state).
        x_cond: Conditioning tensor passed through to the velocity model.
        steps: Number of intervals in the evaluation time grid (steps + 1 points).
        method: ``torchdiffeq`` solver name (e.g. adaptive "dopri5").
        patch_size: Forwarded to ``get_velocity``; presumably the patch size
            used for patched inference — confirm against the flow model.
        atol, rtol: Absolute/relative tolerances for adaptive solvers.

    Returns:
        The trajectory endpoint at t=1 (the noise latent paired with x_data).
    """

    def velocity_field(t, state):
        # Conditional velocity predicted by the flow model at time t.
        return flow_model.get_velocity(state, t, x_cond, patch_size=patch_size)

    # Uniform time grid on [0, 1]; for adaptive solvers these are just the
    # requested output points, not the internal step locations.
    time_grid = torch.linspace(0.0, 1.0, steps + 1, device=x_data.device)

    trajectory = torchdiffeq.odeint(
        velocity_field,
        x_data,
        time_grid,
        method=method,
        atol=atol,
        rtol=rtol,
    )
    # Last trajectory entry is the state at t=1.
    return trajectory[-1]
| |
|
| |
|
def main():
    """Generate (noise, data, condition) triplets for reflow training.

    Loads a trained VPDiffusionFlow checkpoint, inverts the flow ODE on
    ground-truth training images (data -> noise) with ``ode_inverse_solve``,
    and saves each batch of paired tensors to ``--output_dir`` as
    ``batch_{i}.pth`` with keys ``x_noise``, ``x_data``, ``x_cond``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, required=True)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default="reflow_data")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument("--patch_size", type=int, default=64)
    parser.add_argument("--method", type=str, default="dopri5")
    parser.add_argument("--atol", type=float, default=1e-5)
    parser.add_argument("--rtol", type=float, default=1e-5)
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="Max images to generate (for testing)",
    )
    args = parser.parse_args()

    # Load the YAML config and expose it as an attribute namespace.
    with open(os.path.join("configs", args.config), "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)

    # CLI overrides for the data section.
    if args.data_dir:
        config.data.data_dir = args.data_dir
    if args.dataset:
        config.data.dataset = args.dataset

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    config.device = device

    # Seed CPU and NumPy RNGs so generated noise latents are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print("Initializing VPDiffusionFlow...")
    flow = VPDiffusionFlow(args, config)
    flow.load_ckpt(args.resume)

    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Loading dataset {config.data.dataset}...")
    DATASET = datasets.__dict__[config.data.dataset](config)

    # NOTE(review): a first get_loaders(parse_patches=False, validation=...)
    # call whose result was immediately overwritten by the call below has been
    # removed as dead code.
    train_loader, _ = DATASET.get_loaders(parse_patches=True)

    print("Starting generation of reflow pairs...")

    count = 0
    for i, (x_batch, img_id) in enumerate(
        tqdm(train_loader, desc="Generating Reflow Pairs")
    ):
        # Patch-parsed loaders may yield (B, P, C, H, W); fold the patch
        # dimension into the batch dimension.
        if x_batch.ndim == 5:
            x_batch = x_batch.flatten(start_dim=0, end_dim=1)

        # Channels 0-2: degraded input (condition); channels 3+: ground truth.
        input_img = x_batch[:, :3, :, :].to(device)
        gt_img = x_batch[:, 3:, :, :].to(device)

        # Map images into the model's normalized data range.
        x_cond = utils.sampling.data_transform(input_img)
        x_data = utils.sampling.data_transform(gt_img)

        # Invert the flow ODE: clean data -> matched noise latent.
        with torch.no_grad():
            x_noise = ode_inverse_solve(
                flow,
                x_data,
                x_cond,
                steps=args.steps,
                method=args.method,
                patch_size=args.patch_size,
                atol=args.atol,
                rtol=args.rtol,
            )

        # Move pairs to CPU before serialization to keep GPU memory free.
        batch_data = {
            "x_noise": x_noise.cpu(),
            "x_data": x_data.cpu(),
            "x_cond": x_cond.cpu(),
        }

        save_path = os.path.join(args.output_dir, f"batch_{i}.pth")
        torch.save(batch_data, save_path)
        print(f"Saved batch {i} to {save_path}")

        count += input_img.shape[0]
        if args.max_images and count >= args.max_images:
            print(f"Reached max images {args.max_images}")
            break
| |
|
| |
|
# Script entry point: run generation only when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|