Nugget-cloud committed on
Commit
dcf91d4
·
verified ·
1 Parent(s): bb7eeca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +848 -223
app.py CHANGED
@@ -1,257 +1,882 @@
1
- import gradio as gr
2
  import joblib
3
  import pandas as pd
4
  import numpy as np
5
  import plotly.graph_objects as go
6
  import plotly.express as px
7
- from huggingface_hub import hf_hub_download
8
- import warnings
 
9
 
10
- # ==================== CONFIGURATION ====================
11
- # ⚠️ UPDATE THIS WITH YOUR HUGGING FACE REPOSITORY ID
12
- HF_REPO_ID = "YOUR_USERNAME/YOUR_REPO_NAME"
13
- MODEL_FILENAME = "exoplanet_final_model.joblib"
 
 
 
14
 
15
- # Suppress specific warnings for a cleaner output
16
- warnings.filterwarnings("ignore", category=UserWarning, message="Trying to unpickle estimator.*")
17
- warnings.filterwarnings("ignore", category=FutureWarning)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # ==================== LOAD MODEL FROM HUGGING FACE ====================
20
- def load_model_package(repo_id, filename):
21
- """Load the complete model package from Hugging Face Hub"""
 
22
  try:
23
- model_path = hf_hub_download(repo_id=repo_id, filename=filename)
24
- package = joblib.load(model_path)
25
  return package
26
  except Exception as e:
27
- # Fallback for local development if HF download fails
28
- print(f"Could not download from Hugging Face: {e}. Trying local file...")
29
- try:
30
- package = joblib.load(filename)
31
- return package
32
- except FileNotFoundError:
33
- raise gr.Error(f"Model file not found locally or on Hugging Face at {repo_id}. Please check HF_REPO_ID and ensure the model file is available.")
34
- except Exception as e_local:
35
- raise gr.Error(f"Error loading local model: {e_local}")
36
 
37
- # Load package and extract components
38
- try:
39
- print("Loading AI model...")
40
- package = load_model_package(HF_REPO_ID, MODEL_FILENAME)
41
- model = package['ensemble_model']
42
- scaler = package['scaler']
43
- feature_names = package['feature_names']
44
- metadata = package.get('metadata', {}) # Use .get for safety
45
- print("AI model loaded successfully.")
46
- except Exception as e:
47
- # If model loading fails, we can't run the app.
48
- print(str(e))
49
- # Create a dummy structure to prevent the UI from crashing on startup
50
- model, scaler, feature_names, metadata = None, None, [], {'missions': ['N/A'], 'version': 'Error'}
51
 
 
 
 
 
52
 
53
- # ==================== PREDICTION LOGIC ====================
54
- def predict_exoplanet(period, duration, depth, planet_radius, equilibrium_temp, insolation, star_radius, star_temp, star_logg, mission):
55
- """Predicts if a candidate is an exoplanet based on input features."""
56
- if not model:
57
- raise gr.Error("Model is not loaded. Cannot perform prediction.")
58
 
59
- # Create feature dictionary from inputs
60
- features_dict = {}
61
-
62
- # Basic features
63
- feature_map = {
64
- 'period': period, 'duration': duration, 'depth': depth,
65
- 'planet_radius': planet_radius, 'star_radius': star_radius,
66
- 'star_temp': star_temp, 'star_logg': star_logg,
67
- 'equilibrium_temp': equilibrium_temp, 'insolation_flux': insolation
68
- }
69
 
70
- for fname, fval in feature_map.items():
71
- if fname in feature_names:
72
- features_dict[fname] = fval
73
-
74
- # Engineered features
75
- if 'transit_period_ratio' in feature_names and period > 0:
76
- features_dict['transit_period_ratio'] = duration / (period * 24)
77
- if 'radius_ratio' in feature_names and star_radius > 0:
78
- features_dict['radius_ratio'] = planet_radius / star_radius
79
- if 'period_log' in feature_names and period > 0:
80
- features_dict['period_log'] = np.log10(period)
81
- if 'insolation_flux_log' in feature_names and insolation > 0:
82
- features_dict['insolation_flux_log'] = np.log10(insolation)
83
- if 'habitable_zone_dist' in feature_names:
84
- features_dict['habitable_zone_dist'] = abs(equilibrium_temp - 288) / 288
85
-
86
- if 'star_class' in feature_names:
87
- if star_temp >= 7500: features_dict['star_class'] = 5
88
- elif star_temp >= 6000: features_dict['star_class'] = 4
89
- elif star_temp >= 5200: features_dict['star_class'] = 3
90
- elif star_temp >= 3700: features_dict['star_class'] = 2
91
- else: features_dict['star_class'] = 1
92
-
93
- if 'luminosity_class' in feature_names:
94
- if star_logg < 3.5: features_dict['luminosity_class'] = 3
95
- elif star_logg < 4.0: features_dict['luminosity_class'] = 2
96
- else: features_dict['luminosity_class'] = 1
97
-
98
- for m in metadata.get('missions', []):
99
- col_name = f'mission_{m}'
100
- if col_name in feature_names:
101
- features_dict[col_name] = 1 if m == mission else 0
102
-
103
- # Create feature vector in the correct order
104
- feature_vector = [features_dict.get(f, 0) for f in feature_names]
105
- X_input = np.array(feature_vector).reshape(1, -1)
106
-
107
- # Scale and predict
108
- X_scaled = scaler.transform(X_input)
109
- prediction = model.predict(X_scaled)[0]
110
- probabilities = model.predict_proba(X_scaled)[0]
111
-
112
- # Prepare outputs
113
- result_label_val = "PLANET DETECTED!" if prediction == 1 else "FALSE POSITIVE"
114
- confidence = probabilities[1] if prediction == 1 else probabilities[0]
115
-
116
- # Probability gauge
117
- gauge_fig = go.Figure(go.Indicator(
118
- mode="gauge+number",
119
- value=probabilities[1] * 100,
120
- title={'text': "Planet Probability (%)"},
121
- gauge={
122
- 'axis': {'range': [0, 100]},
123
- 'bar': {'color': "darkblue"},
124
- 'steps': [{'range': [0, 50], 'color': "lightgray"}, {'range': [50, 100], 'color': "lightgreen"}],
125
- 'threshold': {'line': {'color': "red", 'width': 4}, 'thickness': 0.75, 'value': 50}
126
- }
127
- ))
128
- gauge_fig.update_layout(height=250, margin=dict(l=20, r=20, t=50, b=20))
129
 
130
- return {result_label_val: confidence}, gauge_fig
131
-
132
- # ==================== BATCH ANALYSIS LOGIC ====================
133
- def batch_analysis(file_obj):
134
- """Performs batch prediction on an uploaded CSV file."""
135
- if not model:
136
- raise gr.Error("Model is not loaded. Cannot perform batch analysis.")
137
- if file_obj is None:
138
- return None, "Please upload a file first."
139
 
140
- try:
141
- df_upload = pd.read_csv(file_obj.name)
142
- except Exception as e:
143
- return None, f"Error reading CSV: {e}"
 
 
 
 
144
 
145
- # For this simplified batch prediction, we only use columns that directly match model features.
146
- # A more robust implementation would perform full feature engineering for each row.
147
- X_batch = pd.DataFrame(columns=feature_names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- for col in feature_names:
150
- if col in df_upload.columns:
151
- X_batch[col] = df_upload[col]
152
- else:
153
- X_batch[col] = 0 # Fill missing feature columns with 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- X_batch = X_batch.fillna(0)
 
 
 
156
 
157
- # Scale and predict
158
- X_scaled = scaler.transform(X_batch)
159
- predictions = model.predict(X_scaled)
160
- probabilities = model.predict_proba(X_scaled)[:, 1]
161
 
162
- # Add results to a new dataframe for clarity
163
- results_df = df_upload.copy()
164
- results_df['prediction'] = ['Planet' if p == 1 else 'False Positive' for p in predictions]
165
- results_df['planet_probability'] = probabilities
166
 
167
- return results_df, f"Analysis complete for {len(results_df)} candidates."
168
-
169
- # ==================== GRADIO UI ====================
170
- css = """
171
- .main-header { font-size: 2.5rem; font-weight: bold; text-align: center; }
172
- .sub-header { text-align: center; color: #666; font-size: 1.2rem; }
173
- """
174
-
175
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
176
- gr.Markdown('<div class="main-header">🪐 NASA Exoplanet AI Detector</div>', elem_classes="main-header")
177
- gr.Markdown(f"<div class='sub-header'>AI-Powered Exoplanet Detection | Model Version: {metadata.get('version', 'N/A')}</div>", elem_classes="sub-header")
 
 
 
 
178
 
179
- with gr.Tabs():
180
- with gr.TabItem("Single Prediction"):
181
- gr.Markdown("### Analyze a Single Exoplanet Candidate")
182
- with gr.Row():
183
- with gr.Column(scale=2):
184
- with gr.Accordion("Orbital & Planet Properties", open=True):
185
- period = gr.Slider(0.0, 10000.0, value=10.0, label="Orbital Period (days)")
186
- duration = gr.Slider(0.0, 48.0, value=3.0, label="Transit Duration (hours)")
187
- depth = gr.Slider(0.0, 100000.0, value=1000.0, label="Transit Depth (ppm)")
188
- planet_radius = gr.Slider(0.1, 100.0, value=1.0, label="Planet Radius (Earth radii)")
189
- equilibrium_temp = gr.Slider(0, 5000, value=288, label="Equilibrium Temperature (K)")
190
- insolation = gr.Slider(0.0, 10000.0, value=1.0, label="Insolation Flux (Earth units)")
191
-
192
- with gr.Accordion("Stellar Properties & Mission", open=True):
193
- star_radius = gr.Slider(0.1, 50.0, value=1.0, label="Star Radius (Solar radii)")
194
- star_temp = gr.Slider(2000, 50000, value=5778, label="Star Temperature (K)")
195
- star_logg = gr.Slider(0.0, 5.0, value=4.4, label="Star log(g)")
196
- mission = gr.Dropdown(metadata.get('missions', ['N/A']), label="Mission", value=metadata.get('missions', ['N/A'])[0])
197
-
198
- predict_btn = gr.Button("Analyze Candidate", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- with gr.Column(scale=1):
201
- gr.Markdown("### Prediction Results")
202
- result_label = gr.Label(label="Prediction")
203
- gauge_plot = gr.Plot(label="Probability Gauge")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- predict_btn.click(
206
- fn=predict_exoplanet,
207
- inputs=[period, duration, depth, planet_radius, equilibrium_temp, insolation, star_radius, star_temp, star_logg, mission],
208
- outputs=[result_label, gauge_plot],
209
- api_name="predict"
210
- )
 
 
 
211
 
212
- with gr.TabItem("Batch Analysis"):
213
- gr.Markdown("### Batch Analysis of Exoplanet Candidates")
214
- gr.Info("Upload a CSV file. The file should contain columns matching the model's features for best results.")
215
- with gr.Row():
216
- file_input = gr.File(label="Upload CSV", file_types=[".csv"])
217
- batch_status = gr.Textbox(label="Status", interactive=False)
218
- batch_run_btn = gr.Button("⚡ Analyze All Candidates", variant="primary")
219
- gr.Markdown("### Results")
220
- batch_output_df = gr.DataFrame(label="Batch Results")
221
 
222
- batch_run_btn.click(fn=batch_analysis, inputs=[file_input], outputs=[batch_output_df, batch_status], api_name="batch_predict")
 
 
 
 
 
 
 
 
 
223
 
224
- with gr.TabItem("Model Analytics"):
225
- gr.Markdown("### Model Performance & Dataset Information")
226
- with gr.Row():
227
- gr.Textbox(f"{metadata.get('test_accuracy', 0)*100:.2f}%", label="Test Accuracy")
228
- gr.Textbox(f"{metadata.get('test_precision', 0):.3f}", label="Precision")
229
- gr.Textbox(f"{metadata.get('test_recall', 0):.3f}", label="Recall")
230
- gr.Textbox(f"{metadata.get('test_f1_score', 0):.3f}", label="F1 Score")
231
- gr.Textbox(f"{metadata.get('test_roc_auc', 0):.3f}", label="ROC-AUC")
 
 
232
 
233
- if 'validation_scores' in metadata:
234
- gr.Markdown("### Individual Model Performance (Validation Set)")
235
- val_scores_df = pd.DataFrame([{"Model": k, "ROC-AUC": v} for k, v in metadata['validation_scores'].items()]).sort_values('ROC-AUC', ascending=False)
236
- fig = px.bar(val_scores_df, x='ROC-AUC', y='Model', orientation='h', title='Model Comparison (Validation ROC-AUC)', color='ROC-AUC', color_continuous_scale='viridis')
237
- fig.update_layout(height=400, yaxis={'categoryorder':'total ascending'})
238
- gr.Plot(value=fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
- with gr.TabItem("ℹ About"):
241
- gr.Markdown("""
242
- ### 🚀 Project Overview
243
- This application provides an interface for an AI model designed to detect exoplanets from NASA telescope data. It is built for the **NASA Space Apps Challenge 2025**.
244
- ### 📊 Data Sources
245
- The model was trained on publicly available data from multiple NASA missions, including Kepler, K2, and TESS.
246
- ### 🤖 Machine Learning Approach
247
- The core of this system is a sophisticated **ensemble model**, which combines the predictions of several machine learning algorithms to achieve higher accuracy and robustness.
248
- ### 🔗 Resources
249
- - [NASA Exoplanet Archive](https://exoplanetarchive.ipac.caltech.edu/)
250
- - [NASA Space Apps Challenge](https://www.spaceappschallenge.org/)
251
- - [Hugging Face (for model hosting)](https://huggingface.co/)
252
- - [Gradio (for the web UI)](https://www.gradio.app/)
253
- """)
 
 
 
 
 
 
 
 
 
254
 
255
- if __name__ == "__main__":
256
- demo.launch()
 
 
 
 
 
 
257
 
 
1
+ import streamlit as st
2
  import joblib
3
  import pandas as pd
4
  import numpy as np
5
  import plotly.graph_objects as go
6
  import plotly.express as px
7
+ from datetime import datetime
8
+ import time
9
+ import io
10
 
11
+ # ==================== PAGE CONFIG ====================
12
+ st.set_page_config(
13
+ page_title="NASA Exoplanet AI Detector",
14
+ page_icon="🪐",
15
+ layout="wide",
16
+ initial_sidebar_state="expanded"
17
+ )
18
 
19
+ # ==================== CUSTOM CSS ====================
20
+ st.markdown("""
21
+ <style>
22
+ .main-header {
23
+ font-size: 3.5rem;
24
+ font-weight: bold;
25
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
26
+ -webkit-background-clip: text;
27
+ -webkit-text-fill-color: transparent;
28
+ text-align: center;
29
+ padding: 20px;
30
+ }
31
+ .sub-header {
32
+ text-align: center;
33
+ color: #666;
34
+ font-size: 1.2rem;
35
+ }
36
+ .metric-card {
37
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
38
+ padding: 20px;
39
+ border-radius: 10px;
40
+ color: white;
41
+ text-align: center;
42
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
43
+ }
44
+ .stButton>button {
45
+ width: 100%;
46
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
47
+ color: white;
48
+ font-weight: bold;
49
+ }
50
+ </style>
51
+ """, unsafe_allow_html=True)
52
 
53
+ # ==================== LOAD MODEL ====================
54
+ @st.cache_resource
55
+ def load_model_package():
56
+ """Load the complete model package"""
57
  try:
58
+ # ⚠️ UPDATE THIS FILENAME WITH YOUR ACTUAL MODEL FILE
59
+ package = joblib.load("exoplanet_final_model.joblib")
60
  return package
61
  except Exception as e:
62
+ st.error(f" Error loading model: {e}")
63
+ st.error("Please update the filename in the code (line 47)")
64
+ st.stop()
 
 
 
 
 
 
65
 
66
+ # Load package
67
+ with st.spinner(" Loading AI model..."):
68
+ package = load_model_package()
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ model = package['ensemble_model']
71
+ scaler = package['scaler']
72
+ feature_names = package['feature_names']
73
+ metadata = package['metadata']
74
 
75
+ # ==================== HEADER ====================
76
+ st.markdown('<div class="main-header">🪐 NASA Space Apps Challenge 2025</div>', unsafe_allow_html=True)
77
+ st.markdown('<div class="sub-header">AI-Powered Exoplanet Detection System</div>', unsafe_allow_html=True)
78
+ st.markdown(f"<div class='sub-header'>Trained on {', '.join(metadata['missions'])} mission data</div>", unsafe_allow_html=True)
 
79
 
80
+ # ==================== SIDEBAR ====================
81
+ with st.sidebar:
82
+ st.image("https://www.nasa.gov/wp-content/uploads/2018/07/nasa-logo.svg", width=200)
 
 
 
 
 
 
 
83
 
84
+ st.markdown("---")
85
+ st.subheader(" Ensemble Components")
86
+ for model_name in metadata['ensemble_model_names']:
87
+ st.text(f"• {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ st.markdown("---")
90
+ st.subheader(" Missions")
91
+ for mission in metadata['missions']:
92
+ st.text(f"• {mission}")
93
+
94
+ st.markdown("---")
95
+ st.info(f"**Model Version:** {metadata['version']}")
96
+ st.info(f"**Created:** {metadata['created_date']}")
 
97
 
98
+ # ==================== MAIN TABS ====================
99
+ tab1, tab2, tab3, tab4, tab5 = st.tabs([
100
+ " Single Prediction",
101
+ " Batch Analysis",
102
+ " Model Analytics",
103
+ " Hyperparameter Tuning",
104
+ "ℹ About"
105
+ ])
106
 
107
+ # ==================== TAB 1: SINGLE PREDICTION ====================
108
+ with tab1:
109
+ st.header(" Analyze Single Exoplanet Candidate")
110
+ st.markdown("Enter the parameters of an exoplanet candidate to predict if it's a **planet** or **false positive**")
111
+
112
+ with st.form("prediction_form"):
113
+ col1, col2, col3 = st.columns(3)
114
+
115
+ with col1:
116
+ st.subheader(" Orbital Properties")
117
+ period = st.number_input("Orbital Period (days)", 0.0, 10000.0, 10.0,
118
+ help="Time for one complete orbit around the star")
119
+ duration = st.number_input("Transit Duration (hours)", 0.0, 48.0, 3.0,
120
+ help="Time the planet takes to cross the star")
121
+ depth = st.number_input("Transit Depth (ppm)", 0.0, 100000.0, 1000.0,
122
+ help="Brightness dip when planet transits")
123
+
124
+ with col2:
125
+ st.subheader(" Planet Properties")
126
+ planet_radius = st.number_input("Planet Radius (Earth radii)", 0.1, 100.0, 1.0,
127
+ help="Size relative to Earth")
128
+ equilibrium_temp = st.number_input("Equilibrium Temperature (K)", 0, 5000, 288,
129
+ help="Expected temperature of the planet")
130
+ insolation = st.number_input("Insolation Flux (Earth units)", 0.0, 10000.0, 1.0,
131
+ help="Energy received from star (Earth=1.0)")
132
+
133
+ with col3:
134
+ st.subheader(" Stellar Properties")
135
+ star_radius = st.number_input("Star Radius (Solar radii)", 0.1, 50.0, 1.0,
136
+ help="Size relative to the Sun")
137
+ star_temp = st.number_input("Star Temperature (K)", 2000, 50000, 5778,
138
+ help="Surface temperature (Sun=5778K)")
139
+ star_logg = st.number_input("Star log(g)", 0.0, 5.0, 4.4,
140
+ help="Surface gravity indicator")
141
+
142
+ mission = st.selectbox("Mission", metadata['missions'], help="Which telescope detected this candidate")
143
+
144
+ submit_button = st.form_submit_button(" Analyze Candidate", type="primary")
145
 
146
+ if submit_button:
147
+ with st.spinner(" Analyzing candidate..."):
148
+ # Create feature dictionary
149
+ features_dict = {}
150
+
151
+ # Basic features
152
+ feature_map = {
153
+ 'period': period,
154
+ 'duration': duration,
155
+ 'depth': depth,
156
+ 'planet_radius': planet_radius,
157
+ 'star_radius': star_radius,
158
+ 'star_temp': star_temp,
159
+ 'star_logg': star_logg,
160
+ 'equilibrium_temp': equilibrium_temp,
161
+ 'insolation_flux': insolation
162
+ }
163
+
164
+ for fname, fval in feature_map.items():
165
+ if fname in feature_names:
166
+ features_dict[fname] = fval
167
+
168
+ # Engineered features
169
+ if 'transit_period_ratio' in feature_names and period > 0:
170
+ features_dict['transit_period_ratio'] = duration / (period * 24)
171
+
172
+ if 'radius_ratio' in feature_names and star_radius > 0:
173
+ features_dict['radius_ratio'] = planet_radius / star_radius
174
+
175
+ if 'period_log' in feature_names and period > 0:
176
+ features_dict['period_log'] = np.log10(period)
177
+
178
+ if 'insolation_flux_log' in feature_names and insolation > 0:
179
+ features_dict['insolation_flux_log'] = np.log10(insolation)
180
+
181
+ if 'habitable_zone_dist' in feature_names:
182
+ features_dict['habitable_zone_dist'] = abs(equilibrium_temp - 288) / 288
183
+
184
+ # Stellar classification
185
+ if 'star_class' in feature_names:
186
+ if star_temp >= 7500: star_class = 5
187
+ elif star_temp >= 6000: star_class = 4
188
+ elif star_temp >= 5200: star_class = 3
189
+ elif star_temp >= 3700: star_class = 2
190
+ else: star_class = 1
191
+ features_dict['star_class'] = star_class
192
+
193
+ if 'luminosity_class' in feature_names:
194
+ if star_logg < 3.5: lum_class = 3
195
+ elif star_logg < 4.0: lum_class = 2
196
+ else: lum_class = 1
197
+ features_dict['luminosity_class'] = lum_class
198
+
199
+ # Mission encoding
200
+ for m in metadata['missions']:
201
+ col_name = f'mission_{m}'
202
+ if col_name in feature_names:
203
+ features_dict[col_name] = 1 if m == mission else 0
204
+
205
+ # Create feature vector
206
+ feature_vector = [features_dict.get(f, 0) for f in feature_names]
207
+ X_input = np.array(feature_vector).reshape(1, -1)
208
+
209
+ # Scale and predict
210
+ X_scaled = scaler.transform(X_input)
211
+ prediction = model.predict(X_scaled)[0]
212
+ probabilities = model.predict_proba(X_scaled)[0]
213
+
214
+ # Display results
215
+ st.markdown("---")
216
+ st.markdown("### Prediction Results")
217
+
218
+ result_col1, result_col2, result_col3 = st.columns([2, 2, 3])
219
+
220
+ with result_col1:
221
+ if prediction == 1:
222
+ st.success("### PLANET DETECTED!")
223
+ confidence = probabilities[1]
224
+ else:
225
+ st.error("### FALSE POSITIVE")
226
+ confidence = probabilities[0]
227
+
228
+ with result_col2:
229
+ st.metric("Confidence Score", f"{confidence*100:.1f}%",
230
+ delta=f"{(confidence-0.5)*100:.1f}% from neutral")
231
+
232
+ if confidence > 0.9:
233
+ st.info(" Very High Confidence")
234
+ elif confidence > 0.75:
235
+ st.info(" High Confidence")
236
+ elif confidence > 0.6:
237
+ st.info(" Moderate Confidence")
238
+ else:
239
+ st.info(" Low Confidence")
240
+
241
+ with result_col3:
242
+ # Probability gauge
243
+ fig = go.Figure(go.Indicator(
244
+ mode="gauge+number+delta",
245
+ value=probabilities[1] * 100,
246
+ title={'text': "Planet Probability (%)"},
247
+ delta={'reference': 50, 'increasing': {'color': "green"}},
248
+ gauge={
249
+ 'axis': {'range': [0, 100], 'tickwidth': 1},
250
+ 'bar': {'color': "darkblue"},
251
+ 'steps': [
252
+ {'range': [0, 25], 'color': "lightgray"},
253
+ {'range': [25, 50], 'color': "gray"},
254
+ {'range': [50, 75], 'color': "lightblue"},
255
+ {'range': [75, 100], 'color': "lightgreen"}
256
+ ],
257
+ 'threshold': {
258
+ 'line': {'color': "red", 'width': 4},
259
+ 'thickness': 0.75,
260
+ 'value': 50
261
+ }
262
+ }
263
+ ))
264
+ fig.update_layout(height=280, margin=dict(l=20, r=20, t=80, b=20))
265
+ st.plotly_chart(fig, use_container_width=True)
266
+
267
+ # Detailed probabilities
268
+ st.markdown("---")
269
+ st.subheader(" Detailed Probabilities")
270
+
271
+ prob_col1, prob_col2 = st.columns(2)
272
 
273
+ with prob_col1:
274
+ st.metric("False Positive Probability", f"{probabilities[0]*100:.2f}%")
275
+ with prob_col2:
276
+ st.metric("Planet Probability", f"{probabilities[1]*100:.2f}%")
277
 
278
+ # ==================== TAB 2: BATCH ANALYSIS ====================
279
+ with tab2:
280
+ st.header(" Batch Analysis")
281
+ st.markdown("Upload a CSV file with multiple exoplanet candidates for batch predictions")
282
 
283
+ st.info(" **Tip:** Your CSV should contain columns matching the feature names used by the model")
 
 
 
284
 
285
+ uploaded_file = st.file_uploader("Choose CSV file", type=['csv'])
286
+
287
+ if uploaded_file:
288
+ df_upload = pd.read_csv(uploaded_file)
289
+
290
+ st.subheader(" Uploaded Data Preview")
291
+ st.dataframe(df_upload.head(10), use_container_width=True)
292
+
293
+ st.metric("Total Candidates", len(df_upload))
294
+
295
+ if st.button(" Analyze All Candidates", type="primary"):
296
+ with st.spinner("Analyzing all candidates..."):
297
+ st.success(f" Would analyze {len(df_upload)} candidates!")
298
+ st.info(" Feature coming soon: Batch prediction implementation")
299
+ st.balloons()
300
 
301
+ # ==================== TAB 3: MODEL ANALYTICS ====================
302
+ with tab3:
303
+ st.header(" Model Performance Analytics")
304
+
305
+ # Metrics Overview
306
+ st.subheader(" Test Set Performance")
307
+ metric_col1, metric_col2, metric_col3, metric_col4, metric_col5 = st.columns(5)
308
+
309
+ with metric_col1:
310
+ st.metric("Accuracy", f"{metadata['test_accuracy']*100:.2f}%")
311
+ with metric_col2:
312
+ st.metric("Precision", f"{metadata['test_precision']:.3f}")
313
+ with metric_col3:
314
+ st.metric("Recall", f"{metadata['test_recall']:.3f}")
315
+ with metric_col4:
316
+ st.metric("F1 Score", f"{metadata['test_f1_score']:.3f}")
317
+ with metric_col5:
318
+ st.metric("ROC-AUC", f"{metadata['test_roc_auc']:.3f}")
319
+
320
+ st.markdown("---")
321
+
322
+ # Dataset Information
323
+ st.subheader(" Dataset Information")
324
+ data_col1, data_col2, data_col3, data_col4 = st.columns(4)
325
+
326
+ with data_col1:
327
+ st.metric("Total Samples", f"{metadata['total_samples']:,}")
328
+ with data_col2:
329
+ st.metric("Planets", f"{metadata['planets_total']:,}")
330
+ with data_col3:
331
+ st.metric("False Positives", f"{metadata['false_positives_total']:,}")
332
+ with data_col4:
333
+ st.metric("Planet %", f"{metadata['planet_percentage']:.1f}%")
334
+
335
+ st.markdown("---")
336
+
337
+ # Model Comparison
338
+ st.subheader(" Individual Model Performance (Validation Set)")
339
+
340
+ if 'validation_scores' in metadata:
341
+ val_scores_df = pd.DataFrame([
342
+ {"Model": k, "ROC-AUC": v}
343
+ for k, v in metadata['validation_scores'].items()
344
+ ]).sort_values('ROC-AUC', ascending=False)
345
+
346
+ fig = px.bar(val_scores_df, x='ROC-AUC', y='Model', orientation='h',
347
+ title='Model Comparison (Validation ROC-AUC)',
348
+ color='ROC-AUC', color_continuous_scale='viridis')
349
+ fig.update_layout(height=400, yaxis={'categoryorder':'total ascending'})
350
+ st.plotly_chart(fig, use_container_width=True)
351
+
352
+ st.markdown("---")
353
+
354
+ # Cross-Validation
355
+ st.subheader(" Cross-Validation Results")
356
+ cv_col1, cv_col2, cv_col3 = st.columns(3)
357
+
358
+ with cv_col1:
359
+ st.metric("CV Mean ROC-AUC", f"{metadata['cv_mean_roc_auc']:.4f}")
360
+ with cv_col2:
361
+ st.metric("CV Std Dev", f"±{metadata['cv_std_roc_auc']:.4f}")
362
+ with cv_col3:
363
+ overfitting_status = metadata.get('overfitting_check', 'Unknown')
364
+ st.metric("Overfitting Check", overfitting_status)
365
 
366
+ # ==================== TAB 4: HYPERPARAMETER TUNING ====================
367
+ with tab4:
368
+ st.header(" Hyperparameter Tuning")
369
+ st.markdown("Customize model hyperparameters and train new models")
370
+
371
+ # ==================== PRESET CONFIGURATIONS ====================
372
+ st.subheader(" Quick Presets")
373
+
374
+ preset_col1, preset_col2, preset_col3, preset_col4 = st.columns(4)
375
+
376
+ with preset_col1:
377
+ if st.button(" Best Performance", help="Optimized for maximum accuracy"):
378
+ st.session_state.preset = "best"
379
+
380
+ with preset_col2:
381
+ if st.button(" Fast Training", help="Quick training, good accuracy"):
382
+ st.session_state.preset = "fast"
383
+
384
+ with preset_col3:
385
+ if st.button(" Anti-Overfit", help="Maximum generalization"):
386
+ st.session_state.preset = "safe"
387
+
388
+ with preset_col4:
389
+ if st.button(" Research Grade", help="Publication-quality"):
390
+ st.session_state.preset = "research"
391
+
392
+ # Initialize session state
393
+ if 'preset' not in st.session_state:
394
+ st.session_state.preset = "best"
395
+
396
+ # Define presets
397
+ presets = {
398
+ "best": {
399
+ "rf_n_estimators": 300, "rf_max_depth": 15, "rf_min_samples_split": 8,
400
+ "rf_min_samples_leaf": 4, "rf_max_features": "sqrt",
401
+ "gb_n_estimators": 150, "gb_learning_rate": 0.05, "gb_max_depth": 5,
402
+ "gb_min_samples_split": 10, "gb_subsample": 0.8,
403
+ "xgb_n_estimators": 200, "xgb_learning_rate": 0.05, "xgb_max_depth": 6,
404
+ "xgb_min_child_weight": 5, "xgb_subsample": 0.8, "xgb_colsample": 0.8,
405
+ "lgb_n_estimators": 200, "lgb_learning_rate": 0.05, "lgb_max_depth": 7,
406
+ "lgb_num_leaves": 25, "lgb_min_child_samples": 20, "lgb_subsample": 0.8
407
+ },
408
+ "fast": {
409
+ "rf_n_estimators": 100, "rf_max_depth": 10, "rf_min_samples_split": 10,
410
+ "rf_min_samples_leaf": 5, "rf_max_features": "sqrt",
411
+ "gb_n_estimators": 75, "gb_learning_rate": 0.1, "gb_max_depth": 4,
412
+ "gb_min_samples_split": 10, "gb_subsample": 0.8,
413
+ "xgb_n_estimators": 100, "xgb_learning_rate": 0.1, "xgb_max_depth": 5,
414
+ "xgb_min_child_weight": 3, "xgb_subsample": 0.8, "xgb_colsample": 0.8,
415
+ "lgb_n_estimators": 100, "lgb_learning_rate": 0.1, "lgb_max_depth": 6,
416
+ "lgb_num_leaves": 20, "lgb_min_child_samples": 15, "lgb_subsample": 0.8
417
+ },
418
+ "safe": {
419
+ "rf_n_estimators": 200, "rf_max_depth": 10, "rf_min_samples_split": 15,
420
+ "rf_min_samples_leaf": 8, "rf_max_features": "sqrt",
421
+ "gb_n_estimators": 100, "gb_learning_rate": 0.03, "gb_max_depth": 3,
422
+ "gb_min_samples_split": 20, "gb_subsample": 0.7,
423
+ "xgb_n_estimators": 150, "xgb_learning_rate": 0.03, "xgb_max_depth": 4,
424
+ "xgb_min_child_weight": 8, "xgb_subsample": 0.7, "xgb_colsample": 0.7,
425
+ "lgb_n_estimators": 150, "lgb_learning_rate": 0.03, "lgb_max_depth": 5,
426
+ "lgb_num_leaves": 15, "lgb_min_child_samples": 30, "lgb_subsample": 0.7
427
+ },
428
+ "research": {
429
+ "rf_n_estimators": 400, "rf_max_depth": 18, "rf_min_samples_split": 6,
430
+ "rf_min_samples_leaf": 3, "rf_max_features": "sqrt",
431
+ "gb_n_estimators": 200, "gb_learning_rate": 0.03, "gb_max_depth": 6,
432
+ "gb_min_samples_split": 8, "gb_subsample": 0.85,
433
+ "xgb_n_estimators": 250, "xgb_learning_rate": 0.03, "xgb_max_depth": 7,
434
+ "xgb_min_child_weight": 4, "xgb_subsample": 0.85, "xgb_colsample": 0.85,
435
+ "lgb_n_estimators": 250, "lgb_learning_rate": 0.03, "lgb_max_depth": 8,
436
+ "lgb_num_leaves": 30, "lgb_min_child_samples": 15, "lgb_subsample": 0.85
437
+ }
438
+ }
439
+
440
+ selected_preset = presets[st.session_state.preset]
441
+ st.success(f" Using '{st.session_state.preset.upper()}' preset configuration!")
442
+
443
+ st.markdown("---")
444
+
445
+ # Create two columns for different models
446
+ col_left, col_right = st.columns(2)
447
+
448
+ with col_left:
449
+ st.subheader(" Random Forest")
450
+ rf_n_estimators = st.slider("RF: n_estimators", 50, 500, selected_preset["rf_n_estimators"], 10)
451
+ rf_max_depth = st.slider("RF: max_depth", 5, 30, selected_preset["rf_max_depth"], 1)
452
+ rf_min_samples_split = st.slider("RF: min_samples_split", 2, 20, selected_preset["rf_min_samples_split"], 1)
453
+ rf_min_samples_leaf = st.slider("RF: min_samples_leaf", 1, 10, selected_preset["rf_min_samples_leaf"], 1)
454
+ rf_max_features = st.selectbox("RF: max_features", ['sqrt', 'log2', None], index=0)
455
+
456
+ with col_right:
457
+ st.subheader(" Gradient Boosting")
458
+ gb_n_estimators = st.slider("GB: n_estimators", 50, 300, selected_preset["gb_n_estimators"], 10)
459
+ gb_learning_rate = st.slider("GB: learning_rate", 0.01, 0.3, selected_preset["gb_learning_rate"], 0.01)
460
+ gb_max_depth = st.slider("GB: max_depth", 3, 10, selected_preset["gb_max_depth"], 1)
461
+ gb_min_samples_split = st.slider("GB: min_samples_split", 2, 20, selected_preset["gb_min_samples_split"], 1)
462
+ gb_subsample = st.slider("GB: subsample", 0.5, 1.0, selected_preset["gb_subsample"], 0.05)
463
+
464
+ with st.expander(" XGBoost Parameters"):
465
+ col1, col2 = st.columns(2)
466
+ with col1:
467
+ xgb_n_estimators = st.slider("XGB: n_estimators", 50, 300, selected_preset["xgb_n_estimators"], 10, key="xgb_n")
468
+ xgb_learning_rate = st.slider("XGB: learning_rate", 0.01, 0.3, selected_preset["xgb_learning_rate"], 0.01, key="xgb_lr")
469
+ xgb_max_depth = st.slider("XGB: max_depth", 3, 10, selected_preset["xgb_max_depth"], 1, key="xgb_depth")
470
+ with col2:
471
+ xgb_min_child_weight = st.slider("XGB: min_child_weight", 1, 10, selected_preset["xgb_min_child_weight"], 1)
472
+ xgb_subsample = st.slider("XGB: subsample", 0.5, 1.0, selected_preset["xgb_subsample"], 0.05, key="xgb_sub")
473
+ xgb_colsample = st.slider("XGB: colsample_bytree", 0.5, 1.0, selected_preset["xgb_colsample"], 0.05)
474
+
475
+ with st.expander(" LightGBM Parameters"):
476
+ col1, col2 = st.columns(2)
477
+ with col1:
478
+ lgb_n_estimators = st.slider("LGB: n_estimators", 50, 300, selected_preset["lgb_n_estimators"], 10, key="lgb_n")
479
+ lgb_learning_rate = st.slider("LGB: learning_rate", 0.01, 0.3, selected_preset["lgb_learning_rate"], 0.01, key="lgb_lr")
480
+ lgb_max_depth = st.slider("LGB: max_depth", 3, 15, selected_preset["lgb_max_depth"], 1, key="lgb_depth")
481
+ with col2:
482
+ lgb_num_leaves = st.slider("LGB: num_leaves", 10, 100, selected_preset["lgb_num_leaves"], 5)
483
+ lgb_min_child_samples = st.slider("LGB: min_child_samples", 5, 50, selected_preset["lgb_min_child_samples"], 5)
484
+ lgb_subsample = st.slider("LGB: subsample", 0.5, 1.0, selected_preset["lgb_subsample"], 0.05, key="lgb_sub")
485
+
486
+ st.markdown("---")
487
+
488
+ # Generate code button
489
+ st.subheader(" Generated Training Code")
490
+
491
+ if st.button(" Generate Retraining Code"):
492
+ generated_code = f"""# Generated on {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
493
 
494
+ # Random Forest Parameters
495
+ rf_params = {{
496
+ 'n_estimators': {rf_n_estimators},
497
+ 'max_depth': {rf_max_depth},
498
+ 'min_samples_split': {rf_min_samples_split},
499
+ 'min_samples_leaf': {rf_min_samples_leaf},
500
+ 'max_features': {repr(rf_max_features)},
501
+ 'random_state': 42, 'n_jobs': -1, 'class_weight': 'balanced'
502
+ }}
503
 
504
+ # Gradient Boosting Parameters
505
+ gb_params = {{
506
+ 'n_estimators': {gb_n_estimators},
507
+ 'learning_rate': {gb_learning_rate},
508
+ 'max_depth': {gb_max_depth},
509
+ 'min_samples_split': {gb_min_samples_split},
510
+ 'subsample': {gb_subsample},
511
+ 'random_state': 42
512
+ }}
513
 
514
+ # XGBoost Parameters
515
+ xgb_params = {{
516
+ 'n_estimators': {xgb_n_estimators},
517
+ 'learning_rate': {xgb_learning_rate},
518
+ 'max_depth': {xgb_max_depth},
519
+ 'min_child_weight': {xgb_min_child_weight},
520
+ 'subsample': {xgb_subsample},
521
+ 'colsample_bytree': {xgb_colsample},
522
+ 'random_state': 42, 'n_jobs': -1
523
+ }}
524
 
525
+ # LightGBM Parameters
526
+ lgb_params = {{
527
+ 'n_estimators': {lgb_n_estimators},
528
+ 'learning_rate': {lgb_learning_rate},
529
+ 'max_depth': {lgb_max_depth},
530
+ 'num_leaves': {lgb_num_leaves},
531
+ 'min_child_samples': {lgb_min_child_samples},
532
+ 'subsample': {lgb_subsample},
533
+ 'random_state': 42, 'n_jobs': -1, 'verbose': -1
534
+ }}
535
 
536
+ # Train models
537
+ trained_models, final_model = train_all_models_anti_overfit(
538
+ X_train_scaled, y_train, X_val_scaled, y_val
539
+ )
540
+ """
541
+ st.code(generated_code, language="python")
542
+ st.success(" Code generated! Copy and paste into Jupyter notebook.")
543
+
544
+ st.markdown("---")
545
+
546
+ # ==================== TRAIN MODEL IN STREAMLIT ====================
547
+ st.subheader(" Train Model with Custom Parameters")
548
+
549
+ train_col1, train_col2 = st.columns([3, 1])
550
+
551
+ with train_col1:
552
+ st.info("""
553
+ **How it works:**
554
+ 1. Adjust hyperparameters above
555
+ 2. Click "Train New Model"
556
+ 3. Wait 5-15 minutes for training
557
+ 4. Download trained model
558
+ 5. Replace old model and restart app
559
+ """)
560
+
561
+ with train_col2:
562
+ train_button = st.button(" Train New Model", type="primary", use_container_width=True)
563
+
564
+ if train_button:
565
+ st.markdown("---")
566
+ st.header(" Training in Progress...")
567
+
568
+ progress_bar = st.progress(0)
569
+ status_text = st.empty()
570
+
571
+ try:
572
+ # Step 1: Load Data
573
+ status_text.text("Step 1/5: Loading datasets...")
574
+ progress_bar.progress(10)
575
+
576
@st.cache_data
def load_training_data():
    """Download the Kepler, TESS and K2 tables from the NASA Exoplanet Archive.

    Each table is requested independently from the archive's TAP ``sync``
    endpoint as CSV and parsed with pandas. The download is best-effort: a
    table whose request fails, returns a non-200 status, or cannot be parsed
    is simply left out of the result.

    Returns:
        dict: Maps mission name ('kepler', 'tess', 'k2') to the DataFrame of
        the corresponding archive table. May be empty if nothing could be
        loaded; the caller is expected to check for that.
    """
    import requests
    from io import StringIO

    # NOTE(review): st.cache_data stores whatever this function returns, so a
    # run in which every download failed presumably keeps returning the empty
    # dict until the cache is cleared - confirm whether that is acceptable.
    query_url = (
        "https://exoplanetarchive.ipac.caltech.edu/TAP/sync"
        "?query=select+*+from+{table}&format=csv"
    )
    # mission key used by the rest of the pipeline -> archive table name
    archive_tables = {'kepler': 'koi', 'tess': 'toi', 'k2': 'k2pandc'}

    datasets = {}
    for mission, table in archive_tables.items():
        try:
            response = requests.get(query_url.format(table=table), timeout=30)
            if response.status_code == 200:
                datasets[mission] = pd.read_csv(StringIO(response.text))
        except (requests.RequestException, ValueError):
            # Network failures (RequestException) and unparsable CSV payloads
            # (pandas parser errors derive from ValueError) skip this mission
            # only; unlike a bare ``except`` this no longer hides
            # KeyboardInterrupt/SystemExit or programming errors.
            continue
    return datasets
600
+
601
+ datasets = load_training_data()
602
+ if len(datasets) == 0:
603
+ st.error(" Unable to load datasets")
604
+ st.stop()
605
+
606
+ st.success(f" Loaded {len(datasets)} dataset(s)")
607
+ progress_bar.progress(20)
608
+
609
+ # Step 2: Preprocess
610
+ status_text.text("Step 2/5: Preprocessing...")
611
+
612
+ from sklearn.model_selection import train_test_split
613
+ from sklearn.preprocessing import RobustScaler
614
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
615
+
616
def quick_preprocess(datasets):
    """Merge the mission tables into one numeric feature frame with a binary target.

    For every mission table that has a known disposition column, the numeric
    columns are kept as features and a ``target`` column is derived from the
    disposition (1 = planet-like label, 0 = anything else). The per-mission
    frames are stacked, sparse columns are dropped, remaining gaps are filled
    with the column median, and infinities are replaced by 0.

    Args:
        datasets: Mapping of mission name to DataFrame. Tables without any of
            the columns 'koi_disposition', 'tfopwg_disp' or 'disposition' are
            skipped.

    Returns:
        pandas.DataFrame: Feature columns followed by ``target``, containing
        no missing values.

    Raises:
        ValueError: If no table has a disposition column, or if missing
            values survive the cleaning steps.
    """
    target_candidates = ('koi_disposition', 'tfopwg_disp', 'disposition')
    # Disposition labels (upper-cased) that count as a positive example.
    # TESS uses short codes; every other mission uses the spelled-out labels.
    positive_labels = {'tess': {'PC', 'CP', 'KP'}}
    default_positive = {'CONFIRMED', 'CANDIDATE'}

    frames = []
    for mission, df in datasets.items():
        target_col = next((c for c in target_candidates if c in df.columns), None)
        if target_col is None:
            continue

        positives = positive_labels.get(mission, default_positive)
        numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        # Exclude the label column (and any pre-existing 'target') from features.
        feature_cols = [c for c in numeric_cols if c not in (target_col, 'target')]

        subset = df[feature_cols].copy()
        subset['target'] = df[target_col].apply(
            lambda value: 1 if str(value).upper() in positives else 0
        )
        frames.append(subset)

    if not frames:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError("No dataset contains a recognised disposition column")

    combined = pd.concat(frames, ignore_index=True)

    # Drop sparse columns first: keep only columns with <70% missing values.
    missing_pct = combined.isnull().sum() / len(combined)
    cols_to_keep = missing_pct[missing_pct < 0.7].index.tolist()
    combined = combined[cols_to_keep].copy()

    feature_cols = [c for c in combined.columns if c != 'target']

    # Fill remaining gaps with the column median (0 if the median is NaN).
    # Assign the result back instead of calling fillna(inplace=True) on the
    # column selection: that chained form warns on pandas 2.x and leaves the
    # frame untouched under Copy-on-Write.
    for col in feature_cols:
        if combined[col].isnull().any():
            median_val = combined[col].median()
            fill_value = 0 if pd.isna(median_val) else median_val
            combined[col] = combined[col].fillna(fill_value)

    # Replace infinite values
    combined = combined.replace([np.inf, -np.inf], 0)

    # Remove rows with any remaining missing feature values
    combined = combined.dropna(subset=feature_cols)

    # Final safety check. Raised explicitly rather than asserted so that it
    # still runs when Python is started with -O.
    if combined.isnull().to_numpy().any():
        raise ValueError("NaN values still present after preprocessing!")

    return combined
677
+
678
+ processed_data = quick_preprocess(datasets)
679
+ X = processed_data.drop('target', axis=1)
680
+ y = processed_data['target']
681
+
682
+ st.success(f" Preprocessed {len(X)} samples")
683
+ progress_bar.progress(35)
684
+
685
+ # Step 3: Split and Scale
686
+ status_text.text("Step 3/5: Splitting and scaling...")
687
+
688
+ X_train, X_test, y_train, y_test = train_test_split(
689
+ X, y, test_size=0.2, random_state=42, stratify=y
690
+ )
691
+
692
+ scaler_new = RobustScaler()
693
+ X_train_scaled = scaler_new.fit_transform(X_train)
694
+ X_test_scaled = scaler_new.transform(X_test)
695
+
696
+ progress_bar.progress(45)
697
+
698
+ # Step 4: Train Models
699
+ status_text.text("Step 4/5: Training models...")
700
+
701
+ models_trained = {}
702
+
703
+ st.write(" Training Random Forest...")
704
+ rf_new = RandomForestClassifier(
705
+ n_estimators=rf_n_estimators, max_depth=rf_max_depth,
706
+ min_samples_split=rf_min_samples_split, min_samples_leaf=rf_min_samples_leaf,
707
+ max_features=rf_max_features, class_weight='balanced',
708
+ random_state=42, n_jobs=-1
709
+ )
710
+ rf_new.fit(X_train_scaled, y_train)
711
+ models_trained['RandomForest'] = rf_new
712
+ progress_bar.progress(55)
713
+
714
+ st.write(" Training Gradient Boosting...")
715
+ gb_new = GradientBoostingClassifier(
716
+ n_estimators=gb_n_estimators, learning_rate=gb_learning_rate,
717
+ max_depth=gb_max_depth, min_samples_split=gb_min_samples_split,
718
+ subsample=gb_subsample, random_state=42
719
+ )
720
+ gb_new.fit(X_train_scaled, y_train)
721
+ models_trained['GradientBoosting'] = gb_new
722
+ progress_bar.progress(65)
723
+
724
# XGBoost is an optional ensemble member: skip it when the package is not
# installed, and report (rather than mislabel) a failure during training.
try:
    import xgboost as xgb
except ImportError:
    st.warning(" XGBoost not available")
else:
    try:
        st.write(" Training XGBoost...")
        xgb_new = xgb.XGBClassifier(
            n_estimators=xgb_n_estimators, learning_rate=xgb_learning_rate,
            max_depth=xgb_max_depth, min_child_weight=xgb_min_child_weight,
            subsample=xgb_subsample, colsample_bytree=xgb_colsample,
            random_state=42, n_jobs=-1
        )
        xgb_new.fit(X_train_scaled, y_train)
        models_trained['XGBoost'] = xgb_new
    except Exception as xgb_error:
        # Stay best-effort so the other models can still form the ensemble,
        # but show the real cause instead of claiming the package is missing.
        # Unlike the former bare ``except`` this does not intercept
        # KeyboardInterrupt/SystemExit.
        st.warning(f" XGBoost training failed: {xgb_error}")
737
+ progress_bar.progress(75)
738
+
739
# LightGBM is an optional ensemble member: skip it when the package is not
# installed, and report (rather than mislabel) a failure during training.
try:
    import lightgbm as lgb
except ImportError:
    st.warning(" LightGBM not available")
else:
    try:
        st.write(" Training LightGBM...")
        lgb_new = lgb.LGBMClassifier(
            n_estimators=lgb_n_estimators, learning_rate=lgb_learning_rate,
            max_depth=lgb_max_depth, num_leaves=lgb_num_leaves,
            min_child_samples=lgb_min_child_samples, subsample=lgb_subsample,
            random_state=42, n_jobs=-1, verbose=-1
        )
        lgb_new.fit(X_train_scaled, y_train)
        models_trained['LightGBM'] = lgb_new
    except Exception as lgb_error:
        # Stay best-effort so the other models can still form the ensemble,
        # but show the real cause instead of claiming the package is missing.
        # Unlike the former bare ``except`` this does not intercept
        # KeyboardInterrupt/SystemExit.
        st.warning(f" LightGBM training failed: {lgb_error}")
752
+ progress_bar.progress(85)
753
+
754
+ st.write(" Creating Ensemble...")
755
+ estimators = [(name, model) for name, model in models_trained.items()]
756
+ ensemble_new = VotingClassifier(estimators=estimators, voting='soft', n_jobs=-1)
757
+ ensemble_new.fit(X_train_scaled, y_train)
758
+ progress_bar.progress(90)
759
+
760
+ # Step 5: Evaluate
761
+ status_text.text("Step 5/5: Evaluating...")
762
+
763
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
764
+
765
+ y_pred = ensemble_new.predict(X_test_scaled)
766
+ y_pred_proba = ensemble_new.predict_proba(X_test_scaled)[:, 1]
767
+
768
+ new_metrics = {
769
+ 'accuracy': accuracy_score(y_test, y_pred),
770
+ 'precision': precision_score(y_test, y_pred, zero_division=0),
771
+ 'recall': recall_score(y_test, y_pred, zero_division=0),
772
+ 'f1_score': f1_score(y_test, y_pred, zero_division=0),
773
+ 'roc_auc': roc_auc_score(y_test, y_pred_proba)
774
+ }
775
+
776
+ progress_bar.progress(100)
777
+ status_text.text(" Training complete!")
778
+
779
+ st.success(" Model training complete!")
780
+
781
+ st.markdown("---")
782
+ st.subheader(" New Model Performance")
783
+
784
+ metric_col1, metric_col2, metric_col3, metric_col4, metric_col5 = st.columns(5)
785
+ with metric_col1:
786
+ st.metric("Accuracy", f"{new_metrics['accuracy']:.3f}")
787
+ with metric_col2:
788
+ st.metric("Precision", f"{new_metrics['precision']:.3f}")
789
+ with metric_col3:
790
+ st.metric("Recall", f"{new_metrics['recall']:.3f}")
791
+ with metric_col4:
792
+ st.metric("F1 Score", f"{new_metrics['f1_score']:.3f}")
793
+ with metric_col5:
794
+ st.metric("ROC-AUC", f"{new_metrics['roc_auc']:.3f}")
795
+
796
+ # Save model
797
+ st.markdown("---")
798
+ st.subheader(" Download New Model")
799
+
800
+ new_model_package = {
801
+ 'ensemble_model': ensemble_new,
802
+ 'individual_models': models_trained,
803
+ 'scaler': scaler_new,
804
+ 'feature_names': X.columns.tolist(),
805
+ 'metadata': {
806
+ 'version': '2.0',
807
+ 'created_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
808
+ 'created_date': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
809
+ 'missions': list(datasets.keys()),
810
+ 'total_samples': len(X),
811
+ 'train_samples': len(X_train),
812
+ 'test_samples': len(X_test),
813
+ 'n_features': len(X.columns),
814
+ 'test_accuracy': float(new_metrics['accuracy']),
815
+ 'test_precision': float(new_metrics['precision']),
816
+ 'test_recall': float(new_metrics['recall']),
817
+ 'test_f1_score': float(new_metrics['f1_score']),
818
+ 'test_roc_auc': float(new_metrics['roc_auc']),
819
+ 'n_models_in_ensemble': len(models_trained),
820
+ 'ensemble_model_names': list(models_trained.keys()),
821
+ 'planets_total': int(y.sum()),
822
+ 'false_positives_total': int((y==0).sum()),
823
+ 'planet_percentage': float(y.mean() * 100),
824
+ 'cv_mean_roc_auc': 0.0,
825
+ 'cv_std_roc_auc': 0.0,
826
+ 'overfitting_check': 'Not tested'
827
+ }
828
+ }
829
+
830
# Serialise the model package into an in-memory buffer so it can be offered
# as a browser download without writing to disk.
buffer = io.BytesIO()
joblib.dump(new_model_package, buffer)
buffer.seek(0)  # rewind so the download starts at the first byte

# Fixed download name. The previously computed ``timestamp`` local was never
# used, and the f-string had no placeholders, so both were dropped.
filename = "exoplanet_final_model.joblib"

st.download_button(
    label="⬇ Download New Model",
    data=buffer,
    file_name=filename,
    mime="application/octet-stream",
    type="primary"
)

st.success(f" Model ready! Update line 47 with: `{filename}`")
846
+
847
+ except Exception as e:
848
+ st.error(f" Error: {str(e)}")
849
 
850
+ # ==================== TAB 5: ABOUT ====================
851
+ with tab5:
852
+ st.header("ℹ About This System")
853
+
854
+ st.markdown("""
855
+ ### Project Overview
856
+ AI-powered exoplanet detection using NASA telescope data.
857
+
858
+ ### Data Sources
859
+ - **Kepler Mission**: Stellar transit observations
860
+ - **TESS Mission**: Transiting Exoplanet Survey Satellite
861
+ - **K2 Mission**: Extended Kepler observations
862
+
863
+ ### ML Approach
864
+ Multi-model ensemble with advanced feature engineering
865
+
866
+ ### NASA Space Apps Challenge 2025
867
+ Built for "A World Away: Hunting for Exoplanets with AI"
868
+
869
+ ### Resources
870
+ - [NASA Exoplanet Archive](https://exoplanetarchive.ipac.caltech.edu/)
871
+ - [Space Apps Challenge](https://www.spaceappschallenge.org/)
872
+ """)
873
 
874
+ st.markdown("---")
875
+ st.markdown("""
876
+ <div style='text-align: center; color: #666;'>
877
+ <p><strong>NASA Space Apps Challenge 2025</strong></p>
878
+ <p>Built with ❤️ using Streamlit & Machine Learning</p>
879
+ <p>🌟 Detecting exoplanets • One transit at a time 🪐</p>
880
+ </div>
881
+ """, unsafe_allow_html=True)
882