NavyDevilDoc commited on
Commit
d44e379
·
verified ·
1 Parent(s): b6a881b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +172 -38
src/streamlit_app.py CHANGED
@@ -1,8 +1,8 @@
1
  import streamlit as st
2
  import numpy as np
3
- import matplotlib.pyplot as plt
4
  import plotly.graph_objects as go
5
  import plotly.express as px
 
6
  from kinematics_visualizer import Motion1D, Motion2D, KinematicsVisualizer
7
 
8
  # Configure Streamlit page
@@ -70,13 +70,6 @@ def main():
70
  ["1D Motion", "2D Projectile Motion", "Compare Motions"]
71
  )
72
 
73
- # Clear matplotlib figures when switching tabs to prevent cross-contamination
74
- if 'current_tab' not in st.session_state:
75
- st.session_state.current_tab = tutorial_type
76
- elif st.session_state.current_tab != tutorial_type:
77
- plt.close('all') # Clear all figures when switching tabs
78
- st.session_state.current_tab = tutorial_type
79
-
80
  if tutorial_type == "1D Motion":
81
  show_1d_motion()
82
  elif tutorial_type == "2D Projectile Motion":
@@ -84,6 +77,135 @@ def main():
84
  elif tutorial_type == "Compare Motions":
85
  show_motion_comparison()
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def show_1d_motion():
88
  st.markdown('<div class="section-header">📏 One-Dimensional Motion</div>', unsafe_allow_html=True)
89
 
@@ -171,11 +293,9 @@ def show_1d_motion():
171
  else:
172
  motion_type = "Constant Velocity Motion"
173
 
174
- visualizer = KinematicsVisualizer()
175
- fig = visualizer.plot_1d_motion(motion, duration, motion_type)
176
-
177
- st.pyplot(fig)
178
- plt.close()
179
 
180
  def show_2d_motion():
181
  st.markdown('<div class="section-header">🎯 Two-Dimensional Projectile Motion</div>', unsafe_allow_html=True)
@@ -309,9 +429,6 @@ def show_2d_motion():
309
 
310
  st.markdown("---")
311
 
312
- # Launch parameter presets and sliders (same as before)
313
- # ... [include all the existing preset and slider code] ...
314
-
315
  # Initialize default values
316
  default_speed = 25.0
317
  default_angle = 45.0
@@ -435,26 +552,31 @@ def show_2d_motion():
435
 
436
  with col2:
437
  # Create and display trajectory with model info
438
- visualizer = KinematicsVisualizer()
439
-
440
  model_info = f"({'Sphere' if is_sphere else 'Point Mass'})"
441
  air_info = " (with Air Resistance)" if air_resistance_enabled else " (No Air Resistance)"
442
  trajectory_title = f"Projectile Motion - {launch_angle:.0f}° Launch {model_info}{air_info}"
443
 
444
- fig = visualizer.plot_2d_trajectory(motion, title=trajectory_title)
445
- st.pyplot(fig)
446
- plt.close()
447
 
448
  # Show model comparison if using sphere
449
  if is_sphere and 'mass_g' in info:
450
  st.markdown("### 📊 Sphere vs Point Mass Comparison")
451
 
452
- fig_comp, ax = plt.subplots(figsize=(12, 6))
 
453
 
454
  # Plot sphere model
455
  data_sphere = motion.trajectory_data(motion.calculate_flight_time())
456
- ax.plot(data_sphere['x'], data_sphere['y'],
457
- 'r-', linewidth=3, label=f'Sphere Model ({info["mass_g"]:.0f}g)', alpha=0.8)
 
 
 
 
 
 
458
 
459
  # Plot equivalent point mass
460
  motion_point = Motion2D(
@@ -464,22 +586,30 @@ def show_2d_motion():
464
  is_sphere=False
465
  )
466
  data_point = motion_point.trajectory_data(motion_point.calculate_flight_time())
467
- ax.plot(data_point['x'], data_point['y'],
468
- 'b--', linewidth=2, label='Point Mass Model', alpha=0.7)
 
 
 
 
 
 
469
 
470
- ax.set_xlabel('Horizontal Position (m)')
471
- ax.set_ylabel('Vertical Position (m)')
472
- ax.set_title('Sphere Model vs Point Mass Model')
473
- ax.grid(True, alpha=0.3)
474
- ax.legend()
475
- ax.set_ylim(bottom=0)
 
 
 
476
 
477
- st.pyplot(fig_comp)
478
- plt.close()
479
-
480
- # Add this import at the top
481
- import plotly.graph_objects as go
482
- import plotly.express as px
483
 
484
  def show_motion_comparison():
485
  st.markdown('<div class="section-header">📊 Compare Different Trajectories</div>', unsafe_allow_html=True)
@@ -575,6 +705,10 @@ def show_motion_comparison():
575
  fig.update_xaxes(range=[0, None])
576
  fig.update_yaxes(range=[0, None])
577
 
 
 
 
 
578
  return fig, trajectories_data
579
 
580
  # Create cached comparison data
 
1
  import streamlit as st
2
  import numpy as np
 
3
  import plotly.graph_objects as go
4
  import plotly.express as px
5
+ from plotly.subplots import make_subplots
6
  from kinematics_visualizer import Motion1D, Motion2D, KinematicsVisualizer
7
 
8
  # Configure Streamlit page
 
70
  ["1D Motion", "2D Projectile Motion", "Compare Motions"]
71
  )
72
 
 
 
 
 
 
 
 
73
  if tutorial_type == "1D Motion":
74
  show_1d_motion()
75
  elif tutorial_type == "2D Projectile Motion":
 
77
  elif tutorial_type == "Compare Motions":
78
  show_motion_comparison()
79
 
80
+ def create_1d_motion_plot(motion: Motion1D, duration: float, title: str):
81
+ """Create 1D motion plots using Plotly"""
82
+ # Generate time arrays
83
+ t, x, v, a = motion.time_arrays(duration, dt=0.01)
84
+
85
+ # Create subplots
86
+ fig = make_subplots(
87
+ rows=3, cols=1,
88
+ subplot_titles=('Position vs Time', 'Velocity vs Time', 'Acceleration vs Time'),
89
+ vertical_spacing=0.08,
90
+ shared_xaxes=True
91
+ )
92
+
93
+ # Position plot
94
+ fig.add_trace(
95
+ go.Scatter(x=t, y=x, mode='lines', name='Position',
96
+ line=dict(color='blue', width=3),
97
+ hovertemplate='Time: %{x:.1f} s<br>Position: %{y:.1f} m<extra></extra>'),
98
+ row=1, col=1
99
+ )
100
+
101
+ # Velocity plot
102
+ fig.add_trace(
103
+ go.Scatter(x=t, y=v, mode='lines', name='Velocity',
104
+ line=dict(color='red', width=3),
105
+ hovertemplate='Time: %{x:.1f} s<br>Velocity: %{y:.1f} m/s<extra></extra>'),
106
+ row=2, col=1
107
+ )
108
+
109
+ # Acceleration plot
110
+ fig.add_trace(
111
+ go.Scatter(x=t, y=a, mode='lines', name='Acceleration',
112
+ line=dict(color='green', width=3),
113
+ hovertemplate='Time: %{x:.1f} s<br>Acceleration: %{y:.1f} m/s²<extra></extra>'),
114
+ row=3, col=1
115
+ )
116
+
117
+ # Update layout
118
+ fig.update_layout(
119
+ title=dict(text=title, font=dict(size=18)),
120
+ showlegend=False,
121
+ height=800,
122
+ template='plotly_white'
123
+ )
124
+
125
+ # Update y-axis labels
126
+ fig.update_yaxes(title_text="Position (m)", row=1, col=1)
127
+ fig.update_yaxes(title_text="Velocity (m/s)", row=2, col=1)
128
+ fig.update_yaxes(title_text="Acceleration (m/s²)", row=3, col=1)
129
+ fig.update_xaxes(title_text="Time (s)", row=3, col=1)
130
+
131
+ # Add grid
132
+ fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.3)')
133
+ fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.3)')
134
+
135
+ return fig
136
+
137
+ def create_2d_trajectory_plot(motion: Motion2D, title: str):
138
+ """Create 2D trajectory plot using Plotly"""
139
+ flight_time = motion.calculate_flight_time()
140
+ data = motion.trajectory_data(flight_time)
141
+
142
+ fig = go.Figure()
143
+
144
+ # Add trajectory line
145
+ fig.add_trace(go.Scatter(
146
+ x=data['x'],
147
+ y=data['y'],
148
+ mode='lines',
149
+ name='Trajectory',
150
+ line=dict(color='blue', width=4),
151
+ hovertemplate='X: %{x:.1f} m<br>Y: %{y:.1f} m<extra></extra>'
152
+ ))
153
+
154
+ # Add launch point
155
+ fig.add_trace(go.Scatter(
156
+ x=[motion.launch_x],
157
+ y=[motion.launch_height],
158
+ mode='markers',
159
+ name='Launch Point',
160
+ marker=dict(color='green', size=12, symbol='circle'),
161
+ hovertemplate='Launch<br>X: %{x:.1f} m<br>Y: %{y:.1f} m<extra></extra>'
162
+ ))
163
+
164
+ # Add landing point
165
+ if len(data['x']) > 0:
166
+ fig.add_trace(go.Scatter(
167
+ x=[data['x'][-1]],
168
+ y=[data['y'][-1]],
169
+ mode='markers',
170
+ name='Landing Point',
171
+ marker=dict(color='red', size=12, symbol='square'),
172
+ hovertemplate='Landing<br>X: %{x:.1f} m<br>Y: %{y:.1f} m<extra></extra>'
173
+ ))
174
+
175
+ # Add maximum height point
176
+ max_height_idx = np.argmax(data['y'])
177
+ if max_height_idx > 0:
178
+ fig.add_trace(go.Scatter(
179
+ x=[data['x'][max_height_idx]],
180
+ y=[data['y'][max_height_idx]],
181
+ mode='markers',
182
+ name='Max Height',
183
+ marker=dict(color='orange', size=10, symbol='triangle-up'),
184
+ hovertemplate='Max Height<br>X: %{x:.1f} m<br>Y: %{y:.1f} m<extra></extra>'
185
+ ))
186
+
187
+ # Update layout
188
+ fig.update_layout(
189
+ title=dict(text=title, font=dict(size=18)),
190
+ xaxis_title="Horizontal Position (m)",
191
+ yaxis_title="Vertical Position (m)",
192
+ showlegend=True,
193
+ hovermode='closest',
194
+ template='plotly_white',
195
+ height=600,
196
+ margin=dict(l=50, r=50, t=80, b=50)
197
+ )
198
+
199
+ # Set axis ranges
200
+ fig.update_xaxes(range=[0, max(data['x']) * 1.1 if len(data['x']) > 0 else 10])
201
+ fig.update_yaxes(range=[0, max(data['y']) * 1.2 if len(data['y']) > 0 else 10])
202
+
203
+ # Add grid
204
+ fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.3)')
205
+ fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.3)')
206
+
207
+ return fig
208
+
209
  def show_1d_motion():
210
  st.markdown('<div class="section-header">📏 One-Dimensional Motion</div>', unsafe_allow_html=True)
211
 
 
293
  else:
294
  motion_type = "Constant Velocity Motion"
295
 
296
+ # Create and display Plotly chart
297
+ fig = create_1d_motion_plot(motion, duration, motion_type)
298
+ st.plotly_chart(fig, use_container_width=True)
 
 
299
 
300
  def show_2d_motion():
301
  st.markdown('<div class="section-header">🎯 Two-Dimensional Projectile Motion</div>', unsafe_allow_html=True)
 
429
 
430
  st.markdown("---")
431
 
 
 
 
432
  # Initialize default values
433
  default_speed = 25.0
434
  default_angle = 45.0
 
552
 
553
  with col2:
554
  # Create and display trajectory with model info
 
 
555
  model_info = f"({'Sphere' if is_sphere else 'Point Mass'})"
556
  air_info = " (with Air Resistance)" if air_resistance_enabled else " (No Air Resistance)"
557
  trajectory_title = f"Projectile Motion - {launch_angle:.0f}° Launch {model_info}{air_info}"
558
 
559
+ # Create and display main trajectory plot
560
+ fig = create_2d_trajectory_plot(motion, trajectory_title)
561
+ st.plotly_chart(fig, use_container_width=True)
562
 
563
  # Show model comparison if using sphere
564
  if is_sphere and 'mass_g' in info:
565
  st.markdown("### 📊 Sphere vs Point Mass Comparison")
566
 
567
+ # Create comparison plot
568
+ fig_comp = go.Figure()
569
 
570
  # Plot sphere model
571
  data_sphere = motion.trajectory_data(motion.calculate_flight_time())
572
+ fig_comp.add_trace(go.Scatter(
573
+ x=data_sphere['x'],
574
+ y=data_sphere['y'],
575
+ mode='lines',
576
+ name=f'Sphere Model ({info["mass_g"]:.0f}g)',
577
+ line=dict(color='red', width=3),
578
+ hovertemplate='<b>Sphere</b><br>X: %{x:.1f} m<br>Y: %{y:.1f} m<extra></extra>'
579
+ ))
580
 
581
  # Plot equivalent point mass
582
  motion_point = Motion2D(
 
586
  is_sphere=False
587
  )
588
  data_point = motion_point.trajectory_data(motion_point.calculate_flight_time())
589
+ fig_comp.add_trace(go.Scatter(
590
+ x=data_point['x'],
591
+ y=data_point['y'],
592
+ mode='lines',
593
+ name='Point Mass Model',
594
+ line=dict(color='blue', width=2, dash='dash'),
595
+ hovertemplate='<b>Point Mass</b><br>X: %{x:.1f} m<br>Y: %{y:.1f} m<extra></extra>'
596
+ ))
597
 
598
+ # Update layout for comparison plot
599
+ fig_comp.update_layout(
600
+ title="Sphere Model vs Point Mass Model",
601
+ xaxis_title="Horizontal Position (m)",
602
+ yaxis_title="Vertical Position (m)",
603
+ showlegend=True,
604
+ template='plotly_white',
605
+ height=400
606
+ )
607
 
608
+ fig_comp.update_yaxes(range=[0, None])
609
+ fig_comp.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.3)')
610
+ fig_comp.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.3)')
611
+
612
+ st.plotly_chart(fig_comp, use_container_width=True)
 
613
 
614
  def show_motion_comparison():
615
  st.markdown('<div class="section-header">📊 Compare Different Trajectories</div>', unsafe_allow_html=True)
 
705
  fig.update_xaxes(range=[0, None])
706
  fig.update_yaxes(range=[0, None])
707
 
708
+ # Add grid
709
+ fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.3)')
710
+ fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.3)')
711
+
712
  return fig, trajectories_data
713
 
714
  # Create cached comparison data