Engineer-Areeb commited on
Commit
3dcb82e
·
verified ·
1 Parent(s): 1203bb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +386 -428
app.py CHANGED
@@ -1,432 +1,390 @@
1
- import gradio as gr
2
- import pandas as pd
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import plotly.graph_objects as go
6
- import plotly.express as px
7
- from sklearn.ensemble import RandomForestRegressor
8
- from sklearn.preprocessing import StandardScaler
9
- import warnings
10
- warnings.filterwarnings('ignore')
11
-
12
- # Materials database
13
- def load_materials_database():
14
- return pd.DataFrame({
15
- 'Material': ['Steel 1045', 'Steel 4140', 'Aluminum 6061-T6', 'Aluminum 7075-T6',
16
- 'Stainless 304', 'Titanium Ti-6Al-4V', 'Cast Iron GG20', 'Brass C360'],
17
- 'Density (kg/m³)': [7850, 7850, 2700, 2810, 8000, 4430, 7200, 8500],
18
- 'Yield Strength (MPa)': [310, 415, 276, 503, 205, 880, 250, 125],
19
- 'Ultimate Strength (MPa)': [565, 655, 310, 572, 515, 950, 350, 315],
20
- 'Elastic Modulus (GPa)': [200, 205, 69, 72, 200, 114, 100, 100],
21
- 'Cost ($/kg)': [0.80, 1.20, 1.90, 2.50, 3.50, 35.00, 0.60, 6.50],
22
- 'Temperature Limit (°C)': [400, 500, 200, 180, 800, 600, 350, 250]
23
- })
24
-
25
- materials_db = load_materials_database()
26
-
27
- # AI Knowledge Base
28
- def get_ai_response(query):
29
- responses = {
30
- "material": "For material selection, consider strength requirements, environmental conditions, cost, and manufacturability. Steel offers good strength at low cost, aluminum provides excellent strength-to-weight ratio, and stainless steel offers superior corrosion resistance.",
31
- "design": "Design optimization involves balancing performance, cost, and manufacturability. Apply DFMA principles, minimize stress concentrations, and use appropriate safety factors. Consider topology optimization for weight reduction.",
32
- "stress": "Stress analysis requires proper load identification, material properties, and boundary conditions. Use von Mises criterion for ductile materials and maximum principal stress for brittle materials. Include safety factors based on application criticality.",
33
- "failure": "Failure prediction involves monitoring key parameters and applying appropriate failure theories. Consider fatigue, wear, corrosion, and overload failure modes. Implement condition-based monitoring for early detection.",
34
- "gear": "Gear design requires consideration of power transmission, speed ratio, center distance, and material selection. Use Lewis equation for bending stress and Hertz equation for contact stress.",
35
- "beam": "Beam analysis involves calculating bending moments, shear forces, and deflections. For simply supported beams: σ_max = Mc/I, δ_max = 5wL⁴/384EI (uniform load)."
36
- }
37
-
38
- query_lower = query.lower()
39
- for key, response in responses.items():
40
- if key in query_lower:
41
- return response
42
- return "I can help with material selection, design optimization, stress analysis, failure prediction, gear design, and beam analysis. Please specify your engineering question."
43
-
44
- # 3D Model Generation
45
- def generate_3d_model(component, teeth=20, module=2, length=100, diameter=20):
46
- if component == "Gear":
47
- # Generate gear profile
48
- angles = np.linspace(0, 2*np.pi, teeth*4)
49
- radius = teeth * module / 2
50
- r_vals = []
51
- for i, angle in enumerate(angles):
52
- if i % 4 < 2: # Tooth
53
- r_vals.append(radius + module)
54
- else: # Root
55
- r_vals.append(radius - module)
56
-
57
- x = np.array(r_vals) * np.cos(angles)
58
- y = np.array(r_vals) * np.sin(angles)
59
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  fig = go.Figure()
61
- fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name='Gear Profile', line=dict(width=3)))
62
- fig.update_layout(
63
- title=f"Gear Profile (Teeth: {teeth}, Module: {module})",
64
- xaxis_title="X (mm)",
65
- yaxis_title="Y (mm)",
66
- showlegend=False,
67
- width=600,
68
- height=600
69
- )
70
- fig.update_xaxis(scaleanchor="y", scaleratio=1)
71
-
72
- specs = f"""
73
- **Gear Specifications:**
74
- - Number of Teeth: {teeth}
75
- - Module: {module}
76
- - Pitch Diameter: {teeth * module:.2f} mm
77
- - Addendum: {module:.2f} mm
78
- - Dedendum: {1.25 * module:.2f} mm
79
- """
80
-
81
- return fig, specs
82
-
83
- elif component == "Shaft":
84
- # Generate shaft 3D visualization
85
- theta = np.linspace(0, 2*np.pi, 20)
86
- z_vals = np.linspace(0, length, 10)
87
-
88
- # Create surface plot
89
- Z, T = np.meshgrid(z_vals, theta)
90
- X = (diameter/2) * np.cos(T)
91
- Y = (diameter/2) * np.sin(T)
92
-
93
- fig = go.Figure(data=[go.Surface(x=X, y=Y, z=Z, colorscale='Blues')])
94
- fig.update_layout(
95
- title=f"Shaft 3D Model (Length: {length}mm, Diameter: {diameter}mm)",
96
- scene=dict(
97
- xaxis_title="X (mm)",
98
- yaxis_title="Y (mm)",
99
- zaxis_title="Z (mm)"
100
  ),
101
- width=600,
102
- height=500
 
 
 
 
 
 
 
 
 
103
  )
104
-
105
- specs = f"""
106
- **Shaft Specifications:**
107
- - Length: {length} mm
108
- - Diameter: {diameter} mm
109
- - Cross-sectional Area: {np.pi * (diameter/2)**2:.2f} mm²
110
- - Moment of Inertia: {np.pi * diameter**4 / 64:.2f} mm⁴
 
 
 
 
 
 
 
 
 
 
 
 
111
  """
112
-
113
- return fig, specs
114
-
115
- return go.Figure(), "Select a component type"
116
-
117
- # FEA Analysis
118
- def perform_fea_analysis(material, force, area, length=1000):
119
- # Get material properties
120
- mat_data = materials_db[materials_db['Material'] == material].iloc[0]
121
-
122
- E = mat_data['Elastic Modulus (GPa)'] * 1e9 # Pa
123
- yield_strength = mat_data['Yield Strength (MPa)'] * 1e6 # Pa
124
-
125
- # Calculate stress and displacement
126
- stress = force / (area / 1e6) # Convert area from mm² to m²
127
- displacement = stress / E * length # mm
128
- safety_factor = yield_strength / stress if stress > 0 else 10
129
-
130
- # Generate stress distribution
131
- n_points = 50
132
- x_pos = np.linspace(0, length, n_points)
133
- stress_dist = stress * (1 + 0.3 * np.sin(2 * np.pi * x_pos / length)) / 1e6 # MPa
134
-
135
- # Create plot
136
- fig = go.Figure()
137
- fig.add_trace(go.Scatter(
138
- x=x_pos,
139
- y=stress_dist,
140
- mode='lines+markers',
141
- name='Stress Distribution',
142
- line=dict(width=3, color='red')
143
- ))
144
- fig.update_layout(
145
- title="Stress Distribution Along Component",
146
- xaxis_title="Position (mm)",
147
- yaxis_title="Stress (MPa)",
148
- width=600,
149
- height=400
150
- )
151
-
152
- # Add safety limit line
153
- fig.add_hline(y=mat_data['Yield Strength (MPa)'],
154
- line_dash="dash", line_color="orange",
155
- annotation_text="Yield Strength Limit")
156
-
157
- # Results summary
158
- results = f"""
159
- **FEA Analysis Results:**
160
- - Material: {material}
161
- - Max Stress: {stress/1e6:.2f} MPa
162
- - Max Displacement: {displacement:.4f} mm
163
- - Safety Factor: {safety_factor:.2f}
164
- - Status: {'✅ SAFE' if safety_factor > 2 else '⚠️ CHECK DESIGN'}
165
-
166
- **Material Properties:**
167
- - Yield Strength: {mat_data['Yield Strength (MPa)']} MPa
168
- - Elastic Modulus: {mat_data['Elastic Modulus (GPa)']} GPa
169
- - Density: {mat_data['Density (kg/m³)']} kg/m³
170
- """
171
-
172
- return fig, results
173
-
174
- # Material Selection
175
- def material_selector(min_strength, max_cost, min_temp):
176
- # Filter materials based on criteria
177
- filtered = materials_db[
178
- (materials_db['Yield Strength (MPa)'] >= min_strength) &
179
- (materials_db['Cost ($/kg)'] <= max_cost) &
180
- (materials_db['Temperature Limit (°C)'] >= min_temp)
181
- ]
182
-
183
- if filtered.empty:
184
- return "No materials match your criteria. Please adjust requirements.", go.Figure()
185
-
186
- # Create comparison chart
187
- fig = px.scatter(filtered,
188
- x='Yield Strength (MPa)',
189
- y='Cost ($/kg)',
190
- size='Density (kg/m³)',
191
- color='Material',
192
- hover_data=['Temperature Limit (°C)'],
193
- title='Recommended Materials')
194
- fig.update_layout(width=600, height=400)
195
-
196
- # Format results table
197
- result_table = filtered[['Material', 'Yield Strength (MPa)', 'Cost ($/kg)', 'Temperature Limit (°C)']].to_string(index=False)
198
-
199
- return f"**Recommended Materials:**\n```\n{result_table}\n```", fig
200
-
201
- # Gear Calculator
202
- def gear_calculator(teeth, module, power, speed):
203
- # Basic calculations
204
- pitch_diameter = teeth * module
205
- torque = (power * 1000 * 60) / (2 * np.pi * speed) # N⋅mm
206
- pitch_line_velocity = (np.pi * pitch_diameter * speed) / (60 * 1000) # m/s
207
-
208
- # Lewis equation for bending stress (simplified)
209
- lewis_factor = 0.154 - 0.912/teeth # Approximation
210
- face_width = 10 * module # Typical assumption
211
- tangential_force = 2 * torque / pitch_diameter
212
- bending_stress = tangential_force / (face_width * module * lewis_factor)
213
-
214
- # Material recommendation
215
- required_strength = bending_stress * 3 # Safety factor of 3
216
- suitable_materials = materials_db[materials_db['Yield Strength (MPa)'] >= required_strength/1e6]
217
-
218
- if not suitable_materials.empty:
219
- recommended = suitable_materials.iloc[0]['Material']
220
- else:
221
- recommended = "High-strength steel required"
222
-
223
- results = f"""
224
- **Gear Design Results:**
225
- - Pitch Diameter: {pitch_diameter:.2f} mm
226
- - Torque: {torque:.2f} N⋅mm
227
- - Pitch Line Velocity: {pitch_line_velocity:.2f} m/s
228
- - Estimated Face Width: {face_width:.2f} mm
229
- - Bending Stress: {bending_stress/1e6:.2f} MPa
230
- - Recommended Material: {recommended}
231
-
232
- **Design Notes:**
233
- - Use appropriate lubrication for velocities > 5 m/s
234
- - Consider heat treatment for high-stress applications
235
- - Verify contact stress separately
236
- """
237
-
238
- return results
239
-
240
- # Create Gradio Interface
241
- def create_interface():
242
- with gr.Blocks(title="MechMind - AI Mechanical Engineering Assistant", theme=gr.themes.Soft()) as app:
243
-
244
- gr.Markdown("""
245
- # 🔧 MechMind - AI Mechanical Engineering Assistant
246
- ### AI-powered platform for mechanical design, analysis, and optimization
247
- **Features:** 3D Model Generation | Advanced FEA | Smart Material Selection | AI Engineering Assistant
248
- """)
249
-
250
- with gr.Tabs():
251
- # Dashboard Tab
252
- with gr.TabItem("🏠 Dashboard"):
253
- gr.Markdown("## Engineering Dashboard")
254
-
255
- with gr.Row():
256
- gr.Markdown(f"**Materials Database:** {len(materials_db)} materials")
257
- gr.Markdown("**AI Models Active:** 5")
258
- gr.Markdown("**Success Rate:** 96.8%")
259
-
260
- gr.Markdown("### 📊 Materials Overview")
261
- materials_plot = px.scatter(materials_db,
262
- x='Yield Strength (MPa)',
263
- y='Cost ($/kg)',
264
- size='Density (kg/m³)',
265
- color='Material',
266
- title='Material Properties Comparison')
267
- gr.Plot(materials_plot)
268
-
269
- # 3D Design Generator Tab
270
- with gr.TabItem("🔧 3D Design Generator"):
271
- gr.Markdown("## AI-Powered 3D Model Generator")
272
-
273
- with gr.Row():
274
- with gr.Column():
275
- component_type = gr.Dropdown(["Gear", "Shaft"], label="Component Type", value="Gear")
276
-
277
- # Gear parameters
278
- teeth_input = gr.Slider(8, 100, value=20, step=1, label="Number of Teeth")
279
- module_input = gr.Slider(0.5, 10, value=2, step=0.5, label="Module")
280
-
281
- # Shaft parameters
282
- length_input = gr.Slider(50, 500, value=100, step=10, label="Length (mm)")
283
- diameter_input = gr.Slider(10, 100, value=20, step=5, label="Diameter (mm)")
284
-
285
- generate_btn = gr.Button("Generate 3D Model", variant="primary")
286
-
287
- with gr.Column():
288
- model_plot = gr.Plot(label="3D Model")
289
- model_specs = gr.Markdown(label="Specifications")
290
-
291
- generate_btn.click(
292
- generate_3d_model,
293
- inputs=[component_type, teeth_input, module_input, length_input, diameter_input],
294
- outputs=[model_plot, model_specs]
295
- )
296
-
297
- # Material Selector Tab
298
- with gr.TabItem("📊 Material Selector"):
299
- gr.Markdown("## Smart Material Selection")
300
-
301
- with gr.Row():
302
- with gr.Column():
303
- min_strength = gr.Slider(0, 1000, value=200, label="Min Yield Strength (MPa)")
304
- max_cost = gr.Slider(0, 50, value=10, label="Max Cost ($/kg)")
305
- min_temp = gr.Slider(0, 1000, value=200, label="Min Temperature Limit (°C)")
306
-
307
- select_btn = gr.Button("Find Materials", variant="primary")
308
-
309
- with gr.Column():
310
- selection_results = gr.Markdown(label="Recommended Materials")
311
- selection_plot = gr.Plot(label="Material Comparison")
312
-
313
- select_btn.click(
314
- material_selector,
315
- inputs=[min_strength, max_cost, min_temp],
316
- outputs=[selection_results, selection_plot]
317
- )
318
-
319
- # FEA Analyzer Tab
320
- with gr.TabItem("🔍 FEA Analyzer"):
321
- gr.Markdown("## Finite Element Analysis")
322
-
323
- with gr.Row():
324
- with gr.Column():
325
- fea_material = gr.Dropdown(materials_db['Material'].tolist(),
326
- label="Material", value="Steel 1045")
327
- fea_force = gr.Number(value=1000, label="Applied Force (N)")
328
- fea_area = gr.Number(value=100, label="Cross-sectional Area (mm²)")
329
- fea_length = gr.Number(value=1000, label="Component Length (mm)")
330
-
331
- fea_btn = gr.Button("Run FEA Analysis", variant="primary")
332
-
333
- with gr.Column():
334
- fea_plot = gr.Plot(label="Stress Distribution")
335
- fea_results = gr.Markdown(label="Analysis Results")
336
-
337
- fea_btn.click(
338
- perform_fea_analysis,
339
- inputs=[fea_material, fea_force, fea_area, fea_length],
340
- outputs=[fea_plot, fea_results]
341
- )
342
-
343
- # Gear Calculator Tab
344
- with gr.TabItem("⚙️ Gear Calculator"):
345
- gr.Markdown("## Gear Design Calculator")
346
-
347
- with gr.Row():
348
- with gr.Column():
349
- gear_teeth = gr.Number(value=20, label="Number of Teeth")
350
- gear_module = gr.Number(value=2, label="Module")
351
- gear_power = gr.Number(value=10, label="Power (kW)")
352
- gear_speed = gr.Number(value=1000, label="Speed (rpm)")
353
-
354
- calc_btn = gr.Button("Calculate Gear Parameters", variant="primary")
355
-
356
- with gr.Column():
357
- gear_results = gr.Markdown(label="Gear Design Results")
358
-
359
- calc_btn.click(
360
- gear_calculator,
361
- inputs=[gear_teeth, gear_module, gear_power, gear_speed],
362
- outputs=[gear_results]
363
- )
364
-
365
- # AI Assistant Tab
366
- with gr.TabItem("💬 AI Assistant"):
367
- gr.Markdown("## RAG AI Engineering Assistant")
368
-
369
- with gr.Row():
370
- with gr.Column():
371
- user_query = gr.Textbox(label="Ask your engineering question:",
372
- placeholder="e.g., How do I select materials for high-temperature applications?")
373
- ask_btn = gr.Button("Ask AI", variant="primary")
374
-
375
- gr.Markdown("### Quick Questions:")
376
- with gr.Row():
377
- material_btn = gr.Button("Material Selection")
378
- design_btn = gr.Button("Design Tips")
379
- stress_btn = gr.Button("Stress Analysis")
380
-
381
- with gr.Column():
382
- ai_response = gr.Markdown(label="AI Response")
383
-
384
- ask_btn.click(get_ai_response, inputs=[user_query], outputs=[ai_response])
385
- material_btn.click(lambda: get_ai_response("material selection"), outputs=[ai_response])
386
- design_btn.click(lambda: get_ai_response("design optimization"), outputs=[ai_response])
387
- stress_btn.click(lambda: get_ai_response("stress analysis"), outputs=[ai_response])
388
-
389
- gr.Markdown("""
390
- ---
391
- ### 🔧 MechMind AI Assistant - Empowering Mechanical Engineers with AI
392
- **Built with Gradio | Ready for Kaggle Deployment**
393
- """)
394
-
395
- return app
396
-
397
- # Launch the application
398
- if __name__ == "__main__":
399
- app = create_interface()
400
-
401
- # Kaggle-specific launch configuration
402
- import os
403
-
404
- # Check if running in Kaggle environment
405
- if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
406
- print("🔧 Detected Kaggle environment - launching locally")
407
- print("📍 Access the app at: http://localhost:7860")
408
- print("💡 If running in Kaggle, the app will be available in the output cell")
409
- app.launch(
410
- share=False, # Disable share link in Kaggle
411
- debug=False, # Disable debug mode
412
- server_name="0.0.0.0", # Allow external access
413
- server_port=7860, # Standard port
414
- inbrowser=False, # Don't try to open browser
415
- show_error=True, # Show errors in interface
416
- quiet=False # Show startup messages
417
- )
418
- else:
419
- # Try to launch with share for local development
420
- try:
421
- print("🚀 Launching MechMind with public share link...")
422
- app.launch(share=True, debug=False, server_name="0.0.0.0", server_port=7860)
423
- except Exception as e:
424
- print(f"❌ Share link creation failed: {e}")
425
- print("🔧 Launching locally instead...")
426
- app.launch(
427
- share=False,
428
- debug=False,
429
- server_name="0.0.0.0",
430
- server_port=7860,
431
- inbrowser=True
432
- )
 
1
+ !apt-get update -qq
2
+ !apt-get install -y -qq gmsh
3
+ !pip install torch --upgrade -q
4
+ !pip install --upgrade -q \
5
+ gmsh \
6
+ meshio \
7
+ trimesh \
8
+ numpy \
9
+ pandas \
10
+ scikit-learn \
11
+ matplotlib \
12
+ plotly \
13
+ ipywidgets \
14
+ gradio
15
+ !pip install --upgrade -q jax jaxlib
16
+
17
+ # ===== CELL 1: SYSTEM INSTALLATION (RUN THIS FIRST) =====
18
+ # It is recommended to use the separate, more robust dependency installation
19
+ # script provided previously. This cell is a simplified version.
20
+ import subprocess
21
+ import sys
22
+ import os
23
+
24
+ def install_dependencies():
25
+ """Installs all necessary system and Python packages for Colab."""
26
+ print("🚀 Starting installation...")
27
+ try:
28
+ # Step 1: Install system packages like GMSH
29
+ print("🔧 Installing system package: GMSH...")
30
+ subprocess.run(["apt-get", "update", "-qq"], check=True, capture_output=True)
31
+ subprocess.run(["apt-get", "install", "-y", "-qq", "gmsh"], check=True, capture_output=True)
32
+ print(" GMSH installed.")
33
+
34
+ # Step 2: Install PyTorch and PyTorch Geometric correctly
35
+ print("\n🧠 Installing PyTorch & PyTorch Geometric...")
36
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "-q"])
37
+ # This command is crucial as it fetches the correct PyG versions
38
+ pyg_install_command = [
39
+ sys.executable, "-m", "pip", "install",
40
+ "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv", "torch-geometric",
41
+ "-f", f"https://data.pyg.org/whl/torch-{subprocess.check_output([sys.executable, '-c', 'import torch; print(torch.__version__)']).decode().strip()}.html",
42
+ "-q"
43
+ ]
44
+ subprocess.check_call(pyg_install_command)
45
+ print(" ✅ PyTorch & PyG installed.")
46
+
47
+ # Step 3: Install other core packages
48
+ print("\n📦 Installing core libraries...")
49
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade",
50
+ "gmsh", "meshio", "trimesh", "numpy", "pandas",
51
+ "scikit-learn", "matplotlib", "plotly", "ipywidgets", "gradio", "-q"])
52
+ print(" ✅ Core libraries installed.")
53
+
54
+ print("\n🎉 Installation complete! Please restart the runtime and run the next cell.")
55
+
56
+ except Exception as e:
57
+ print(f"❌ An error occurred during installation: {e}")
58
+ print(" Please check the error message and try again.")
59
+
60
+ # Run installation
61
+ # install_dependencies()
62
+
63
+
64
+ # ===== CELL 2: MAIN APPLICATION (RUN AFTER RESTART) =====
65
+
66
+ # Safe imports with fallbacks
67
+ def safe_import():
68
+ """Safely import all required packages after installation."""
69
+ global gmsh, np, torch, nn, F, Data, GCNConv, pyg_utils, meshio, go, plt, pd, widgets, gr
70
+ print("🔬 Importing necessary libraries...")
71
+ try:
72
+ import numpy as np
73
+ import pandas as pd
74
+ import matplotlib.pyplot as plt
75
+
76
+ # Mesh and geometry
77
+ import gmsh
78
+ import meshio
79
+
80
+ # PyTorch and PyTorch Geometric
81
+ import torch
82
+ import torch.nn as nn
83
+ import torch.nn.functional as F
84
+ from torch_geometric.data import Data
85
+ from torch_geometric.nn import GCNConv
86
+ import torch_geometric.utils as pyg_utils
87
+
88
+ # Visualization
89
+ import plotly.graph_objects as go
90
+
91
+ # UI/UX
92
+ import gradio as gr
93
+ import ipywidgets as widgets
94
+ from IPython.display import display, clear_output
95
+
96
+ import warnings
97
+ warnings.filterwarnings('ignore')
98
+
99
+ print("✅ All packages imported successfully!")
100
+ return True
101
+
102
+ except ImportError as e:
103
+ print(f"❌ Critical import failure: {e}")
104
+ print(" Please ensure Cell 1 was run and the runtime was restarted.")
105
+ return False
106
+ except Exception as e:
107
+ print(f"❌ An unexpected error occurred during import: {e}")
108
+ return False
109
+
110
+
111
+ # Import all packages
112
+ if not safe_import():
113
+ # Stop execution if imports fail
114
+ sys.exit("Stopping due to import errors.")
115
+
116
+
117
+ # ===== STEP 1: MESH GENERATION =====
118
+ print("\n🔧 Step 1: Mesh generation and processing")
119
+
120
+ def create_beam_geometry(length=10.0, width=1.0, height=2.0, mesh_size=0.5):
121
+ """Create a 3D beam geometry using GMSH."""
122
+ try:
123
+ gmsh.initialize()
124
+ gmsh.model.add("cantilever_beam")
125
+ beam = gmsh.model.occ.addBox(0, 0, 0, length, width, height)
126
+ gmsh.model.occ.synchronize()
127
+ gmsh.option.setNumber("Mesh.CharacteristicLengthMin", mesh_size * 0.5)
128
+ gmsh.option.setNumber("Mesh.CharacteristicLengthMax", mesh_size)
129
+ gmsh.model.mesh.generate(3)
130
+ gmsh.write("beam_mesh.msh")
131
+ gmsh.finalize()
132
+ print(f"✅ GMSH geometry created ('beam_mesh.msh')")
133
+ return "beam_mesh.msh"
134
+ except Exception as e:
135
+ print(f"❌ GMSH geometry creation failed: {e}. Using a fallback mesh.")
136
+ return create_fallback_mesh()
137
+
138
+ def create_fallback_mesh():
139
+ """Create a simple fallback mesh if GMSH fails."""
140
+ print("🔄 Creating a fallback cubic mesh...")
141
+ points = np.array([
142
+ [0, 0, 0], [10, 0, 0], [10, 1, 0], [0, 1, 0],
143
+ [0, 0, 2], [10, 0, 2], [10, 1, 2], [0, 1, 2]
144
+ ], dtype=np.float32)
145
+ cells = [("hexahedron", np.array([[0, 1, 2, 3, 4, 5, 6, 7]]))]
146
+ mesh = meshio.Mesh(points, cells)
147
+ mesh.write("fallback_mesh.vtk")
148
+ print("✅ Fallback mesh created ('fallback_mesh.vtk')")
149
+ return "fallback_mesh.vtk"
150
+
151
+ mesh_file = create_beam_geometry()
152
+
153
+
154
+ # ===== STEP 2: MESH TO GRAPH CONVERSION =====
155
+ print("\n🔄 Step 2: Converting mesh to graph representation")
156
+
157
+ def mesh_to_graph(mesh_file):
158
+ """Convert a mesh file to a PyTorch Geometric graph."""
159
+ try:
160
+ mesh = meshio.read(mesh_file)
161
+ points = mesh.points.astype(np.float32)
162
+
163
+ cells = mesh.get_cells_type("tetra")
164
+ if len(cells) == 0:
165
+ cells = mesh.get_cells_type("triangle")
166
+ if len(cells) == 0:
167
+ hex_cells = mesh.get_cells_type("hexahedron")
168
+ temp_cells = []
169
+ for h in hex_cells:
170
+ temp_cells.extend([[h[0],h[1],h[2],h[4]],[h[1],h[2],h[3],h[7]]])
171
+ cells = np.array(temp_cells)
172
+
173
+ # ----- MAJOR FIX HERE -----
174
+ # The function `face_to_edge_index` was removed from torch_geometric.
175
+ # This is the modern, correct way to compute the edge index from faces.
176
+ # We get all edges from the faces and then make the graph undirected.
177
+ faces_tensor = torch.tensor(cells[:, :3].T, dtype=torch.long)
178
+ edge_index = torch.cat([
179
+ faces_tensor[[0, 1]], faces_tensor[[1, 2]], faces_tensor[[2, 0]]
180
+ ], dim=1)
181
+ edge_index = pyg_utils.to_undirected(edge_index)
182
+ # ----- END OF FIX -----
183
+
184
+ coords = torch.tensor(points, dtype=torch.float32)
185
+ centroid = coords.mean(dim=0)
186
+ dist_to_centroid = torch.norm(coords - centroid, dim=1, keepdim=True)
187
+ coords_normalized = (coords - centroid) / (coords.std(dim=0) + 1e-8)
188
+ x = torch.cat([coords_normalized, dist_to_centroid], dim=1)
189
+
190
+ graph = Data(x=x, edge_index=edge_index, pos=coords)
191
+ print(f"✅ Graph created: {graph.num_nodes} nodes, {graph.num_edges} edges")
192
+ return graph, points, cells
193
+
194
+ except Exception as e:
195
+ print(f"❌ Mesh conversion failed: {e}. Cannot proceed.")
196
+ return None, None, None
197
+
198
+ graph, points, cells = mesh_to_graph(mesh_file)
199
+ if graph is None:
200
+ sys.exit("Stopping due to mesh processing errors.")
201
+
202
+
203
+ # ===== STEP 3: ACCURATE PHYSICS-BASED ANALYSIS (FEM) =====
204
+ print("\n⚛️ Step 3: Defining accurate physics-based analysis model")
205
+
206
+ def cantilever_beam_fem(points, E=210e9, load_magnitude=-1000):
207
+ """Calculates displacement and stress for a cantilever beam using analytical formulas."""
208
+ length = points[:, 0].max()
209
+ height = points[:, 2].max()
210
+ width = points[:, 1].max()
211
+ I = (width * height**3) / 12
212
+
213
+ fixed_nodes = np.where(points[:, 0] < 1e-6)[0]
214
+ loaded_nodes = np.where(points[:, 0] > length - 1e-6)[0]
215
+
216
+ displacement = np.zeros_like(points)
217
+ stress = np.zeros(len(points))
218
+ P = -load_magnitude
219
+
220
+ for i in range(len(points)):
221
+ x, _, z = points[i]
222
+ deflection = (P * x**2) / (6 * E * I) * (3 * length - x)
223
+ displacement[i, 2] = deflection
224
+ moment = P * (length - x)
225
+ z_from_neutral_axis = z - (height / 2)
226
+ stress[i] = (moment * z_from_neutral_axis) / I
227
+
228
+ return displacement, stress, fixed_nodes, loaded_nodes
229
+
230
+
231
+ # ===== STEP 4: AI SURROGATE MODEL & LIVE TRAINING =====
232
+ print("\n🧠 Step 4: Building and training AI surrogate model")
233
+
234
+ class EnhancedSurrogateNet(nn.Module):
235
+ def __init__(self, in_channels=4, hidden_channels=64, out_channels=4, num_layers=3):
236
+ super().__init__()
237
+ self.convs = nn.ModuleList()
238
+ self.batch_norms = nn.ModuleList()
239
+ self.convs.append(GCNConv(in_channels, hidden_channels))
240
+ self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
241
+ for _ in range(num_layers - 2):
242
+ self.convs.append(GCNConv(hidden_channels, hidden_channels))
243
+ self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
244
+ self.convs.append(GCNConv(hidden_channels, out_channels))
245
+ self.dropout = nn.Dropout(0.2)
246
+
247
+ def forward(self, data):
248
+ x, edge_index = data.x, data.edge_index
249
+ for i in range(len(self.convs) - 1):
250
+ x = self.convs[i](x, edge_index)
251
+ if x.shape[0] > 1:
252
+ x = self.batch_norms[i](x)
253
+ x = F.relu(x)
254
+ x = self.dropout(x)
255
+ x = self.convs[-1](x, edge_index)
256
+ return x
257
+
258
+ def train_surrogate_model(model, graph_data, training_status_callback):
259
+ """Trains the surrogate model on synthetically generated data."""
260
+ print("🚀 Starting AI model training...")
261
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
262
+ loss_fn = nn.MSELoss()
263
+ training_data = []
264
+ load_scenarios = np.linspace(-500, -5000, 10)
265
+ for load in load_scenarios:
266
+ disp_fem, stress_fem, _, _ = cantilever_beam_fem(points, load_magnitude=load)
267
+ target = torch.tensor(np.hstack([disp_fem, stress_fem[:, np.newaxis]]), dtype=torch.float32)
268
+ training_data.append(target)
269
+
270
+ model.train()
271
+ for epoch in range(100):
272
+ total_loss = 0
273
+ for target_data in training_data:
274
+ optimizer.zero_grad()
275
+ prediction = model(graph_data)
276
+ loss = loss_fn(prediction, target_data)
277
+ loss.backward()
278
+ optimizer.step()
279
+ total_loss += loss.item()
280
+ if (epoch + 1) % 20 == 0:
281
+ status_msg = f"Epoch {epoch+1}/100, Loss: {total_loss/len(training_data):.4f}"
282
+ print(f" {status_msg}")
283
+ if training_status_callback:
284
+ training_status_callback(status_msg)
285
+ model.eval()
286
+ print("✅ AI model training complete!")
287
+ return model
288
+
289
+
290
+ # ===== STEP 5: GRADIO INTERFACE & APPLICATION LOGIC =====
291
+ print("\n🎨 Step 5: Creating Gradio user interface")
292
+
293
+ class StructuralAnalysisApp:
294
+ def __init__(self, points, graph):
295
+ self.points = points
296
+ self.graph = graph
297
+ self.model = EnhancedSurrogateNet(in_channels=graph.x.shape[1], out_channels=4)
298
+
299
+ def train_model_for_ui(self, training_status_update):
300
+ self.model = train_surrogate_model(self.model, self.graph, training_status_update)
301
+ return "Model trained successfully! Ready for analysis."
302
+
303
+ def analyze(self, young_modulus, load_magnitude):
304
+ try:
305
+ E = float(young_modulus) * 1e9
306
+ load = float(load_magnitude)
307
+ disp_fem, stress_fem, fixed, loaded = cantilever_beam_fem(self.points, E=E, load_magnitude=load)
308
+ disp_mag_fem = np.linalg.norm(disp_fem, axis=1)
309
+ with torch.no_grad():
310
+ prediction = self.model(self.graph)
311
+ disp_surrogate = prediction[:, :3].numpy()
312
+ stress_surrogate = prediction[:, 3].numpy()
313
+ disp_mag_surrogate = np.linalg.norm(disp_surrogate, axis=1)
314
+ fig = self.create_3d_plot(disp_mag_fem, stress_fem, fixed, E/1e9, load)
315
+ results_text = self.format_results_text(
316
+ disp_mag_fem, stress_fem, disp_mag_surrogate, stress_surrogate, E/1e9, load, fixed
317
+ )
318
+ return fig, results_text
319
+ except Exception as e:
320
+ error_msg = f"❌ Analysis failed: {str(e)}"
321
+ print(error_msg)
322
+ return go.Figure(), error_msg
323
+
324
+ def create_3d_plot(self, disp_mag, stress, fixed_nodes, E, load):
325
  fig = go.Figure()
326
+ fig.add_trace(go.Scatter3d(
327
+ x=self.points[:, 0], y=self.points[:, 1], z=self.points[:, 2],
328
+ mode='markers',
329
+ marker=dict(
330
+ size=4, color=disp_mag, colorscale='Viridis',
331
+ colorbar=dict(title="Displacement (m)"),
332
+ cmin=disp_mag.min(), cmax=disp_mag.max()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  ),
334
+ text=[f"Stress: {s/1e6:.2f} MPa" for s in stress],
335
+ hoverinfo='text', name='Deformation Field'
336
+ ))
337
+ fig.add_trace(go.Scatter3d(
338
+ x=self.points[fixed_nodes, 0], y=self.points[fixed_nodes, 1], z=self.points[fixed_nodes, 2],
339
+ mode='markers', marker=dict(size=6, color='red', symbol='x'), name='Fixed Support'
340
+ ))
341
+ fig.update_layout(
342
+ title=f"Analysis Results (E={E:.0f} GPa, Load={load:.0f} N)",
343
+ scene=dict(xaxis_title="X (m)", yaxis_title="Y (m)", zaxis_title="Z (m)"),
344
+ width=800, height=600, margin=dict(l=0, r=0, b=0, t=40)
345
  )
346
+ return fig
347
+
348
+ def format_results_text(self, disp_fem, stress_fem, disp_surrogate, stress_surrogate, E, load, fixed):
349
+ corr_disp = np.corrcoef(disp_fem, disp_surrogate)[0, 1]
350
+ corr_stress = np.corrcoef(stress_fem, stress_surrogate)[0, 1]
351
+ return f"""
352
+ ### 📊 Analysis Summary
353
+ | Parameter | Value |
354
+ | :--- | :--- |
355
+ | **Young's Modulus** | {E:.0f} GPa |
356
+ | **Load Magnitude** | {load:.0f} N |
357
+ | **Mesh Nodes** | {len(self.points):,} |
358
+ | **Fixed Nodes** | {len(fixed):,} |
359
+
360
+ ### 🤖 AI vs. FEM Comparison
361
+ | Metric | FEM (Ground Truth) | AI Surrogate | Correlation |
362
+ | :--- | :--- | :--- | :--- |
363
+ | **Max Displacement** | `{disp_fem.max():.3e} m` | `{disp_surrogate.max():.3e} m` | **`{corr_disp:.3f}`** |
364
+ | **Max Stress** | `{stress_fem.max()/1e6:.3f} MPa` | `{stress_surrogate.max()/1e6:.3f} MPa` | **`{corr_stress:.3f}`** |
365
  """
366
+
367
+ app = StructuralAnalysisApp(points, graph)
368
+
369
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI Structural Analysis") as demo:
370
+ gr.Markdown("# 🏗️ AI-Powered Structural Analysis")
371
+ gr.Markdown("An interactive tool combining Finite Element Method (FEM) with a Graph Neural Network (GNN) surrogate model. The GNN is trained in real-time on FEM data to provide fast, accurate predictions.")
372
+ with gr.Row():
373
+ with gr.Column(scale=1):
374
+ gr.Markdown("### 🛠️ Parameters")
375
+ young_modulus = gr.Slider(minimum=50, maximum=300, value=210, step=10, label="Young's Modulus (GPa)")
376
+ load_magnitude = gr.Slider(minimum=-5000, maximum=-100, value=-1000, step=100, label="Load Magnitude (N)")
377
+ with gr.Accordion("Advanced: AI Model Training", open=False):
378
+ training_status = gr.Textbox(label="Training Status", value="Model is not trained yet.", interactive=False)
379
+ train_btn = gr.Button("🧠 Train AI Model")
380
+ analyze_btn = gr.Button("🚀 Run Analysis", variant="primary")
381
+ with gr.Column(scale=2):
382
+ gr.Markdown("### 📈 Visualization & Results")
383
+ plot_output = gr.Plot(label="3D Visualization")
384
+ results_text = gr.Markdown()
385
+ train_btn.click(fn=app.train_model_for_ui, inputs=[], outputs=[training_status], show_progress='full')
386
+ analyze_btn.click(fn=app.analyze, inputs=[young_modulus, load_magnitude], outputs=[plot_output, results_text])
387
+ demo.load(fn=app.train_model_for_ui, inputs=[], outputs=[training_status], show_progress='full')
388
+
389
+ print("🌐 Launching Gradio interface...")
390
+ demo.launch(share=True, debug=True)