harishaseebat92 commited on
Commit
98575a1
·
verified ·
1 Parent(s): db9fdf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +487 -180
app.py CHANGED
@@ -1,207 +1,514 @@
1
- import csv
2
- import sys
 
 
 
3
  import numpy as np
 
 
4
  import plotly.graph_objects as go
5
- import gradio as gr
 
6
  from scipy.spatial import Delaunay
7
- import traceback
8
 
9
- # --- FIX for CSV Field Size Limit ---
10
- # Increase the CSV field size limit to handle large Plotly JSON objects that are
11
- # generated when Gradio's Examples component tries to log the output.
12
- max_int = sys.maxsize
13
- while True:
14
- try:
15
- csv.field_size_limit(max_int)
16
- break
17
- except OverflowError:
18
- max_int = int(max_int / 10)
19
- # -----------------------------------------
20
-
21
-
22
- def solve_and_plot_interactive(Lx: float,
23
- Ly: float,
24
- t_max: float,
25
- M: int,
26
- Gamma: float = 0.1,
27
- Nx: int = 50,
28
- Ny: int = 50,
29
- initial: str = "gaussian",
30
- bc: str = "dirichlet"):
31
  """
32
- Solves the 2D heat equation and returns an interactive Plotly figure with a fixed axis range.
 
33
  """
34
- # --- 1. Simulation Setup (No changes here) ---
35
- x = np.linspace(0, Lx, Nx)
36
- y = np.linspace(0, Ly, Ny)
37
- dx, dy = x[1] - x[0], y[1] - y[0]
38
- dt = 0.9 / (2 * Gamma * (1/dx**2 + 1/dy**2))
39
- Nt = int(np.ceil(t_max / dt)) + 1
40
- rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
41
- X, Y = np.meshgrid(x, y, indexing='ij')
42
-
43
- u = np.zeros((Nx, Ny))
44
- if initial == "gaussian":
45
- u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
46
- elif initial == "random":
47
- u = np.random.rand(Nx, Ny)
48
- elif initial == "sinusoidal":
49
- kx, ky = 2 * np.pi / Lx, 2 * np.pi / Ly
50
- u = np.sin(kx * X) * np.sin(ky * Y)
51
- elif initial == "step":
52
- u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
53
-
54
- # --- 2. Solve the Heat Equation (No changes here) ---
55
- time_indices = np.linspace(0, Nt - 1, M, dtype=int)
56
- U_slider = np.zeros((M, Nx, Ny))
57
- store_idx = 0
58
- if 0 in time_indices:
59
- U_slider[store_idx] = u.copy()
60
- store_idx += 1
61
-
62
- for n in range(1, Nt):
63
- un = u.copy()
64
- u[1:-1, 1:-1] = (
65
- un[1:-1, 1:-1]
66
- + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
67
- + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
68
- )
69
- if bc == "dirichlet":
70
- u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
71
- elif bc == "neumann":
72
- u[0, :], u[-1, :], u[:, 0], u[:, -1] = u[1, :], u[-2, :], u[:, 1], u[:, -2]
73
- elif bc == "periodic":
74
- u[0, :], u[-1, :], u[:, 0], u[:, -1] = un[-2, :], un[1, :], un[:, -2], un[:, 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- if n in time_indices and store_idx < M:
77
- U_slider[store_idx] = u.copy()
78
- store_idx += 1
79
-
80
- # --- 3. Create a Plotly Figure for Web ---
81
-
82
- # --- FIX: Calculate the global min/max temperature for a fixed Z-axis range ---
83
- z_min = U_slider.min()
84
- z_max = U_slider.max()
85
- # Add a tiny buffer if the data is completely flat to avoid a display error
86
- if z_max == z_min:
87
- z_max += 1e-9
88
- # --------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
89
 
90
- points_2d = np.vstack([X.ravel(), Y.ravel()]).T
91
- tri = Delaunay(points_2d)
92
-
93
- fig = go.Figure()
94
-
95
- for i in range(M):
96
- time_value = (time_indices[i] / (Nt-1)) * t_max if Nt > 1 else 0
97
- z_data = U_slider[i, :, :].flatten()
98
- fig.add_trace(
99
- go.Mesh3d(
100
- x=X.flatten(), y=Y.flatten(), z=z_data,
101
- i=tri.simplices[:, 0], j=tri.simplices[:, 1], k=tri.simplices[:, 2],
102
- intensity=z_data,
103
- colorscale='Viridis',
104
- # --- FIX: Set the intensity range to the global min/max for a consistent color bar ---
105
- cmin=z_min,
106
- cmax=z_max,
107
- # -------------------------------------------------------------------------------------
108
- name=f'Time: {time_value:.2f}s',
109
- showscale=True if i == 0 else False,
110
- visible=(i == 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  )
112
- )
113
 
114
- # Slider creation logic is unchanged
115
- steps = []
116
- for i in range(len(fig.data)):
117
- time_value = (time_indices[i] / (Nt-1)) * t_max if Nt > 1 else 0
118
- step = dict(
119
- method="update",
120
- args=[{"visible": [False] * len(fig.data)}],
121
- label=f"{time_value:.2f}s"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
123
- step["args"][0]["visible"][i] = True
124
- steps.append(step)
125
-
126
- sliders = [dict(active=0, currentvalue={"prefix": "Time: "}, pad={"t": 50}, steps=steps)]
127
-
128
- # --- FIX: Update the layout to use the fixed axis ranges ---
129
- fig.update_layout(
130
- title=f'2D Heat Eq init={initial}, bc={bc}',
131
- scene=dict(
132
- xaxis_title='X',
133
- yaxis_title='Y',
134
- zaxis_title='Temperature',
135
- # Set the axis ranges to be fixed across all animation frames
136
- xaxis=dict(range=[0, Lx]),
137
- yaxis=dict(range=[0, Ly]),
138
- zaxis=dict(range=[z_min, z_max]),
139
- ),
140
- sliders=sliders
 
 
 
 
141
  )
142
- # -------------------------------------------------------------
143
-
144
- return fig
145
 
 
 
 
 
146
 
147
- # --- Gradio Interface Function with Error Handling (No changes here) ---
148
- def gradio_interface(lx, ly, t_max, m_steps, gamma, nx, ny, initial, bc):
149
- try:
150
- nx, ny, m_steps = int(nx), int(ny), int(m_steps)
151
- fig = solve_and_plot_interactive(
152
- Lx=lx, Ly=ly, t_max=t_max, M=m_steps, Gamma=gamma, Nx=nx, Ny=ny,
153
- initial=initial, bc=bc
154
- )
155
- return fig
156
- except Exception as e:
157
- error_text = traceback.format_exc()
158
- print(error_text)
159
- error_fig = go.Figure().update_layout(
160
- title_text="⚠️ Application Error",
161
- annotations=[dict(text=f"An error occurred: {e}", showarrow=False)]
162
- )
163
- return error_fig
164
 
165
- # --- Gradio UI Definition (No changes here) ---
166
- with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
167
- gr.Markdown("# ♨️ 2D Heat Equation Simulator")
168
- gr.Markdown("Adjust parameters and click 'Run' to generate an interactive plot directly in your browser.")
 
 
 
169
 
 
 
 
 
170
  with gr.Row():
171
  with gr.Column(scale=1):
172
  gr.Markdown("## Simulation Parameters")
173
- lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
174
- ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
175
- nx_slider = gr.Slider(10, 80, 40, 1, label="Nx (Grid Points X)")
176
- ny_slider = gr.Slider(10, 80, 40, 1, label="Ny (Grid Points Y)")
177
- t_slider = gr.Slider(0.01, 2.0, 0.5, 0.01, label="t_max")
178
- gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
179
- m_slider = gr.Slider(10, 80, 30, 1, label="M (Time Steps)")
180
-
181
- with gr.Row():
182
- initial_dropdown = gr.Dropdown(["gaussian", "random", "sinusoidal", "step"], value="gaussian", label="Initial")
183
- bc_dropdown = gr.Dropdown(["dirichlet", "neumann", "periodic"], value="dirichlet", label="Boundary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- run_btn = gr.Button("Run Simulation", variant="primary")
186
 
187
- with gr.Column(scale=3):
188
- plot_output = gr.Plot(label="Interactive Heatmap")
 
 
189
 
190
- inputs_list = [lx_slider, ly_slider, t_slider, m_slider, gamma_slider,
191
- nx_slider, ny_slider, initial_dropdown, bc_dropdown]
192
-
193
- run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=plot_output)
194
-
 
195
  gr.Examples(
196
- examples=[
197
- [1.0, 1.0, 0.5, 30, 0.1, 30, 30, "gaussian", "dirichlet"],
198
- [2.0, 1.0, 1.0, 40, 0.05, 40, 20, "sinusoidal", "periodic"],
199
- ],
200
- inputs=inputs_list,
201
- outputs=plot_output,
202
- fn=gradio_interface,
203
  cache_examples=False
204
  )
205
 
206
  if __name__ == "__main__":
207
- demo.launch()
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import tempfile
4
+ import gradio as gr
5
+ import cudaq
6
  import numpy as np
7
+ import cupy as cp
8
+ from pathlib import Path
9
  import plotly.graph_objects as go
10
+ import plotly.io as pio
11
+ import imageio # Keep for potential future use, but GIF generation removed
12
  from scipy.spatial import Delaunay
 
13
 
14
+ # Set Plotly engine for image export
15
+ try:
16
+ pio.kaleido.scope.mathjax = None
17
+ except AttributeError:
18
+ pass
19
+
20
+ def simulate_qlbm_and_animate(num_reg_qubits: int, T: int, distribution_type: str, ux_input: float, uy_input: float, velocity_field_type: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  """
22
+ Simulates a 2D advection-diffusion problem using a Quantum Lattice Boltzmann Method (QLBM)
23
+ and generates an interactive Plotly figure with a slider for selected time steps.
24
  """
25
+ # GIF related variables removed as GIF generation is no longer needed
26
+ # video_length = T
27
+ # simulation_fps = 0.1
28
+ # frames = int(simulation_fps * video_length)
29
+ # if frames == 0:
30
+ # frames = 1
31
+
32
+ num_anc = 3
33
+ num_qubits_total = 2 * num_reg_qubits + num_anc
34
+ current_N = 2**num_reg_qubits
35
+ N_tot_state_vector = 2**num_qubits_total
36
+ num_ranks = 1
37
+ rank = 0
38
+ N_sub_per_rank = int(N_tot_state_vector // num_ranks)
39
+
40
+ # timesteps_per_frame logic removed as it was for GIF
41
+ # timesteps_per_frame = 1
42
+ # if frames < T and frames > 0:
43
+ # timesteps_per_frame = int(T / frames)
44
+ # if timesteps_per_frame == 0:
45
+ # timesteps_per_frame = 1
46
+
47
+ # Initial state setup
48
+ if distribution_type == "Sine Wave (Original)":
49
+ selected_initial_state_function_raw = lambda x, y, N_val_func: \
50
+ np.sin(x * 2 * np.pi / N_val_func) * (1 - 0.5 * x / N_val_func) * \
51
+ np.sin(y * 4 * np.pi / N_val_func) * (1 - 0.5 * y / N_val_func) + 1
52
+ elif distribution_type == "Gaussian":
53
+ selected_initial_state_function_raw = lambda x, y, N_val_func: \
54
+ np.exp(-((x - N_val_func / 2)**2 / (2 * (N_val_func / 5)**2) +
55
+ (y - N_val_func / 2)**2 / (2 * (N_val_func / 5)**2))) * 1.8 + 0.2
56
+ elif distribution_type == "Random":
57
+ selected_initial_state_function_raw = lambda x, y, N_val_func: \
58
+ np.random.rand(N_val_func, N_val_func) * 1.5 + 0.2 if isinstance(x, int) else \
59
+ np.random.rand(x.shape, x.shape[3]) * 1.5 + 0.2
60
+ else:
61
+ print(f"Warning: Unknown distribution type '{distribution_type}'. Defaulting to Sine Wave.")
62
+ selected_initial_state_function_raw = lambda x, y, N_val_func: \
63
+ np.sin(x * 2 * np.pi / N_val_func) * (1 - 0.5 * x / N_val_func) * \
64
+ np.sin(y * 4 * np.pi / N_val_func) * (1 - 0.5 * y / N_val_func) + 1
65
+
66
+ initial_state_func_eval = lambda x_coords, y_coords: \
67
+ selected_initial_state_function_raw(x_coords, y_coords, current_N) * \
68
+ (y_coords < current_N).astype(int)
69
+
70
+ with tempfile.TemporaryDirectory() as tmp_npy_dir:
71
+ intermediate_folder_path = Path(tmp_npy_dir)
72
+
73
+ cudaq.set_target('nvidia', option='fp64')
74
+
75
+ @cudaq.kernel
76
+ def alloc_kernel(num_qubits_alloc: int):
77
+ qubits = cudaq.qvector(num_qubits_alloc)
78
+
79
+ from cupy.cuda.memory import MemoryPointer, UnownedMemory
80
+
81
+ def to_cupy_array(state):
82
+ tensor = state.getTensor()
83
+ pDevice = tensor.data()
84
+ sizeByte = tensor.get_num_elements() * tensor.get_element_size()
85
+ mem = UnownedMemory(pDevice, sizeByte, owner=state)
86
+ memptr_obj = MemoryPointer(mem, 0)
87
+ cupy_array_val = cp.ndarray(tensor.get_num_elements(),
88
+ dtype=cp.complex128,
89
+ memptr=memptr_obj)
90
+ return cupy_array_val
91
+
92
+ class QLBMAdvecDiffD2Q5_new:
93
+ def __init__(self, ux=0.2, uy=0.15) -> None:
94
+ self.dim = 2
95
+ self.ndir = 5
96
+ self.nq_dir = math.ceil(np.log2(self.ndir))
97
+ self.dirs =
98
+ for dir_int in range(self.ndir):
99
+ dir_bin = f"{dir_int:b}".zfill(self.nq_dir)
100
+ self.dirs.append(dir_bin)
101
+ self.e_unitvec = np.array([0, 1, -1, 1, -1])
102
+ self.wts = np.array([2/6, 1/6, 1/6, 1/6, 1/6])
103
+ self.cs = 1 / np.sqrt(3)
104
+ self.ux = ux
105
+ self.uy = uy
106
+ self.u = np.array([0, self.ux, self.ux, self.uy, self.uy])
107
+ self.wtcoeffs = np.multiply(self.wts, 1 + self.e_unitvec * self.u / self.cs**2)
108
+ self.create_circuit()
109
+
110
+ def create_circuit(self):
111
+ v = np.pad(self.wtcoeffs, (0, 2**num_anc - self.ndir))
112
+ v = v**0.5
113
+ v += 1 # Original line was v += 1, not v += 1
114
+ v = v / np.linalg.norm(v)
115
+ U_prep = 2 * np.outer(v, v) - np.eye(len(v))
116
+ cudaq.register_operation("prep_op", U_prep)
117
+
118
+ def collisionOp(dirs_list):
119
+ dirs_i_list_val =
120
+ for dir_str in dirs_list:
121
+ dirs_i = [(int(c)) for c in dir_str]
122
+ dirs_i_list_val += dirs_i[::-1]
123
+ return dirs_i_list_val
124
+
125
+ self.dirs_i_list = collisionOp(self.dirs)
126
+
127
+ @cudaq.kernel
128
+ def rshift(q: cudaq.qview, n: int):
129
+ for i in range(n):
130
+ if i == n - 1:
131
+ x(q[n - 1 - i])
132
+ elif i == n - 2:
133
+ x.ctrl(q[n - 1 - (i + 1)], q[n - 1 - i])
134
+ else:
135
+ x.ctrl(q[0:n - 1 - i], q[n - 1 - i])
136
+
137
+ @cudaq.kernel
138
+ def lshift(q: cudaq.qview, n: int):
139
+ for i in range(n):
140
+ if i == 0:
141
+ x(q) # Corrected from x(q)
142
+ elif i == 1:
143
+ x.ctrl(q, q[3])
144
+ else:
145
+ x.ctrl(q[0:i], q[i])
146
+
147
+ @cudaq.kernel
148
+ def d2q5_tstep(q: cudaq.qview, nqx: int, nqy: int, nq_dir_val: int, dirs_i_val: list[int]):
149
+ qx = q[0:nqx]
150
+ qy = q[nqx:nqx + nqy]
151
+ qdir = q[nqx + nqy:nqx + nqy + nq_dir_val]
152
+
153
+ idx_lqx = 2
154
+ b_list = dirs_i_val[idx_lqx * nq_dir_val:(idx_lqx + 1) * nq_dir_val]
155
+ for j in range(nq_dir_val):
156
+ if b_list[j] == 0: x(qdir[j])
157
+ cudaq.control(lshift, qdir, qx, nqx)
158
+ for j in range(nq_dir_val):
159
+ if b_list[j] == 0: x(qdir[j])
160
+
161
+ idx_rqx = 1
162
+ b_list = dirs_i_val[idx_rqx * nq_dir_val:(idx_rqx + 1) * nq_dir_val]
163
+ for j in range(nq_dir_val):
164
+ if b_list[j] == 0: x(qdir[j])
165
+ cudaq.control(rshift, qdir, qx, nqx)
166
+ for j in range(nq_dir_val):
167
+ if b_list[j] == 0: x(qdir[j])
168
+
169
+ idx_lqy = 4
170
+ b_list = dirs_i_val[idx_lqy * nq_dir_val:(idx_lqy + 1) * nq_dir_val]
171
+ for j in range(nq_dir_val):
172
+ if b_list[j] == 0: x(qdir[j])
173
+ cudaq.control(lshift, qdir, qy, nqy)
174
+ for j in range(nq_dir_val):
175
+ if b_list[j] == 0: x(qdir[j])
176
+
177
+ idx_rqy = 3
178
+ b_list = dirs_i_val[idx_rqy * nq_dir_val:(idx_rqy + 1) * nq_dir_val]
179
+ for j in range(nq_dir_val):
180
+ if b_list[j] == 0: x(qdir[j])
181
+ cudaq.control(rshift, qdir, qy, nqy)
182
+ for j in range(nq_dir_val):
183
+ if b_list[j] == 0: x(qdir[j])
184
+
185
+ @cudaq.kernel
186
+ def d2q5_tstep_wrapper(state_arg: cudaq.State, nqx: int, nqy: int, nq_dir_val: int, dirs_i_val: list[int]):
187
+ q = cudaq.qvector(state_arg)
188
+ qdir = q[nqx + nqy:nqx + nqy + nq_dir_val]
189
+ prep_op(qdir[4], qdir[3], qdir) # Corrected from qdir
190
+ d2q5_tstep(q, nqx, nqy, nq_dir_val, dirs_i_val)
191
+ prep_op(qdir[4], qdir[3], qdir) # Corrected from qdir
192
+
193
+ @cudaq.kernel
194
+ def d2q5_tstep_wrapper_hadamard(vec_arg: list[complex], nqx: int, nqy: int, nq_dir_val: int, dirs_i_val: list[int]):
195
+ q = cudaq.qvector(vec_arg)
196
+ qdir = q[nqx + nqy:nqx + nqy + nq_dir_val]
197
+ qy = q[nqx:nqx + nqy]
198
+ prep_op(qdir[4], qdir[3], qdir) # Corrected from qdir
199
+ d2q5_tstep(q, nqx, nqy, nq_dir_val, dirs_i_val)
200
+ prep_op(qdir[4], qdir[3], qdir) # Corrected from qdir
201
+ for i in range(nqy):
202
+ h(qy[i])
203
+
204
+ def run_timestep_func(vec_arg, hadamard=False):
205
+ if hadamard:
206
+ result = cudaq.get_state(d2q5_tstep_wrapper_hadamard, vec_arg, num_reg_qubits, num_reg_qubits, self.nq_dir, self.dirs_i_list)
207
+ else:
208
+ result = cudaq.get_state(d2q5_tstep_wrapper, vec_arg, num_reg_qubits, num_reg_qubits, self.nq_dir, self.dirs_i_list)
209
+ num_nonzero_ranks = num_ranks / (2**num_anc)
210
+ rank_slice_cupy = to_cupy_array(result)
211
+ if rank >= num_nonzero_ranks and num_nonzero_ranks > 0:
212
+ sub_sv_zeros = np.zeros(N_sub_per_rank, dtype=np.complex128)
213
+ cp.cuda.runtime.memcpy(rank_slice_cupy.data.ptr, sub_sv_zeros.ctypes.data, sub_sv_zeros.nbytes, cp.cuda.runtime.memcpyHostToDevice)
214
+ if rank == 0 and num_nonzero_ranks < 1 and N_sub_per_rank > 0:
215
+ limit_idx = int(N_tot_state_vector / (2**num_anc))
216
+ if limit_idx < rank_slice_cupy.size:
217
+ rank_slice_cupy[limit_idx:] = 0
218
+ return result
219
+ self.run_timestep = run_timestep_func
220
+
221
+ def write_state(self, state_to_write, t_step_str_val):
222
+ rank_slice_cupy = to_cupy_array(state_to_write)
223
+ num_nonzero_ranks = num_ranks / (2**num_anc)
224
+ if rank < num_nonzero_ranks or (rank == 0 and num_nonzero_ranks <= 0):
225
+ save_path = intermediate_folder_path / f"{t_step_str_val}_{rank}.npy"
226
+ with open(save_path, 'wb') as f:
227
+ arr_to_save = None
228
+ data_limit = N_sub_per_rank
229
+ if num_nonzero_ranks < 1 and rank == 0:
230
+ data_limit = int(N_tot_state_vector / (2**num_anc))
231
+ if data_limit > 0:
232
+ relevant_part_cupy = cp.real(rank_slice_cupy[:data_limit])
233
+ else:
234
+ relevant_part_cupy = cp.array(, dtype=cp.float64)
235
+ if relevant_part_cupy.size >= current_N * current_N:
236
+ arr_flat = relevant_part_cupy[:current_N * current_N]
237
+ if downsampling_factor > 1 and current_N > 0:
238
+ arr_reshaped = arr_flat.reshape((current_N, current_N))
239
+ arr_downsampled = arr_reshaped[::downsampling_factor, ::downsampling_factor]
240
+ arr_to_save = arr_downsampled.flatten()
241
+ else:
242
+ arr_to_save = arr_flat
243
+ elif relevant_part_cupy.size > 0:
244
+ if downsampling_factor > 1:
245
+ arr_to_save = relevant_part_cupy[::downsampling_factor]
246
+ else:
247
+ arr_to_save = relevant_part_cupy
248
+ if arr_to_save is not None and arr_to_save.size > 0:
249
+ np.save(f, arr_to_save.get() if isinstance(arr_to_save, cp.ndarray) else arr_to_save)
250
+
251
+ def run_evolution(self, initial_state_arg, total_timesteps, observable=False, timesteps_to_save=None):
252
+ current_state_val = initial_state_arg
253
+ for t_iter in range(total_timesteps):
254
+ next_state_val = None
255
+ if t_iter == total_timesteps - 1 and observable:
256
+ next_state_val = self.run_timestep(current_state_val, True)
257
+ self.write_state(next_state_val, str(t_iter)) # Save final state
258
+ else:
259
+ next_state_val = self.run_timestep(current_state_val)
260
+ # Save data only for specific intervals for the slider
261
+ if timesteps_to_save and t_iter in timesteps_to_save:
262
+ self.write_state(next_state_val, str(t_iter))
263
+ if rank == 0 and t_iter % 10 == 0: # Print progress less frequently
264
+ print(f"Timestep: {t_iter}/{total_timesteps}")
265
+ cp.get_default_memory_pool().free_all_blocks()
266
+ current_state_val = next_state_val
267
+ if rank == 0:
268
+ print(f"Timestep: {total_timesteps}/{total_timesteps} (Evolution complete)")
269
+ cp.get_default_memory_pool().free_all_blocks()
270
+ self.final_state = current_state_val
271
+
272
+ downsampling_factor = 2**5
273
+ if current_N == 0:
274
+ print("Error: current_N is zero. num_reg_qubits likely too small.")
275
+ return None, None
276
+ if current_N < downsampling_factor:
277
+ downsampling_factor = current_N if current_N > 0 else 1
278
+
279
+ qlbm_obj = QLBMAdvecDiffD2Q5_new(ux=ux_input, uy=uy_input)
280
+ initial_state_val = cudaq.get_state(alloc_kernel, num_qubits_total)
281
 
282
+ xv_init = np.arange(current_N)
283
+ yv_init = np.arange(current_N)
284
+ initial_grid_2d_X, initial_grid_2d_Y = np.meshgrid(xv_init, yv_init)
285
+
286
+ if distribution_type == "Random":
287
+ initial_grid_2d = selected_initial_state_function_raw(current_N, current_N, current_N)
288
+ else:
289
+ initial_grid_2d = initial_state_func_eval(initial_grid_2d_X, initial_grid_2d_Y)
290
+
291
+ sub_sv_init_flat = initial_grid_2d.flatten().astype(np.complex128)
292
+ full_initial_sv_host = np.zeros(N_sub_per_rank, dtype=np.complex128)
293
+ num_computational_states = current_N * current_N
294
+
295
+ if len(sub_sv_init_flat) == num_computational_states:
296
+ if num_computational_states <= N_sub_per_rank:
297
+ full_initial_sv_host[:num_computational_states] = sub_sv_init_flat
298
+ else:
299
+ print(f"Error: Grid data {num_computational_states} > N_sub_per_rank {N_sub_per_rank}")
300
+ return None, None
301
+ else:
302
+ print(f"Warning: Initial state size {len(sub_sv_init_flat)}!= expected {num_computational_states}")
303
+ fill_len = min(len(sub_sv_init_flat), num_computational_states, N_sub_per_rank)
304
+ full_initial_sv_host[:fill_len] = sub_sv_init_flat[:fill_len]
305
 
306
+ rank_slice_init = to_cupy_array(initial_state_val)
307
+ print(f'Rank {rank}: Initializing state with {distribution_type} (ux={ux_input}, uy={uy_input})...')
308
+ cp.cuda.runtime.memcpy(rank_slice_init.data.ptr, full_initial_sv_host.ctypes.data, full_initial_sv_host.nbytes, cp.cuda.runtime.memcpyHostToDevice)
309
+ print(f'Rank {rank}: Initial state copied. Size: {len(sub_sv_init_flat)}. N_sub_per_rank: {N_sub_per_rank}')
310
+
311
+ # Explicitly save initial state (t=0)
312
+ qlbm_obj.write_state(initial_state_val, "0")
313
+
314
+ print("Starting QLBM evolution...")
315
+ # Define specific timesteps to save for the slider
316
+ timesteps_for_slider = # T-1 is the last t_iter
317
+ qlbm_obj.run_evolution(initial_state_val, T, timesteps_to_save=timesteps_for_slider)
318
+ print("QLBM evolution complete.")
319
+
320
+ print("Generating plots with Plotly...")
321
+ downsampled_N = current_N // downsampling_factor
322
+ if downsampled_N == 0 and current_N > 0:
323
+ downsampled_N = 1
324
+ elif current_N == 0:
325
+ print("Error: current_N is zero before Plotly stage.")
326
+ return None, None
327
+
328
+ # Load data for specific time steps for interactive plot
329
+ # These correspond to the filenames saved: 0, T//4, 3*T//4, T-1
330
+ time_steps_to_load =
331
+ data_frames =
332
+ actual_timesteps_loaded =
333
+ for t in time_steps_to_load:
334
+ file_path = intermediate_folder_path / f"{t}_{rank}.npy"
335
+ if file_path.exists():
336
+ sol_loaded = np.load(file_path)
337
+ if sol_loaded.size == downsampled_N * downsampled_N:
338
+ Z_data = np.reshape(sol_loaded, (downsampled_N, downsampled_N))
339
+ data_frames.append(Z_data)
340
+ actual_timesteps_loaded.append(t)
341
+ else:
342
+ print(f"Warning: File {file_path} size {sol_loaded.size}!= expected {downsampled_N*downsampled_N}. Skipping.")
343
+ else:
344
+ print(f"Warning: File {file_path} not found. Skipping.")
345
+
346
+ if not data_frames:
347
+ print("Error: No data frames loaded for interactive plot.")
348
+ return None, None
349
+
350
+ x_coords_plot = np.linspace(-10, 10, downsampled_N)
351
+ y_coords_plot = np.linspace(-10, 10, downsampled_N)
352
+
353
+ # Calculate global min/max for consistent scaling
354
+ z_min = min([np.min(Z) for Z in data_frames])
355
+ z_max = max([np.max(Z) for Z in data_frames])
356
+ if z_max == z_min:
357
+ z_max += 1e-9
358
+
359
+ # Create interactive Plotly figure with slider
360
+ fig = go.Figure()
361
+
362
+ for i, Z in enumerate(data_frames):
363
+ fig.add_trace(
364
+ go.Surface(
365
+ z=Z, x=x_coords_plot, y=y_coords_plot,
366
+ colorscale='Viridis',
367
+ cmin=z_min, cmax=z_max,
368
+ name=f'Time: {actual_timesteps_loaded[i]}',
369
+ showscale=(i == 0) # Show color scale only for the first trace
370
+ )
371
  )
 
372
 
373
+ steps =
374
+ for i in range(len(data_frames)):
375
+ step = dict(
376
+ method="update",
377
+ args=[{"visible": [False] * len(data_frames)}],
378
+ label=f"Time: {actual_timesteps_loaded[i]}"
379
+ )
380
+ step["args"]["visible"][i] = True
381
+ steps.append(step)
382
+
383
+ sliders =
384
+
385
+ fig.update_layout(
386
+ title='QLBM Simulation - Density Evolution',
387
+ scene=dict(
388
+ xaxis_title='X',
389
+ yaxis_title='Y',
390
+ zaxis_title='Density',
391
+ xaxis=dict(range=[x_coords_plot, x_coords_plot[-1]]),
392
+ yaxis=dict(range=[y_coords_plot, y_coords_plot[-1]]),
393
+ zaxis=dict(range=[z_min, z_max]),
394
+ ),
395
+ sliders=sliders,
396
+ width=1000, # Increased width
397
+ height=900 # Increased height
398
  )
399
+
400
+ # GIF generation logic removed as per request
401
+ #... (removed all GIF related code)...
402
+
403
+ return fig # Return only the interactive Plotly figure
404
+
405
+ # Gradio Interface Definition
406
+ def qlbm_gradio_interface(num_reg_qubits_input: int, timescale_input: int, distribution_type_param: str, ux_param: float, uy_param: float, velocity_field_type_param: str):
407
+ num_reg_qubits_val = int(num_reg_qubits_input)
408
+ timescale_val = int(timescale_input)
409
+ ux_val = float(ux_param)
410
+ uy_val = float(uy_param)
411
+
412
+ print(f"Gradio Interface: num_reg_qubits={num_reg_qubits_val}, T={timescale_val}, Distribution={distribution_type_param}, ux={ux_val}, uy={uy_val}, VelocityFieldType={velocity_field_type_param}")
413
+
414
+ plot_fig = simulate_qlbm_and_animate( # Only expecting plot_fig now
415
+ num_reg_qubits=num_reg_qubits_val,
416
+ T=timescale_val,
417
+ distribution_type=distribution_type_param,
418
+ ux_input=ux_val,
419
+ uy_input=uy_val,
420
+ velocity_field_type=velocity_field_type_param # Pass the new dummy parameter
421
  )
 
 
 
422
 
423
+ if plot_fig is None:
424
+ gr.Warning("Simulation or plotting failed. Please check console for errors.")
425
+ return None
426
+ return plot_fig # Return only the interactive Plotly figure
427
 
428
+ with gr.Blocks(theme=gr.themes.Soft(), title="QLBM Simulation with Plotly") as qlbm_demo:
429
+ gr.Markdown(
430
+ """
431
+ # ⚛️ Quantum Lattice Boltzmann Method (QLBM) Simulator (Plotly Animation)
432
+ Welcome to the Quantum Lattice Boltzmann Method (QLBM) simulator! This version uses Plotly for 3D animation and interactive plots.
 
 
 
 
 
 
 
 
 
 
 
 
433
 
434
+ **How this Simulator Works:**
435
+ This simulator implements a D2Q5 model on a quantum computer simulator (CUDA-Q).
436
+ - Control grid size (via Number of Register Qubits: $N=2^{\text{num_reg_qubits}}$).
437
+ - Set total simulation time (Timescale T).
438
+ - Choose initial distribution.
439
+ - Set advection velocities `ux` and `uy`.
440
+ The simulation generates an interactive Plotly figure with a slider for selected time steps.
441
 
442
+ **Note:** Higher qubit counts and longer timescales are computationally intensive. Advection velocities should be small (e.g., < 0.3).
443
+ The Plotly figure allows interactive exploration of specific time steps.
444
+ """
445
+ )
446
  with gr.Row():
447
  with gr.Column(scale=1):
448
  gr.Markdown("## Simulation Parameters")
449
+ num_reg_qubits_slider = gr.Slider(
450
+ minimum=2, maximum=10, value=8, step=1,
451
+ label="Number of Register Qubits (num_reg_qubits)",
452
+ info="Grid N = 2^num_reg_qubits. Max 10 (Note: >8 slow; >9 may hit simulator/memory limits on free tiers)."
453
+ )
454
+ timescale_slider = gr.Slider(
455
+ minimum=0, maximum=2000, value=100, step=10,
456
+ label="Timescale (T)", info="Total number of timesteps. Max 2000."
457
+ )
458
+
459
+ # Group 1: Initial Conditions
460
+ with gr.Accordion("Initial Conditions", open=True):
461
+ distribution_options =
462
+ distribution_type_input = gr.Radio(
463
+ choices=distribution_options, value="Sine Wave (Original)",
464
+ label="Initial Distribution Type", info="Select the initial pattern of the substance."
465
+ )
466
+
467
+ # Group 2: Velocity Fields
468
+ with gr.Accordion("Velocity Fields", open=True):
469
+ velocity_field_options = # Dummy options
470
+ velocity_field_type_input = gr.Radio(
471
+ choices=velocity_field_options, value="Uniform",
472
+ label="Velocity Field Type", info="Select the type of background velocity field."
473
+ )
474
+ ux_slider = gr.Slider(
475
+ minimum=-0.4, maximum=0.4, value=0.2, step=0.01,
476
+ label="Advection Velocity ux", info="x-component of background advection."
477
+ )
478
+ uy_slider = gr.Slider(
479
+ minimum=-0.4, maximum=0.4, value=0.15, step=0.01,
480
+ label="Advection Velocity uy", info="y-component of background advection."
481
+ )
482
 
483
+ run_qlbm_btn = gr.Button("Run QLBM Simulation", variant="primary")
484
 
485
+ with gr.Column(scale=2):
486
+ # Removed gr.Image for GIF
487
+ # qlbm_plot_output = gr.Image(label="QLBM Simulation Animation (GIF)", type="filepath", height=900)
488
+ qlbm_interactive_plot = gr.Plot(label="Interactive Density Plot with Slider")
489
 
490
+ qlbm_inputs_list = [num_reg_qubits_slider, timescale_slider, distribution_type_input, ux_slider, uy_slider, velocity_field_type_input]
491
+ run_qlbm_btn.click(
492
+ fn=qlbm_gradio_interface,
493
+ inputs=qlbm_inputs_list,
494
+ outputs=[qlbm_interactive_plot] # Only interactive plot
495
+ )
496
  gr.Examples(
497
+ examples=,
498
+ [6, 50, "Gaussian", 0.1, 0.05, "Uniform"],
499
+ ,
500
+ inputs=qlbm_inputs_list,
501
+ outputs=[qlbm_interactive_plot], # Only interactive plot
502
+ fn=qlbm_gradio_interface,
 
503
  cache_examples=False
504
  )
505
 
506
  if __name__ == "__main__":
507
+ try:
508
+ cudaq.set_target('nvidia', option='fp64')
509
+ print(f"CUDA-Q Target successfully set to: {cudaq.get_target().name}")
510
+ except Exception as e_target:
511
+ print(f"Warning: Could not set CUDA-Q target to 'nvidia'. Error: {e_target}")
512
+ print(f"Current CUDA-Q Target: {cudaq.get_target().name}. Performance may be affected.")
513
+
514
+ qlbm_demo.launch()