harishaseebat92 commited on
Commit
f271967
·
verified ·
1 Parent(s): 4c922b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -250
app.py CHANGED
@@ -1,302 +1,224 @@
1
  import numpy as np
2
- import matplotlib.pyplot as plt
3
- from matplotlib.animation import FuncAnimation
4
- import tempfile
5
  import gradio as gr
6
- import plotly.graph_objects as go
7
- import base64
8
- from fastapi import FastAPI
9
- from pydantic import BaseModel
10
 
11
- # --- Simulation Core (3D) ---
12
- def solve_3d_heat_equation(Lx: float, Ly: float, Lz: float,
13
- t_max: float,
14
- Gamma: float = 0.1,
15
- Nx: int = 30, Ny: int = 30, Nz: int = 30,
16
- initial: str = "gaussian",
17
- bc: str = "dirichlet"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  x = np.linspace(0, Lx, Nx)
19
  y = np.linspace(0, Ly, Ny)
20
- z = np.linspace(0, Lz, Nz)
21
-
22
- # Corrected dx, dy, dz calculation
23
- dx = x[1] - x[0] if Nx > 1 else Lx
24
- dy = y[1] - y[0] if Ny > 1 else Ly
25
- dz = z[1] - z[0] if Nz > 1 else Lz
26
-
27
- if dx == 0 or dy == 0 or dz == 0:
28
- raise ValueError("Grid spacing (dx, dy, dz) cannot be zero. Ensure Nx, Ny, Nz > 1.")
29
 
30
- # Stability condition for 3D FTCS scheme
31
- dt = 0.5 / (Gamma * (1/dx**2 + 1/dy**2 + 1/dz**2))
 
32
  Nt = int(np.ceil(t_max / dt)) + 1
33
-
34
- rx, ry, rz = Gamma * dt / dx**2, Gamma * dt / dy**2, Gamma * dt / dz**2
35
-
36
- X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
37
 
 
 
 
38
  if initial == "gaussian":
39
- u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2 + (Z - Lz/2)**2) / (2*(max(Lx,Ly,Lz)/10)**2)))
40
  elif initial == "random":
41
- u = np.random.rand(Nx, Ny, Nz)
42
  elif initial == "sinusoidal":
43
- kx = 2 * np.pi / Lx if Lx > 0 else 0
44
- ky = 2 * np.pi / Ly if Ly > 0 else 0
45
- kz = 2 * np.pi / Lz if Lz > 0 else 0
46
- u = np.sin(kx * X) * np.sin(ky * Y) * np.sin(kz * Z)
47
  elif initial == "step":
48
- u = np.where((X < Lx/2) & (Y < Ly/2) & (Z < Lz/2), 1.0, 0.0)
49
  else:
50
  raise ValueError(f"Unknown initial condition: {initial}")
51
 
52
- U = np.zeros((Nt, Nx, Ny, Nz))
53
- U[0] = u.copy() # Store initial condition
 
 
 
 
 
 
 
54
 
 
55
  for n in range(1, Nt):
56
  un = u.copy()
57
- if Nx > 2 and Ny > 2 and Nz > 2:
58
- u[1:-1, 1:-1, 1:-1] = (
59
- un[1:-1, 1:-1, 1:-1]
60
- + rx * (un[2:, 1:-1, 1:-1] - 2 * un[1:-1, 1:-1, 1:-1] + un[:-2, 1:-1, 1:-1])
61
- + ry * (un[1:-1, 2:, 1:-1] - 2 * un[1:-1, 1:-1, 1:-1] + un[1:-1, :-2, 1:-1])
62
- + rz * (un[1:-1, 1:-1, 2:] - 2 * un[1:-1, 1:-1, 1:-1] + un[1:-1, 1:-1, :-2])
63
- )
64
-
65
  if bc == "dirichlet":
66
- if Nx > 0: u[0, :, :] = u[-1, :, :] = 0.0
67
- if Ny > 0: u[:, 0, :] = u[:, -1, :] = 0.0
68
- if Nz > 0: u[:, :, 0] = u[:, :, -1] = 0.0
69
  elif bc == "neumann":
70
- if Nx > 1: u[0, :, :] = u[1, :, :]; u[-1, :, :] = u[-2, :, :]
71
- if Ny > 1: u[:, 0, :] = u[:, 1, :]; u[:, -1, :] = u[:, -2, :]
72
- if Nz > 1: u[:, :, 0] = u[:, :, 1]; u[:, :, -1] = u[:, :, -2]
 
73
  elif bc == "periodic":
74
- if Nx > 1: u[0, :, :] = un[-2, :, :]; u[-1, :, :] = un[1, :, :]
75
- if Ny > 1: u[:, 0, :] = un[:, -2, :]; u[-1, :] = un[:, 1, :]
76
- if Nz > 1: u[:, :, 0] = un[:, :, -2]; u[:, :, -1] = un[:, :, 1]
77
- else:
78
- raise ValueError(f"Unknown bc: {bc}")
79
- U[n] = u.copy()
80
- return U, dt
81
-
82
- # --- Animation Generator (3D Slice) ---
83
- def create_animation_gif_3d_slice(U, Lx, Ly, Lz, initial, bc, Gamma, frame_skip, dt):
84
- Nt, Nx, Ny, Nz = U.shape
85
- fig, ax = plt.subplots()
86
-
87
- slice_z_idx = Nz // 2 if Nz > 0 else 0
88
- z_coord_slice = np.linspace(0, Lz, Nz)[slice_z_idx] if Nz > 0 else 0
89
-
90
- data_slice = U[0, :, :, slice_z_idx].T if Nt > 0 and Nz > 0 else np.zeros((Ny, Nx))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- vmin_val = U.min() if Nt > 0 and U.size > 0 and np.all(np.isfinite(U)) else 0
93
- vmax_val = U.max() if Nt > 0 and U.size > 0 and np.all(np.isfinite(U)) else 1
94
- if vmin_val == vmax_val: vmax_val = vmin_val + 1
95
-
96
- im = ax.imshow(data_slice, cmap='viridis', origin='lower',
97
- extent=[0, Lx, 0, Ly], vmin=vmin_val, vmax=vmax_val)
98
- ax.set_title(f"3D Heat Eq (xy-slice at z={z_coord_slice:.2f})\ninit={initial}, bc={bc}, Gamma={Gamma:.2f}")
99
- ax.set_xlabel("x")
100
- ax.set_ylabel("y")
101
- plt.colorbar(im, ax=ax, label="u")
102
-
103
- def update(frame):
104
- if Nz > 0:
105
- im.set_data(U[frame, :, :, slice_z_idx].T)
106
- return [im]
107
-
108
- # Corrected idx generation
109
- if Nt <= 1:
110
- idx = [0] # Only initial frame if Nt is 0 or 1
111
- else:
112
- current_frame_skip = max(1, frame_skip)
113
- idx = list(range(0, Nt, current_frame_skip))
114
- if (Nt - 1) not in idx:
115
- idx.append(Nt - 1)
116
- idx = sorted(list(set(idx)))
117
-
118
- ani = FuncAnimation(fig, update, frames=idx, blit=True)
119
-
120
- with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmpfile:
121
- ani.save(tmpfile.name, writer='pillow', fps=max(1, 30 // max(1, frame_skip)))
122
- gif_path = tmpfile.name
123
 
124
- plt.close(fig)
125
- return gif_path
126
 
127
- # --- Plotly Figure Generator (3D Volume) ---
128
- def create_plotly_figure_3d(u_3d, Lx, Ly, Lz, time_label):
129
- Nx, Ny, Nz = u_3d.shape
130
- if Nx == 0 or Ny == 0 or Nz == 0:
131
- return go.Figure(layout_title_text=f"3D Heat Distribution (No Data) at t={time_label}")
132
-
133
- x_coords = np.linspace(0, Lx, Nx)
134
- y_coords = np.linspace(0, Ly, Ny)
135
- z_coords = np.linspace(0, Lz, Nz)
136
 
137
- X, Y, Z = np.meshgrid(x_coords, y_coords, z_coords, indexing='ij')
138
-
139
- vmin = np.min(u_3d[np.isfinite(u_3d)]) if np.any(np.isfinite(u_3d)) else 0
140
- vmax = np.max(u_3d[np.isfinite(u_3d)]) if np.any(np.isfinite(u_3d)) else 1
141
- if vmin == vmax: vmax = vmin + 0.1
142
-
143
- fig = go.Figure(data=go.Volume(
144
- x=X.flatten(),
145
- y=Y.flatten(),
146
- z=Z.flatten(),
147
- value=u_3d.flatten(),
148
- isomin=vmin,
149
- isomax=vmax,
150
- opacity=0.1,
151
- surface_count=17,
152
- colorscale='viridis'
153
- ))
154
- fig.update_layout(
155
- title=f"3D Heat Distribution at t={time_label}",
156
- scene=dict(
157
- xaxis_title='x',
158
- yaxis_title='y',
159
- zaxis_title='z',
160
- aspectmode='cube'
161
- )
162
- )
163
- return fig
164
-
165
- # --- Simulation Runner (Extracted Logic for 3D) ---
166
- def run_simulation_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip):
167
- nx = max(3, int(nx))
168
- ny = max(3, int(ny))
169
- nz = max(3, int(nz))
170
-
171
- U, dt = solve_3d_heat_equation(
172
- Lx=lx, Ly=ly, Lz=lz, t_max=t_max, Gamma=gamma,
173
- Nx=nx, Ny=ny, Nz=nz, initial=initial, bc=bc
174
  )
175
- Nt = U.shape[0]
176
-
177
- idx0 = 0
178
- idx1 = round((Nt - 1) / 4) if Nt > 1 else 0
179
- idx2 = round(3 * (Nt - 1) / 4) if Nt > 1 else 0
180
- idx3 = Nt - 1 if Nt > 1 else 0
181
-
182
- u0 = U[idx0]
183
- u1 = U[idx1]
184
- u2 = U[idx2]
185
- u3 = U[idx3]
186
-
187
- fig0 = create_plotly_figure_3d(u0, lx, ly, lz, "0")
188
- fig1 = create_plotly_figure_3d(u1, lx, ly, lz, f"{idx1*dt:.2f}")
189
- fig2 = create_plotly_figure_3d(u2, lx, ly, lz, f"{idx2*dt:.2f}")
190
- fig3 = create_plotly_figure_3d(u3, lx, ly, lz, f"{idx3*dt:.2f}")
191
-
192
- gif_path = create_animation_gif_3d_slice(U, lx, ly, lz, initial, bc, gamma, frame_skip, dt)
193
- return gif_path, fig0, fig1, fig2, fig3
194
 
195
- # --- Gradio Interface Logic (3D) ---
196
- def gradio_interface_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip):
197
- nx_int, ny_int, nz_int = int(nx), int(ny), int(nz)
198
- frame_skip_int = max(1, int(frame_skip))
199
 
200
- nx_int = max(3, nx_int)
201
- ny_int = max(3, ny_int)
202
- nz_int = max(3, nz_int)
 
203
 
204
- gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(
205
- lx, ly, lz, t_max, gamma, nx_int, ny_int, nz_int, initial, bc, frame_skip_int
206
- )
207
- return gif_path, fig0, fig1, fig2, fig3
208
-
209
- # --- Gradio UI Layout (3D) ---
210
- with gr.Blocks(theme=gr.themes.Soft(), title="3D Heat Simulator") as demo:
211
- gr.Markdown("# 🔥 3D Heat Equation Simulator\nAdjust parameters and run the simulation. Animation shows a central xy-slice. Grid (Nx,Ny,Nz) min 3.")
212
  with gr.Row():
213
  with gr.Column(scale=1):
214
  gr.Markdown("## Domain & Grid")
215
- lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
216
- ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
217
- lz_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lz")
218
- nx_slider = gr.Slider(3, 60, 20, 1, label="Nx (min 3, e.g., 20-40)")
219
- ny_slider = gr.Slider(3, 60, 20, 1, label="Ny (min 3, e.g., 20-40)")
220
- nz_slider = gr.Slider(3, 60, 20, 1, label="Nz (min 3, e.g., 20-40)")
221
 
222
  gr.Markdown("## Simulation")
223
- t_slider = gr.Slider(0.01, 1.0, 0.1, 0.01, label="t_max (e.g., 0.1-0.5)")
224
- gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
225
 
226
  gr.Markdown("## Conditions")
227
  initial_dropdown = gr.Dropdown(
228
- ["gaussian", "random", "sinusoidal", "step"], "gaussian", label="Initial"
229
  )
230
  bc_dropdown = gr.Dropdown(
231
- ["dirichlet", "neumann", "periodic"], "dirichlet", label="Boundary"
232
  )
233
 
234
- gr.Markdown("## Animation")
235
- frame_skip_slider = gr.Slider(1, 50, 5, 1, label="Frame Skip (e.g., 5-10)")
236
- run_btn = gr.Button("Run 3D Simulation", variant="primary")
237
 
238
- with gr.Column(scale=3):
239
- gif_output = gr.Image(label="Animation (Central XY Slice)")
240
- with gr.Row():
241
- plot1 = gr.Plot(label="3D Volume at t=0")
242
- plot2 = gr.Plot(label="3D Volume at t=t/4")
243
- with gr.Row():
244
- plot3 = gr.Plot(label="3D Volume at t=3t/4")
245
- plot4 = gr.Plot(label="3D Volume at t=t_max")
246
-
247
- inputs_list = [lx_slider, ly_slider, lz_slider, t_slider, gamma_slider,
248
- nx_slider, ny_slider, nz_slider, initial_dropdown, bc_dropdown, frame_skip_slider]
249
- outputs_list = [gif_output, plot1, plot2, plot3, plot4]
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- run_btn.click(fn=gradio_interface_3d, inputs=inputs_list, outputs=outputs_list)
252
 
 
253
  gr.Examples(
254
  examples=[
255
- [1.0, 1.0, 1.0, 0.1, 0.1, 20, 20, 20, "gaussian", "dirichlet", 5],
256
- [1.5, 1.0, 0.5, 0.2, 0.05, 25, 20, 15, "sinusoidal", "periodic", 10],
257
- [1.0, 1.0, 1.0, 0.05, 0.2, 15, 15, 15, "step", "neumann", 2]
258
  ],
259
  inputs=inputs_list,
260
- outputs=outputs_list,
261
- fn=gradio_interface_3d,
262
- cache_examples=False
263
  )
264
 
265
- # --- FastAPI Setup for API Endpoint (3D) ---
266
- app = FastAPI()
267
-
268
- app = gr.mount_gradio_app(app, demo, path="/")
269
-
270
- class SimulationParams3D(BaseModel):
271
- lx: float
272
- ly: float
273
- lz: float
274
- t_max: float
275
- gamma: float
276
- nx: int
277
- ny: int
278
- nz: int
279
- initial: str
280
- bc: str
281
- frame_skip: int
282
-
283
- @app.post("/simulate_3d")
284
- def simulate_3d_api(params: SimulationParams3D):
285
- params.frame_skip = max(1, params.frame_skip)
286
- params.nx = max(3, params.nx)
287
- params.ny = max(3, params.ny)
288
- params.nz = max(3, params.nz)
289
-
290
- gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(**params.dict())
291
- with open(gif_path, "rb") as f:
292
- gif_data = base64.b64encode(f.read()).decode('utf-8')
293
- return {
294
- "gif_base64": gif_data,
295
- "plot0_3d_volume": fig0.to_json(),
296
- "plot1_3d_volume": fig1.to_json(),
297
- "plot2_3d_volume": fig2.to_json(),
298
- "plot3_3d_volume": fig3.to_json()
299
- }
300
-
301
  if __name__ == "__main__":
302
  demo.launch()
 
1
  import numpy as np
2
+ import pyvista as pv
 
 
3
  import gradio as gr
 
 
 
 
4
 
5
+ # --- Core Simulation and Plotting Function ---
6
+ def solve_and_plot_interactive(Lx: float,
7
+ Ly: float,
8
+ t_max: float,
9
+ M: int, # New: Number of time steps for the slider
10
+ Gamma: float = 0.1,
11
+ Nx: int = 50,
12
+ Ny: int = 50,
13
+ initial: str = "gaussian",
14
+ bc: str = "dirichlet"):
15
+ """
16
+ Solves the 2D heat equation and displays the result in an interactive
17
+ PyVista window with a time slider.
18
+
19
+ Args:
20
+ Lx (float): Length of the domain in the x-direction.
21
+ Ly (float): Length of the domain in the y-direction.
22
+ t_max (float): Maximum simulation time.
23
+ M (int): Number of equidistant time steps to be available on the slider.
24
+ Gamma (float): Thermal diffusivity.
25
+ Nx (int): Number of grid points in the x-direction.
26
+ Ny (int): Number of grid points in the y-direction.
27
+ initial (str): Initial condition type.
28
+ bc (str): Boundary condition type.
29
+ """
30
+ # --- 1. Simulation Setup ---
31
+ # Spatial grid
32
  x = np.linspace(0, Lx, Nx)
33
  y = np.linspace(0, Ly, Ny)
34
+ dx, dy = x[1] - x[0], y[1] - y[0]
35
+ if dx == 0 or dy == 0:
36
+ raise ValueError("Nx and Ny must be > 1.")
 
 
 
 
 
 
37
 
38
+ # Time stepping for stability
39
+ # A small factor (0.9) is added for more robust stability
40
+ dt = 0.9 / (2 * Gamma * (1/dx**2 + 1/dy**2))
41
  Nt = int(np.ceil(t_max / dt)) + 1
42
+ rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
 
 
 
43
 
44
+ # Initial condition
45
+ X, Y = np.meshgrid(x, y, indexing='ij')
46
+ u = np.zeros((Nx, Ny))
47
  if initial == "gaussian":
48
+ u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
49
  elif initial == "random":
50
+ u = np.random.rand(Nx, Ny)
51
  elif initial == "sinusoidal":
52
+ kx, ky = 2 * np.pi / Lx, 2 * np.pi / Ly
53
+ u = np.sin(kx * X) * np.sin(ky * Y)
 
 
54
  elif initial == "step":
55
+ u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
56
  else:
57
  raise ValueError(f"Unknown initial condition: {initial}")
58
 
59
+ # --- 2. Solve the Heat Equation ---
60
+ # Select M equidistant time indices to store for visualization
61
+ time_indices = np.linspace(0, Nt - 1, M, dtype=int)
62
+ U_slider = np.zeros((M, Nx, Ny))
63
+ store_idx = 0
64
+
65
+ if 0 in time_indices:
66
+ U_slider[store_idx] = u.copy()
67
+ store_idx += 1
68
 
69
+ # Time-stepping loop
70
  for n in range(1, Nt):
71
  un = u.copy()
72
+ # Interior update using finite differences
73
+ u[1:-1, 1:-1] = (
74
+ un[1:-1, 1:-1]
75
+ + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
76
+ + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
77
+ )
78
+ # Boundary conditions
 
79
  if bc == "dirichlet":
80
+ u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
 
 
81
  elif bc == "neumann":
82
+ u[0, :] = u[1, :]
83
+ u[-1, :] = u[-2, :]
84
+ u[:, 0] = u[:, 1]
85
+ u[:, -1] = u[:, -2]
86
  elif bc == "periodic":
87
+ # Note: A true periodic BC often uses ghost cells. This is a simpler implementation.
88
+ u[0, :] = un[-2, :]
89
+ u[-1, :] = un[1, :]
90
+ u[:, 0] = un[:, -2]
91
+ u[:, -1] = un[:, 1]
92
+
93
+ # Store the frame if it's one of the selected time steps
94
+ if n in time_indices:
95
+ if store_idx < M:
96
+ U_slider[store_idx] = u.copy()
97
+ store_idx += 1
98
+
99
+ # --- 3. Interactive Visualization with PyVista ---
100
+ # Create a 2D mesh (a structured grid) in 3D space (with Z=0)
101
+ grid = pv.StructuredGrid()
102
+ grid.points = np.c_[X.flatten('F'), Y.flatten('F'), np.zeros(Nx * Ny)]
103
+ grid.dimensions = [Nx, Ny, 1]
104
+
105
+ # Add the initial temperature data as scalars to the grid
106
+ # The data needs to be flattened in 'C' (row-major) order for PyVista
107
+ grid['temperature'] = U_slider[0, :, :].flatten('C')
108
+
109
+ # Set up the plotter
110
+ plotter = pv.Plotter()
111
+ plotter.add_mesh(grid, scalars='temperature', cmap='viridis',
112
+ scalar_bar_args={'title': 'Temperature'})
113
+ plotter.view_xy() # Set camera to look down the Z-axis
114
+
115
+ # Define the callback function that updates the plot when the slider moves
116
+ def update_plot(time_step_index):
117
+ # Get the integer index from the slider
118
+ idx = int(time_step_index)
119
+ # Update the scalars on the grid
120
+ grid['temperature'] = U_slider[idx, :, :].flatten('C')
121
+ # Optional: Add a text annotation for the current time
122
+ time_value = (time_indices[idx] / (Nt-1)) * t_max if Nt > 1 else 0
123
+ plotter.add_text(f"Time: {time_value:.2f}s", name='time_label')
124
+
125
+ # Add the slider widget to the plotter
126
+ plotter.add_slider_widget(
127
+ callback=update_plot,
128
+ rng=[0, M - 1], # The slider range corresponds to the indices of U_slider
129
+ value=0,
130
+ title="Time Step",
131
+ style='modern'
132
+ )
133
 
134
+ # Display the plotter window. This is a blocking call.
135
+ # The script will pause here until you close the PyVista window.
136
+ plotter.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
 
 
138
 
139
+ # --- Gradio Interface Function ---
140
+ def gradio_interface(lx, ly, t_max, m_steps, gamma, nx, ny, initial, bc):
141
+ """Wrapper function to connect Gradio inputs to the simulation."""
142
+ # Ensure integer types for grid dimensions
143
+ nx, ny, m_steps = int(nx), int(ny), int(m_steps)
 
 
 
 
144
 
145
+ # Run the simulation and launch the interactive plot
146
+ solve_and_plot_interactive(
147
+ Lx=lx, Ly=ly, t_max=t_max, M=m_steps, Gamma=gamma, Nx=nx, Ny=ny,
148
+ initial=initial, bc=bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  )
150
+ # This function no longer needs to return anything to Gradio
151
+ return "Simulation window launched. Please check your desktop."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
 
 
 
 
153
 
154
+ # --- Gradio UI Definition ---
155
+ with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
156
+ gr.Markdown("# ♨️ 2D Heat Equation Simulator")
157
+ gr.Markdown("Adjust parameters and click 'Run' to launch an interactive window where you can pan, zoom, and scrub through time.")
158
 
 
 
 
 
 
 
 
 
159
  with gr.Row():
160
  with gr.Column(scale=1):
161
  gr.Markdown("## Domain & Grid")
162
+ lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx (Domain Length X)")
163
+ ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly (Domain Length Y)")
164
+ nx_slider = gr.Slider(10, 200, 50, 1, label="Nx (Grid Points X)")
165
+ ny_slider = gr.Slider(10, 200, 50, 1, label="Ny (Grid Points Y)")
 
 
166
 
167
  gr.Markdown("## Simulation")
168
+ t_slider = gr.Slider(0.01, 5.0, 0.5, 0.01, label="t_max (Total Time)")
169
+ gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma (Diffusivity)")
170
 
171
  gr.Markdown("## Conditions")
172
  initial_dropdown = gr.Dropdown(
173
+ ["gaussian", "random", "sinusoidal", "step"], value="gaussian", label="Initial Condition"
174
  )
175
  bc_dropdown = gr.Dropdown(
176
+ ["dirichlet", "neumann", "periodic"], value="dirichlet", label="Boundary Condition"
177
  )
178
 
179
+ gr.Markdown("## Interactive Plot")
180
+ # New slider to control the number of time steps in the interactive window
181
+ m_slider = gr.Slider(2, 200, 40, 1, label="M (Time Steps on Slider)")
182
 
183
+ run_btn = gr.Button("Run Simulation", variant="primary")
184
+
185
+ with gr.Column(scale=2):
186
+ # The output is now a simple text confirmation
187
+ status_output = gr.Textbox(label="Status")
188
+ gr.Markdown(
189
+ """
190
+ ### How to Use:
191
+ 1. Set your desired simulation parameters on the left.
192
+ 2. `M (Time Steps on Slider)` controls how many time points will be available in the interactive view. Higher values give smoother time control but use more memory.
193
+ 3. Click **Run Simulation**.
194
+ 4. A new window will open.
195
+ - **Left-Click + Drag**: Rotate the view.
196
+ - **Right-Click + Drag**: Pan the view.
197
+ - **Scroll Wheel**: Zoom in and out.
198
+ - **Use the Slider**: To move through simulation time.
199
+ 5. Close the interactive window to run another simulation.
200
+ """
201
+ )
202
+
203
+
204
+ # Connect the button to the interface function
205
+ inputs_list = [lx_slider, ly_slider, t_slider, m_slider, gamma_slider,
206
+ nx_slider, ny_slider, initial_dropdown, bc_dropdown]
207
 
208
+ run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=status_output)
209
 
210
+ # Define some example configurations
211
  gr.Examples(
212
  examples=[
213
+ [1.0, 1.0, 0.5, 50, 0.1, 50, 50, "gaussian", "dirichlet"],
214
+ [2.0, 1.0, 1.0, 80, 0.05, 60, 30, "sinusoidal", "periodic"],
215
+ [1.0, 1.0, 0.2, 40, 0.2, 80, 80, "step", "neumann"],
216
  ],
217
  inputs=inputs_list,
218
+ outputs=[status_output],
219
+ fn=gradio_interface,
220
+ cache_examples=False # It's better to rerun live simulations
221
  )
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  if __name__ == "__main__":
224
  demo.launch()