import argparse
import os
import yaml
import torch
import torch.nn as nn
import numpy as np
import torchdiffeq
import utils
from diff2flow import VPDiffusionFlow, dict2namespace
import datasets
from tqdm import tqdm
def ode_inverse_solve(
    flow_model,
    x_data,
    x_cond,
    steps=100,
    method="dopri5",
    patch_size=64,
    atol=1e-5,
    rtol=1e-5,
):
    """
    Integrate the probability-flow ODE from t=0 (data) to t=1 (noise).

    Maps a clean sample ``x_data`` to its noise latent under the learned
    velocity field of ``flow_model``, conditioned on ``x_cond``.

    Args:
        flow_model: model exposing ``get_velocity(x, t, x_cond, patch_size=...)``.
        x_data: state at t=0 (the data sample).
        x_cond: conditioning tensor passed through to the velocity call.
        steps: number of intervals for the evaluation time grid.
        method: torchdiffeq solver name (e.g. "dopri5").
        patch_size: forwarded to ``get_velocity`` (used for patched inputs).
        atol, rtol: solver tolerances.

    Returns:
        The final state at t=1 (the noise latent).
    """

    # Velocity field in torchdiffeq's (t, state) calling convention.
    # During inversion t traverses 0 -> 1.
    def _velocity(t, state):
        return flow_model.get_velocity(state, t, x_cond, patch_size=patch_size)

    # Evaluation grid spanning [0, 1] on the same device as the data.
    time_grid = torch.linspace(0.0, 1.0, steps + 1, device=x_data.device)

    trajectory = torchdiffeq.odeint(
        _velocity,
        x_data,
        time_grid,
        method=method,
        atol=atol,
        rtol=rtol,
    )
    # odeint returns the state at every grid point; the last entry is t=1.
    return trajectory[-1]
def main():
    """Generate (noise, condition, data) triplets for reflow training.

    Loads a trained VPDiffusionFlow checkpoint, iterates over training
    patches, inverts each clean patch through the probability-flow ODE
    (t=0 -> t=1) via :func:`ode_inverse_solve`, and saves each batch of
    pairs to ``--output_dir`` as ``batch_<i>.pth``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, required=True)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default="reflow_data")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument("--patch_size", type=int, default=64)
    parser.add_argument("--method", type=str, default="dopri5")
    parser.add_argument("--atol", type=float, default=1e-5)
    parser.add_argument("--rtol", type=float, default=1e-5)
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="Max images to generate (for testing)",
    )
    args = parser.parse_args()

    # Load config from the configs/ directory; CLI flags may override
    # the dataset location/name.
    with open(os.path.join("configs", args.config), "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)
    if args.data_dir:
        config.data.data_dir = args.data_dir
    if args.dataset:
        config.data.dataset = args.dataset

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    config.device = device

    # Reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Load model from checkpoint
    print("Initializing VPDiffusionFlow...")
    flow = VPDiffusionFlow(args, config)
    flow.load_ckpt(args.resume)
    os.makedirs(args.output_dir, exist_ok=True)

    # Load dataset. We use TRAINING patches (parse_patches=True) so the
    # reflow pairs match the original training distribution, and so the
    # ODE inversion can run per-patch (memory-friendly and consistent).
    # NOTE: a previous version also built a throwaway full-image loader
    # (parse_patches=False) whose result was immediately discarded; that
    # redundant call has been removed.
    print(f"Loading dataset {config.data.dataset}...")
    DATASET = datasets.__dict__[config.data.dataset](config)
    train_loader, _ = DATASET.get_loaders(parse_patches=True)

    print("Starting generation of reflow pairs...")
    count = 0
    for i, (x_batch, img_id) in enumerate(
        tqdm(train_loader, desc="Generating Reflow Pairs")
    ):
        # x_batch: [B, N, 6, H, W] when parse_patches=True; flatten the
        # batch and patch dims so every patch is processed independently.
        if x_batch.ndim == 5:
            x_batch = x_batch.flatten(start_dim=0, end_dim=1)
        input_img = x_batch[:, :3, :, :].to(device)  # condition (rainy)
        gt_img = x_batch[:, 3:, :, :].to(device)  # target (clean)

        # Transform data to [-1, 1]
        x_cond = utils.sampling.data_transform(input_img)
        x_data = utils.sampling.data_transform(gt_img)

        # ODE inversion: x_data (t=0) -> x_noise (t=1). Inputs here are
        # already patches, so patch_size only matters if a patch exceeds
        # it (get_velocity handles internal patching in that case —
        # presumably a no-op for these small inputs; verify if sizes change).
        with torch.no_grad():
            x_noise = ode_inverse_solve(
                flow,
                x_data,
                x_cond,
                steps=args.steps,
                method=args.method,
                patch_size=args.patch_size,
                atol=args.atol,
                rtol=args.rtol,
            )

        # Persist the triplet: x_noise is the reflow model's input (t=1),
        # x_data its target output (t=0), x_cond the condition.
        batch_data = {
            "x_noise": x_noise.cpu(),
            "x_data": x_data.cpu(),
            "x_cond": x_cond.cpu(),
        }
        save_path = os.path.join(args.output_dir, f"batch_{i}.pth")
        torch.save(batch_data, save_path)
        print(f"Saved batch {i} to {save_path}")

        # count tracks processed patches; stop early when --max_images is set.
        count += input_img.shape[0]
        if args.max_images and count >= args.max_images:
            print(f"Reached max images {args.max_images}")
            break


if __name__ == "__main__":
    main()