harishaseebat92 commited on
Commit
64d58d4
·
verified ·
1 Parent(s): 9f5a2d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -171
app.py CHANGED
@@ -1,85 +1,60 @@
1
  import numpy as np
2
- import plotly.graph_objects as go # Changed from matplotlib
3
  import gradio as gr
4
- # Removed: matplotlib.pyplot, matplotlib.animation.FuncAnimation, tempfile
5
-
6
- def solve_and_create_plotly_animation(Lx: float,
7
- Ly: float,
8
- t_max: float,
9
- Gamma: float = 0.1,
10
- Nx: int = 50,
11
- Ny: int = 50,
12
- initial: str = "gaussian",
13
- bc: str = "dirichlet",
14
- frame_skip: int = 1):
15
  """
16
- Solve the 2D heat equation u_t = Gamma*(u_xx + u_yy) and return an interactive Plotly animation.
17
- Initial conditions: {"gaussian", "random", "sinusoidal", "step"}
18
- Boundary conditions: {"dirichlet", "neumann", "periodic"}
19
  """
20
  # Spatial grid
21
- x_coords = np.linspace(0, Lx, Nx)
22
- y_coords = np.linspace(0, Ly, Ny)
23
- dx, dy = x_coords[1] - x_coords, y_coords[1] - y_coords
24
-
25
- if dx == 0 or dy == 0: # Should not happen with Nx, Ny >= 3 from Gradio sliders
26
- # Create an empty figure with an error message for Gradio
27
- fig = go.Figure()
28
- fig.update_layout(title_text="Error: Nx and Ny must be > 1 for dx, dy calculation.",
29
- xaxis_showticklabels=False, yaxis_showticklabels=False)
30
- return fig
31
-
32
 
33
  # Time stepping for stability
34
- # Ensure dt is positive, handle potential division by zero if Gamma is zero or dx/dy are problematic
35
- denominator = (2 * Gamma * (1/dx**2 + 1/dy**2))
36
- if denominator <= 0: # Avoid division by zero or negative dt
37
- fig = go.Figure()
38
- fig.update_layout(title_text=f"Error: Unstable dt parameters (Gamma={Gamma}, dx={dx}, dy={dy}). Check Gamma.",
39
- xaxis_showticklabels=False, yaxis_showticklabels=False)
40
- return fig
41
- dt = 1.0 / denominator
42
-
43
  Nt = int(np.ceil(t_max / dt)) + 1
44
- if Nt <=1: # Ensure there's at least an initial and one computed frame for animation
45
- Nt = 2
46
-
47
  rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
48
 
49
  # Initial condition
50
- X, Y = np.meshgrid(x_coords, y_coords, indexing='ij') # Use x_coords, y_coords
51
- u = np.zeros((Nx, Ny)) # Initialize u
52
  if initial == "gaussian":
53
  u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
54
  elif initial == "random":
55
  u = np.random.rand(Nx, Ny)
56
  elif initial == "sinusoidal":
57
- # Ensure Lx and Ly are not zero to avoid division by zero
58
- kx = 2 * np.pi / Lx if Lx > 0 else 0
59
- ky = 2 * np.pi / Ly if Ly > 0 else 0
60
  u = np.sin(kx * X) * np.sin(ky * Y)
61
  elif initial == "step":
62
  u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
63
- else: # Should not happen due to dropdown
64
- fig = go.Figure()
65
- fig.update_layout(title_text=f"Error: Unknown initial condition: {initial}",
66
- xaxis_showticklabels=False, yaxis_showticklabels=False)
67
- return fig
68
 
69
- # Storage for solution frames to be animated
70
- # We will store only the frames selected by frame_skip
71
- frames_to_animate_data =
72
- frames_to_animate_data.append(u.copy().T) # Transpose for heatmap like imshow(origin='lower')
73
 
74
  # Time-stepping loop
75
  for n in range(1, Nt):
76
  un = u.copy()
77
- # Interior update
78
  u[1:-1, 1:-1] = (
79
  un[1:-1, 1:-1]
80
  + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
81
  + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
82
  )
 
83
  # Boundary conditions
84
  if bc == "dirichlet":
85
  u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
@@ -89,168 +64,169 @@ def solve_and_create_plotly_animation(Lx: float,
89
  u[:, 0] = u[:, 1]
90
  u[:, -1] = u[:, -2]
91
  elif bc == "periodic":
92
- # Ensure indices are valid for periodic BCs (Nx, Ny >= 3)
93
- if Nx >= 3 and Ny >=3:
94
- u[0, :] = un[-2, :]
95
- u[-1, :] = un[1, :]
96
- u[:, 0] = un[:, -2]
97
- u[:, -1] = un[:, 1]
98
- else: # Fallback for very small grids where periodic BCs are ill-defined
99
- u[0, :] = u[1, :] # Effectively Neumann
100
- u[-1, :] = u[-2, :]
101
- u[:, 0] = u[:, 1]
102
- u[:, -1] = u[:, -2]
103
- else: # Should not happen
104
- fig = go.Figure()
105
- fig.update_layout(title_text=f"Error: Unknown bc: {bc}",
106
- xaxis_showticklabels=False, yaxis_showticklabels=False)
107
- return fig
108
-
109
- if n % frame_skip == 0 or n == Nt -1: # Store frame based on frame_skip
110
- frames_to_animate_data.append(u.copy().T) # Transpose
111
-
112
- if not frames_to_animate_data:
113
- fig = go.Figure()
114
- fig.update_layout(title_text="Error: No frames generated for animation.",
115
- xaxis_showticklabels=False, yaxis_showticklabels=False)
116
- return fig
117
-
118
- # Determine global min/max for consistent colorscale
119
- global_min = np.min([np.min(frame) for frame in frames_to_animate_data])
120
- global_max = np.max([np.max(frame) for frame in frames_to_animate_data])
121
- if global_min == global_max: # Avoid issues if all values are the same
122
- global_max += 1e-6 if global_max == 0 else abs(global_max * 0.01) # Add a tiny epsilon
123
-
124
- # Create Plotly animation
125
  fig = go.Figure(
126
- data=[go.Heatmap(
127
- z=frames_to_animate_data,
128
- x=x_coords, # Use actual coordinates for axes
129
- y=y_coords, # Use actual coordinates for axes
130
- colorscale='Viridis',
131
- zmin=global_min,
132
- zmax=global_max,
133
- showscale=True,
134
- colorbar_title_text="u"
135
- )],
136
- layout=go.Layout(
137
- title_text=f"2D Heat Eq: init={initial}, bc={bc}, Gamma={Gamma:.3f}",
138
- xaxis_title_text="x",
139
- yaxis_title_text="y",
140
- # yaxis_scaleanchor="x", # Makes pixels square if dx=dy and Lx=Ly
141
- font=dict(size=10),
142
- # Ensure plot updates don't change axis ranges during animation
143
- xaxis_range=[0, Lx],
144
- yaxis_range=[0, Ly],
145
- ),
146
- frames=[
147
- go.Frame(
148
- data=[go.Heatmap(
149
- z=frame_data,
150
- x=x_coords,
151
- y=y_coords,
152
- colorscale='Viridis',
153
- zmin=global_min,
154
- zmax=global_max
155
- )],
156
- name=f"frame{k}"
157
- ) for k, frame_data in enumerate(frames_to_animate_data)
158
- ]
159
  )
160
 
161
- # Add animation control buttons
 
 
 
 
 
 
 
 
 
162
  fig.update_layout(
163
- updatemenus=), # No transition smoothing
164
- dict(label="Pause",
165
- method="animate",
166
- args=[[None], {"frame": {"duration": 0, "redraw": False},
167
- "mode": "immediate",
168
- "transition": {"duration": 0}}])
169
- ],
170
  direction="left",
171
- pad={"r": 10, "t": 70}, # Adjust padding
172
- x=0.1, xanchor="left",
173
- y=0, yanchor="top"
 
 
 
 
 
 
 
 
 
174
  )]
175
  )
176
 
177
- # Add frame slider
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  fig.update_layout(
179
- sliders=[dict(
180
- active=0,
181
- steps=[
182
- dict(label=str(k),
183
- method="animate",
184
- args=[[f"frame{k}"],
185
- {"mode": "immediate",
186
- "frame": {"duration": 100, "redraw": True}, # Duration if played
187
- "transition": {"duration": 0}}]) # No transition for slider drag
188
- for k in range(len(frames_to_animate_data))
189
- ],
190
- currentvalue={"prefix": "Frame: ", "visible": True, "xanchor": "right"},
191
- pad={"t": 60, "b": 10} # Adjust padding
192
- )]
193
  )
 
194
  return fig
195
 
196
 
 
197
  def gradio_interface(lx, ly, t_max, gamma, nx, ny, initial, bc, frame_skip):
 
198
  nx, ny, frame_skip = int(nx), int(ny), int(frame_skip)
199
- # Call the new function that returns a Plotly figure
200
- return solve_and_create_plotly_animation(
 
201
  Lx=lx, Ly=ly, t_max=t_max, Gamma=gamma, Nx=nx, Ny=ny,
202
- initial=initial, bc=bc, frame_skip=frame_skip
 
 
 
 
 
 
203
  )
 
204
 
 
205
  with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
206
- gr.Markdown("# ♨️ 2D Heat Equation Simulator (Interactive with Plotly)\nAdjust parameters and run the simulation.")
 
207
  with gr.Row():
208
  with gr.Column(scale=1):
209
  gr.Markdown("## Domain & Grid")
210
  lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
211
  ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
212
- nx_slider = gr.Slider(3, 200, 50, 1, label="Nx (min 3)") # Min 3 for some BCs
213
- ny_slider = gr.Slider(3, 200, 50, 1, label="Ny (min 3)") # Min 3 for some BCs
 
214
  gr.Markdown("## Simulation")
215
  t_slider = gr.Slider(0.01, 5.0, 0.5, 0.01, label="t_max")
216
- gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma (Diffusion Coeff.)")
 
217
  gr.Markdown("## Conditions")
218
  initial_dropdown = gr.Dropdown(
219
- ["gaussian", "random", "sinusoidal", "step"], value="gaussian", label="Initial Condition"
220
  )
221
  bc_dropdown = gr.Dropdown(
222
- ["dirichlet", "neumann", "periodic"], value="dirichlet", label="Boundary Condition"
223
  )
 
224
  gr.Markdown("## Animation")
225
- frame_skip_slider = gr.Slider(1, 50, 5, 1, label="Frame Skip (Higher = Fewer Frames)")
226
  run_btn = gr.Button("Run Simulation", variant="primary")
 
227
  with gr.Column(scale=3):
228
- # Changed from gr.Image to gr.Plot
229
- plot_output = gr.Plot(label="Interactive Heatmap Animation")
230
 
231
  inputs_list = [lx_slider, ly_slider, t_slider, gamma_slider,
232
  nx_slider, ny_slider, initial_dropdown, bc_dropdown, frame_skip_slider]
233
 
234
  run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=plot_output)
235
-
236
  gr.Examples(
237
- examples=[1.0, 1.0, 0.5, 0.1, 50, 50, "gaussian", "dirichlet", 5],
 
238
  [2.0, 1.0, 1.0, 0.05, 60, 30, "sinusoidal", "periodic", 10],
239
  [1.0, 1.0, 0.2, 0.2, 80, 80, "step", "neumann", 2],
240
- [0.5, 0.5, 0.1, 0.5, 30, 30, "random", "dirichlet", 1],
241
  inputs=inputs_list,
242
- outputs=[plot_output], # Output is now a single Plot component
243
- fn=gradio_interface,
244
- # cache_examples=True # Consider enabling if simulations are slow and inputs are identical
245
  )
246
- gr.Markdown("### How to Interact:"
247
- "\n- **Play/Pause Buttons:** Control the animation playback (located below the plot title)."
248
- "\n- **Slider:** Drag to scrub through frames (located below the plot)."
249
- "\n- **Plotly Modebar (top right of plot on hover):**"
250
- "\n - **Zoom:** Use zoom tools (box zoom, zoom in/out icons, scroll wheel)."
251
- "\n - **Pan:** Use the pan tool to move the view when zoomed."
252
- "\n - **Autoscale/Reset:** Return to the default view or reset axes."
253
- )
254
 
255
  if __name__ == "__main__":
 
256
  demo.launch()
 
1
  import numpy as np
 
2
  import gradio as gr
3
+ import plotly.graph_objects as go
4
+
5
+ # --- 1. Simulation Core (Modified to return raw data) ---
6
+ def solve_2d_heat_equation(Lx: float,
7
+ Ly: float,
8
+ t_max: float,
9
+ Gamma: float = 0.1,
10
+ Nx: int = 50,
11
+ Ny: int = 50,
12
+ initial: str = "gaussian",
13
+ bc: str = "dirichlet"):
14
  """
15
+ Solves the 2D heat equation and returns the entire time-series data.
 
 
16
  """
17
  # Spatial grid
18
+ x = np.linspace(0, Lx, Nx)
19
+ y = np.linspace(0, Ly, Ny)
20
+ dx, dy = x[1] - x[0], y[1] - y[0]
21
+ if dx == 0 or dy == 0:
22
+ raise ValueError("Nx and Ny must be > 1.")
 
 
 
 
 
 
23
 
24
  # Time stepping for stability
25
+ # A small factor (0.9) is added for extra stability margin
26
+ dt = 0.9 / (2 * Gamma * (1/dx**2 + 1/dy**2))
 
 
 
 
 
 
 
27
  Nt = int(np.ceil(t_max / dt)) + 1
 
 
 
28
  rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
29
 
30
  # Initial condition
31
+ X, Y = np.meshgrid(x, y, indexing='ij')
 
32
  if initial == "gaussian":
33
  u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
34
  elif initial == "random":
35
  u = np.random.rand(Nx, Ny)
36
  elif initial == "sinusoidal":
37
+ kx, ky = 2 * np.pi / Lx, 2 * np.pi / Ly
 
 
38
  u = np.sin(kx * X) * np.sin(ky * Y)
39
  elif initial == "step":
40
  u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
41
+ else:
42
+ raise ValueError(f"Unknown initial condition: {initial}")
 
 
 
43
 
44
+ # Storage for solution
45
+ U = np.zeros((Nt, Nx, Ny))
46
+ U[0] = u.copy()
 
47
 
48
  # Time-stepping loop
49
  for n in range(1, Nt):
50
  un = u.copy()
51
+ # Interior update using a vectorized operation
52
  u[1:-1, 1:-1] = (
53
  un[1:-1, 1:-1]
54
  + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
55
  + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
56
  )
57
+
58
  # Boundary conditions
59
  if bc == "dirichlet":
60
  u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
 
64
  u[:, 0] = u[:, 1]
65
  u[:, -1] = u[:, -2]
66
  elif bc == "periodic":
67
+ # Implement periodic boundary conditions correctly
68
+ u[0, :] = un[-2, :]
69
+ u[-1, :] = un[1, :]
70
+ u[:, 0] = un[:, -2]
71
+ u[:, -1] = un[:, 1]
72
+ else:
73
+ raise ValueError(f"Unknown bc: {bc}")
74
+ U[n] = u.copy()
75
+
76
+ # Return the full data history and the time step size
77
+ return U, dt
78
+
79
+ # --- 2. Plotly Animation Generator (New Function) ---
80
+ def create_plotly_animation(U, Lx, Ly, initial, bc, Gamma, frame_skip, dt):
81
+ """Creates an interactive Plotly animation from the raw simulation data."""
82
+ Nt, Nx, Ny = U.shape
83
+ vmin, vmax = U.min(), U.max()
84
+
85
+ # Create a list of frame indices to animate based on frame_skip
86
+ idx = list(range(0, Nt, frame_skip))
87
+ if not idx or idx[-1] != Nt - 1: # Ensure the last frame is always included
88
+ idx.append(Nt - 1)
89
+
90
+ # Subsample the data for animation frames
91
+ frames_to_animate_data = U[idx]
92
+
93
+ # Create the figure with frames
 
 
 
 
 
 
94
  fig = go.Figure(
95
+ frames=[go.Frame(data=go.Heatmap(z=frame_data.T, zmin=vmin, zmax=vmax), name=str(i))
96
+ for i, frame_data in enumerate(frames_to_animate_data)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  )
98
 
99
+ # Add the initial heatmap trace (frame 0) that will be displayed first
100
+ fig.add_trace(go.Heatmap(
101
+ z=frames_to_animate_data[0].T,
102
+ colorscale='viridis',
103
+ zmin=vmin,
104
+ zmax=vmax,
105
+ colorbar=dict(title="u")
106
+ ))
107
+
108
+ # Create and configure the play/pause button
109
  fig.update_layout(
110
+ updatemenus=[dict(
111
+ type="buttons",
 
 
 
 
 
112
  direction="left",
113
+ x=0.1,
114
+ xanchor="left",
115
+ y=1.15,
116
+ yanchor="top",
117
+ buttons=[dict(label="Play",
118
+ method="animate",
119
+ args=[None, {"frame": {"duration": 50, "redraw": True},
120
+ "fromcurrent": True, "transition": {"duration": 0}}]),
121
+ dict(label="Pause",
122
+ method="animate",
123
+ args=[[None], {"frame": {"duration": 0, "redraw": False},
124
+ "mode": "immediate", "transition": {"duration": 0}}])]
125
  )]
126
  )
127
 
128
+ # Create and configure the frame slider
129
+ sliders = [dict(
130
+ active=0,
131
+ yanchor="top",
132
+ xanchor="left",
133
+ currentvalue=dict(
134
+ font=dict(size=16),
135
+ prefix="Time: ",
136
+ visible=True,
137
+ xanchor="right"
138
+ ),
139
+ pad=dict(b=10, t=50),
140
+ len=0.9,
141
+ x=0.1,
142
+ y=0,
143
+ steps=[dict(
144
+ args=[[f.name], {"frame": {"duration": 0, "redraw": True},
145
+ "mode": "immediate",
146
+ "transition": {"duration": 0}}],
147
+ label=f"{idx[i]*dt:.2f}s", # Display time in seconds on slider
148
+ method="animate")
149
+ for i, f in enumerate(fig.frames)]
150
+ )]
151
+
152
  fig.update_layout(
153
+ title=f"2D Heat Eq — init={initial}, bc={bc}, Gamma={Gamma:.2f}",
154
+ xaxis_title="x",
155
+ yaxis_title="y",
156
+ sliders=sliders,
157
+ # Set aspect ratio to match the domain
158
+ yaxis=dict(scaleanchor="x", scaleratio=Ly/Lx if Lx > 0 else 1)
 
 
 
 
 
 
 
 
159
  )
160
+
161
  return fig
162
 
163
 
164
+ # --- 3. Gradio Interface Logic (Modified to connect new functions) ---
165
  def gradio_interface(lx, ly, t_max, gamma, nx, ny, initial, bc, frame_skip):
166
+ """Main function for the Gradio interface."""
167
  nx, ny, frame_skip = int(nx), int(ny), int(frame_skip)
168
+
169
+ # Call the solver to get the raw simulation data and time step
170
+ U, dt = solve_2d_heat_equation(
171
  Lx=lx, Ly=ly, t_max=t_max, Gamma=gamma, Nx=nx, Ny=ny,
172
+ initial=initial, bc=bc
173
+ ) # <<< THIS IS THE CORRECTED LINE with the closing parenthesis
174
+
175
+ # Create the Plotly figure from the data
176
+ fig = create_plotly_animation(
177
+ U=U, Lx=lx, Ly=ly, initial=initial, bc=bc, Gamma=gamma,
178
+ frame_skip=frame_skip, dt=dt
179
  )
180
+ return fig
181
 
182
+ # --- 4. Gradio UI Layout (Modified to use gr.Plot) ---
183
  with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
184
+ gr.Markdown("# ♨️ Interactive 2D Heat Equation Simulator\nAdjust parameters and run the simulation.")
185
+
186
  with gr.Row():
187
  with gr.Column(scale=1):
188
  gr.Markdown("## Domain & Grid")
189
  lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
190
  ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
191
+ nx_slider = gr.Slider(3, 200, 50, 1, label="Nx")
192
+ ny_slider = gr.Slider(3, 200, 50, 1, label="Ny")
193
+
194
  gr.Markdown("## Simulation")
195
  t_slider = gr.Slider(0.01, 5.0, 0.5, 0.01, label="t_max")
196
+ gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
197
+
198
  gr.Markdown("## Conditions")
199
  initial_dropdown = gr.Dropdown(
200
+ ["gaussian", "random", "sinusoidal", "step"], "gaussian", label="Initial"
201
  )
202
  bc_dropdown = gr.Dropdown(
203
+ ["dirichlet", "neumann", "periodic"], "dirichlet", label="Boundary"
204
  )
205
+
206
  gr.Markdown("## Animation")
207
+ frame_skip_slider = gr.Slider(1, 50, 5, 1, label="Frame Skip")
208
  run_btn = gr.Button("Run Simulation", variant="primary")
209
+
210
  with gr.Column(scale=3):
211
+ gr.Markdown("### Interactive Heatmap Animation\nUse the play/pause buttons, drag the slider, or use your mouse/trackpad to zoom and pan the plot.")
212
+ plot_output = gr.Plot(label="Interactive Heatmap") # Changed from gr.Image to gr.Plot
213
 
214
  inputs_list = [lx_slider, ly_slider, t_slider, gamma_slider,
215
  nx_slider, ny_slider, initial_dropdown, bc_dropdown, frame_skip_slider]
216
 
217
  run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=plot_output)
218
+
219
  gr.Examples(
220
+ examples=[
221
+ [1.0, 1.0, 0.5, 0.1, 50, 50, "gaussian", "dirichlet", 5],
222
  [2.0, 1.0, 1.0, 0.05, 60, 30, "sinusoidal", "periodic", 10],
223
  [1.0, 1.0, 0.2, 0.2, 80, 80, "step", "neumann", 2],
224
+ ],
225
  inputs=inputs_list,
226
+ outputs=[plot_output],
227
+ fn=gradio_interface
 
228
  )
 
 
 
 
 
 
 
 
229
 
230
  if __name__ == "__main__":
231
+ # To run this, you will need to install plotly: pip install plotly
232
  demo.launch()