Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,10 +18,10 @@ def solve_3d_heat_equation(Lx: float, Ly: float, Lz: float,
|
|
| 18 |
x = np.linspace(0, Lx, Nx)
|
| 19 |
y = np.linspace(0, Ly, Ny)
|
| 20 |
z = np.linspace(0, Lz, Nz)
|
| 21 |
-
dx, dy, dz = x[1] - x, y[1] - y, z[1] - z
|
| 22 |
|
| 23 |
if dx == 0 or dy == 0 or dz == 0:
|
| 24 |
-
raise ValueError("Nx, Ny, and Nz must be > 1.")
|
| 25 |
|
| 26 |
# Stability condition for 3D FTCS scheme
|
| 27 |
dt = 0.5 / (Gamma * (1/dx**2 + 1/dy**2 + 1/dz**2)) # Adjusted for 3D
|
|
@@ -44,7 +44,7 @@ def solve_3d_heat_equation(Lx: float, Ly: float, Lz: float,
|
|
| 44 |
raise ValueError(f"Unknown initial condition: {initial}")
|
| 45 |
|
| 46 |
U = np.zeros((Nt, Nx, Ny, Nz))
|
| 47 |
-
U = u.copy()
|
| 48 |
|
| 49 |
for n in range(1, Nt):
|
| 50 |
un = u.copy()
|
|
@@ -83,7 +83,6 @@ def create_animation_gif_3d_slice(U, Lx, Ly, Lz, initial, bc, Gamma, frame_skip,
|
|
| 83 |
Nt, Nx, Ny, Nz = U.shape
|
| 84 |
fig, ax = plt.subplots()
|
| 85 |
|
| 86 |
-
# We'll animate the central xy-slice
|
| 87 |
slice_z_idx = Nz // 2
|
| 88 |
z_coord_slice = np.linspace(0, Lz, Nz)[slice_z_idx]
|
| 89 |
|
|
@@ -100,17 +99,22 @@ def create_animation_gif_3d_slice(U, Lx, Ly, Lz, initial, bc, Gamma, frame_skip,
|
|
| 100 |
im.set_data(U[frame, :, :, slice_z_idx].T)
|
| 101 |
return [im]
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
ani = FuncAnimation(fig, update, frames=idx, blit=True)
|
| 111 |
|
| 112 |
with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmpfile:
|
| 113 |
-
ani.save(tmpfile.name, writer='pillow', fps=max(1, 30 // frame_skip))
|
| 114 |
gif_path = tmpfile.name
|
| 115 |
|
| 116 |
plt.close(fig)
|
|
@@ -130,10 +134,10 @@ def create_plotly_figure_3d(u_3d, Lx, Ly, Lz, time_label):
|
|
| 130 |
y=Y.flatten(),
|
| 131 |
z=Z.flatten(),
|
| 132 |
value=u_3d.flatten(),
|
| 133 |
-
isomin=u_3d.min(),
|
| 134 |
isomax=u_3d.max(),
|
| 135 |
-
opacity=0.1,
|
| 136 |
-
surface_count=
|
| 137 |
colorscale='viridis'
|
| 138 |
))
|
| 139 |
fig.update_layout(
|
|
@@ -141,7 +145,9 @@ def create_plotly_figure_3d(u_3d, Lx, Ly, Lz, time_label):
|
|
| 141 |
scene=dict(
|
| 142 |
xaxis_title='x',
|
| 143 |
yaxis_title='y',
|
| 144 |
-
zaxis_title='z'
|
|
|
|
|
|
|
| 145 |
)
|
| 146 |
)
|
| 147 |
return fig
|
|
@@ -152,7 +158,7 @@ def run_simulation_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_s
|
|
| 152 |
Lx=lx, Ly=ly, Lz=lz, t_max=t_max, Gamma=gamma,
|
| 153 |
Nx=nx, Ny=ny, Nz=nz, initial=initial, bc=bc
|
| 154 |
)
|
| 155 |
-
Nt = U.shape
|
| 156 |
|
| 157 |
idx0 = 0
|
| 158 |
idx1 = round((Nt - 1) / 4) if Nt > 1 else 0
|
|
@@ -175,6 +181,8 @@ def run_simulation_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_s
|
|
| 175 |
# --- Gradio Interface Logic (3D) ---
|
| 176 |
def gradio_interface_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip):
|
| 177 |
nx, ny, nz, frame_skip = int(nx), int(ny), int(nz), int(frame_skip)
|
|
|
|
|
|
|
| 178 |
gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(
|
| 179 |
lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip
|
| 180 |
)
|
|
@@ -188,13 +196,13 @@ with gr.Blocks(theme=gr.themes.Soft(), title="3D Heat Simulator") as demo:
|
|
| 188 |
gr.Markdown("## Domain & Grid")
|
| 189 |
lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
|
| 190 |
ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
|
| 191 |
-
lz_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lz
|
| 192 |
-
nx_slider = gr.Slider(3,
|
| 193 |
-
ny_slider = gr.Slider(3,
|
| 194 |
-
nz_slider = gr.Slider(3,
|
| 195 |
|
| 196 |
gr.Markdown("## Simulation")
|
| 197 |
-
t_slider = gr.Slider(0.01,
|
| 198 |
gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
|
| 199 |
|
| 200 |
gr.Markdown("## Conditions")
|
|
@@ -206,7 +214,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="3D Heat Simulator") as demo:
|
|
| 206 |
)
|
| 207 |
|
| 208 |
gr.Markdown("## Animation")
|
| 209 |
-
frame_skip_slider = gr.Slider(1, 50,
|
| 210 |
run_btn = gr.Button("Run 3D Simulation", variant="primary")
|
| 211 |
|
| 212 |
with gr.Column(scale=3):
|
|
@@ -225,43 +233,43 @@ with gr.Blocks(theme=gr.themes.Soft(), title="3D Heat Simulator") as demo:
|
|
| 225 |
run_btn.click(fn=gradio_interface_3d, inputs=inputs_list, outputs=outputs_list)
|
| 226 |
|
| 227 |
gr.Examples(
|
| 228 |
-
examples=[1.0, 1.0, 1.0, 0.
|
| 229 |
-
[1.5, 1.0, 0.5, 0.
|
| 230 |
-
[1.0, 1.0, 1.0, 0.
|
| 231 |
inputs=inputs_list,
|
| 232 |
outputs=outputs_list,
|
| 233 |
-
fn=gradio_interface_3d
|
|
|
|
| 234 |
)
|
| 235 |
|
| 236 |
# --- FastAPI Setup for API Endpoint (3D) ---
|
| 237 |
app = FastAPI()
|
| 238 |
|
| 239 |
-
# Mount Gradio app to FastAPI
|
| 240 |
app = gr.mount_gradio_app(app, demo, path="/")
|
| 241 |
|
| 242 |
-
# Define the simulation parameters model for 3D
|
| 243 |
class SimulationParams3D(BaseModel):
|
| 244 |
lx: float
|
| 245 |
ly: float
|
| 246 |
-
lz: float
|
| 247 |
t_max: float
|
| 248 |
gamma: float
|
| 249 |
nx: int
|
| 250 |
ny: int
|
| 251 |
-
nz: int
|
| 252 |
initial: str
|
| 253 |
bc: str
|
| 254 |
frame_skip: int
|
| 255 |
|
| 256 |
-
|
| 257 |
-
@app.post("/simulate_3d") # Renamed endpoint
|
| 258 |
def simulate_3d_api(params: SimulationParams3D):
|
|
|
|
|
|
|
| 259 |
gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(**params.dict())
|
| 260 |
with open(gif_path, "rb") as f:
|
| 261 |
gif_data = base64.b64encode(f.read()).decode('utf-8')
|
| 262 |
return {
|
| 263 |
"gif_base64": gif_data,
|
| 264 |
-
"plot0_3d_volume": fig0.to_json(),
|
| 265 |
"plot1_3d_volume": fig1.to_json(),
|
| 266 |
"plot2_3d_volume": fig2.to_json(),
|
| 267 |
"plot3_3d_volume": fig3.to_json()
|
|
|
|
| 18 |
x = np.linspace(0, Lx, Nx)
|
| 19 |
y = np.linspace(0, Ly, Ny)
|
| 20 |
z = np.linspace(0, Lz, Nz)
|
| 21 |
+
dx, dy, dz = x[1] - x, y[1] - y, z[1] - z # Corrected indexing
|
| 22 |
|
| 23 |
if dx == 0 or dy == 0 or dz == 0:
|
| 24 |
+
raise ValueError("Nx, Ny, and Nz must be > 1 for valid dx, dy, dz.")
|
| 25 |
|
| 26 |
# Stability condition for 3D FTCS scheme
|
| 27 |
dt = 0.5 / (Gamma * (1/dx**2 + 1/dy**2 + 1/dz**2)) # Adjusted for 3D
|
|
|
|
| 44 |
raise ValueError(f"Unknown initial condition: {initial}")
|
| 45 |
|
| 46 |
U = np.zeros((Nt, Nx, Ny, Nz))
|
| 47 |
+
U = u.copy() # Store initial condition
|
| 48 |
|
| 49 |
for n in range(1, Nt):
|
| 50 |
un = u.copy()
|
|
|
|
| 83 |
Nt, Nx, Ny, Nz = U.shape
|
| 84 |
fig, ax = plt.subplots()
|
| 85 |
|
|
|
|
| 86 |
slice_z_idx = Nz // 2
|
| 87 |
z_coord_slice = np.linspace(0, Lz, Nz)[slice_z_idx]
|
| 88 |
|
|
|
|
| 99 |
im.set_data(U[frame, :, :, slice_z_idx].T)
|
| 100 |
return [im]
|
| 101 |
|
| 102 |
+
if Nt <= 1:
|
| 103 |
+
idx = # Only one frame (initial state)
|
| 104 |
+
else:
|
| 105 |
+
idx = list(range(0, Nt, frame_skip))
|
| 106 |
+
if not idx: # If frame_skip is too large for Nt > 1
|
| 107 |
+
idx =
|
| 108 |
+
if idx[-1]!= Nt - 1: # Ensure last frame is included
|
| 109 |
+
idx.append(Nt - 1)
|
| 110 |
+
# Remove duplicates if Nt-1 was already included by range
|
| 111 |
+
idx = sorted(list(set(idx)))
|
| 112 |
|
| 113 |
|
| 114 |
ani = FuncAnimation(fig, update, frames=idx, blit=True)
|
| 115 |
|
| 116 |
with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmpfile:
|
| 117 |
+
ani.save(tmpfile.name, writer='pillow', fps=max(1, 30 // max(1,frame_skip)))
|
| 118 |
gif_path = tmpfile.name
|
| 119 |
|
| 120 |
plt.close(fig)
|
|
|
|
| 134 |
y=Y.flatten(),
|
| 135 |
z=Z.flatten(),
|
| 136 |
value=u_3d.flatten(),
|
| 137 |
+
isomin=u_3d.min(), # Use actual min/max of the current data
|
| 138 |
isomax=u_3d.max(),
|
| 139 |
+
opacity=0.1,
|
| 140 |
+
surface_count=17, # Adjusted for potentially better performance/look
|
| 141 |
colorscale='viridis'
|
| 142 |
))
|
| 143 |
fig.update_layout(
|
|
|
|
| 145 |
scene=dict(
|
| 146 |
xaxis_title='x',
|
| 147 |
yaxis_title='y',
|
| 148 |
+
zaxis_title='z',
|
| 149 |
+
# Aspect ratio can be important for 3D visualization
|
| 150 |
+
aspectmode='cube'
|
| 151 |
)
|
| 152 |
)
|
| 153 |
return fig
|
|
|
|
| 158 |
Lx=lx, Ly=ly, Lz=lz, t_max=t_max, Gamma=gamma,
|
| 159 |
Nx=nx, Ny=ny, Nz=nz, initial=initial, bc=bc
|
| 160 |
)
|
| 161 |
+
Nt = U.shape # Corrected: U.shape is Nt
|
| 162 |
|
| 163 |
idx0 = 0
|
| 164 |
idx1 = round((Nt - 1) / 4) if Nt > 1 else 0
|
|
|
|
| 181 |
# --- Gradio Interface Logic (3D) ---
|
| 182 |
def gradio_interface_3d(lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip):
|
| 183 |
nx, ny, nz, frame_skip = int(nx), int(ny), int(nz), int(frame_skip)
|
| 184 |
+
# Ensure frame_skip is at least 1
|
| 185 |
+
frame_skip = max(1, frame_skip)
|
| 186 |
gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(
|
| 187 |
lx, ly, lz, t_max, gamma, nx, ny, nz, initial, bc, frame_skip
|
| 188 |
)
|
|
|
|
| 196 |
gr.Markdown("## Domain & Grid")
|
| 197 |
lx_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lx")
|
| 198 |
ly_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Ly")
|
| 199 |
+
lz_slider = gr.Slider(0.1, 5.0, 1.0, 0.1, label="Lz")
|
| 200 |
+
nx_slider = gr.Slider(3, 60, 20, 1, label="Nx (e.g., 20-40 for speed)") # Max reduced
|
| 201 |
+
ny_slider = gr.Slider(3, 60, 20, 1, label="Ny (e.g., 20-40 for speed)") # Max reduced
|
| 202 |
+
nz_slider = gr.Slider(3, 60, 20, 1, label="Nz (e.g., 20-40 for speed)") # Max reduced
|
| 203 |
|
| 204 |
gr.Markdown("## Simulation")
|
| 205 |
+
t_slider = gr.Slider(0.01, 1.0, 0.1, 0.01, label="t_max (e.g., 0.1-0.5)") # Max t_max reduced
|
| 206 |
gamma_slider = gr.Slider(0.001, 1.0, 0.1, 0.001, label="Gamma")
|
| 207 |
|
| 208 |
gr.Markdown("## Conditions")
|
|
|
|
| 214 |
)
|
| 215 |
|
| 216 |
gr.Markdown("## Animation")
|
| 217 |
+
frame_skip_slider = gr.Slider(1, 50, 5, 1, label="Frame Skip (e.g., 5-10)")
|
| 218 |
run_btn = gr.Button("Run 3D Simulation", variant="primary")
|
| 219 |
|
| 220 |
with gr.Column(scale=3):
|
|
|
|
| 233 |
run_btn.click(fn=gradio_interface_3d, inputs=inputs_list, outputs=outputs_list)
|
| 234 |
|
| 235 |
gr.Examples(
|
| 236 |
+
examples=[1.0, 1.0, 1.0, 0.1, 0.1, 20, 20, 20, "gaussian", "dirichlet", 5],
|
| 237 |
+
[1.5, 1.0, 0.5, 0.2, 0.05, 25, 20, 15, "sinusoidal", "periodic", 10],
|
| 238 |
+
[1.0, 1.0, 1.0, 0.05, 0.2, 15, 15, 15, "step", "neumann", 2],
|
| 239 |
inputs=inputs_list,
|
| 240 |
outputs=outputs_list,
|
| 241 |
+
fn=gradio_interface_3d,
|
| 242 |
+
cache_examples=False # Consider disabling cache if examples are slow or large
|
| 243 |
)
|
| 244 |
|
| 245 |
# --- FastAPI Setup for API Endpoint (3D) ---
|
| 246 |
app = FastAPI()
|
| 247 |
|
|
|
|
| 248 |
app = gr.mount_gradio_app(app, demo, path="/")
|
| 249 |
|
|
|
|
| 250 |
class SimulationParams3D(BaseModel):
|
| 251 |
lx: float
|
| 252 |
ly: float
|
| 253 |
+
lz: float
|
| 254 |
t_max: float
|
| 255 |
gamma: float
|
| 256 |
nx: int
|
| 257 |
ny: int
|
| 258 |
+
nz: int
|
| 259 |
initial: str
|
| 260 |
bc: str
|
| 261 |
frame_skip: int
|
| 262 |
|
| 263 |
+
@app.post("/simulate_3d")
|
|
|
|
| 264 |
def simulate_3d_api(params: SimulationParams3D):
|
| 265 |
+
# Ensure frame_skip is at least 1 for API calls too
|
| 266 |
+
params.frame_skip = max(1, params.frame_skip)
|
| 267 |
gif_path, fig0, fig1, fig2, fig3 = run_simulation_3d(**params.dict())
|
| 268 |
with open(gif_path, "rb") as f:
|
| 269 |
gif_data = base64.b64encode(f.read()).decode('utf-8')
|
| 270 |
return {
|
| 271 |
"gif_base64": gif_data,
|
| 272 |
+
"plot0_3d_volume": fig0.to_json(),
|
| 273 |
"plot1_3d_volume": fig1.to_json(),
|
| 274 |
"plot2_3d_volume": fig2.to_json(),
|
| 275 |
"plot3_3d_volume": fig3.to_json()
|