harishaseebat92 commited on
Commit
381e9af
·
verified ·
1 Parent(s): 3016dc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -135
app.py CHANGED
@@ -1,48 +1,31 @@
1
  import numpy as np
2
  import pyvista as pv
 
3
  import gradio as gr
 
4
 
5
- # --- Core Simulation and Plotting Function ---
6
  def solve_and_plot_interactive(Lx: float,
7
  Ly: float,
8
  t_max: float,
9
- M: int, # New: Number of time steps for the slider
10
  Gamma: float = 0.1,
11
  Nx: int = 50,
12
  Ny: int = 50,
13
  initial: str = "gaussian",
14
  bc: str = "dirichlet"):
15
  """
16
- Solves the 2D heat equation and displays the result in an interactive
17
- PyVista window with a time slider.
18
-
19
- Args:
20
- Lx (float): Length of the domain in the x-direction.
21
- Ly (float): Length of the domain in the y-direction.
22
- t_max (float): Maximum simulation time.
23
- M (int): Number of equidistant time steps to be available on the slider.
24
- Gamma (float): Thermal diffusivity.
25
- Nx (int): Number of grid points in the x-direction.
26
- Ny (int): Number of grid points in the y-direction.
27
- initial (str): Initial condition type.
28
- bc (str): Boundary condition type.
29
  """
30
- # --- 1. Simulation Setup ---
31
- # Spatial grid
32
  x = np.linspace(0, Lx, Nx)
33
  y = np.linspace(0, Ly, Ny)
34
  dx, dy = x[1] - x[0], y[1] - y[0]
35
- if dx == 0 or dy == 0:
36
- raise ValueError("Nx and Ny must be > 1.")
37
-
38
- # Time stepping for stability
39
- # A small factor (0.9) is added for more robust stability
40
  dt = 0.9 / (2 * Gamma * (1/dx**2 + 1/dy**2))
41
  Nt = int(np.ceil(t_max / dt)) + 1
42
  rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
43
-
44
- # Initial condition
45
  X, Y = np.meshgrid(x, y, indexing='ij')
 
46
  u = np.zeros((Nx, Ny))
47
  if initial == "gaussian":
48
  u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
@@ -53,171 +36,140 @@ def solve_and_plot_interactive(Lx: float,
53
  u = np.sin(kx * X) * np.sin(ky * Y)
54
  elif initial == "step":
55
  u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
56
- else:
57
- raise ValueError(f"Unknown initial condition: {initial}")
58
 
59
- # --- 2. Solve the Heat Equation ---
60
- # Select M equidistant time indices to store for visualization
61
  time_indices = np.linspace(0, Nt - 1, M, dtype=int)
62
  U_slider = np.zeros((M, Nx, Ny))
63
  store_idx = 0
64
-
65
  if 0 in time_indices:
66
  U_slider[store_idx] = u.copy()
67
  store_idx += 1
68
 
69
- # Time-stepping loop
70
  for n in range(1, Nt):
71
  un = u.copy()
72
- # Interior update using finite differences
73
  u[1:-1, 1:-1] = (
74
  un[1:-1, 1:-1]
75
  + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
76
  + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
77
  )
78
- # Boundary conditions
79
  if bc == "dirichlet":
80
  u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
81
  elif bc == "neumann":
82
- u[0, :] = u[1, :]
83
- u[-1, :] = u[-2, :]
84
- u[:, 0] = u[:, 1]
85
- u[:, -1] = u[:, -2]
86
  elif bc == "periodic":
87
- # Note: A true periodic BC often uses ghost cells. This is a simpler implementation.
88
- u[0, :] = un[-2, :]
89
- u[-1, :] = un[1, :]
90
- u[:, 0] = un[:, -2]
91
- u[:, -1] = un[:, 1]
92
 
93
- # Store the frame if it's one of the selected time steps
94
- if n in time_indices:
95
- if store_idx < M:
96
- U_slider[store_idx] = u.copy()
97
- store_idx += 1
98
-
99
- # --- 3. Interactive Visualization with PyVista ---
100
- # Create a 2D mesh (a structured grid) in 3D space (with Z=0)
101
- grid = pv.StructuredGrid()
102
- grid.points = np.c_[X.flatten('F'), Y.flatten('F'), np.zeros(Nx * Ny)]
103
- grid.dimensions = [Nx, Ny, 1]
104
-
105
- # Add the initial temperature data as scalars to the grid
106
- # The data needs to be flattened in 'C' (row-major) order for PyVista
107
- grid['temperature'] = U_slider[0, :, :].flatten('C')
108
-
109
- # Set up the plotter
110
- plotter = pv.Plotter()
111
- plotter.add_mesh(grid, scalars='temperature', cmap='viridis',
112
- scalar_bar_args={'title': 'Temperature'})
113
- plotter.view_xy() # Set camera to look down the Z-axis
114
-
115
- # Define the callback function that updates the plot when the slider moves
116
- def update_plot(time_step_index):
117
- # Get the integer index from the slider
118
- idx = int(time_step_index)
119
- # Update the scalars on the grid
120
- grid['temperature'] = U_slider[idx, :, :].flatten('C')
121
- # Optional: Add a text annotation for the current time
122
- time_value = (time_indices[idx] / (Nt-1)) * t_max if Nt > 1 else 0
123
- plotter.add_text(f"Time: {time_value:.2f}s", name='time_label')
124
-
125
- # Add the slider widget to the plotter
126
- plotter.add_slider_widget(
127
- callback=update_plot,
128
- rng=[0, M - 1], # The slider range corresponds to the indices of U_slider
129
- value=0,
130
- title="Time Step",
131
- style='modern'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  )
133
 
134
- # Display the plotter window. This is a blocking call.
135
- # The script will pause here until you close the PyVista window.
136
- plotter.show()
137
 
138
 
139
  # --- Gradio Interface Function ---
140
  def gradio_interface(lx, ly, t_max, m_steps, gamma, nx, ny, initial, bc):
141
- """Wrapper function to connect Gradio inputs to the simulation."""
142
- # Ensure integer types for grid dimensions
143
  nx, ny, m_steps = int(nx), int(ny), int(m_steps)
144
-
145
- # Run the simulation and launch the interactive plot
146
- solve_and_plot_interactive(
147
  Lx=lx, Ly=ly, t_max=t_max, M=m_steps, Gamma=gamma, Nx=nx, Ny=ny,
148
  initial=initial, bc=bc
149
  )
150
- # This function no longer needs to return anything to Gradio
151
- return "Simulation window launched. Please check your desktop."
152
 
153
 
154
  # --- Gradio UI Definition ---
155
  with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
156
  gr.Markdown("# ♨️ 2D Heat Equation Simulator")
157
- gr.Markdown("Adjust parameters and click 'Run' to launch an interactive window where you can pan, zoom, and scrub through time.")
158
 
159
  with gr.Row():
160
  with gr.Column(scale=1):
161
- gr.Markdown("## Domain & Grid")
162
  lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx (Domain Length X)")
163
  ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly (Domain Length Y)")
164
- nx_slider = gr.Slider(10, 200, 50, 1, label="Nx (Grid Points X)")
165
- ny_slider = gr.Slider(10, 200, 50, 1, label="Ny (Grid Points Y)")
166
-
167
- gr.Markdown("## Simulation")
168
- t_slider = gr.Slider(0.01, 5.0, 0.5, 0.01, label="t_max (Total Time)")
169
  gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma (Diffusivity)")
170
-
171
- gr.Markdown("## Conditions")
172
- initial_dropdown = gr.Dropdown(
173
- ["gaussian", "random", "sinusoidal", "step"], value="gaussian", label="Initial Condition"
174
- )
175
- bc_dropdown = gr.Dropdown(
176
- ["dirichlet", "neumann", "periodic"], value="dirichlet", label="Boundary Condition"
177
- )
178
-
179
- gr.Markdown("## Interactive Plot")
180
- # New slider to control the number of time steps in the interactive window
181
- m_slider = gr.Slider(2, 200, 40, 1, label="M (Time Steps on Slider)")
182
 
183
  run_btn = gr.Button("Run Simulation", variant="primary")
184
 
185
- with gr.Column(scale=2):
186
- # The output is now a simple text confirmation
187
- status_output = gr.Textbox(label="Status")
188
- gr.Markdown(
189
- """
190
- ### How to Use:
191
- 1. Set your desired simulation parameters on the left.
192
- 2. `M (Time Steps on Slider)` controls how many time points will be available in the interactive view. Higher values give smoother time control but use more memory.
193
- 3. Click **Run Simulation**.
194
- 4. A new window will open.
195
- - **Left-Click + Drag**: Rotate the view.
196
- - **Right-Click + Drag**: Pan the view.
197
- - **Scroll Wheel**: Zoom in and out.
198
- - **Use the Slider**: To move through simulation time.
199
- 5. Close the interactive window to run another simulation.
200
- """
201
- )
202
-
203
 
204
- # Connect the button to the interface function
205
  inputs_list = [lx_slider, ly_slider, t_slider, m_slider, gamma_slider,
206
  nx_slider, ny_slider, initial_dropdown, bc_dropdown]
207
 
208
- run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=status_output)
209
 
210
- # Define some example configurations
211
  gr.Examples(
212
  examples=[
213
- [1.0, 1.0, 0.5, 50, 0.1, 50, 50, "gaussian", "dirichlet"],
214
- [2.0, 1.0, 1.0, 80, 0.05, 60, 30, "sinusoidal", "periodic"],
215
- [1.0, 1.0, 0.2, 40, 0.2, 80, 80, "step", "neumann"],
216
  ],
217
  inputs=inputs_list,
218
- outputs=[status_output],
219
  fn=gradio_interface,
220
- cache_examples=False # It's better to rerun live simulations
221
  )
222
 
223
  if __name__ == "__main__":
 
1
  import numpy as np
2
  import pyvista as pv
3
+ import plotly.graph_objects as go
4
  import gradio as gr
5
+ from scipy.spatial import Delaunay
6
 
 
7
  def solve_and_plot_interactive(Lx: float,
8
  Ly: float,
9
  t_max: float,
10
+ M: int,
11
  Gamma: float = 0.1,
12
  Nx: int = 50,
13
  Ny: int = 50,
14
  initial: str = "gaussian",
15
  bc: str = "dirichlet"):
16
  """
17
+ Solves the 2D heat equation and returns an interactive Plotly figure
18
+ that can be rendered in a web browser.
 
 
 
 
 
 
 
 
 
 
 
19
  """
20
+ # --- 1. Simulation Setup (Same as before) ---
 
21
  x = np.linspace(0, Lx, Nx)
22
  y = np.linspace(0, Ly, Ny)
23
  dx, dy = x[1] - x[0], y[1] - y[0]
 
 
 
 
 
24
  dt = 0.9 / (2 * Gamma * (1/dx**2 + 1/dy**2))
25
  Nt = int(np.ceil(t_max / dt)) + 1
26
  rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
 
 
27
  X, Y = np.meshgrid(x, y, indexing='ij')
28
+
29
  u = np.zeros((Nx, Ny))
30
  if initial == "gaussian":
31
  u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
 
36
  u = np.sin(kx * X) * np.sin(ky * Y)
37
  elif initial == "step":
38
  u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
 
 
39
 
40
+ # --- 2. Solve the Heat Equation (Same as before) ---
 
41
  time_indices = np.linspace(0, Nt - 1, M, dtype=int)
42
  U_slider = np.zeros((M, Nx, Ny))
43
  store_idx = 0
 
44
  if 0 in time_indices:
45
  U_slider[store_idx] = u.copy()
46
  store_idx += 1
47
 
 
48
  for n in range(1, Nt):
49
  un = u.copy()
 
50
  u[1:-1, 1:-1] = (
51
  un[1:-1, 1:-1]
52
  + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
53
  + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
54
  )
 
55
  if bc == "dirichlet":
56
  u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
57
  elif bc == "neumann":
58
+ u[0, :], u[-1, :], u[:, 0], u[:, -1] = u[1, :], u[-2, :], u[:, 1], u[:, -2]
 
 
 
59
  elif bc == "periodic":
60
+ u[0, :], u[-1, :], u[:, 0], u[:, -1] = un[-2, :], un[1, :], un[:, -2], un[:, 1]
 
 
 
 
61
 
62
+ if n in time_indices and store_idx < M:
63
+ U_slider[store_idx] = u.copy()
64
+ store_idx += 1
65
+
66
+ # --- 3. Create a Plotly Figure for Web ---
67
+ # We use Delaunay triangulation to create the mesh faces for Plotly
68
+ points_2d = np.vstack([X.ravel(), Y.ravel()]).T
69
+ tri = Delaunay(points_2d)
70
+
71
+ # Create the figure
72
+ fig = go.Figure()
73
+
74
+ # Add one mesh trace for each time step. We'll make only the first one visible.
75
+ for i in range(M):
76
+ time_value = (time_indices[i] / (Nt-1)) * t_max if Nt > 1 else 0
77
+ z_data = U_slider[i, :, :].flatten()
78
+ fig.add_trace(
79
+ go.Mesh3d(
80
+ x=X.flatten(), y=Y.flatten(), z=z_data,
81
+ i=tri.simplices[:, 0], j=tri.simplices[:, 1], k=tri.simplices[:, 2],
82
+ intensity=z_data,
83
+ colorscale='Viridis',
84
+ name=f'Time: {time_value:.2f}s',
85
+ showscale=True if i == 0 else False, # Show colorbar only once
86
+ visible=(i == 0) # Make only the first trace visible
87
+ )
88
+ )
89
+
90
+ # Create the slider
91
+ steps = []
92
+ for i in range(len(fig.data)):
93
+ time_value = (time_indices[i] / (Nt-1)) * t_max if Nt > 1 else 0
94
+ step = dict(
95
+ method="update",
96
+ args=[{"visible": [False] * len(fig.data)}], # hide all traces
97
+ label=f"{time_value:.2f}s"
98
+ )
99
+ step["args"][0]["visible"][i] = True # show the i-th trace
100
+ steps.append(step)
101
+
102
+ sliders = [dict(
103
+ active=0,
104
+ currentvalue={"prefix": "Time: "},
105
+ pad={"t": 50},
106
+ steps=steps
107
+ )]
108
+
109
+ # Update the layout of the figure
110
+ fig.update_layout(
111
+ title=f'2D Heat Eq — init={initial}, bc={bc}',
112
+ scene=dict(
113
+ xaxis_title='X',
114
+ yaxis_title='Y',
115
+ zaxis_title='Temperature'
116
+ ),
117
+ sliders=sliders
118
  )
119
 
120
+ return fig
 
 
121
 
122
 
123
  # --- Gradio Interface Function ---
124
  def gradio_interface(lx, ly, t_max, m_steps, gamma, nx, ny, initial, bc):
 
 
125
  nx, ny, m_steps = int(nx), int(ny), int(m_steps)
126
+ # This function now returns a Plotly figure object
127
+ return solve_and_plot_interactive(
 
128
  Lx=lx, Ly=ly, t_max=t_max, M=m_steps, Gamma=gamma, Nx=nx, Ny=ny,
129
  initial=initial, bc=bc
130
  )
 
 
131
 
132
 
133
  # --- Gradio UI Definition ---
134
  with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
135
  gr.Markdown("# ♨️ 2D Heat Equation Simulator")
136
+ gr.Markdown("Adjust parameters and click 'Run' to generate an interactive plot directly in your browser.")
137
 
138
  with gr.Row():
139
  with gr.Column(scale=1):
140
+ gr.Markdown("## Simulation Parameters")
141
  lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx (Domain Length X)")
142
  ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly (Domain Length Y)")
143
+ nx_slider = gr.Slider(10, 100, 40, 1, label="Nx (Grid Points X)") # Reduced max for performance
144
+ ny_slider = gr.Slider(10, 100, 40, 1, label="Ny (Grid Points Y)")
145
+ t_slider = gr.Slider(0.01, 2.0, 0.5, 0.01, label="t_max (Total Time)")
 
 
146
  gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma (Diffusivity)")
147
+ m_slider = gr.Slider(10, 100, 30, 1, label="M (Time Steps on Slider)") # Reduced max
148
+
149
+ with gr.Row():
150
+ initial_dropdown = gr.Dropdown(["gaussian", "random", "sinusoidal", "step"], value="gaussian", label="Initial")
151
+ bc_dropdown = gr.Dropdown(["dirichlet", "neumann", "periodic"], value="dirichlet", label="Boundary")
 
 
 
 
 
 
 
152
 
153
  run_btn = gr.Button("Run Simulation", variant="primary")
154
 
155
+ with gr.Column(scale=3):
156
+ # The output is now a gr.Plot component that will render the Plotly figure
157
+ plot_output = gr.Plot(label="Interactive Heatmap")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
 
159
  inputs_list = [lx_slider, ly_slider, t_slider, m_slider, gamma_slider,
160
  nx_slider, ny_slider, initial_dropdown, bc_dropdown]
161
 
162
+ run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=plot_output)
163
 
 
164
  gr.Examples(
165
  examples=[
166
+ [1.0, 1.0, 0.5, 30, 0.1, 40, 40, "gaussian", "dirichlet"],
167
+ [2.0, 1.0, 1.0, 50, 0.05, 50, 25, "sinusoidal", "periodic"],
 
168
  ],
169
  inputs=inputs_list,
170
+ outputs=plot_output,
171
  fn=gradio_interface,
172
+ cache_examples=True # Caching is fine for Plotly objects
173
  )
174
 
175
  if __name__ == "__main__":