# File size: 10,024 Bytes
# 56c4b9b
import argparse
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.interpolate import interp1d
import time

from solver import *


### For nRMSE evaluation
def compute_nrmse(u_computed, u_reference):
    """Return the batch-averaged normalized RMSE of a computed solution.

    For each sample, the RMSE taken over axes 1 and 2 is divided by the
    root-mean-square magnitude of the reference over the same axes. The
    per-sample ratios are then averaged over the batch axis.

    Args:
        u_computed (np.ndarray): Computed solution [batch_size, len(t_coordinate), N].
        u_reference (np.ndarray): Reference solution [batch_size, len(t_coordinate), N].

    Returns:
        Scalar mean of the per-sample ratio RMSE / RMS(reference).
    """
    reduce_axes = (1, 2)  # everything except the batch axis
    residual = u_computed - u_reference
    per_sample_rmse = np.sqrt(np.mean(residual**2, axis=reduce_axes))
    per_sample_scale = np.sqrt(np.mean(u_reference**2, axis=reduce_axes))
    return np.mean(per_sample_rmse / per_sample_scale)


### For convergence test
def init(xc,
         modes: list =["sin", "sinsin", "Gaussian", "react", "possin"],
         u0=1.0,
         du=0.1):
    """Build a stack of 1D initial profiles, one row per requested mode.

    Args:
        xc (np.ndarray): Cell center coordinates.
        modes (list): Names of the profiles to generate, in output order.
            Supported names are "sin", "sinsin", "Gaussian", "react" and
            "possin".
        u0 (float): Amplitude applied to the "sin" and "possin" profiles.
        du (float): Amplitude of the 8x-frequency component added in the
            "sinsin" profile.

    Returns:
        np.ndarray: Stacked initial conditions with shape [len(modes), len(xc)].

    Raises:
        AssertionError: If a requested mode is not one of the supported names.
    """
    supported = ["sin", "sinsin", "Gaussian", "react", "possin"]

    def _phase():
        # Argument shared by all sinusoidal profiles.
        return (xc + 1.0) * np.pi

    def _gaussian():
        # Gaussian bump, for diffusion checks.
        t0 = 1.0
        return np.exp(-(xc**2) * np.pi / (4.0 * t0)) / np.sqrt(2.0 * t0)

    def _react():
        # Bump centered at x = pi, for the reaction-diffusion equation.
        log_profile = -0.5 * (xc - np.pi) ** 2 / (0.25 * np.pi) ** 2
        return np.exp(log_profile)

    # Each entry is evaluated lazily so only the requested profiles are built.
    builders = {
        "sin": lambda: u0 * np.sin(_phase()),
        "sinsin": lambda: np.sin(_phase()) + du * np.sin(_phase() * 8.0),
        "Gaussian": _gaussian,
        "react": _react,
        "possin": lambda: u0 * np.abs(np.sin(_phase())),
    }

    def _build(mode):
        assert mode in supported, f"mode {mode} not supported!"
        return builders[mode]()

    return np.stack([_build(mode) for mode in modes])

def interpolate_solution(u_fine, x_fine, t_fine, x_coarse, t_coarse):
    """Resample a fine-grid solution onto a coarser space-time grid.

    Linear interpolation is applied along axis 2 (sampled at ``x_fine``)
    first, then along axis 1 (sampled at ``t_fine``). Query points outside
    the fine grid are linearly extrapolated.

    Args:
        u_fine: Fine solution; axis 1 must match ``t_fine`` and axis 2 must
            match ``x_fine``.
        x_fine: Spatial sample positions of ``u_fine``.
        t_fine: Temporal sample positions of ``u_fine``.
        x_coarse: Spatial positions to evaluate at.
        t_coarse: Temporal positions to evaluate at.

    Returns:
        The fine solution evaluated on the (t_coarse, x_coarse) grid.
    """
    interp_options = dict(kind='linear', fill_value="extrapolate")

    # Space first: evaluate the fine solution at the coarse spatial points.
    resample_in_x = interp1d(x_fine, u_fine, axis=2, **interp_options)
    u_on_coarse_x = resample_in_x(x_coarse)

    # Then time: evaluate the spatially resampled field at the coarse times.
    resample_in_t = interp1d(t_fine, u_on_coarse_x, axis=1, **interp_options)
    return resample_in_t(t_coarse)

def compute_error(coarse_tuple, fine_tuple):
    """Measure how far a coarse solution is from a fine one on the coarse grid.

    The fine solution is first resampled onto the coarse space-time grid.
    The norm of the difference is then taken over axes 1 and 2 for every
    entry along axis 0, averaged over axis 0, and divided by
    ``sqrt(u_coarse.size)``.

    Args:
        coarse_tuple: ``(u, x, t)`` for the coarse run.
        fine_tuple: ``(u, x, t)`` for the fine run used as reference.

    Returns:
        Scalar error value.
    """
    u_coarse, x_coarse, t_coarse = coarse_tuple
    u_ref, x_ref, t_ref = fine_tuple
    u_ref_on_coarse = interpolate_solution(u_ref, x_ref, t_ref, x_coarse, t_coarse)

    # Norm per leading-axis entry, then the mean of those norms.
    mismatch = u_coarse - u_ref_on_coarse
    per_sample_norm = np.linalg.norm(mismatch, axis=(1, 2))
    # NOTE: u_coarse.size counts every element, including the leading axis.
    return np.mean(per_sample_norm) / np.sqrt(u_coarse.size)

def get_x_coordinate(x_min, x_max, nx):
    """Return the centers of ``nx`` equal-width cells spanning [x_min, x_max].

    Args:
        x_min: Left edge of the domain.
        x_max: Right edge of the domain.
        nx: Number of cells.

    Returns:
        np.ndarray of length ``nx`` holding the cell-center coordinates.
    """
    cell_edges = np.linspace(x_min, x_max, nx+1)
    half_width = 0.5 * ((x_max - x_min) / nx)
    # Shift each left edge by half a cell; the last edge has no cell to its right.
    return cell_edges[:-1] + half_width

def get_t_coordinate(t_min, t_max, nt):
    """Return a uniform time grid with spacing ``nt``.

    Despite its name, ``nt`` is used as the time-step size, not a count.
    The grid starts at 0 (``t_min`` only enters through the span
    ``t_max - t_min``) and holds ``ceil((t_max - t_min) / nt) + 2`` points,
    so its last entry lies one step beyond the span.

    Args:
        t_min: Start of the time interval.
        t_max: End of the time interval.
        nt: Time-step size.

    Returns:
        np.ndarray of time points ``0, nt, 2*nt, ...``.
    """
    span = t_max - t_min
    step_count = np.ceil(span / nt) + 1
    step_indices = np.arange(step_count + 1)
    return step_indices * nt

def convergence_test(nu, rho,
                     nxs=[256, 512, 1024, 2048],
                     dts=[0.01, 0.01, 0.01, 0.01],
                     t_min=0, t_max=2,
                     x_min=-1, x_max=1):
    """Estimate the spatial convergence rate of ``solver`` by grid refinement.

    The solver is run once per ``(nx, dt)`` pair on the initial conditions
    produced by ``init``. Every run except the last is compared against the
    last (finest) run via ``compute_error``. The observed rate between two
    consecutive resolutions is ``log(e_i / e_{i+1}) / log(nx_{i+1} / nx_i)``.

    Args:
        nu: Diffusion coefficient forwarded to ``solver``.
        rho: Reaction coefficient forwarded to ``solver``.
        nxs: Spatial resolutions, coarsest to finest.
        dts: Time-step size paired with each resolution.
        t_min, t_max: Time interval handed to ``get_t_coordinate``.
        x_min, x_max: Spatial interval handed to ``get_x_coordinate``.

    Returns:
        Mean of the observed rates over consecutive resolution pairs.
    """
    print(f"##### Running convergence test for the solver #####")
    runs = []  # one (u, xc, tc) tuple per resolution
    for nx, dt in zip(nxs, dts):
        print(f"**** Spatio resolution {nx} ****")
        tc = get_t_coordinate(t_min, t_max, dt)
        xc = get_x_coordinate(x_min, x_max, nx)
        # NOTE(review): solver comes from the project's `solver` module; its
        # output is presumably [batch, len(tc), nx] -- confirm.
        u = solver(init(xc), tc, nu, rho)
        runs.append((np.squeeze(np.array(u)), np.array(xc), np.array(tc)))
        print(f"**** Finished ****")

    # Errors of every coarser run measured against the finest run.
    errors = [compute_error(runs[i], runs[-1]) for i in range(len(nxs) - 1)]

    # Observed order between each pair of consecutive resolutions.
    rates = []
    for nx_lo, nx_hi, err_lo, err_hi in zip(nxs, nxs[1:], errors, errors[1:]):
        rate = np.log(err_lo / err_hi) / np.log(nx_hi / nx_lo)
        print(f"Error measured at spatio resolution {nx_lo} is {err_lo:.3e}")
        print(f"Rate of convergence measured at spatio resolution {nx_lo} is {rate:.3f}")
        rates.append(rate)

    return np.mean(rates)

def save_visualization(u_batch_np: np.ndarray, u_ref_np: np.ndarray, save_file_idx=0, save_dir=None):
    """Save a three-panel image comparing a computed and a reference solution.

    The panels show, top to bottom, the computed solution, the reference
    solution and their difference (computed - reference). The image is
    written to ``<save_dir>/visualization_<save_file_idx>.png`` and the
    figure is closed afterwards so repeated calls do not accumulate open
    figures.

    Args:
        u_batch_np: 2D array of the computed solution for one sample.
        u_ref_np: 2D array of the reference solution, same shape.
        save_file_idx: Identifier embedded in the output file name.
        save_dir: Output directory. When ``None`` the module-level
            ``args.save_pth`` (set when the file runs as a script) is used,
            which keeps existing callers working.
    """
    if save_dir is None:
        # Backward-compatible fallback to the CLI namespace parsed in __main__.
        save_dir = args.save_pth

    difference_np = u_batch_np - u_ref_np
    # (data, colormap, colorbar label, title) for each panel, top to bottom.
    panels = [
        (u_batch_np, 'viridis', "Predicted values", "Computed Solution over Space and Time"),
        (u_ref_np, 'viridis', "Reference values", "Reference Solution over Space and Time"),
        (difference_np, 'coolwarm', "Prediction error", "Prediction error over Space and Time"),
    ]

    fig, axs = plt.subplots(3, 1, figsize=(7, 12))
    try:
        for ax, (field, cmap, cbar_label, title) in zip(axs, panels):
            # Both axes are drawn on a normalized [0, 1] range, not in physical units.
            image = ax.imshow(field, aspect='auto', extent=[0, 1, 1, 0], cmap=cmap)
            cbar = fig.colorbar(image, ax=ax)
            cbar.set_label(cbar_label, fontsize=14)
            ax.set_xlabel("Spatial Dimension (x)", fontsize=14)
            ax.set_ylabel("Temporal Dimension (t)", fontsize=14)
            ax.set_title(title, fontsize=16)

        fig.subplots_adjust(hspace=0.4)
        fig.savefig(os.path.join(save_dir, f'visualization_{save_file_idx}.png'))
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)

def time_min_max(t_coordinate):
    """Return the first and last entries of the time grid as a pair.

    The endpoints are taken by position; no sorting or min/max search is done.
    """
    t_first = t_coordinate[0]
    t_last = t_coordinate[-1]
    return t_first, t_last

def x_coord_min_max(x_coordinate):
    """Return the first and last entries of the spatial grid as a pair.

    The endpoints are taken by position; no sorting or min/max search is done.
    """
    x_first = x_coordinate[0]
    x_last = x_coordinate[-1]
    return x_first, x_last

def load_data(path, is_h5py=True):
    """Load a reference dataset and the endpoints of its coordinate grids.

    Args:
        path: Path to the dataset file.
        is_h5py: Must be True; HDF5 is the only supported format.

    Returns:
        dict with keys ``tensor``, ``t_coordinate``, ``x_coordinate`` (arrays
        read from the file's 'tensor', 't-coordinate' and 'x-coordinate'
        datasets) and ``t_min``, ``t_max``, ``x_min``, ``x_max`` (the first
        and last entries of the respective coordinate arrays).

    Raises:
        NotImplementedError: If ``is_h5py`` is False.
    """
    if is_h5py:
        #TODO: make sure this works out
        with h5py.File(path, 'r') as f:
            # Do NOT modify the data loading code
            t_coordinate = np.array(f['t-coordinate'])
            u = np.array(f['tensor'])
            x_coordinate = np.array(f['x-coordinate'])
    else:
        raise NotImplementedError("Only h5py format is supported for now.")

    t_min, t_max = time_min_max(t_coordinate)
    # Use the spatial helper for the spatial grid (same result, clearer intent).
    # NOTE(review): these are the first/last stored coordinates; if the file
    # stores cell centers they are not the domain edges -- confirm.
    x_min, x_max = x_coord_min_max(x_coordinate)
    return dict(
        tensor=u,
        t_coordinate=t_coordinate,
        x_coordinate=x_coordinate,
        t_min=t_min,
        t_max=t_max,
        x_min=x_min,
        x_max=x_max
    )


if __name__ == "__main__":
    # Script entry point: load the reference dataset, run the solver on its
    # initial conditions, report nRMSE / runtime / convergence rate, and save
    # a visualization of one sample.
    parser = argparse.ArgumentParser(description="Script for solving 1D Reaction-Diffusion Equation.")

    parser.add_argument("--save-pth", type=str,
                        default='.',
                        help="The folder to save experimental results.")

    # The default is a string so it matches the declared type; argparse does
    # not apply `type` to non-string defaults.
    parser.add_argument("--run-id", type=str,
                        default="0",
                        help="The id of the current run.")

    parser.add_argument("--nu", type=float,
                        default=0.5,
                        choices=[0.5, 1.0, 2.0, 5.0],
                        help="The diffusion coefficient.")

    parser.add_argument("--rho", type=float,
                        default=1.0,
                        choices=[1.0, 2.0, 5.0, 10.0],
                        help="The reaction coefficient.")
    parser.add_argument("--dataset-path-for-eval", type=str,
                        default='/usr1/username/data/CodePDE/ReactionDiffusion/ReacDiff_Nu0.5_Rho1.0.hdf5',
                        help="The path to load the dataset.")

    # `args` stays a module-level name: save_visualization reads args.save_pth.
    args = parser.parse_args()

    data_dict = load_data(args.dataset_path_for_eval, is_h5py=True)
    u = data_dict['tensor']
    t_coordinate = data_dict['t_coordinate']
    x_coordinate = data_dict['x_coordinate']

    print(f"Loaded data with shape: {u.shape}")
    # t_coordinate contains T+1 time points, i.e., 0, t_1, ..., t_T.

    # Extract test set: the first time slice is the initial condition and the
    # full trajectory is the reference.
    u0 = u[:, 0]
    u_ref = u[:, :]

    # Hyperparameters
    batch_size, N = u0.shape
    nu, rho = args.nu, args.rho

    # Run solver (only this call is timed)
    print("##### Running the solver on the given dataset #####")
    start_time = time.time()
    u_batch = solver(u0, t_coordinate, nu, rho)
    end_time = time.time()
    print("##### Finished #####")

    # Evaluation
    nrmse = compute_nrmse(u_batch, u_ref)
    avg_rate = convergence_test(
        nu,
        rho,
        t_min=data_dict['t_min'],
        t_max=data_dict['t_max']/10,  # to save time
        x_min=data_dict['x_min'],
        x_max=data_dict['x_max']
    )
    print("Result summary")
    print(
        f"nRMSE: {nrmse:.3e}\t| "
        f"Time: {end_time - start_time:.2f}s\t| "
        f"Average convergence rate: {avg_rate:.3f}\t|"
    )

    # Visualization for a single sample: index 2 when available, otherwise the
    # last sample, so batches with fewer than three samples do not fail.
    sample_idx = min(2, batch_size - 1)
    save_visualization(u_batch[sample_idx], u_ref[sample_idx], args.run_id)