harishaseebat92 commited on
Commit
3d367f2
·
verified ·
1 Parent(s): c7b125a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -65
app.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
2
  import gradio as gr
3
  import plotly.graph_objects as go
4
 
5
- # --- 1. Simulation Core (Modified to return raw data) ---
6
  def solve_2d_heat_equation(Lx: float,
7
  Ly: float,
8
  t_max: float,
@@ -11,23 +11,15 @@ def solve_2d_heat_equation(Lx: float,
11
  Ny: int = 50,
12
  initial: str = "gaussian",
13
  bc: str = "dirichlet"):
14
- """
15
- Solves the 2D heat equation and returns the entire time-series data.
16
- """
17
- # Spatial grid
18
  x = np.linspace(0, Lx, Nx)
19
  y = np.linspace(0, Ly, Ny)
20
  dx, dy = x[1] - x[0], y[1] - y[0]
21
  if dx == 0 or dy == 0:
22
  raise ValueError("Nx and Ny must be > 1.")
23
-
24
- # Time stepping for stability
25
- # A small factor (0.9) is added for extra stability margin
26
  dt = 0.9 / (2 * Gamma * (1/dx**2 + 1/dy**2))
27
  Nt = int(np.ceil(t_max / dt)) + 1
28
  rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
29
-
30
- # Initial condition
31
  X, Y = np.meshgrid(x, y, indexing='ij')
32
  if initial == "gaussian":
33
  u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
@@ -40,22 +32,15 @@ def solve_2d_heat_equation(Lx: float,
40
  u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
41
  else:
42
  raise ValueError(f"Unknown initial condition: {initial}")
43
-
44
- # Storage for solution
45
  U = np.zeros((Nt, Nx, Ny))
46
  U[0] = u.copy()
47
-
48
- # Time-stepping loop
49
  for n in range(1, Nt):
50
  un = u.copy()
51
- # Interior update using a vectorized operation
52
  u[1:-1, 1:-1] = (
53
  un[1:-1, 1:-1]
54
  + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
55
  + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
56
  )
57
-
58
- # Boundary conditions
59
  if bc == "dirichlet":
60
  u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
61
  elif bc == "neumann":
@@ -64,7 +49,6 @@ def solve_2d_heat_equation(Lx: float,
64
  u[:, 0] = u[:, 1]
65
  u[:, -1] = u[:, -2]
66
  elif bc == "periodic":
67
- # Implement periodic boundary conditions correctly
68
  u[0, :] = un[-2, :]
69
  u[-1, :] = un[1, :]
70
  u[:, 0] = un[:, -2]
@@ -72,31 +56,21 @@ def solve_2d_heat_equation(Lx: float,
72
  else:
73
  raise ValueError(f"Unknown bc: {bc}")
74
  U[n] = u.copy()
75
-
76
- # Return the full data history and the time step size
77
  return U, dt
78
 
79
- # --- 2. Plotly Animation Generator (New Function) ---
80
  def create_plotly_animation(U, Lx, Ly, initial, bc, Gamma, frame_skip, dt):
81
- """Creates an interactive Plotly animation from the raw simulation data."""
82
  Nt, Nx, Ny = U.shape
83
  vmin, vmax = U.min(), U.max()
84
-
85
- # Create a list of frame indices to animate based on frame_skip
86
  idx = list(range(0, Nt, frame_skip))
87
- if not idx or idx[-1] != Nt - 1: # Ensure the last frame is always included
88
  idx.append(Nt - 1)
89
-
90
- # Subsample the data for animation frames
91
  frames_to_animate_data = U[idx]
92
-
93
- # Create the figure with frames
94
  fig = go.Figure(
95
  frames=[go.Frame(data=go.Heatmap(z=frame_data.T, zmin=vmin, zmax=vmax), name=str(i))
96
  for i, frame_data in enumerate(frames_to_animate_data)]
97
  )
98
-
99
- # Add the initial heatmap trace (frame 0) that will be displayed first
100
  fig.add_trace(go.Heatmap(
101
  z=frames_to_animate_data[0].T,
102
  colorscale='viridis',
@@ -104,8 +78,6 @@ def create_plotly_animation(U, Lx, Ly, initial, bc, Gamma, frame_skip, dt):
104
  zmax=vmax,
105
  colorbar=dict(title="u")
106
  ))
107
-
108
- # Create and configure the play/pause button
109
  fig.update_layout(
110
  updatemenus=[dict(
111
  type="buttons",
@@ -124,8 +96,6 @@ def create_plotly_animation(U, Lx, Ly, initial, bc, Gamma, frame_skip, dt):
124
  "mode": "immediate", "transition": {"duration": 0}}])]
125
  )]
126
  )
127
-
128
- # Create and configure the frame slider
129
  sliders = [dict(
130
  active=0,
131
  yanchor="top",
@@ -144,45 +114,40 @@ def create_plotly_animation(U, Lx, Ly, initial, bc, Gamma, frame_skip, dt):
144
  args=[[f.name], {"frame": {"duration": 0, "redraw": True},
145
  "mode": "immediate",
146
  "transition": {"duration": 0}}],
147
- label=f"{idx[i]*dt:.2f}s", # Display time in seconds on slider
148
  method="animate")
149
  for i, f in enumerate(fig.frames)]
150
  )]
151
-
152
  fig.update_layout(
153
  title=f"2D Heat Eq — init={initial}, bc={bc}, Gamma={Gamma:.2f}",
154
  xaxis_title="x",
155
  yaxis_title="y",
156
  sliders=sliders,
157
- # Set aspect ratio to match the domain
158
  yaxis=dict(scaleanchor="x", scaleratio=Ly/Lx if Lx > 0 else 1)
159
  )
160
-
161
  return fig
162
 
163
-
164
- # --- 3. Gradio Interface Logic (Modified to connect new functions) ---
165
  def gradio_interface(lx, ly, t_max, gamma, nx, ny, initial, bc, frame_skip):
166
  """Main function for the Gradio interface."""
167
  nx, ny, frame_skip = int(nx), int(ny), int(frame_skip)
168
-
169
- # Call the solver to get the raw simulation data and time step
170
- U, dt = solve_2d_heat_equation(
171
- Lx=lx, Ly=ly, t_max=t_max, Gamma=gamma, Nx=nx, Ny=ny,
172
- initial=initial, bc=bc
173
- ) # <<< THIS IS THE CORRECTED LINE with the closing parenthesis
174
-
175
- # Create the Plotly figure from the data
176
- fig = create_plotly_animation(
177
- U=U, Lx=lx, Ly=ly, initial=initial, bc=bc, Gamma=gamma,
178
- frame_skip=frame_skip, dt=dt
179
- )
180
- return fig
181
 
182
- # --- 4. Gradio UI Layout (Modified to use gr.Plot) ---
183
  with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
184
  gr.Markdown("# ♨️ Interactive 2D Heat Equation Simulator\nAdjust parameters and run the simulation.")
185
-
186
  with gr.Row():
187
  with gr.Column(scale=1):
188
  gr.Markdown("## Domain & Grid")
@@ -190,11 +155,9 @@ with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
190
  ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
191
  nx_slider = gr.Slider(3, 200, 50, 1, label="Nx")
192
  ny_slider = gr.Slider(3, 200, 50, 1, label="Ny")
193
-
194
  gr.Markdown("## Simulation")
195
  t_slider = gr.Slider(0.01, 5.0, 0.5, 0.01, label="t_max")
196
  gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
197
-
198
  gr.Markdown("## Conditions")
199
  initial_dropdown = gr.Dropdown(
200
  ["gaussian", "random", "sinusoidal", "step"], "gaussian", label="Initial"
@@ -202,20 +165,15 @@ with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
202
  bc_dropdown = gr.Dropdown(
203
  ["dirichlet", "neumann", "periodic"], "dirichlet", label="Boundary"
204
  )
205
-
206
  gr.Markdown("## Animation")
207
  frame_skip_slider = gr.Slider(1, 50, 5, 1, label="Frame Skip")
208
  run_btn = gr.Button("Run Simulation", variant="primary")
209
-
210
  with gr.Column(scale=3):
211
  gr.Markdown("### Interactive Heatmap Animation\nUse the play/pause buttons, drag the slider, or use your mouse/trackpad to zoom and pan the plot.")
212
- plot_output = gr.Plot(label="Interactive Heatmap") # Changed from gr.Image to gr.Plot
213
-
214
  inputs_list = [lx_slider, ly_slider, t_slider, gamma_slider,
215
  nx_slider, ny_slider, initial_dropdown, bc_dropdown, frame_skip_slider]
216
-
217
  run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=plot_output)
218
-
219
  gr.Examples(
220
  examples=[
221
  [1.0, 1.0, 0.5, 0.1, 50, 50, "gaussian", "dirichlet", 5],
@@ -228,5 +186,4 @@ with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
228
  )
229
 
230
  if __name__ == "__main__":
231
- # To run this, you will need to install plotly: pip install plotly
232
  demo.launch()
 
2
  import gradio as gr
3
  import plotly.graph_objects as go
4
 
5
+ # --- 1. Simulation Core ---
6
  def solve_2d_heat_equation(Lx: float,
7
  Ly: float,
8
  t_max: float,
 
11
  Ny: int = 50,
12
  initial: str = "gaussian",
13
  bc: str = "dirichlet"):
14
+ # [Unchanged code from original]
 
 
 
15
  x = np.linspace(0, Lx, Nx)
16
  y = np.linspace(0, Ly, Ny)
17
  dx, dy = x[1] - x[0], y[1] - y[0]
18
  if dx == 0 or dy == 0:
19
  raise ValueError("Nx and Ny must be > 1.")
 
 
 
20
  dt = 0.9 / (2 * Gamma * (1/dx**2 + 1/dy**2))
21
  Nt = int(np.ceil(t_max / dt)) + 1
22
  rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
 
 
23
  X, Y = np.meshgrid(x, y, indexing='ij')
24
  if initial == "gaussian":
25
  u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
 
32
  u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
33
  else:
34
  raise ValueError(f"Unknown initial condition: {initial}")
 
 
35
  U = np.zeros((Nt, Nx, Ny))
36
  U[0] = u.copy()
 
 
37
  for n in range(1, Nt):
38
  un = u.copy()
 
39
  u[1:-1, 1:-1] = (
40
  un[1:-1, 1:-1]
41
  + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
42
  + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
43
  )
 
 
44
  if bc == "dirichlet":
45
  u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
46
  elif bc == "neumann":
 
49
  u[:, 0] = u[:, 1]
50
  u[:, -1] = u[:, -2]
51
  elif bc == "periodic":
 
52
  u[0, :] = un[-2, :]
53
  u[-1, :] = un[1, :]
54
  u[:, 0] = un[:, -2]
 
56
  else:
57
  raise ValueError(f"Unknown bc: {bc}")
58
  U[n] = u.copy()
 
 
59
  return U, dt
60
 
61
+ # --- 2. Plotly Animation Generator ---
62
  def create_plotly_animation(U, Lx, Ly, initial, bc, Gamma, frame_skip, dt):
63
+ # [Unchanged code from original]
64
  Nt, Nx, Ny = U.shape
65
  vmin, vmax = U.min(), U.max()
 
 
66
  idx = list(range(0, Nt, frame_skip))
67
+ if not idx or idx[-1] != Nt - 1:
68
  idx.append(Nt - 1)
 
 
69
  frames_to_animate_data = U[idx]
 
 
70
  fig = go.Figure(
71
  frames=[go.Frame(data=go.Heatmap(z=frame_data.T, zmin=vmin, zmax=vmax), name=str(i))
72
  for i, frame_data in enumerate(frames_to_animate_data)]
73
  )
 
 
74
  fig.add_trace(go.Heatmap(
75
  z=frames_to_animate_data[0].T,
76
  colorscale='viridis',
 
78
  zmax=vmax,
79
  colorbar=dict(title="u")
80
  ))
 
 
81
  fig.update_layout(
82
  updatemenus=[dict(
83
  type="buttons",
 
96
  "mode": "immediate", "transition": {"duration": 0}}])]
97
  )]
98
  )
 
 
99
  sliders = [dict(
100
  active=0,
101
  yanchor="top",
 
114
  args=[[f.name], {"frame": {"duration": 0, "redraw": True},
115
  "mode": "immediate",
116
  "transition": {"duration": 0}}],
117
+ label=f"{idx[i]*dt:.2f}s",
118
  method="animate")
119
  for i, f in enumerate(fig.frames)]
120
  )]
 
121
  fig.update_layout(
122
  title=f"2D Heat Eq — init={initial}, bc={bc}, Gamma={Gamma:.2f}",
123
  xaxis_title="x",
124
  yaxis_title="y",
125
  sliders=sliders,
 
126
  yaxis=dict(scaleanchor="x", scaleratio=Ly/Lx if Lx > 0 else 1)
127
  )
 
128
  return fig
129
 
130
+ # --- 3. Gradio Interface Logic ---
 
131
  def gradio_interface(lx, ly, t_max, gamma, nx, ny, initial, bc, frame_skip):
132
  """Main function for the Gradio interface."""
133
  nx, ny, frame_skip = int(nx), int(ny), int(frame_skip)
134
+ try:
135
+ U, dt = solve_2d_heat_equation(
136
+ Lx=lx, Ly=ly, t_max=t_max, Gamma=gamma, Nx=nx, Ny=ny,
137
+ initial=initial, bc=bc
138
+ )
139
+ fig = create_plotly_animation(
140
+ U=U, Lx=lx, Ly=ly, initial=initial, bc=bc, Gamma=gamma,
141
+ frame_skip=frame_skip, dt=dt
142
+ )
143
+ return fig
144
+ except Exception as e:
145
+ print(f"Error in simulation: {e}")
146
+ return None
147
 
148
+ # --- 4. Gradio UI Layout ---
149
  with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
150
  gr.Markdown("# ♨️ Interactive 2D Heat Equation Simulator\nAdjust parameters and run the simulation.")
 
151
  with gr.Row():
152
  with gr.Column(scale=1):
153
  gr.Markdown("## Domain & Grid")
 
155
  ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
156
  nx_slider = gr.Slider(3, 200, 50, 1, label="Nx")
157
  ny_slider = gr.Slider(3, 200, 50, 1, label="Ny")
 
158
  gr.Markdown("## Simulation")
159
  t_slider = gr.Slider(0.01, 5.0, 0.5, 0.01, label="t_max")
160
  gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
 
161
  gr.Markdown("## Conditions")
162
  initial_dropdown = gr.Dropdown(
163
  ["gaussian", "random", "sinusoidal", "step"], "gaussian", label="Initial"
 
165
  bc_dropdown = gr.Dropdown(
166
  ["dirichlet", "neumann", "periodic"], "dirichlet", label="Boundary"
167
  )
 
168
  gr.Markdown("## Animation")
169
  frame_skip_slider = gr.Slider(1, 50, 5, 1, label="Frame Skip")
170
  run_btn = gr.Button("Run Simulation", variant="primary")
 
171
  with gr.Column(scale=3):
172
  gr.Markdown("### Interactive Heatmap Animation\nUse the play/pause buttons, drag the slider, or use your mouse/trackpad to zoom and pan the plot.")
173
+ plot_output = gr.Plot(label="Interactive Heatmap")
 
174
  inputs_list = [lx_slider, ly_slider, t_slider, gamma_slider,
175
  nx_slider, ny_slider, initial_dropdown, bc_dropdown, frame_skip_slider]
 
176
  run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=plot_output)
 
177
  gr.Examples(
178
  examples=[
179
  [1.0, 1.0, 0.5, 0.1, 50, 50, "gaussian", "dirichlet", 5],
 
186
  )
187
 
188
  if __name__ == "__main__":
 
189
  demo.launch()