File size: 6,346 Bytes
2368e93
 
 
 
 
 
 
 
 
 
 
 
4c7009f
 
 
 
 
 
 
 
 
 
 
2368e93
 
 
 
 
 
 
4c7009f
2368e93
 
 
 
 
 
 
4c7009f
2368e93
 
 
 
 
 
 
4c7009f
2368e93
 
 
 
 
 
 
 
 
 
 
 
 
4c7009f
 
 
 
 
 
2368e93
 
 
 
 
 
4c7009f
2368e93
 
 
 
4c7009f
2368e93
 
4c7009f
2368e93
 
 
4c7009f
2368e93
 
 
 
4c7009f
2368e93
4c7009f
2368e93
 
 
4c7009f
2368e93
4c7009f
 
 
 
 
2368e93
 
 
 
 
 
 
 
 
4c7009f
2368e93
 
 
4c7009f
2368e93
 
4c7009f
2368e93
4c7009f
2368e93
4c7009f
2368e93
4c7009f
 
 
2368e93
 
 
 
4c7009f
 
 
 
2368e93
 
 
4c7009f
2368e93
 
 
 
 
 
4c7009f
2368e93
 
4c7009f
 
 
 
 
2368e93
4c7009f
 
2368e93
4c7009f
2368e93
 
 
 
4c7009f
2368e93
 
 
 
4c7009f
2368e93
4c7009f
2368e93
 
4c7009f
2368e93
4c7009f
2368e93
 
 
 
 
4c7009f
2368e93
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import argparse
import os
import yaml
import torch
import torch.nn as nn
import numpy as np
import torchdiffeq
import utils
from diff2flow import VPDiffusionFlow, dict2namespace
import datasets
from tqdm import tqdm


def ode_inverse_solve(
    flow_model,
    x_data,
    x_cond,
    steps=100,
    method="dopri5",
    patch_size=64,
    atol=1e-5,
    rtol=1e-5,
):
    """Invert the flow ODE by integrating from t=0 (data) to t=1 (noise).

    Args:
        flow_model: model exposing ``get_velocity(x, t, x_cond, patch_size=...)``;
            t is expected in [0, 1].
        x_data: clean data tensor at t=0 (the integration initial condition).
        x_cond: conditioning tensor passed through to the velocity model.
        steps: number of intervals for the evaluation grid (steps+1 time points);
            for fixed-step solvers this sets the step count.
        method: torchdiffeq solver name (e.g. "dopri5", "euler").
        patch_size: forwarded to ``get_velocity`` (used for patched inference).
        atol, rtol: adaptive-solver tolerances.

    Returns:
        The state at t=1 (the inverted noise latent).
    """

    def _velocity(t, state):
        # torchdiffeq hands us (t, state); the model's time convention is
        # already [0, 1], so t passes through unchanged.
        return flow_model.get_velocity(state, t, x_cond, patch_size=patch_size)

    # Evaluation grid covering the forward-in-time inversion path 0 -> 1.
    schedule = torch.linspace(0.0, 1.0, steps + 1, device=x_data.device)

    trajectory = torchdiffeq.odeint(
        _velocity, x_data, schedule, method=method, atol=atol, rtol=rtol
    )
    # Only the terminal state (t=1) is needed.
    return trajectory[-1]


def main():
    """Generate (noise, condition, data) triplets for training a reflow model.

    Loads a trained VPDiffusionFlow checkpoint, iterates over training patches,
    and runs an ODE inversion (t=0 data -> t=1 noise) on each batch. Each batch's
    (x_noise, x_data, x_cond) triplet is saved as ``batch_{i}.pth`` in
    ``--output_dir`` for later reflow training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, required=True)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default="reflow_data")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument("--patch_size", type=int, default=64)
    parser.add_argument("--method", type=str, default="dopri5")
    parser.add_argument("--atol", type=float, default=1e-5)
    parser.add_argument("--rtol", type=float, default=1e-5)
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="Max images to generate (for testing)",
    )
    args = parser.parse_args()

    # Load Config
    with open(os.path.join("configs", args.config), "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)

    # CLI overrides take precedence over the YAML config.
    if args.data_dir:
        config.data.data_dir = args.data_dir
    if args.dataset:
        config.data.dataset = args.dataset

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    config.device = device

    # Reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Load Model
    print("Initializing VPDiffusionFlow...")
    flow = VPDiffusionFlow(args, config)
    flow.load_ckpt(args.resume)

    os.makedirs(args.output_dir, exist_ok=True)

    # Load Dataset
    print(f"Loading dataset {config.data.dataset}...")
    DATASET = datasets.__dict__[config.data.dataset](config)

    # We generate reflow pairs from the TRAINING set, using PATCHES
    # (parse_patches=True) to match the original training distribution and to
    # keep the per-patch ODE inversion memory-efficient. Patches can later be
    # stitched for full-image inference.
    # Bug fix: a previous version called get_loaders twice (first with
    # parse_patches=False and a bogus ``validation=`` argument), discarding the
    # first result. Only the patch loader is needed.
    train_loader, _ = DATASET.get_loaders(parse_patches=True)

    print("Starting generation of reflow pairs...")

    count = 0

    # Iterate through training patches
    for i, (x_batch, img_id) in enumerate(
        tqdm(train_loader, desc="Generating Reflow Pairs")
    ):
        # x_batch: [B, N, 6, H, W] when parse_patches=True; flatten B and N so
        # every patch is processed as an independent sample.
        if x_batch.ndim == 5:
            x_batch = x_batch.flatten(start_dim=0, end_dim=1)

        # Channels 0:3 are the degraded input, 3:6 the clean ground truth.
        input_img = x_batch[:, :3, :, :].to(device)  # Input (Rainy)
        gt_img = x_batch[:, 3:, :, :].to(device)  # GT (Clean)

        # Transform data to [-1, 1]
        x_cond = utils.sampling.data_transform(input_img)
        x_data = utils.sampling.data_transform(gt_img)

        # Run ODE Inversion: x_data (t=0) -> x_noise (t=1). Each patch is
        # inverted independently; VPDiffusionFlow.get_velocity handles the
        # patch_size argument internally for larger inputs.
        with torch.no_grad():
            x_noise = ode_inverse_solve(
                flow,
                x_data,
                x_cond,
                steps=args.steps,
                method=args.method,
                patch_size=args.patch_size,
                atol=args.atol,
                rtol=args.rtol,
            )

        # Save the triplet for this batch:
        #   x_noise - reflow model input at t=1 (inverted latent)
        #   x_data  - reflow target output at t=0
        #   x_cond  - conditioning image
        batch_data = {
            "x_noise": x_noise.cpu(),
            "x_data": x_data.cpu(),
            "x_cond": x_cond.cpu(),
        }

        save_path = os.path.join(args.output_dir, f"batch_{i}.pth")
        torch.save(batch_data, save_path)

        print(f"Saved batch {i} to {save_path}")

        count += input_img.shape[0]
        if args.max_images and count >= args.max_images:
            print(f"Reached max images {args.max_images}")
            break


if __name__ == "__main__":
    main()