Nugget-cloud commited on
Commit
7908a22
·
verified ·
1 Parent(s): 236efbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -858
app.py CHANGED
@@ -1,882 +1,50 @@
1
  import streamlit as st
2
  import joblib
3
- import pandas as pd
4
- import numpy as np
5
- import plotly.graph_objects as go
6
- import plotly.express as px
7
- from datetime import datetime
8
  import time
9
  import io
 
10
 
11
  # ==================== PAGE CONFIG ====================
12
  st.set_page_config(
13
- page_title="NASA Exoplanet AI Detector",
14
- page_icon="🪐",
15
  layout="wide",
16
- initial_sidebar_state="expanded"
17
  )
18
 
19
- # ==================== CUSTOM CSS ====================
20
- st.markdown("""
21
- <style>
22
- .main-header {
23
- font-size: 3.5rem;
24
- font-weight: bold;
25
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
26
- -webkit-background-clip: text;
27
- -webkit-text-fill-color: transparent;
28
- text-align: center;
29
- padding: 20px;
30
- }
31
- .sub-header {
32
- text-align: center;
33
- color: #666;
34
- font-size: 1.2rem;
35
- }
36
- .metric-card {
37
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
38
- padding: 20px;
39
- border-radius: 10px;
40
- color: white;
41
- text-align: center;
42
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
43
- }
44
- .stButton>button {
45
- width: 100%;
46
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
47
- color: white;
48
- font-weight: bold;
49
- }
50
- </style>
51
- """, unsafe_allow_html=True)
52
-
53
  # ==================== LOAD MODEL ====================
54
  @st.cache_resource
55
  def load_model_package():
56
  """Load the complete model package"""
57
  try:
58
- # ⚠️ UPDATE THIS FILENAME WITH YOUR ACTUAL MODEL FILE
59
  package = joblib.load("exoplanet_final_model.joblib")
 
60
  return package
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  except Exception as e:
62
- st.error(f" Error loading model: {e}")
63
- st.error("Please update the filename in the code (line 47)")
64
- st.stop()
65
 
66
  # Load package
67
  with st.spinner(" Loading AI model..."):
68
  package = load_model_package()
69
 
70
- model = package['ensemble_model']
71
- scaler = package['scaler']
72
- feature_names = package['feature_names']
73
- metadata = package['metadata']
74
-
75
- # ==================== HEADER ====================
76
- st.markdown('<div class="main-header">🪐 NASA Space Apps Challenge 2025</div>', unsafe_allow_html=True)
77
- st.markdown('<div class="sub-header">AI-Powered Exoplanet Detection System</div>', unsafe_allow_html=True)
78
- st.markdown(f"<div class='sub-header'>Trained on {', '.join(metadata['missions'])} mission data</div>", unsafe_allow_html=True)
79
-
80
- # ==================== SIDEBAR ====================
81
- with st.sidebar:
82
- st.image("https://www.nasa.gov/wp-content/uploads/2018/07/nasa-logo.svg", width=200)
83
-
84
- st.markdown("---")
85
- st.subheader(" Ensemble Components")
86
- for model_name in metadata['ensemble_model_names']:
87
- st.text(f"• {model_name}")
88
-
89
- st.markdown("---")
90
- st.subheader(" Missions")
91
- for mission in metadata['missions']:
92
- st.text(f"• {mission}")
93
-
94
- st.markdown("---")
95
- st.info(f"**Model Version:** {metadata['version']}")
96
- st.info(f"**Created:** {metadata['created_date']}")
97
-
98
- # ==================== MAIN TABS ====================
99
- tab1, tab2, tab3, tab4, tab5 = st.tabs([
100
- " Single Prediction",
101
- " Batch Analysis",
102
- " Model Analytics",
103
- " Hyperparameter Tuning",
104
- "ℹ About"
105
- ])
106
-
107
- # ==================== TAB 1: SINGLE PREDICTION ====================
108
- with tab1:
109
- st.header(" Analyze Single Exoplanet Candidate")
110
- st.markdown("Enter the parameters of an exoplanet candidate to predict if it's a **planet** or **false positive**")
111
-
112
- with st.form("prediction_form"):
113
- col1, col2, col3 = st.columns(3)
114
-
115
- with col1:
116
- st.subheader(" Orbital Properties")
117
- period = st.number_input("Orbital Period (days)", 0.0, 10000.0, 10.0,
118
- help="Time for one complete orbit around the star")
119
- duration = st.number_input("Transit Duration (hours)", 0.0, 48.0, 3.0,
120
- help="Time the planet takes to cross the star")
121
- depth = st.number_input("Transit Depth (ppm)", 0.0, 100000.0, 1000.0,
122
- help="Brightness dip when planet transits")
123
-
124
- with col2:
125
- st.subheader(" Planet Properties")
126
- planet_radius = st.number_input("Planet Radius (Earth radii)", 0.1, 100.0, 1.0,
127
- help="Size relative to Earth")
128
- equilibrium_temp = st.number_input("Equilibrium Temperature (K)", 0, 5000, 288,
129
- help="Expected temperature of the planet")
130
- insolation = st.number_input("Insolation Flux (Earth units)", 0.0, 10000.0, 1.0,
131
- help="Energy received from star (Earth=1.0)")
132
-
133
- with col3:
134
- st.subheader(" Stellar Properties")
135
- star_radius = st.number_input("Star Radius (Solar radii)", 0.1, 50.0, 1.0,
136
- help="Size relative to the Sun")
137
- star_temp = st.number_input("Star Temperature (K)", 2000, 50000, 5778,
138
- help="Surface temperature (Sun=5778K)")
139
- star_logg = st.number_input("Star log(g)", 0.0, 5.0, 4.4,
140
- help="Surface gravity indicator")
141
-
142
- mission = st.selectbox("Mission", metadata['missions'], help="Which telescope detected this candidate")
143
-
144
- submit_button = st.form_submit_button(" Analyze Candidate", type="primary")
145
-
146
- if submit_button:
147
- with st.spinner(" Analyzing candidate..."):
148
- # Create feature dictionary
149
- features_dict = {}
150
-
151
- # Basic features
152
- feature_map = {
153
- 'period': period,
154
- 'duration': duration,
155
- 'depth': depth,
156
- 'planet_radius': planet_radius,
157
- 'star_radius': star_radius,
158
- 'star_temp': star_temp,
159
- 'star_logg': star_logg,
160
- 'equilibrium_temp': equilibrium_temp,
161
- 'insolation_flux': insolation
162
- }
163
-
164
- for fname, fval in feature_map.items():
165
- if fname in feature_names:
166
- features_dict[fname] = fval
167
-
168
- # Engineered features
169
- if 'transit_period_ratio' in feature_names and period > 0:
170
- features_dict['transit_period_ratio'] = duration / (period * 24)
171
-
172
- if 'radius_ratio' in feature_names and star_radius > 0:
173
- features_dict['radius_ratio'] = planet_radius / star_radius
174
-
175
- if 'period_log' in feature_names and period > 0:
176
- features_dict['period_log'] = np.log10(period)
177
-
178
- if 'insolation_flux_log' in feature_names and insolation > 0:
179
- features_dict['insolation_flux_log'] = np.log10(insolation)
180
-
181
- if 'habitable_zone_dist' in feature_names:
182
- features_dict['habitable_zone_dist'] = abs(equilibrium_temp - 288) / 288
183
-
184
- # Stellar classification
185
- if 'star_class' in feature_names:
186
- if star_temp >= 7500: star_class = 5
187
- elif star_temp >= 6000: star_class = 4
188
- elif star_temp >= 5200: star_class = 3
189
- elif star_temp >= 3700: star_class = 2
190
- else: star_class = 1
191
- features_dict['star_class'] = star_class
192
-
193
- if 'luminosity_class' in feature_names:
194
- if star_logg < 3.5: lum_class = 3
195
- elif star_logg < 4.0: lum_class = 2
196
- else: lum_class = 1
197
- features_dict['luminosity_class'] = lum_class
198
-
199
- # Mission encoding
200
- for m in metadata['missions']:
201
- col_name = f'mission_{m}'
202
- if col_name in feature_names:
203
- features_dict[col_name] = 1 if m == mission else 0
204
-
205
- # Create feature vector
206
- feature_vector = [features_dict.get(f, 0) for f in feature_names]
207
- X_input = np.array(feature_vector).reshape(1, -1)
208
-
209
- # Scale and predict
210
- X_scaled = scaler.transform(X_input)
211
- prediction = model.predict(X_scaled)[0]
212
- probabilities = model.predict_proba(X_scaled)[0]
213
-
214
- # Display results
215
- st.markdown("---")
216
- st.markdown("### Prediction Results")
217
-
218
- result_col1, result_col2, result_col3 = st.columns([2, 2, 3])
219
-
220
- with result_col1:
221
- if prediction == 1:
222
- st.success("### PLANET DETECTED!")
223
- confidence = probabilities[1]
224
- else:
225
- st.error("### FALSE POSITIVE")
226
- confidence = probabilities[0]
227
-
228
- with result_col2:
229
- st.metric("Confidence Score", f"{confidence*100:.1f}%",
230
- delta=f"{(confidence-0.5)*100:.1f}% from neutral")
231
-
232
- if confidence > 0.9:
233
- st.info(" Very High Confidence")
234
- elif confidence > 0.75:
235
- st.info(" High Confidence")
236
- elif confidence > 0.6:
237
- st.info(" Moderate Confidence")
238
- else:
239
- st.info(" Low Confidence")
240
-
241
- with result_col3:
242
- # Probability gauge
243
- fig = go.Figure(go.Indicator(
244
- mode="gauge+number+delta",
245
- value=probabilities[1] * 100,
246
- title={'text': "Planet Probability (%)"},
247
- delta={'reference': 50, 'increasing': {'color': "green"}},
248
- gauge={
249
- 'axis': {'range': [0, 100], 'tickwidth': 1},
250
- 'bar': {'color': "darkblue"},
251
- 'steps': [
252
- {'range': [0, 25], 'color': "lightgray"},
253
- {'range': [25, 50], 'color': "gray"},
254
- {'range': [50, 75], 'color': "lightblue"},
255
- {'range': [75, 100], 'color': "lightgreen"}
256
- ],
257
- 'threshold': {
258
- 'line': {'color': "red", 'width': 4},
259
- 'thickness': 0.75,
260
- 'value': 50
261
- }
262
- }
263
- ))
264
- fig.update_layout(height=280, margin=dict(l=20, r=20, t=80, b=20))
265
- st.plotly_chart(fig, use_container_width=True)
266
-
267
- # Detailed probabilities
268
- st.markdown("---")
269
- st.subheader(" Detailed Probabilities")
270
-
271
- prob_col1, prob_col2 = st.columns(2)
272
-
273
- with prob_col1:
274
- st.metric("False Positive Probability", f"{probabilities[0]*100:.2f}%")
275
- with prob_col2:
276
- st.metric("Planet Probability", f"{probabilities[1]*100:.2f}%")
277
-
278
- # ==================== TAB 2: BATCH ANALYSIS ====================
279
- with tab2:
280
- st.header(" Batch Analysis")
281
- st.markdown("Upload a CSV file with multiple exoplanet candidates for batch predictions")
282
-
283
- st.info(" **Tip:** Your CSV should contain columns matching the feature names used by the model")
284
-
285
- uploaded_file = st.file_uploader("Choose CSV file", type=['csv'])
286
-
287
- if uploaded_file:
288
- df_upload = pd.read_csv(uploaded_file)
289
-
290
- st.subheader(" Uploaded Data Preview")
291
- st.dataframe(df_upload.head(10), use_container_width=True)
292
-
293
- st.metric("Total Candidates", len(df_upload))
294
-
295
- if st.button("⚡ Analyze All Candidates", type="primary"):
296
- with st.spinner("Analyzing all candidates..."):
297
- st.success(f" Would analyze {len(df_upload)} candidates!")
298
- st.info(" Feature coming soon: Batch prediction implementation")
299
- st.balloons()
300
-
301
- # ==================== TAB 3: MODEL ANALYTICS ====================
302
- with tab3:
303
- st.header(" Model Performance Analytics")
304
-
305
- # Metrics Overview
306
- st.subheader(" Test Set Performance")
307
- metric_col1, metric_col2, metric_col3, metric_col4, metric_col5 = st.columns(5)
308
-
309
- with metric_col1:
310
- st.metric("Accuracy", f"{metadata['test_accuracy']*100:.2f}%")
311
- with metric_col2:
312
- st.metric("Precision", f"{metadata['test_precision']:.3f}")
313
- with metric_col3:
314
- st.metric("Recall", f"{metadata['test_recall']:.3f}")
315
- with metric_col4:
316
- st.metric("F1 Score", f"{metadata['test_f1_score']:.3f}")
317
- with metric_col5:
318
- st.metric("ROC-AUC", f"{metadata['test_roc_auc']:.3f}")
319
-
320
- st.markdown("---")
321
-
322
- # Dataset Information
323
- st.subheader(" Dataset Information")
324
- data_col1, data_col2, data_col3, data_col4 = st.columns(4)
325
-
326
- with data_col1:
327
- st.metric("Total Samples", f"{metadata['total_samples']:,}")
328
- with data_col2:
329
- st.metric("Planets", f"{metadata['planets_total']:,}")
330
- with data_col3:
331
- st.metric("False Positives", f"{metadata['false_positives_total']:,}")
332
- with data_col4:
333
- st.metric("Planet %", f"{metadata['planet_percentage']:.1f}%")
334
-
335
- st.markdown("---")
336
-
337
- # Model Comparison
338
- st.subheader(" Individual Model Performance (Validation Set)")
339
-
340
- if 'validation_scores' in metadata:
341
- val_scores_df = pd.DataFrame([
342
- {"Model": k, "ROC-AUC": v}
343
- for k, v in metadata['validation_scores'].items()
344
- ]).sort_values('ROC-AUC', ascending=False)
345
-
346
- fig = px.bar(val_scores_df, x='ROC-AUC', y='Model', orientation='h',
347
- title='Model Comparison (Validation ROC-AUC)',
348
- color='ROC-AUC', color_continuous_scale='viridis')
349
- fig.update_layout(height=400, yaxis={'categoryorder':'total ascending'})
350
- st.plotly_chart(fig, use_container_width=True)
351
-
352
- st.markdown("---")
353
-
354
- # Cross-Validation
355
- st.subheader(" Cross-Validation Results")
356
- cv_col1, cv_col2, cv_col3 = st.columns(3)
357
-
358
- with cv_col1:
359
- st.metric("CV Mean ROC-AUC", f"{metadata['cv_mean_roc_auc']:.4f}")
360
- with cv_col2:
361
- st.metric("CV Std Dev", f"±{metadata['cv_std_roc_auc']:.4f}")
362
- with cv_col3:
363
- overfitting_status = metadata.get('overfitting_check', 'Unknown')
364
- st.metric("Overfitting Check", overfitting_status)
365
-
366
- # ==================== TAB 4: HYPERPARAMETER TUNING ====================
367
- with tab4:
368
- st.header(" Hyperparameter Tuning")
369
- st.markdown("Customize model hyperparameters and train new models")
370
-
371
- # ==================== PRESET CONFIGURATIONS ====================
372
- st.subheader(" Quick Presets")
373
-
374
- preset_col1, preset_col2, preset_col3, preset_col4 = st.columns(4)
375
-
376
- with preset_col1:
377
- if st.button(" Best Performance", help="Optimized for maximum accuracy"):
378
- st.session_state.preset = "best"
379
-
380
- with preset_col2:
381
- if st.button(" Fast Training", help="Quick training, good accuracy"):
382
- st.session_state.preset = "fast"
383
-
384
- with preset_col3:
385
- if st.button(" Anti-Overfit", help="Maximum generalization"):
386
- st.session_state.preset = "safe"
387
-
388
- with preset_col4:
389
- if st.button(" Research Grade", help="Publication-quality"):
390
- st.session_state.preset = "research"
391
-
392
- # Initialize session state
393
- if 'preset' not in st.session_state:
394
- st.session_state.preset = "best"
395
-
396
- # Define presets
397
- presets = {
398
- "best": {
399
- "rf_n_estimators": 300, "rf_max_depth": 15, "rf_min_samples_split": 8,
400
- "rf_min_samples_leaf": 4, "rf_max_features": "sqrt",
401
- "gb_n_estimators": 150, "gb_learning_rate": 0.05, "gb_max_depth": 5,
402
- "gb_min_samples_split": 10, "gb_subsample": 0.8,
403
- "xgb_n_estimators": 200, "xgb_learning_rate": 0.05, "xgb_max_depth": 6,
404
- "xgb_min_child_weight": 5, "xgb_subsample": 0.8, "xgb_colsample": 0.8,
405
- "lgb_n_estimators": 200, "lgb_learning_rate": 0.05, "lgb_max_depth": 7,
406
- "lgb_num_leaves": 25, "lgb_min_child_samples": 20, "lgb_subsample": 0.8
407
- },
408
- "fast": {
409
- "rf_n_estimators": 100, "rf_max_depth": 10, "rf_min_samples_split": 10,
410
- "rf_min_samples_leaf": 5, "rf_max_features": "sqrt",
411
- "gb_n_estimators": 75, "gb_learning_rate": 0.1, "gb_max_depth": 4,
412
- "gb_min_samples_split": 10, "gb_subsample": 0.8,
413
- "xgb_n_estimators": 100, "xgb_learning_rate": 0.1, "xgb_max_depth": 5,
414
- "xgb_min_child_weight": 3, "xgb_subsample": 0.8, "xgb_colsample": 0.8,
415
- "lgb_n_estimators": 100, "lgb_learning_rate": 0.1, "lgb_max_depth": 6,
416
- "lgb_num_leaves": 20, "lgb_min_child_samples": 15, "lgb_subsample": 0.8
417
- },
418
- "safe": {
419
- "rf_n_estimators": 200, "rf_max_depth": 10, "rf_min_samples_split": 15,
420
- "rf_min_samples_leaf": 8, "rf_max_features": "sqrt",
421
- "gb_n_estimators": 100, "gb_learning_rate": 0.03, "gb_max_depth": 3,
422
- "gb_min_samples_split": 20, "gb_subsample": 0.7,
423
- "xgb_n_estimators": 150, "xgb_learning_rate": 0.03, "xgb_max_depth": 4,
424
- "xgb_min_child_weight": 8, "xgb_subsample": 0.7, "xgb_colsample": 0.7,
425
- "lgb_n_estimators": 150, "lgb_learning_rate": 0.03, "lgb_max_depth": 5,
426
- "lgb_num_leaves": 15, "lgb_min_child_samples": 30, "lgb_subsample": 0.7
427
- },
428
- "research": {
429
- "rf_n_estimators": 400, "rf_max_depth": 18, "rf_min_samples_split": 6,
430
- "rf_min_samples_leaf": 3, "rf_max_features": "sqrt",
431
- "gb_n_estimators": 200, "gb_learning_rate": 0.03, "gb_max_depth": 6,
432
- "gb_min_samples_split": 8, "gb_subsample": 0.85,
433
- "xgb_n_estimators": 250, "xgb_learning_rate": 0.03, "xgb_max_depth": 7,
434
- "xgb_min_child_weight": 4, "xgb_subsample": 0.85, "xgb_colsample": 0.85,
435
- "lgb_n_estimators": 250, "lgb_learning_rate": 0.03, "lgb_max_depth": 8,
436
- "lgb_num_leaves": 30, "lgb_min_child_samples": 15, "lgb_subsample": 0.85
437
- }
438
- }
439
-
440
- selected_preset = presets[st.session_state.preset]
441
- st.success(f" Using '{st.session_state.preset.upper()}' preset configuration!")
442
-
443
- st.markdown("---")
444
-
445
- # Create two columns for different models
446
- col_left, col_right = st.columns(2)
447
-
448
- with col_left:
449
- st.subheader(" Random Forest")
450
- rf_n_estimators = st.slider("RF: n_estimators", 50, 500, selected_preset["rf_n_estimators"], 10)
451
- rf_max_depth = st.slider("RF: max_depth", 5, 30, selected_preset["rf_max_depth"], 1)
452
- rf_min_samples_split = st.slider("RF: min_samples_split", 2, 20, selected_preset["rf_min_samples_split"], 1)
453
- rf_min_samples_leaf = st.slider("RF: min_samples_leaf", 1, 10, selected_preset["rf_min_samples_leaf"], 1)
454
- rf_max_features = st.selectbox("RF: max_features", ['sqrt', 'log2', None], index=0)
455
-
456
- with col_right:
457
- st.subheader(" Gradient Boosting")
458
- gb_n_estimators = st.slider("GB: n_estimators", 50, 300, selected_preset["gb_n_estimators"], 10)
459
- gb_learning_rate = st.slider("GB: learning_rate", 0.01, 0.3, selected_preset["gb_learning_rate"], 0.01)
460
- gb_max_depth = st.slider("GB: max_depth", 3, 10, selected_preset["gb_max_depth"], 1)
461
- gb_min_samples_split = st.slider("GB: min_samples_split", 2, 20, selected_preset["gb_min_samples_split"], 1)
462
- gb_subsample = st.slider("GB: subsample", 0.5, 1.0, selected_preset["gb_subsample"], 0.05)
463
-
464
- with st.expander(" XGBoost Parameters"):
465
- col1, col2 = st.columns(2)
466
- with col1:
467
- xgb_n_estimators = st.slider("XGB: n_estimators", 50, 300, selected_preset["xgb_n_estimators"], 10, key="xgb_n")
468
- xgb_learning_rate = st.slider("XGB: learning_rate", 0.01, 0.3, selected_preset["xgb_learning_rate"], 0.01, key="xgb_lr")
469
- xgb_max_depth = st.slider("XGB: max_depth", 3, 10, selected_preset["xgb_max_depth"], 1, key="xgb_depth")
470
- with col2:
471
- xgb_min_child_weight = st.slider("XGB: min_child_weight", 1, 10, selected_preset["xgb_min_child_weight"], 1)
472
- xgb_subsample = st.slider("XGB: subsample", 0.5, 1.0, selected_preset["xgb_subsample"], 0.05, key="xgb_sub")
473
- xgb_colsample = st.slider("XGB: colsample_bytree", 0.5, 1.0, selected_preset["xgb_colsample"], 0.05)
474
-
475
- with st.expander(" LightGBM Parameters"):
476
- col1, col2 = st.columns(2)
477
- with col1:
478
- lgb_n_estimators = st.slider("LGB: n_estimators", 50, 300, selected_preset["lgb_n_estimators"], 10, key="lgb_n")
479
- lgb_learning_rate = st.slider("LGB: learning_rate", 0.01, 0.3, selected_preset["lgb_learning_rate"], 0.01, key="lgb_lr")
480
- lgb_max_depth = st.slider("LGB: max_depth", 3, 15, selected_preset["lgb_max_depth"], 1, key="lgb_depth")
481
- with col2:
482
- lgb_num_leaves = st.slider("LGB: num_leaves", 10, 100, selected_preset["lgb_num_leaves"], 5)
483
- lgb_min_child_samples = st.slider("LGB: min_child_samples", 5, 50, selected_preset["lgb_min_child_samples"], 5)
484
- lgb_subsample = st.slider("LGB: subsample", 0.5, 1.0, selected_preset["lgb_subsample"], 0.05, key="lgb_sub")
485
-
486
- st.markdown("---")
487
-
488
- # Generate code button
489
- st.subheader(" Generated Training Code")
490
-
491
- if st.button(" Generate Retraining Code"):
492
- generated_code = f"""# Generated on {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
493
-
494
- # Random Forest Parameters
495
- rf_params = {{
496
- 'n_estimators': {rf_n_estimators},
497
- 'max_depth': {rf_max_depth},
498
- 'min_samples_split': {rf_min_samples_split},
499
- 'min_samples_leaf': {rf_min_samples_leaf},
500
- 'max_features': {repr(rf_max_features)},
501
- 'random_state': 42, 'n_jobs': -1, 'class_weight': 'balanced'
502
- }}
503
-
504
- # Gradient Boosting Parameters
505
- gb_params = {{
506
- 'n_estimators': {gb_n_estimators},
507
- 'learning_rate': {gb_learning_rate},
508
- 'max_depth': {gb_max_depth},
509
- 'min_samples_split': {gb_min_samples_split},
510
- 'subsample': {gb_subsample},
511
- 'random_state': 42
512
- }}
513
-
514
- # XGBoost Parameters
515
- xgb_params = {{
516
- 'n_estimators': {xgb_n_estimators},
517
- 'learning_rate': {xgb_learning_rate},
518
- 'max_depth': {xgb_max_depth},
519
- 'min_child_weight': {xgb_min_child_weight},
520
- 'subsample': {xgb_subsample},
521
- 'colsample_bytree': {xgb_colsample},
522
- 'random_state': 42, 'n_jobs': -1
523
- }}
524
-
525
- # LightGBM Parameters
526
- lgb_params = {{
527
- 'n_estimators': {lgb_n_estimators},
528
- 'learning_rate': {lgb_learning_rate},
529
- 'max_depth': {lgb_max_depth},
530
- 'num_leaves': {lgb_num_leaves},
531
- 'min_child_samples': {lgb_min_child_samples},
532
- 'subsample': {lgb_subsample},
533
- 'random_state': 42, 'n_jobs': -1, 'verbose': -1
534
- }}
535
-
536
- # Train models
537
- trained_models, final_model = train_all_models_anti_overfit(
538
- X_train_scaled, y_train, X_val_scaled, y_val
539
- )
540
- """
541
- st.code(generated_code, language="python")
542
- st.success(" Code generated! Copy and paste into Jupyter notebook.")
543
-
544
- st.markdown("---")
545
-
546
- # ==================== TRAIN MODEL IN STREAMLIT ====================
547
- st.subheader(" Train Model with Custom Parameters")
548
-
549
- train_col1, train_col2 = st.columns([3, 1])
550
-
551
- with train_col1:
552
- st.info("""
553
- **How it works:**
554
- 1. Adjust hyperparameters above
555
- 2. Click "Train New Model"
556
- 3. Wait 5-15 minutes for training
557
- 4. Download trained model
558
- 5. Replace old model and restart app
559
- """)
560
-
561
- with train_col2:
562
- train_button = st.button(" Train New Model", type="primary", use_container_width=True)
563
-
564
- if train_button:
565
- st.markdown("---")
566
- st.header(" Training in Progress...")
567
-
568
- progress_bar = st.progress(0)
569
- status_text = st.empty()
570
-
571
- try:
572
- # Step 1: Load Data
573
- status_text.text("Step 1/5: Loading datasets...")
574
- progress_bar.progress(10)
575
-
576
- @st.cache_data
577
- def load_training_data():
578
- import requests
579
- from io import StringIO
580
- datasets = {}
581
- try:
582
- url = "https://exoplanetarchive.ipac.caltech.edu/TAP/sync?query=select+*+from+koi&format=csv"
583
- response = requests.get(url, timeout=30)
584
- if response.status_code == 200:
585
- datasets['kepler'] = pd.read_csv(StringIO(response.text))
586
- except: pass
587
- try:
588
- url = "https://exoplanetarchive.ipac.caltech.edu/TAP/sync?query=select+*+from+toi&format=csv"
589
- response = requests.get(url, timeout=30)
590
- if response.status_code == 200:
591
- datasets['tess'] = pd.read_csv(StringIO(response.text))
592
- except: pass
593
- try:
594
- url = "https://exoplanetarchive.ipac.caltech.edu/TAP/sync?query=select+*+from+k2pandc&format=csv"
595
- response = requests.get(url, timeout=30)
596
- if response.status_code == 200:
597
- datasets['k2'] = pd.read_csv(StringIO(response.text))
598
- except: pass
599
- return datasets
600
-
601
- datasets = load_training_data()
602
- if len(datasets) == 0:
603
- st.error(" Unable to load datasets")
604
- st.stop()
605
-
606
- st.success(f" Loaded {len(datasets)} dataset(s)")
607
- progress_bar.progress(20)
608
-
609
- # Step 2: Preprocess
610
- status_text.text("Step 2/5: Preprocessing...")
611
-
612
- from sklearn.model_selection import train_test_split
613
- from sklearn.preprocessing import RobustScaler
614
- from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
615
-
616
- def quick_preprocess(datasets):
617
- dfs = []
618
- for mission, df in datasets.items():
619
- df_copy = df.copy()
620
- numeric_cols = df_copy.select_dtypes(include=[np.number]).columns.tolist()
621
- target_cols = ['koi_disposition', 'tfopwg_disp', 'disposition']
622
- target_col = None
623
- for tc in target_cols:
624
- if tc in df_copy.columns:
625
- target_col = tc
626
- break
627
- if target_col is None:
628
- continue
629
-
630
- # Create binary target
631
- if mission == 'kepler':
632
- df_copy['target'] = df_copy[target_col].apply(
633
- lambda x: 1 if str(x).upper() in ['CONFIRMED', 'CANDIDATE'] else 0
634
- )
635
- elif mission == 'tess':
636
- df_copy['target'] = df_copy[target_col].apply(
637
- lambda x: 1 if str(x).upper() in ['PC', 'CP', 'KP'] else 0
638
- )
639
- else:
640
- df_copy['target'] = df_copy[target_col].apply(
641
- lambda x: 1 if str(x).upper() in ['CONFIRMED', 'CANDIDATE'] else 0
642
- )
643
-
644
- keep_cols = [col for col in numeric_cols if col != target_col] + ['target']
645
- df_subset = df_copy[keep_cols].copy()
646
- dfs.append(df_subset)
647
-
648
- # Combine all datasets
649
- combined = pd.concat(dfs, ignore_index=True)
650
-
651
- # CRITICAL: Remove columns with too many missing values FIRST
652
- missing_pct = combined.isnull().sum() / len(combined)
653
- cols_to_keep = missing_pct[missing_pct < 0.7].index.tolist() # Keep columns with <70% missing
654
- combined = combined[cols_to_keep]
655
-
656
- # Fill remaining NaN values with median
657
- for col in combined.columns:
658
- if col != 'target':
659
- if combined[col].isnull().any():
660
- median_val = combined[col].median()
661
- # If median is also NaN (all values are NaN), use 0
662
- if pd.isna(median_val):
663
- combined[col].fillna(0, inplace=True)
664
- else:
665
- combined[col].fillna(median_val, inplace=True)
666
-
667
- # Replace infinite values
668
- combined = combined.replace([np.inf, -np.inf], 0)
669
-
670
- # Remove rows with ANY remaining missing values in features
671
- combined = combined.dropna(subset=[col for col in combined.columns if col != 'target'])
672
-
673
- # Final safety check: ensure NO NaN values remain
674
- assert combined.isnull().sum().sum() == 0, "NaN values still present after preprocessing!"
675
-
676
- return combined
677
-
678
- processed_data = quick_preprocess(datasets)
679
- X = processed_data.drop('target', axis=1)
680
- y = processed_data['target']
681
-
682
- st.success(f" Preprocessed {len(X)} samples")
683
- progress_bar.progress(35)
684
-
685
- # Step 3: Split and Scale
686
- status_text.text("Step 3/5: Splitting and scaling...")
687
-
688
- X_train, X_test, y_train, y_test = train_test_split(
689
- X, y, test_size=0.2, random_state=42, stratify=y
690
- )
691
-
692
- scaler_new = RobustScaler()
693
- X_train_scaled = scaler_new.fit_transform(X_train)
694
- X_test_scaled = scaler_new.transform(X_test)
695
-
696
- progress_bar.progress(45)
697
-
698
- # Step 4: Train Models
699
- status_text.text("Step 4/5: Training models...")
700
-
701
- models_trained = {}
702
-
703
- st.write(" Training Random Forest...")
704
- rf_new = RandomForestClassifier(
705
- n_estimators=rf_n_estimators, max_depth=rf_max_depth,
706
- min_samples_split=rf_min_samples_split, min_samples_leaf=rf_min_samples_leaf,
707
- max_features=rf_max_features, class_weight='balanced',
708
- random_state=42, n_jobs=-1
709
- )
710
- rf_new.fit(X_train_scaled, y_train)
711
- models_trained['RandomForest'] = rf_new
712
- progress_bar.progress(55)
713
-
714
- st.write(" Training Gradient Boosting...")
715
- gb_new = GradientBoostingClassifier(
716
- n_estimators=gb_n_estimators, learning_rate=gb_learning_rate,
717
- max_depth=gb_max_depth, min_samples_split=gb_min_samples_split,
718
- subsample=gb_subsample, random_state=42
719
- )
720
- gb_new.fit(X_train_scaled, y_train)
721
- models_trained['GradientBoosting'] = gb_new
722
- progress_bar.progress(65)
723
-
724
- try:
725
- import xgboost as xgb
726
- st.write(" Training XGBoost...")
727
- xgb_new = xgb.XGBClassifier(
728
- n_estimators=xgb_n_estimators, learning_rate=xgb_learning_rate,
729
- max_depth=xgb_max_depth, min_child_weight=xgb_min_child_weight,
730
- subsample=xgb_subsample, colsample_bytree=xgb_colsample,
731
- random_state=42, n_jobs=-1
732
- )
733
- xgb_new.fit(X_train_scaled, y_train)
734
- models_trained['XGBoost'] = xgb_new
735
- except:
736
- st.warning(" XGBoost not available")
737
- progress_bar.progress(75)
738
-
739
- try:
740
- import lightgbm as lgb
741
- st.write(" Training LightGBM...")
742
- lgb_new = lgb.LGBMClassifier(
743
- n_estimators=lgb_n_estimators, learning_rate=lgb_learning_rate,
744
- max_depth=lgb_max_depth, num_leaves=lgb_num_leaves,
745
- min_child_samples=lgb_min_child_samples, subsample=lgb_subsample,
746
- random_state=42, n_jobs=-1, verbose=-1
747
- )
748
- lgb_new.fit(X_train_scaled, y_train)
749
- models_trained['LightGBM'] = lgb_new
750
- except:
751
- st.warning(" LightGBM not available")
752
- progress_bar.progress(85)
753
-
754
- st.write(" Creating Ensemble...")
755
- estimators = [(name, model) for name, model in models_trained.items()]
756
- ensemble_new = VotingClassifier(estimators=estimators, voting='soft', n_jobs=-1)
757
- ensemble_new.fit(X_train_scaled, y_train)
758
- progress_bar.progress(90)
759
-
760
- # Step 5: Evaluate
761
- status_text.text("Step 5/5: Evaluating...")
762
-
763
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
764
-
765
- y_pred = ensemble_new.predict(X_test_scaled)
766
- y_pred_proba = ensemble_new.predict_proba(X_test_scaled)[:, 1]
767
-
768
- new_metrics = {
769
- 'accuracy': accuracy_score(y_test, y_pred),
770
- 'precision': precision_score(y_test, y_pred, zero_division=0),
771
- 'recall': recall_score(y_test, y_pred, zero_division=0),
772
- 'f1_score': f1_score(y_test, y_pred, zero_division=0),
773
- 'roc_auc': roc_auc_score(y_test, y_pred_proba)
774
- }
775
-
776
- progress_bar.progress(100)
777
- status_text.text(" Training complete!")
778
-
779
- st.success(" Model training complete!")
780
-
781
- st.markdown("---")
782
- st.subheader(" New Model Performance")
783
-
784
- metric_col1, metric_col2, metric_col3, metric_col4, metric_col5 = st.columns(5)
785
- with metric_col1:
786
- st.metric("Accuracy", f"{new_metrics['accuracy']:.3f}")
787
- with metric_col2:
788
- st.metric("Precision", f"{new_metrics['precision']:.3f}")
789
- with metric_col3:
790
- st.metric("Recall", f"{new_metrics['recall']:.3f}")
791
- with metric_col4:
792
- st.metric("F1 Score", f"{new_metrics['f1_score']:.3f}")
793
- with metric_col5:
794
- st.metric("ROC-AUC", f"{new_metrics['roc_auc']:.3f}")
795
-
796
- # Save model
797
- st.markdown("---")
798
- st.subheader(" Download New Model")
799
-
800
- new_model_package = {
801
- 'ensemble_model': ensemble_new,
802
- 'individual_models': models_trained,
803
- 'scaler': scaler_new,
804
- 'feature_names': X.columns.tolist(),
805
- 'metadata': {
806
- 'version': '2.0',
807
- 'created_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
808
- 'created_date': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
809
- 'missions': list(datasets.keys()),
810
- 'total_samples': len(X),
811
- 'train_samples': len(X_train),
812
- 'test_samples': len(X_test),
813
- 'n_features': len(X.columns),
814
- 'test_accuracy': float(new_metrics['accuracy']),
815
- 'test_precision': float(new_metrics['precision']),
816
- 'test_recall': float(new_metrics['recall']),
817
- 'test_f1_score': float(new_metrics['f1_score']),
818
- 'test_roc_auc': float(new_metrics['roc_auc']),
819
- 'n_models_in_ensemble': len(models_trained),
820
- 'ensemble_model_names': list(models_trained.keys()),
821
- 'planets_total': int(y.sum()),
822
- 'false_positives_total': int((y==0).sum()),
823
- 'planet_percentage': float(y.mean() * 100),
824
- 'cv_mean_roc_auc': 0.0,
825
- 'cv_std_roc_auc': 0.0,
826
- 'overfitting_check': 'Not tested'
827
- }
828
- }
829
-
830
- buffer = io.BytesIO()
831
- joblib.dump(new_model_package, buffer)
832
- buffer.seek(0)
833
-
834
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
835
- filename = f"exoplanet_final_model.joblib"
836
-
837
- st.download_button(
838
- label="⬇ Download New Model",
839
- data=buffer,
840
- file_name=filename,
841
- mime="application/octet-stream",
842
- type="primary"
843
- )
844
-
845
- st.success(f" Model ready! Update line 47 with: `{filename}`")
846
-
847
- except Exception as e:
848
- st.error(f" Error: {str(e)}")
849
-
850
- # ==================== TAB 5: ABOUT ====================
851
- with tab5:
852
- st.header("ℹ About This System")
853
-
854
- st.markdown("""
855
- ### Project Overview
856
- AI-powered exoplanet detection using NASA telescope data.
857
-
858
- ### Data Sources
859
- - **Kepler Mission**: Stellar transit observations
860
- - **TESS Mission**: Transiting Exoplanet Survey Satellite
861
- - **K2 Mission**: Extended Kepler observations
862
-
863
- ### ML Approach
864
- Multi-model ensemble with advanced feature engineering
865
-
866
- ### NASA Space Apps Challenge 2025
867
- Built for "A World Away: Hunting for Exoplanets with AI"
868
-
869
- ### Resources
870
- - [NASA Exoplanet Archive](https://exoplanetarchive.ipac.caltech.edu/)
871
- - [Space Apps Challenge](https://www.spaceappschallenge.org/)
872
- """)
873
-
874
- st.markdown("---")
875
- st.markdown("""
876
- <div style='text-align: center; color: #666;'>
877
- <p><strong>NASA Space Apps Challenge 2025</strong></p>
878
- <p>Built with ❤️ using Streamlit & Machine Learning</p>
879
- <p>🌟 Detecting exoplanets • One transit at a time 🪐</p>
880
- </div>
881
- """, unsafe_allow_html=True)
882
 
 
 
1
  import streamlit as st
2
  import joblib
 
 
 
 
 
3
  import time
4
  import io
5
+ from huggingface_hub import hf_hub_download
6
 
7
  # ==================== PAGE CONFIG ====================
8
  st.set_page_config(
9
+ page_title="Exoplanet Classification",
10
+ page_icon="🌌",
11
  layout="wide",
12
+ initial_sidebar_state="expanded",
13
  )
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # ==================== LOAD MODEL ====================
16
  @st.cache_resource
17
  def load_model_package():
18
  """Load the complete model package"""
19
  try:
20
+ # Try to load the model from a local path first
21
  package = joblib.load("exoplanet_final_model.joblib")
22
+ st.info("Loaded model from local file.")
23
  return package
24
+ except FileNotFoundError:
25
+ st.info("Model file not found locally. Attempting to download from Hugging Face Hub...")
26
+ try:
27
+ # Download from Hugging Face Hub
28
+ model_path = hf_hub_download(
29
+ repo_id="Nugget-cloud/nasa-space-apps-exoplanet",
30
+ filename="exoplanet_final_model.joblib"
31
+ )
32
+ package = joblib.load(model_path)
33
+ st.success("Model successfully downloaded and loaded from Hugging Face Hub.")
34
+ return package
35
+ except Exception as hub_e:
36
+ st.error(f"Failed to download or load model from Hugging Face Hub: {hub_e}")
37
+ st.error("You can train a new model in the 'Hyperparameter Tuning' tab.")
38
+ return None
39
  except Exception as e:
40
+ st.error(f"An unexpected error occurred while loading the model: {e}")
41
+ return None
 
42
 
43
  # Load package
44
  with st.spinner(" Loading AI model..."):
45
  package = load_model_package()
46
 
47
+ if package is None:
48
+ st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ model = package['ensemble_model']