File size: 16,292 Bytes
3dcb82e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8190971
3dcb82e
 
 
 
 
 
 
8190971
3dcb82e
 
 
 
 
 
 
 
 
 
 
8190971
3dcb82e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8190971
3dcb82e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
!apt-get update -qq
!apt-get install -y -qq gmsh
!pip install torch --upgrade -q
!pip install --upgrade -q \
    gmsh \
    meshio \
    trimesh \
    numpy \
    pandas \
    scikit-learn \
    matplotlib \
    plotly \
    ipywidgets \
    gradio
!pip install --upgrade -q jax jaxlib

# ===== CELL 1: SYSTEM INSTALLATION (RUN THIS FIRST) =====
# It is recommended to use the separate, more robust dependency installation
# script provided previously. This cell is a simplified version.
import subprocess
import sys
import os

def install_dependencies():
    """Install the system and Python packages this notebook needs (Colab).

    Steps, in order:
      1. GMSH via apt-get.
      2. PyTorch, then the PyTorch Geometric wheels built for the torch
         version that is now installed.
      3. The remaining pip libraries (meshing, data, plotting, UI).

    Returns:
        True when every step succeeded, False otherwise. On failure the error
        is printed; for commands whose output was captured (the apt-get
        calls) the command's stderr is printed too, because a bare
        CalledProcessError message only reports the exit status.
    """
    print("🚀 Starting installation...")
    try:
        # Step 1: Install system packages like GMSH
        print("🔧 Installing system package: GMSH...")
        subprocess.run(["apt-get", "update", "-qq"], check=True, capture_output=True, text=True)
        subprocess.run(["apt-get", "install", "-y", "-qq", "gmsh"], check=True, capture_output=True, text=True)
        print("   ✅ GMSH installed.")

        # Step 2: Install PyTorch and PyTorch Geometric correctly
        print("\n🧠 Installing PyTorch & PyTorch Geometric...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "-q"])
        # Ask a fresh interpreter for the torch version: the wheel index URL must
        # match the build that was just installed, not one imported earlier.
        torch_version = subprocess.check_output(
            [sys.executable, "-c", "import torch; print(torch.__version__)"], text=True
        ).strip()
        pyg_install_command = [
            sys.executable, "-m", "pip", "install",
            "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv", "torch-geometric",
            "-f", f"https://data.pyg.org/whl/torch-{torch_version}.html",
            "-q",
        ]
        subprocess.check_call(pyg_install_command)
        print("   ✅ PyTorch & PyG installed.")

        # Step 3: Install other core packages
        print("\n📦 Installing core libraries...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade",
                               "gmsh", "meshio", "trimesh", "numpy", "pandas",
                               "scikit-learn", "matplotlib", "plotly", "ipywidgets", "gradio", "-q"])
        print("   ✅ Core libraries installed.")

        print("\n🎉 Installation complete! Please restart the runtime and run the next cell.")
        return True

    except Exception as e:
        print(f"❌ An error occurred during installation: {e}")
        # CalledProcessError from subprocess.run(capture_output=True) carries the
        # command's stderr; without printing it the real cause is invisible.
        stderr = getattr(e, "stderr", None)
        if stderr:
            print(stderr.strip() if isinstance(stderr, str) else stderr)
        print("   Please check the error message and try again.")
        return False

# Run installation
# install_dependencies()


# ===== CELL 2: MAIN APPLICATION (RUN AFTER RESTART) =====

# Safe imports with fallbacks
def safe_import():
    """Import every third-party package the app needs and publish it as a global.

    Intended to run after the installation cell and a runtime restart. Each
    name listed in the ``global`` statement is bound at module scope by the
    matching import below, so the rest of the notebook can use ``np``,
    ``torch``, ``gr`` and so on.

    Returns:
        True if all imports succeeded, False otherwise (the reason is printed).
    """
    global gmsh, np, torch, nn, F, Data, GCNConv, pyg_utils, meshio, go, plt, pd, widgets, gr
    print("🔬 Importing necessary libraries...")
    try:
        import numpy as np
        import pandas as pd
        import matplotlib.pyplot as plt

        # Mesh and geometry
        import gmsh
        import meshio

        # PyTorch and PyTorch Geometric
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        from torch_geometric.data import Data
        from torch_geometric.nn import GCNConv
        import torch_geometric.utils as pyg_utils

        # Visualization
        import plotly.graph_objects as go

        # UI/UX
        import gradio as gr
        import ipywidgets as widgets

        import warnings
        # Silences *all* warnings for the rest of the session to keep the
        # notebook output readable; note this also hides numerical warnings.
        warnings.filterwarnings('ignore')

        print("✅ All packages imported successfully!")
        return True

    except ImportError as e:
        print(f"❌ Critical import failure: {e}")
        print("   Please ensure Cell 1 was run and the runtime was restarted.")
        return False
    except Exception as e:
        print(f"❌ An unexpected error occurred during import: {e}")
        return False


# Import all packages; stop early if anything essential is missing.
if not safe_import():
    # sys.exit raises SystemExit, which halts the rest of this cell/script.
    sys.exit("Stopping due to import errors.")


# ===== STEP 1: MESH GENERATION =====
print("\n🔧 Step 1: Mesh generation and processing")

def create_beam_geometry(length=10.0, width=1.0, height=2.0, mesh_size=0.5):
    """Mesh a box-shaped beam with GMSH and write it to ``beam_mesh.msh``.

    The box spans [0, length] x [0, width] x [0, height]. Element sizes are
    bounded between ``0.5 * mesh_size`` and ``mesh_size`` and a 3D (volume)
    mesh is generated.

    Args:
        length, width, height: Box dimensions along x, y and z.
        mesh_size: Maximum characteristic element length.

    Returns:
        The path of the written mesh file. If any GMSH step fails, the error
        is printed and the path returned by ``create_fallback_mesh()`` is
        returned instead.
    """
    filename = "beam_mesh.msh"
    try:
        gmsh.initialize()
        try:
            gmsh.model.add("cantilever_beam")
            gmsh.model.occ.addBox(0, 0, 0, length, width, height)
            gmsh.model.occ.synchronize()
            gmsh.option.setNumber("Mesh.CharacteristicLengthMin", mesh_size * 0.5)
            gmsh.option.setNumber("Mesh.CharacteristicLengthMax", mesh_size)
            gmsh.model.mesh.generate(3)
            gmsh.write(filename)
        finally:
            # Release GMSH even when meshing fails; otherwise the next
            # initialize() call runs against a half-built session.
            gmsh.finalize()
    except Exception as e:
        print(f"❌ GMSH geometry creation failed: {e}. Using a fallback mesh.")
        return create_fallback_mesh()

    print(f"✅ GMSH geometry created ('{filename}')")
    return filename

def create_fallback_mesh(length=10.0, width=1.0, height=2.0, filename="fallback_mesh.vtk"):
    """Write a single-hexahedron box mesh, used when GMSH meshing fails.

    The defaults reproduce the original hard-coded 10 x 1 x 2 box, which matches
    ``create_beam_geometry``'s default dimensions.

    Vertex order follows the meshio/VTK hexahedron convention: the bottom face
    (z = 0) counter-clockwise, then the top face in the same order.

    Args:
        length, width, height: Box dimensions along x, y and z.
        filename: Output path; the format is inferred by meshio from the suffix.

    Returns:
        The path that was written.
    """
    print("🔄 Creating a fallback cubic mesh...")
    points = np.array([
        [0, 0, 0], [length, 0, 0], [length, width, 0], [0, width, 0],
        [0, 0, height], [length, 0, height], [length, width, height], [0, width, height],
    ], dtype=np.float32)
    cells = [("hexahedron", np.arange(8, dtype=np.int64).reshape(1, 8))]
    meshio.Mesh(points, cells).write(filename)
    print(f"✅ Fallback mesh created ('{filename}')")
    return filename

# Path to the generated (or fallback) mesh consumed by Step 2.
mesh_file = create_beam_geometry()


# ===== STEP 2: MESH TO GRAPH CONVERSION =====
print("\n🔄 Step 2: Converting mesh to graph representation")

def _select_cells(mesh):
    """Return the connectivity array used to build the graph (tetra > triangle > hexahedron)."""
    for cell_type in ("tetra", "triangle"):
        cells = np.asarray(mesh.get_cells_type(cell_type))
        if len(cells) > 0:
            return cells
    hexes = np.asarray(mesh.get_cells_type("hexahedron"))
    if len(hexes) == 0:
        raise ValueError("mesh has no tetra, triangle or hexahedron cells")
    # Six tetrahedra around the 0-6 body diagonal (meshio/VTK hexahedron node
    # order: bottom face 0-1-2-3, top face 4-5-6-7). Together they tile the hex.
    hex_to_tets = np.array([
        [0, 1, 2, 6], [0, 2, 3, 6], [0, 3, 7, 6],
        [0, 7, 4, 6], [0, 4, 5, 6], [0, 5, 1, 6],
    ])
    return hexes[:, hex_to_tets].reshape(-1, 4)


def _cells_to_edge_index(cells, num_nodes):
    """Build an undirected edge_index linking every pair of vertices that share a cell."""
    cells = np.asarray(cells, dtype=np.int64)
    k = cells.shape[1]
    # All vertex pairs within a cell: 3 edges for a triangle, 6 for a tetrahedron.
    pairs = [(a, b) for a in range(k) for b in range(a + 1, k)]
    src = np.concatenate([cells[:, a] for a, _ in pairs])
    dst = np.concatenate([cells[:, b] for _, b in pairs])
    edge_index = torch.from_numpy(np.stack([src, dst]))
    return pyg_utils.to_undirected(edge_index, num_nodes=num_nodes)


def mesh_to_graph(mesh_file):
    """Convert a mesh file into a PyTorch Geometric graph.

    Nodes are mesh vertices. An undirected edge joins every pair of vertices
    that share a cell. Cells are taken from tetrahedra if present, otherwise
    triangles, otherwise hexahedra split into six tetrahedra each.

    Node features ``x`` (N, 4) are the per-axis standardized coordinates plus
    the raw distance to the centroid; ``pos`` holds the raw coordinates.

    Args:
        mesh_file: Path to any mesh format meshio can read.

    Returns:
        ``(graph, points, cells)``: the ``Data`` graph, the float32 (N, 3)
        vertex array, and the cell connectivity actually used. On any failure
        the error is printed and ``(None, None, None)`` is returned.
    """
    try:
        mesh = meshio.read(mesh_file)
        points = mesh.points.astype(np.float32)
        cells = _select_cells(mesh)
        edge_index = _cells_to_edge_index(cells, num_nodes=len(points))

        coords = torch.tensor(points, dtype=torch.float32)
        centroid = coords.mean(dim=0)
        dist_to_centroid = torch.norm(coords - centroid, dim=1, keepdim=True)
        # The epsilon keeps an axis with zero spread from dividing by zero.
        coords_normalized = (coords - centroid) / (coords.std(dim=0) + 1e-8)
        x = torch.cat([coords_normalized, dist_to_centroid], dim=1)

        graph = Data(x=x, edge_index=edge_index, pos=coords)
        print(f"✅ Graph created: {graph.num_nodes} nodes, {graph.num_edges} edges")
        return graph, points, cells

    except Exception as e:
        print(f"❌ Mesh conversion failed: {e}. Cannot proceed.")
        return None, None, None

graph, points, cells = mesh_to_graph(mesh_file)
if graph is None:
    # sys.exit raises SystemExit, which halts the rest of this cell/script.
    sys.exit("Stopping due to mesh processing errors.")


# ===== STEP 3: ANALYTICAL BEAM MODEL (labelled "FEM" throughout the app) =====
print("\n⚛️  Step 3: Defining accurate physics-based analysis model")

def cantilever_beam_fem(points, E=210e9, load_magnitude=-1000):
    """Evaluate the Euler-Bernoulli cantilever solution at every node.

    Despite the name, this is the closed-form beam solution, not a
    finite-element solve. The beam axis is x. The face at the minimum x is
    clamped, and a tip load ``load_magnitude`` acts along z at the maximum-x
    end; a positive value points in +z, so the default -1000 pushes downward.
    The rectangular cross-section spans the y and z extents of ``points``, and
    the neutral axis sits at mid-height. All positions are measured from the
    coordinate minima, so meshes that do not start at the origin work too.

    Args:
        points: (N, 3) node coordinates (x, y, z).
        E: Young's modulus in Pa.
        load_magnitude: Signed tip load along z in N.

    Returns:
        displacement: (N, 3) array with only the z column non-zero,
            w(x) = F x^2 (3L - x) / (6 E I) with F = load_magnitude, so the
            deflection has the same sign as the load.
        stress: (N,) bending stress in Pa, sigma = -F (L - x) (z - z_mid) / I
            (positive = tension; a downward load puts the top fibres in tension).
        fixed_nodes: Indices of nodes on the clamped face (x == x_min).
        loaded_nodes: Indices of nodes on the loaded face (x == x_max).

    Raises:
        ValueError: If the cross-section has zero second moment of area.
    """
    pts = np.asarray(points)
    # Keep float32 inputs float32 (as before), but never truncate into an int array.
    out_dtype = np.result_type(pts.dtype, np.float32)
    if pts.shape[0] == 0:
        empty_idx = np.array([], dtype=np.intp)
        return np.zeros(pts.shape, dtype=out_dtype), np.zeros(0), empty_idx, empty_idx

    x = pts[:, 0].astype(np.float64)
    y = pts[:, 1].astype(np.float64)
    z = pts[:, 2].astype(np.float64)
    x_rel = x - x.min()
    length = x_rel.max()
    width = y.max() - y.min()
    height = z.max() - z.min()
    I = (width * height**3) / 12
    if I <= 0:
        raise ValueError("degenerate beam: cross-section has zero width or height")

    tol = 1e-6
    fixed_nodes = np.where(x_rel < tol)[0]
    loaded_nodes = np.where(x_rel > length - tol)[0]

    displacement = np.zeros(pts.shape, dtype=out_dtype)
    displacement[:, 2] = load_magnitude * x_rel**2 * (3 * length - x_rel) / (6 * E * I)

    z_from_neutral_axis = z - (z.min() + height / 2)
    stress = -load_magnitude * (length - x_rel) * z_from_neutral_axis / I

    return displacement, stress, fixed_nodes, loaded_nodes


# ===== STEP 4: AI SURROGATE MODEL & LIVE TRAINING =====
print("\n🧠 Step 4: Building and training AI surrogate model")

class EnhancedSurrogateNet(nn.Module):
    """Stack of GCN layers mapping per-node features to per-node outputs.

    Every layer except the last is GCNConv -> BatchNorm1d -> ReLU -> Dropout.
    The last layer is a plain GCNConv, so outputs are unbounded (regression).
    In this app the 4 outputs are used as (u_x, u_y, u_z, stress).

    Args:
        in_channels: Width of the node feature vectors.
        hidden_channels: Width of the hidden layers.
        out_channels: Width of the per-node output.
        num_layers: Total number of GCNConv layers (must be >= 1).
        dropout: Dropout probability applied after each hidden layer.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels=4, hidden_channels=64, out_channels=4, num_layers=3, dropout=0.2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        channels = in_channels
        # num_layers - 1 hidden layers, each followed by its own BatchNorm.
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(channels, hidden_channels))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
            channels = hidden_channels
        self.convs.append(GCNConv(channels, out_channels))
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Return the (num_nodes, out_channels) prediction for graph ``data``."""
        x, edge_index = data.x, data.edge_index
        for conv, norm in zip(self.convs[:-1], self.batch_norms):
            x = conv(x, edge_index)
            # BatchNorm1d cannot compute batch statistics from a single node in
            # training mode, so skip it for one-node graphs.
            if x.shape[0] > 1:
                x = norm(x)
            x = F.relu(x)
            x = self.dropout(x)
        return self.convs[-1](x, edge_index)

def train_surrogate_model(model, graph_data, training_status_callback=None, epochs=100, lr=0.01,
                          weight_decay=5e-4):
    """Fit ``model`` to analytical beam responses for a sweep of tip loads.

    For each load in ``np.linspace(-500, -5000, 10)`` the per-node target is
    ``[u_x, u_y, u_z, stress]`` from ``cantilever_beam_fem`` at its default E.
    Each epoch takes one Adam step per load scenario.

    NOTE(review): the graph's node features do not encode the load, so every
    scenario presents the same input with a different target. At best the
    network learns their mean. Targets are also in raw SI units, so the stress
    channel (orders of magnitude larger than displacements) dominates the MSE.
    Conditioning on the load and normalizing targets would need matching
    changes wherever the model's output is consumed.

    Args:
        model: The network to train (modified in place and returned).
        graph_data: The graph the model is evaluated on.
        training_status_callback: Optional callable receiving progress strings.
        epochs: Number of passes over the load scenarios.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        The trained model, left in eval mode.
    """
    print("🚀 Starting AI model training...")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.MSELoss()

    # Use the graph's own coordinates so the targets line up with its nodes;
    # fall back to the module-level `points` if the graph carries no `pos`.
    pos = getattr(graph_data, "pos", None)
    node_points = pos.detach().cpu().numpy() if pos is not None else points

    targets = []
    for load in np.linspace(-500, -5000, 10):
        disp_fem, stress_fem, _, _ = cantilever_beam_fem(node_points, load_magnitude=load)
        targets.append(torch.tensor(np.column_stack([disp_fem, stress_fem]), dtype=torch.float32))

    model.train()
    try:
        for epoch in range(1, epochs + 1):
            total_loss = 0.0
            for target in targets:
                optimizer.zero_grad()
                loss = loss_fn(model(graph_data), target)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            if epoch % 20 == 0 or epoch == epochs:
                status_msg = f"Epoch {epoch}/{epochs}, Loss: {total_loss / len(targets):.4f}"
                print(f"   {status_msg}")
                if training_status_callback is not None:
                    training_status_callback(status_msg)
    finally:
        # Always leave the model in inference mode (dropout off, BatchNorm frozen),
        # even if training is interrupted.
        model.eval()
    print("✅ AI model training complete!")
    return model


# ===== STEP 5: GRADIO INTERFACE & APPLICATION LOGIC =====
print("\n🎨 Step 5: Creating Gradio user interface")

class StructuralAnalysisApp:
    """Glue between the Gradio UI, the analytical beam model and the GNN surrogate.

    Attributes:
        points: (N, 3) node coordinates used by the beam model and the plot.
        graph: Graph fed to the surrogate network.
        model: The EnhancedSurrogateNet surrogate.
        is_trained: Whether train_model_for_ui has completed at least once.
    """

    def __init__(self, points, graph):
        self.points = points
        self.graph = graph
        self.model = EnhancedSurrogateNet(in_channels=graph.x.shape[1], out_channels=4)
        # Lets the results flag predictions that come from random weights.
        self.is_trained = False

    def train_model_for_ui(self, training_status_update=None):
        """Train the surrogate and return a status string for the UI.

        ``training_status_update`` is optional because Gradio invokes event
        handlers with one argument per declared input; handlers wired with
        ``inputs=[]`` call this method with no arguments at all.
        """
        self.model = train_surrogate_model(self.model, self.graph, training_status_update)
        self.is_trained = True
        return "Model trained successfully! Ready for analysis."

    def analyze(self, young_modulus, load_magnitude):
        """Run the beam model and the surrogate for one parameter set.

        Args:
            young_modulus: Young's modulus in GPa (slider value).
            load_magnitude: Signed tip load in N (slider value).

        Returns:
            ``(figure, markdown)``. On any error, an empty Figure and an error
            message are returned instead so the UI keeps working.
        """
        try:
            E = float(young_modulus) * 1e9
            load = float(load_magnitude)
            disp_fem, stress_fem, fixed, _ = cantilever_beam_fem(self.points, E=E, load_magnitude=load)
            disp_mag_fem = np.linalg.norm(disp_fem, axis=1)

            # NOTE(review): the graph input carries neither E nor the load, so the
            # surrogate returns the same prediction for every slider setting.
            # A freshly built model is still in training mode; switch off
            # dropout and BatchNorm updates before predicting.
            self.model.eval()
            with torch.no_grad():
                prediction = self.model(self.graph).cpu().numpy()
            disp_surrogate = prediction[:, :3]
            stress_surrogate = prediction[:, 3]
            disp_mag_surrogate = np.linalg.norm(disp_surrogate, axis=1)

            fig = self.create_3d_plot(disp_mag_fem, stress_fem, fixed, E / 1e9, load)
            results_text = self.format_results_text(
                disp_mag_fem, stress_fem, disp_mag_surrogate, stress_surrogate, E / 1e9, load, fixed
            )
            return fig, results_text
        except Exception as e:
            error_msg = f"❌ Analysis failed: {str(e)}"
            print(error_msg)
            return go.Figure(), error_msg

    def create_3d_plot(self, disp_mag, stress, fixed_nodes, E, load):
        """Scatter the nodes coloured by displacement magnitude and mark the fixed support.

        ``E`` is in GPa and ``load`` in N; both are used only in the title.
        Hovering a node shows its stress in MPa.
        """
        fig = go.Figure()
        fig.add_trace(go.Scatter3d(
            x=self.points[:, 0], y=self.points[:, 1], z=self.points[:, 2],
            mode='markers',
            marker=dict(
                size=4, color=disp_mag, colorscale='Viridis',
                colorbar=dict(title="Displacement (m)"),
                cmin=disp_mag.min(), cmax=disp_mag.max()
            ),
            text=[f"Stress: {s/1e6:.2f} MPa" for s in stress],
            hoverinfo='text', name='Deformation Field'
        ))
        fig.add_trace(go.Scatter3d(
            x=self.points[fixed_nodes, 0], y=self.points[fixed_nodes, 1], z=self.points[fixed_nodes, 2],
            mode='markers', marker=dict(size=6, color='red', symbol='x'), name='Fixed Support'
        ))
        fig.update_layout(
            title=f"Analysis Results (E={E:.0f} GPa, Load={load:.0f} N)",
            scene=dict(xaxis_title="X (m)", yaxis_title="Y (m)", zaxis_title="Z (m)"),
            width=800, height=600, margin=dict(l=0, r=0, b=0, t=40)
        )
        return fig

    @staticmethod
    def _safe_corr(a, b):
        """Pearson correlation of two equal-length 1-D arrays, or NaN when undefined.

        ``np.corrcoef`` divides by each array's standard deviation, so a constant
        input (e.g. an untrained surrogate) would otherwise produce NaN with a
        RuntimeWarning.
        """
        a = np.asarray(a, dtype=float).ravel()
        b = np.asarray(b, dtype=float).ravel()
        if a.size < 2 or a.size != b.size or a.std() == 0 or b.std() == 0:
            return float("nan")
        return float(np.corrcoef(a, b)[0, 1])

    def format_results_text(self, disp_fem, stress_fem, disp_surrogate, stress_surrogate, E, load, fixed):
        """Build the Markdown summary comparing the analytical result with the surrogate.

        ``disp_*`` are per-node displacement magnitudes (m), ``stress_*`` per-node
        stresses (Pa), ``E`` is in GPa, ``load`` in N and ``fixed`` the fixed-node
        indices. Correlations that are undefined are shown as ``n/a``.
        """
        corr_disp = self._safe_corr(disp_fem, disp_surrogate)
        corr_stress = self._safe_corr(stress_fem, stress_surrogate)

        def fmt_corr(value):
            return f"{value:.3f}" if np.isfinite(value) else "n/a"

        # Assembled line by line so the Markdown carries no leading indentation,
        # whichever renderer ends up displaying it.
        lines = [
            "### 📊 Analysis Summary",
            "| Parameter | Value |",
            "| :--- | :--- |",
            f"| **Young's Modulus** | {E:.0f} GPa |",
            f"| **Load Magnitude** | {load:.0f} N |",
            f"| **Mesh Nodes** | {len(self.points):,} |",
            f"| **Fixed Nodes** | {len(fixed):,} |",
            "",
            "### 🤖 AI vs. FEM Comparison",
            "| Metric | FEM (Ground Truth) | AI Surrogate | Correlation |",
            "| :--- | :--- | :--- | :--- |",
            f"| **Max Displacement** | `{disp_fem.max():.3e} m` | `{disp_surrogate.max():.3e} m` "
            f"| **`{fmt_corr(corr_disp)}`** |",
            f"| **Max Stress** | `{stress_fem.max()/1e6:.3f} MPa` | `{stress_surrogate.max()/1e6:.3f} MPa` "
            f"| **`{fmt_corr(corr_stress)}`** |",
        ]
        if not getattr(self, "is_trained", True):
            lines += ["", "_Note: the AI surrogate has not been trained yet; its predictions come from random weights._"]
        return "\n".join(lines)

app = StructuralAnalysisApp(points, graph)

with gr.Blocks(theme=gr.themes.Soft(), title="AI Structural Analysis") as demo:
    gr.Markdown("# 🏗️ AI-Powered Structural Analysis")
    gr.Markdown("An interactive tool combining Finite Element Method (FEM) with a Graph Neural Network (GNN) surrogate model. The GNN is trained in real-time on FEM data to provide fast, accurate predictions.")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🛠️ Parameters")
            young_modulus = gr.Slider(minimum=50, maximum=300, value=210, step=10, label="Young's Modulus (GPa)")
            load_magnitude = gr.Slider(minimum=-5000, maximum=-100, value=-1000, step=100, label="Load Magnitude (N)")
            with gr.Accordion("Advanced: AI Model Training", open=False):
                training_status = gr.Textbox(label="Training Status", value="Model is not trained yet.", interactive=False)
                train_btn = gr.Button("🧠 Train AI Model")
            analyze_btn = gr.Button("🚀 Run Analysis", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### 📈 Visualization & Results")
            plot_output = gr.Plot(label="3D Visualization")
            results_text = gr.Markdown()
    # Gradio calls `fn` with one positional argument per component in `inputs`.
    # With inputs=[] that is zero arguments, so the status-callback argument of
    # train_model_for_ui is supplied explicitly here.
    train_btn.click(fn=lambda: app.train_model_for_ui(None), inputs=[], outputs=[training_status], show_progress='full')
    analyze_btn.click(fn=app.analyze, inputs=[young_modulus, load_magnitude], outputs=[plot_output, results_text])
    # NOTE(review): `load` fires on every page load, so each visitor retrains the
    # single shared `app.model`.
    demo.load(fn=lambda: app.train_model_for_ui(None), inputs=[], outputs=[training_status], show_progress='full')

print("🌐 Launching Gradio interface...")
demo.launch(share=True, debug=True)