harishaseebat92 commited on
Commit
9f5a2d3
·
verified ·
1 Parent(s): 75f0e0a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -0
app.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import plotly.graph_objects as go # Changed from matplotlib
3
+ import gradio as gr
4
+ # Removed: matplotlib.pyplot, matplotlib.animation.FuncAnimation, tempfile
5
+
6
+ def solve_and_create_plotly_animation(Lx: float,
7
+ Ly: float,
8
+ t_max: float,
9
+ Gamma: float = 0.1,
10
+ Nx: int = 50,
11
+ Ny: int = 50,
12
+ initial: str = "gaussian",
13
+ bc: str = "dirichlet",
14
+ frame_skip: int = 1):
15
+ """
16
+ Solve the 2D heat equation u_t = Gamma*(u_xx + u_yy) and return an interactive Plotly animation.
17
+ Initial conditions: {"gaussian", "random", "sinusoidal", "step"}
18
+ Boundary conditions: {"dirichlet", "neumann", "periodic"}
19
+ """
20
+ # Spatial grid
21
+ x_coords = np.linspace(0, Lx, Nx)
22
+ y_coords = np.linspace(0, Ly, Ny)
23
+ dx, dy = x_coords[1] - x_coords, y_coords[1] - y_coords
24
+
25
+ if dx == 0 or dy == 0: # Should not happen with Nx, Ny >= 3 from Gradio sliders
26
+ # Create an empty figure with an error message for Gradio
27
+ fig = go.Figure()
28
+ fig.update_layout(title_text="Error: Nx and Ny must be > 1 for dx, dy calculation.",
29
+ xaxis_showticklabels=False, yaxis_showticklabels=False)
30
+ return fig
31
+
32
+
33
+ # Time stepping for stability
34
+ # Ensure dt is positive, handle potential division by zero if Gamma is zero or dx/dy are problematic
35
+ denominator = (2 * Gamma * (1/dx**2 + 1/dy**2))
36
+ if denominator <= 0: # Avoid division by zero or negative dt
37
+ fig = go.Figure()
38
+ fig.update_layout(title_text=f"Error: Unstable dt parameters (Gamma={Gamma}, dx={dx}, dy={dy}). Check Gamma.",
39
+ xaxis_showticklabels=False, yaxis_showticklabels=False)
40
+ return fig
41
+ dt = 1.0 / denominator
42
+
43
+ Nt = int(np.ceil(t_max / dt)) + 1
44
+ if Nt <=1: # Ensure there's at least an initial and one computed frame for animation
45
+ Nt = 2
46
+
47
+ rx, ry = Gamma * dt / dx**2, Gamma * dt / dy**2
48
+
49
+ # Initial condition
50
+ X, Y = np.meshgrid(x_coords, y_coords, indexing='ij') # Use x_coords, y_coords
51
+ u = np.zeros((Nx, Ny)) # Initialize u
52
+ if initial == "gaussian":
53
+ u = np.exp(-(((X - Lx/2)**2 + (Y - Ly/2)**2) / (2*(Lx/10)**2)))
54
+ elif initial == "random":
55
+ u = np.random.rand(Nx, Ny)
56
+ elif initial == "sinusoidal":
57
+ # Ensure Lx and Ly are not zero to avoid division by zero
58
+ kx = 2 * np.pi / Lx if Lx > 0 else 0
59
+ ky = 2 * np.pi / Ly if Ly > 0 else 0
60
+ u = np.sin(kx * X) * np.sin(ky * Y)
61
+ elif initial == "step":
62
+ u = np.where((X < Lx/2) & (Y < Ly/2), 1.0, 0.0)
63
+ else: # Should not happen due to dropdown
64
+ fig = go.Figure()
65
+ fig.update_layout(title_text=f"Error: Unknown initial condition: {initial}",
66
+ xaxis_showticklabels=False, yaxis_showticklabels=False)
67
+ return fig
68
+
69
+ # Storage for solution frames to be animated
70
+ # We will store only the frames selected by frame_skip
71
+ frames_to_animate_data =
72
+ frames_to_animate_data.append(u.copy().T) # Transpose for heatmap like imshow(origin='lower')
73
+
74
+ # Time-stepping loop
75
+ for n in range(1, Nt):
76
+ un = u.copy()
77
+ # Interior update
78
+ u[1:-1, 1:-1] = (
79
+ un[1:-1, 1:-1]
80
+ + rx * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[:-2, 1:-1])
81
+ + ry * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, :-2])
82
+ )
83
+ # Boundary conditions
84
+ if bc == "dirichlet":
85
+ u[0, :] = u[-1, :] = u[:, 0] = u[:, -1] = 0.0
86
+ elif bc == "neumann":
87
+ u[0, :] = u[1, :]
88
+ u[-1, :] = u[-2, :]
89
+ u[:, 0] = u[:, 1]
90
+ u[:, -1] = u[:, -2]
91
+ elif bc == "periodic":
92
+ # Ensure indices are valid for periodic BCs (Nx, Ny >= 3)
93
+ if Nx >= 3 and Ny >=3:
94
+ u[0, :] = un[-2, :]
95
+ u[-1, :] = un[1, :]
96
+ u[:, 0] = un[:, -2]
97
+ u[:, -1] = un[:, 1]
98
+ else: # Fallback for very small grids where periodic BCs are ill-defined
99
+ u[0, :] = u[1, :] # Effectively Neumann
100
+ u[-1, :] = u[-2, :]
101
+ u[:, 0] = u[:, 1]
102
+ u[:, -1] = u[:, -2]
103
+ else: # Should not happen
104
+ fig = go.Figure()
105
+ fig.update_layout(title_text=f"Error: Unknown bc: {bc}",
106
+ xaxis_showticklabels=False, yaxis_showticklabels=False)
107
+ return fig
108
+
109
+ if n % frame_skip == 0 or n == Nt -1: # Store frame based on frame_skip
110
+ frames_to_animate_data.append(u.copy().T) # Transpose
111
+
112
+ if not frames_to_animate_data:
113
+ fig = go.Figure()
114
+ fig.update_layout(title_text="Error: No frames generated for animation.",
115
+ xaxis_showticklabels=False, yaxis_showticklabels=False)
116
+ return fig
117
+
118
+ # Determine global min/max for consistent colorscale
119
+ global_min = np.min([np.min(frame) for frame in frames_to_animate_data])
120
+ global_max = np.max([np.max(frame) for frame in frames_to_animate_data])
121
+ if global_min == global_max: # Avoid issues if all values are the same
122
+ global_max += 1e-6 if global_max == 0 else abs(global_max * 0.01) # Add a tiny epsilon
123
+
124
+ # Create Plotly animation
125
+ fig = go.Figure(
126
+ data=[go.Heatmap(
127
+ z=frames_to_animate_data,
128
+ x=x_coords, # Use actual coordinates for axes
129
+ y=y_coords, # Use actual coordinates for axes
130
+ colorscale='Viridis',
131
+ zmin=global_min,
132
+ zmax=global_max,
133
+ showscale=True,
134
+ colorbar_title_text="u"
135
+ )],
136
+ layout=go.Layout(
137
+ title_text=f"2D Heat Eq: init={initial}, bc={bc}, Gamma={Gamma:.3f}",
138
+ xaxis_title_text="x",
139
+ yaxis_title_text="y",
140
+ # yaxis_scaleanchor="x", # Makes pixels square if dx=dy and Lx=Ly
141
+ font=dict(size=10),
142
+ # Ensure plot updates don't change axis ranges during animation
143
+ xaxis_range=[0, Lx],
144
+ yaxis_range=[0, Ly],
145
+ ),
146
+ frames=[
147
+ go.Frame(
148
+ data=[go.Heatmap(
149
+ z=frame_data,
150
+ x=x_coords,
151
+ y=y_coords,
152
+ colorscale='Viridis',
153
+ zmin=global_min,
154
+ zmax=global_max
155
+ )],
156
+ name=f"frame{k}"
157
+ ) for k, frame_data in enumerate(frames_to_animate_data)
158
+ ]
159
+ )
160
+
161
+ # Add animation control buttons
162
+ fig.update_layout(
163
+ updatemenus=), # No transition smoothing
164
+ dict(label="Pause",
165
+ method="animate",
166
+ args=[[None], {"frame": {"duration": 0, "redraw": False},
167
+ "mode": "immediate",
168
+ "transition": {"duration": 0}}])
169
+ ],
170
+ direction="left",
171
+ pad={"r": 10, "t": 70}, # Adjust padding
172
+ x=0.1, xanchor="left",
173
+ y=0, yanchor="top"
174
+ )]
175
+ )
176
+
177
+ # Add frame slider
178
+ fig.update_layout(
179
+ sliders=[dict(
180
+ active=0,
181
+ steps=[
182
+ dict(label=str(k),
183
+ method="animate",
184
+ args=[[f"frame{k}"],
185
+ {"mode": "immediate",
186
+ "frame": {"duration": 100, "redraw": True}, # Duration if played
187
+ "transition": {"duration": 0}}]) # No transition for slider drag
188
+ for k in range(len(frames_to_animate_data))
189
+ ],
190
+ currentvalue={"prefix": "Frame: ", "visible": True, "xanchor": "right"},
191
+ pad={"t": 60, "b": 10} # Adjust padding
192
+ )]
193
+ )
194
+ return fig
195
+
196
+
197
+ def gradio_interface(lx, ly, t_max, gamma, nx, ny, initial, bc, frame_skip):
198
+ nx, ny, frame_skip = int(nx), int(ny), int(frame_skip)
199
+ # Call the new function that returns a Plotly figure
200
+ return solve_and_create_plotly_animation(
201
+ Lx=lx, Ly=ly, t_max=t_max, Gamma=gamma, Nx=nx, Ny=ny,
202
+ initial=initial, bc=bc, frame_skip=frame_skip
203
+ )
204
+
205
+ with gr.Blocks(theme=gr.themes.Soft(), title="2D Heat Simulator") as demo:
206
+ gr.Markdown("# ♨️ 2D Heat Equation Simulator (Interactive with Plotly)\nAdjust parameters and run the simulation.")
207
+ with gr.Row():
208
+ with gr.Column(scale=1):
209
+ gr.Markdown("## Domain & Grid")
210
+ lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
211
+ ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
212
+ nx_slider = gr.Slider(3, 200, 50, 1, label="Nx (min 3)") # Min 3 for some BCs
213
+ ny_slider = gr.Slider(3, 200, 50, 1, label="Ny (min 3)") # Min 3 for some BCs
214
+ gr.Markdown("## Simulation")
215
+ t_slider = gr.Slider(0.01, 5.0, 0.5, 0.01, label="t_max")
216
+ gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma (Diffusion Coeff.)")
217
+ gr.Markdown("## Conditions")
218
+ initial_dropdown = gr.Dropdown(
219
+ ["gaussian", "random", "sinusoidal", "step"], value="gaussian", label="Initial Condition"
220
+ )
221
+ bc_dropdown = gr.Dropdown(
222
+ ["dirichlet", "neumann", "periodic"], value="dirichlet", label="Boundary Condition"
223
+ )
224
+ gr.Markdown("## Animation")
225
+ frame_skip_slider = gr.Slider(1, 50, 5, 1, label="Frame Skip (Higher = Fewer Frames)")
226
+ run_btn = gr.Button("Run Simulation", variant="primary")
227
+ with gr.Column(scale=3):
228
+ # Changed from gr.Image to gr.Plot
229
+ plot_output = gr.Plot(label="Interactive Heatmap Animation")
230
+
231
+ inputs_list = [lx_slider, ly_slider, t_slider, gamma_slider,
232
+ nx_slider, ny_slider, initial_dropdown, bc_dropdown, frame_skip_slider]
233
+
234
+ run_btn.click(fn=gradio_interface, inputs=inputs_list, outputs=plot_output)
235
+
236
+ gr.Examples(
237
+ examples=[1.0, 1.0, 0.5, 0.1, 50, 50, "gaussian", "dirichlet", 5],
238
+ [2.0, 1.0, 1.0, 0.05, 60, 30, "sinusoidal", "periodic", 10],
239
+ [1.0, 1.0, 0.2, 0.2, 80, 80, "step", "neumann", 2],
240
+ [0.5, 0.5, 0.1, 0.5, 30, 30, "random", "dirichlet", 1],
241
+ inputs=inputs_list,
242
+ outputs=[plot_output], # Output is now a single Plot component
243
+ fn=gradio_interface,
244
+ # cache_examples=True # Consider enabling if simulations are slow and inputs are identical
245
+ )
246
+ gr.Markdown("### How to Interact:"
247
+ "\n- **Play/Pause Buttons:** Control the animation playback (located below the plot title)."
248
+ "\n- **Slider:** Drag to scrub through frames (located below the plot)."
249
+ "\n- **Plotly Modebar (top right of plot on hover):**"
250
+ "\n - **Zoom:** Use zoom tools (box zoom, zoom in/out icons, scroll wheel)."
251
+ "\n - **Pan:** Use the pan tool to move the view when zoomed."
252
+ "\n - **Autoscale/Reset:** Return to the default view or reset axes."
253
+ )
254
+
255
+ if __name__ == "__main__":
256
+ demo.launch()