harishaseebat92 commited on
Commit
8a4db24
·
verified ·
1 Parent(s): df0ed6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -32
app.py CHANGED
@@ -18,10 +18,10 @@ def solve_3d_heat_equation(Lx: float, Ly: float, Lz: float,
18
  x = np.linspace(0, Lx, Nx)
19
  y = np.linspace(0, Ly, Ny)
20
  z = np.linspace(0, Lz, Nz)
21
- dx, dy, dz = x[1] - x, y[1] - y, z[1] - z
22
 
23
  if dx == 0 or dy == 0 or dz == 0:
24
- raise ValueError("Nx, Ny, and Nz must be > 1.")
25
 
26
  # Stability condition for 3D FTCS scheme
27
  dt = 0.5 / (Gamma * (1/dx**2 + 1/dy**2 + 1/dz**2)) # Adjusted for 3D
@@ -44,7 +44,7 @@ def solve_3d_heat_equation(Lx: float, Ly: float, Lz: float,
44
  raise ValueError(f"Unknown initial condition: {initial}")
45
 
46
  U = np.zeros((Nt, Nx, Ny, Nz))
47
- U = u.copy()
48
 
49
  for n in range(1, Nt):
50
  un = u.copy()
@@ -83,7 +83,6 @@ def create_animation_gif_3d_slice(U, Lx, Ly, Lz, initial, bc, Gamma, frame_skip,
83
  Nt, Nx, Ny, Nz = U.shape
84
  fig, ax = plt.subplots()
85
 
86
- # We'll animate the central xy-slice
87
  slice_z_idx = Nz // 2
88
  z_coord_slice = np.linspace(0, Lz, Nz)[slice_z_idx]
89
 
@@ -100,17 +99,22 @@ def create_animation_gif_3d_slice(U, Lx, Ly, Lz, initial, bc, Gamma, frame_skip,
100
  im.set_data(U[frame, :, :, slice_z_idx].T)
101
  return [im]
102
 
103
- idx = list(range(0, Nt, frame_skip))
104
- if not idx or (idx[-1]!= Nt - 1 and Nt > 1) : # Ensure last frame is included if Nt > 1
105
- if Nt-1 not in idx: idx.append(Nt - 1)
106
- if not idx and Nt ==1: # Handle case with only one time step
107
- idx =
 
 
 
 
 
108
 
109
 
110
  ani = FuncAnimation(fig, update, frames=idx, blit=True)
111
 
112
  with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmpfile:
113
- ani.save(tmpfile.name, writer='pillow', fps=max(1, 30 // frame_skip)) # Adjust fps based on skip
114
  gif_path = tmpfile.name
115
 
116
  plt.close(fig)
@@ -130,10 +134,10 @@ def create_plotly_figure_3d(u_3d, Lx, Ly, Lz, time_label):
130
  y=Y.flatten(),
131
  z=Z.flatten(),
132
  value=u_3d.flatten(),
133
- isomin=u_3d.min(),
134
  isomax=u_3d.max(),
135
- opacity=0.1, # needs to be small to see through all surfaces
136
- surface_count=20, # needs to be a large number for good volume rendering
137
  colorscale='viridis'
138
  ))
139
  fig.update_layout(
@@ -141,7 +145,9 @@ def create_plotly_figure_3d(u_3d, Lx, Ly, Lz, time_label):
141
  scene=dict(
142
  xaxis_title='x',
143
  yaxis_title='y',
144
- zaxis_title='z'
 
 
145
  )
146
  )
147
  return fig
@@ -152,7 +158,7 @@ def run_simulation_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_s
152
  Lx=lx, Ly=ly, Lz=lz, t_max=t_max, Gamma=gamma,
153
  Nx=nx, Ny=ny, Nz=nz, initial=initial, bc=bc
154
  )
155
- Nt = U.shape
156
 
157
  idx0 = 0
158
  idx1 = round((Nt - 1) / 4) if Nt > 1 else 0
@@ -175,6 +181,8 @@ def run_simulation_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_s
175
  # --- Gradio Interface Logic (3D) ---
176
  def gradio_interface_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip):
177
  nx, ny, nz, frame_skip = int(nx), int(ny), int(nz), int(frame_skip)
 
 
178
  gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(
179
  lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip
180
  )
@@ -188,13 +196,13 @@ with gr.Blocks(theme=gr.themes.Soft(), title="3D Heat Simulator") as demo:
188
  gr.Markdown("## Domain & Grid")
189
  lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
190
  ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
191
- lz_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lz (New)") # New
192
- nx_slider = gr.Slider(3, 100, 30, 1, label="Nx") # Max reduced for 3D
193
- ny_slider = gr.Slider(3, 100, 30, 1, label="Ny") # Max reduced for 3D
194
- nz_slider = gr.Slider(3, 100, 30, 1, label="Nz (New)") # New, max reduced
195
 
196
  gr.Markdown("## Simulation")
197
- t_slider = gr.Slider(0.01, 2.0, 0.2, 0.01, label="t_max") # Max t_max reduced
198
  gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
199
 
200
  gr.Markdown("## Conditions")
@@ -206,7 +214,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="3D Heat Simulator") as demo:
206
  )
207
 
208
  gr.Markdown("## Animation")
209
- frame_skip_slider = gr.Slider(1, 50, 10, 1, label="Frame Skip") # Default skip increased
210
  run_btn = gr.Button("Run 3D Simulation", variant="primary")
211
 
212
  with gr.Column(scale=3):
@@ -225,43 +233,43 @@ with gr.Blocks(theme=gr.themes.Soft(), title="3D Heat Simulator") as demo:
225
  run_btn.click(fn=gradio_interface_3d, inputs=inputs_list, outputs=outputs_list)
226
 
227
  gr.Examples(
228
- examples=[1.0, 1.0, 1.0, 0.2, 0.1, 20, 20, 20, "gaussian", "dirichlet", 10], # Reduced Nx,Ny,Nz for example
229
- [1.5, 1.0, 0.5, 0.3, 0.05, 25, 20, 15, "sinusoidal", "periodic", 15],
230
- [1.0, 1.0, 1.0, 0.1, 0.2, 30, 30, 30, "step", "neumann", 5],
231
  inputs=inputs_list,
232
  outputs=outputs_list,
233
- fn=gradio_interface_3d
 
234
  )
235
 
236
  # --- FastAPI Setup for API Endpoint (3D) ---
237
  app = FastAPI()
238
 
239
- # Mount Gradio app to FastAPI
240
  app = gr.mount_gradio_app(app, demo, path="/")
241
 
242
- # Define the simulation parameters model for 3D
243
  class SimulationParams3D(BaseModel):
244
  lx: float
245
  ly: float
246
- lz: float # New
247
  t_max: float
248
  gamma: float
249
  nx: int
250
  ny: int
251
- nz: int # New
252
  initial: str
253
  bc: str
254
  frame_skip: int
255
 
256
- # Custom API endpoint to run 3D simulation and return results
257
- @app.post("/simulate_3d") # Renamed endpoint
258
  def simulate_3d_api(params: SimulationParams3D):
 
 
259
  gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(**params.dict())
260
  with open(gif_path, "rb") as f:
261
  gif_data = base64.b64encode(f.read()).decode('utf-8')
262
  return {
263
  "gif_base64": gif_data,
264
- "plot0_3d_volume": fig0.to_json(), # Indicate 3D plot
265
  "plot1_3d_volume": fig1.to_json(),
266
  "plot2_3d_volume": fig2.to_json(),
267
  "plot3_3d_volume": fig3.to_json()
 
18
  x = np.linspace(0, Lx, Nx)
19
  y = np.linspace(0, Ly, Ny)
20
  z = np.linspace(0, Lz, Nz)
21
+ dx, dy, dz = x[1] - x, y[1] - y, z[1] - z # Corrected indexing
22
 
23
  if dx == 0 or dy == 0 or dz == 0:
24
+ raise ValueError("Nx, Ny, and Nz must be > 1 for valid dx, dy, dz.")
25
 
26
  # Stability condition for 3D FTCS scheme
27
  dt = 0.5 / (Gamma * (1/dx**2 + 1/dy**2 + 1/dz**2)) # Adjusted for 3D
 
44
  raise ValueError(f"Unknown initial condition: {initial}")
45
 
46
  U = np.zeros((Nt, Nx, Ny, Nz))
47
+ U = u.copy() # Store initial condition
48
 
49
  for n in range(1, Nt):
50
  un = u.copy()
 
83
  Nt, Nx, Ny, Nz = U.shape
84
  fig, ax = plt.subplots()
85
 
 
86
  slice_z_idx = Nz // 2
87
  z_coord_slice = np.linspace(0, Lz, Nz)[slice_z_idx]
88
 
 
99
  im.set_data(U[frame, :, :, slice_z_idx].T)
100
  return [im]
101
 
102
+ if Nt <= 1:
103
+ idx = # Only one frame (initial state)
104
+ else:
105
+ idx = list(range(0, Nt, frame_skip))
106
+ if not idx: # If frame_skip is too large for Nt > 1
107
+ idx =
108
+ if idx[-1]!= Nt - 1: # Ensure last frame is included
109
+ idx.append(Nt - 1)
110
+ # Remove duplicates if Nt-1 was already included by range
111
+ idx = sorted(list(set(idx)))
112
 
113
 
114
  ani = FuncAnimation(fig, update, frames=idx, blit=True)
115
 
116
  with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmpfile:
117
+ ani.save(tmpfile.name, writer='pillow', fps=max(1, 30 // max(1,frame_skip)))
118
  gif_path = tmpfile.name
119
 
120
  plt.close(fig)
 
134
  y=Y.flatten(),
135
  z=Z.flatten(),
136
  value=u_3d.flatten(),
137
+ isomin=u_3d.min(), # Use actual min/max of the current data
138
  isomax=u_3d.max(),
139
+ opacity=0.1,
140
+ surface_count=17, # Adjusted for potentially better performance/look
141
  colorscale='viridis'
142
  ))
143
  fig.update_layout(
 
145
  scene=dict(
146
  xaxis_title='x',
147
  yaxis_title='y',
148
+ zaxis_title='z',
149
+ # Aspect ratio can be important for 3D visualization
150
+ aspectmode='cube'
151
  )
152
  )
153
  return fig
 
158
  Lx=lx, Ly=ly, Lz=lz, t_max=t_max, Gamma=gamma,
159
  Nx=nx, Ny=ny, Nz=nz, initial=initial, bc=bc
160
  )
161
+ Nt = U.shape # Corrected: U.shape is Nt
162
 
163
  idx0 = 0
164
  idx1 = round((Nt - 1) / 4) if Nt > 1 else 0
 
181
  # --- Gradio Interface Logic (3D) ---
182
  def gradio_interface_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip):
183
  nx, ny, nz, frame_skip = int(nx), int(ny), int(nz), int(frame_skip)
184
+ # Ensure frame_skip is at least 1
185
+ frame_skip = max(1, frame_skip)
186
  gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(
187
  lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip
188
  )
 
196
  gr.Markdown("## Domain & Grid")
197
  lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
198
  ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
199
+ lz_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lz")
200
+ nx_slider = gr.Slider(3, 60, 20, 1, label="Nx (e.g., 20-40 for speed)") # Max reduced
201
+ ny_slider = gr.Slider(3, 60, 20, 1, label="Ny (e.g., 20-40 for speed)") # Max reduced
202
+ nz_slider = gr.Slider(3, 60, 20, 1, label="Nz (e.g., 20-40 for speed)") # Max reduced
203
 
204
  gr.Markdown("## Simulation")
205
+ t_slider = gr.Slider(0.01, 1.0, 0.1, 0.01, label="t_max (e.g., 0.1-0.5)") # Max t_max reduced
206
  gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
207
 
208
  gr.Markdown("## Conditions")
 
214
  )
215
 
216
  gr.Markdown("## Animation")
217
+ frame_skip_slider = gr.Slider(1, 50, 5, 1, label="Frame Skip (e.g., 5-10)")
218
  run_btn = gr.Button("Run 3D Simulation", variant="primary")
219
 
220
  with gr.Column(scale=3):
 
233
  run_btn.click(fn=gradio_interface_3d, inputs=inputs_list, outputs=outputs_list)
234
 
235
  gr.Examples(
236
+ examples=[1.0, 1.0, 1.0, 0.1, 0.1, 20, 20, 20, "gaussian", "dirichlet", 5],
237
+ [1.5, 1.0, 0.5, 0.2, 0.05, 25, 20, 15, "sinusoidal", "periodic", 10],
238
+ [1.0, 1.0, 1.0, 0.05, 0.2, 15, 15, 15, "step", "neumann", 2],
239
  inputs=inputs_list,
240
  outputs=outputs_list,
241
+ fn=gradio_interface_3d,
242
+ cache_examples=False # Consider disabling cache if examples are slow or large
243
  )
244
 
245
  # --- FastAPI Setup for API Endpoint (3D) ---
246
  app = FastAPI()
247
 
 
248
  app = gr.mount_gradio_app(app, demo, path="/")
249
 
 
250
  class SimulationParams3D(BaseModel):
251
  lx: float
252
  ly: float
253
+ lz: float
254
  t_max: float
255
  gamma: float
256
  nx: int
257
  ny: int
258
+ nz: int
259
  initial: str
260
  bc: str
261
  frame_skip: int
262
 
263
+ @app.post("/simulate_3d")
 
264
  def simulate_3d_api(params: SimulationParams3D):
265
+ # Ensure frame_skip is at least 1 for API calls too
266
+ params.frame_skip = max(1, params.frame_skip)
267
  gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(**params.dict())
268
  with open(gif_path, "rb") as f:
269
  gif_data = base64.b64encode(f.read()).decode('utf-8')
270
  return {
271
  "gif_base64": gif_data,
272
+ "plot0_3d_volume": fig0.to_json(),
273
  "plot1_3d_volume": fig1.to_json(),
274
  "plot2_3d_volume": fig2.to_json(),
275
  "plot3_3d_volume": fig3.to_json()