Nugget-cloud committed on
Commit
ec77d63
·
verified ·
1 Parent(s): 680083f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -181
app.py CHANGED
@@ -1,196 +1,258 @@
1
  import gradio as gr
2
  import joblib
 
3
  import numpy as np
 
 
4
  from huggingface_hub import hf_hub_download
5
- import json
6
-
7
- # Global variables to store loaded models
8
- ensemble_model = None
9
- feature_scaler = None
10
- feature_imputer = None
11
- variance_selector = None
12
- feature_info = None
13
- model_metrics = None
14
-
15
- def load_models():
16
- """Load all models from the Hugging Face repository"""
17
- global ensemble_model, feature_scaler, feature_imputer, variance_selector, feature_info, model_metrics
18
-
 
19
  try:
20
- # Load models from your repository
21
- repo_id = "Nugget-cloud/nasa-space-apps-exoplanet"
22
-
23
- print("Loading ensemble model...")
24
- ensemble_model = joblib.load(hf_hub_download(repo_id, "exoplanet_ensemble_model.joblib"))
25
-
26
- print("Loading feature scaler...")
27
- feature_scaler = joblib.load(hf_hub_download(repo_id, "feature_scaler.joblib"))
28
-
29
- print("Loading feature imputer...")
30
- feature_imputer = joblib.load(hf_hub_download(repo_id, "feature_imputer.joblib"))
31
-
32
- print("Loading variance selector...")
33
- variance_selector = joblib.load(hf_hub_download(repo_id, "variance_selector.joblib"))
34
-
35
- # Optional files
36
- try:
37
- print("Loading feature info...")
38
- feature_info = joblib.load(hf_hub_download(repo_id, "feature_info.joblib"))
39
- except:
40
- print("Feature info not found, skipping...")
41
- feature_info = None
42
-
43
- try:
44
- print("Loading model metrics...")
45
- model_metrics = joblib.load(hf_hub_download(repo_id, "model_metrics.joblib"))
46
- except:
47
- print("Model metrics not found, skipping...")
48
- model_metrics = None
49
-
50
- print("All models loaded successfully!")
51
- return True
52
-
53
  except Exception as e:
54
- print(f"Error loading models: {str(e)}")
55
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- def predict_exoplanet(features_input):
58
- """Make prediction using the loaded models"""
59
- global ensemble_model, feature_scaler, feature_imputer, variance_selector
60
 
61
- try:
62
- # Load models if not already loaded
63
- if ensemble_model is None:
64
- if not load_models():
65
- return {"error": "Failed to load models"}
66
-
67
- # Parse input features
68
- if isinstance(features_input, str):
69
- # If input is comma-separated string
70
- features = [float(x.strip()) for x in features_input.split(',')]
71
- elif isinstance(features_input, list):
72
- # If input is already a list
73
- features = [float(x) for x in features_input]
74
- else:
75
- return {"error": "Invalid input format. Expected comma-separated string or list of numbers."}
76
-
77
- # Convert to numpy array
78
- features_array = np.array(features).reshape(1, -1)
79
-
80
- print(f"Original features shape: {features_array.shape}")
81
- print(f"Original features: {features_array}")
82
-
83
- # Apply preprocessing pipeline
84
- # 1. Impute missing values
85
- if feature_imputer:
86
- features_array = feature_imputer.transform(features_array)
87
- print(f"After imputation: {features_array.shape}")
88
-
89
- # 2. Scale features
90
- if feature_scaler:
91
- features_array = feature_scaler.transform(features_array)
92
- print(f"After scaling: {features_array.shape}")
93
-
94
- # 3. Select features (variance threshold)
95
- if variance_selector:
96
- features_array = variance_selector.transform(features_array)
97
- print(f"After variance selection: {features_array.shape}")
98
-
99
- # 4. Make prediction
100
- prediction = ensemble_model.predict(features_array)[0]
101
-
102
- # Get prediction probabilities if available
103
- probabilities = None
104
- if hasattr(ensemble_model, 'predict_proba'):
105
- probabilities = ensemble_model.predict_proba(features_array)[0].tolist()
106
-
107
- result = {
108
- "success": True,
109
- "prediction": int(prediction),
110
- "probabilities": probabilities,
111
- "confidence": max(probabilities) if probabilities else None,
112
- "input_features_count": len(features),
113
- "processed_features_count": features_array.shape[1],
114
- "model_info": {
115
- "model_type": str(type(ensemble_model).__name__),
116
- "has_probabilities": hasattr(ensemble_model, 'predict_proba')
117
- }
 
 
 
 
 
 
 
118
  }
119
-
120
- if feature_info:
121
- result["feature_info"] = feature_info
122
-
123
- if model_metrics:
124
- result["model_metrics"] = model_metrics
125
-
126
- return result
127
-
 
 
 
 
 
 
128
  except Exception as e:
129
- return {
130
- "success": False,
131
- "error": str(e)
132
- }
133
 
134
- def predict_api(features_str):
135
- """API endpoint function"""
136
- result = predict_exoplanet(features_str)
137
- return result
138
-
139
- # Create Gradio interface with API support
140
- def create_interface():
141
- with gr.Blocks(title="🪐 Exoplanet Classification API") as iface:
142
- gr.Markdown("# 🪐 Exoplanet Classification Model")
143
- gr.Markdown("Enter comma-separated feature values for exoplanet prediction using NASA Kepler/TESS data.")
144
-
145
- with gr.Row():
146
- with gr.Column():
147
- features_input = gr.Textbox(
148
- label="Features (comma-separated)",
149
- placeholder="1.2,3.4,5.6,7.8,9.1,2.3,4.5,6.7",
150
- info="Enter numerical features separated by commas"
151
- )
152
- predict_btn = gr.Button("Predict", variant="primary")
153
 
154
- with gr.Column():
155
- output = gr.JSON(label="Prediction Result")
156
-
157
- # Connect the button click to the function
158
- predict_btn.click(
159
- fn=predict_api,
160
- inputs=features_input,
161
- outputs=output,
162
- api_name="predict" # This creates an API endpoint
163
- )
164
-
165
- # Example inputs
166
- gr.Markdown("### Example Inputs:")
167
- gr.Markdown("Try these example feature sets:")
168
-
169
- examples = gr.Examples(
170
- examples=[
171
- ["1.2,3.4,5.6,7.8,9.1,2.3,4.5,6.7"],
172
- ["0.5,1.8,2.1,4.2,6.3,1.9,3.7,5.2"],
173
- ["2.1,4.3,6.5,8.7,10.9,3.2,5.4,7.6"]
174
- ],
175
- inputs=features_input,
176
- outputs=output,
177
- fn=predict_api
178
- )
179
-
180
- gr.Markdown("""
181
- ### API Usage
182
- This Space provides an API endpoint at `/api/predict` that accepts:
183
- ```json
184
- {"data": ["1.2,3.4,5.6,7.8,9.1,2.3,4.5,6.7"]}
185
- ```
186
- """)
187
 
188
- return iface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  if __name__ == "__main__":
191
- demo = create_interface()
192
- demo.launch(
193
- server_name="0.0.0.0",
194
- server_port=7860,
195
- share=True
196
- )
 
1
  import gradio as gr
2
  import joblib
3
+ import pandas as pd
4
  import numpy as np
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
  from huggingface_hub import hf_hub_download
8
+ import warnings
9
+
10
+ # ==================== CONFIGURATION ====================
11
+ # ⚠️ UPDATE THIS WITH YOUR HUGGING FACE REPOSITORY ID
12
+ HF_REPO_ID = "YOUR_USERNAME/YOUR_REPO_NAME"
13
+ MODEL_FILENAME = "exoplanet_final_model.joblib"
14
+
15
+ # Suppress specific warnings for a cleaner output
16
+ warnings.filterwarnings("ignore", category=UserWarning, message="Trying to unpickle estimator.*")
17
+ warnings.filterwarnings("ignore", category=FutureWarning)
18
+
19
+ # ==================== LOAD MODEL FROM HUGGING FACE ====================
20
+ @gr.cache(show_api=False)
21
+ def load_model_package(repo_id, filename):
22
+ """Load the complete model package from Hugging Face Hub"""
23
  try:
24
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
25
+ package = joblib.load(model_path)
26
+ return package
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  except Exception as e:
28
+ # Fallback for local development if HF download fails
29
+ print(f"Could not download from Hugging Face: {e}. Trying local file...")
30
+ try:
31
+ package = joblib.load(filename)
32
+ return package
33
+ except FileNotFoundError:
34
+ raise gr.Error(f"Model file not found locally or on Hugging Face at {repo_id}. Please check HF_REPO_ID and ensure the model file is available.")
35
+ except Exception as e_local:
36
+ raise gr.Error(f"Error loading local model: {e_local}")
37
+
38
+ # Load package and extract components
39
+ try:
40
+ print("Loading AI model...")
41
+ package = load_model_package(HF_REPO_ID, MODEL_FILENAME)
42
+ model = package['ensemble_model']
43
+ scaler = package['scaler']
44
+ feature_names = package['feature_names']
45
+ metadata = package.get('metadata', {}) # Use .get for safety
46
+ print("AI model loaded successfully.")
47
+ except Exception as e:
48
+ # If model loading fails, we can't run the app.
49
+ print(str(e))
50
+ # Create a dummy structure to prevent the UI from crashing on startup
51
+ model, scaler, feature_names, metadata = None, None, [], {'missions': ['N/A'], 'version': 'Error'}
52
+
53
+
54
+ # ==================== PREDICTION LOGIC ====================
55
+ def predict_exoplanet(period, duration, depth, planet_radius, equilibrium_temp, insolation, star_radius, star_temp, star_logg, mission):
56
+ """Predicts if a candidate is an exoplanet based on input features."""
57
+ if not model:
58
+ raise gr.Error("Model is not loaded. Cannot perform prediction.")
59
 
60
+ # Create feature dictionary from inputs
61
+ features_dict = {}
 
62
 
63
+ # Basic features
64
+ feature_map = {
65
+ 'period': period, 'duration': duration, 'depth': depth,
66
+ 'planet_radius': planet_radius, 'star_radius': star_radius,
67
+ 'star_temp': star_temp, 'star_logg': star_logg,
68
+ 'equilibrium_temp': equilibrium_temp, 'insolation_flux': insolation
69
+ }
70
+
71
+ for fname, fval in feature_map.items():
72
+ if fname in feature_names:
73
+ features_dict[fname] = fval
74
+
75
+ # Engineered features
76
+ if 'transit_period_ratio' in feature_names and period > 0:
77
+ features_dict['transit_period_ratio'] = duration / (period * 24)
78
+ if 'radius_ratio' in feature_names and star_radius > 0:
79
+ features_dict['radius_ratio'] = planet_radius / star_radius
80
+ if 'period_log' in feature_names and period > 0:
81
+ features_dict['period_log'] = np.log10(period)
82
+ if 'insolation_flux_log' in feature_names and insolation > 0:
83
+ features_dict['insolation_flux_log'] = np.log10(insolation)
84
+ if 'habitable_zone_dist' in feature_names:
85
+ features_dict['habitable_zone_dist'] = abs(equilibrium_temp - 288) / 288
86
+
87
+ if 'star_class' in feature_names:
88
+ if star_temp >= 7500: features_dict['star_class'] = 5
89
+ elif star_temp >= 6000: features_dict['star_class'] = 4
90
+ elif star_temp >= 5200: features_dict['star_class'] = 3
91
+ elif star_temp >= 3700: features_dict['star_class'] = 2
92
+ else: features_dict['star_class'] = 1
93
+
94
+ if 'luminosity_class' in feature_names:
95
+ if star_logg < 3.5: features_dict['luminosity_class'] = 3
96
+ elif star_logg < 4.0: features_dict['luminosity_class'] = 2
97
+ else: features_dict['luminosity_class'] = 1
98
+
99
+ for m in metadata.get('missions', []):
100
+ col_name = f'mission_{m}'
101
+ if col_name in feature_names:
102
+ features_dict[col_name] = 1 if m == mission else 0
103
+
104
+ # Create feature vector in the correct order
105
+ feature_vector = [features_dict.get(f, 0) for f in feature_names]
106
+ X_input = np.array(feature_vector).reshape(1, -1)
107
+
108
+ # Scale and predict
109
+ X_scaled = scaler.transform(X_input)
110
+ prediction = model.predict(X_scaled)[0]
111
+ probabilities = model.predict_proba(X_scaled)[0]
112
+
113
+ # Prepare outputs
114
+ result_label_val = "PLANET DETECTED!" if prediction == 1 else "FALSE POSITIVE"
115
+ confidence = probabilities[1] if prediction == 1 else probabilities[0]
116
+
117
+ # Probability gauge
118
+ gauge_fig = go.Figure(go.Indicator(
119
+ mode="gauge+number",
120
+ value=probabilities[1] * 100,
121
+ title={'text': "Planet Probability (%)"},
122
+ gauge={
123
+ 'axis': {'range': [0, 100]},
124
+ 'bar': {'color': "darkblue"},
125
+ 'steps': [{'range': [0, 50], 'color': "lightgray"}, {'range': [50, 100], 'color': "lightgreen"}],
126
+ 'threshold': {'line': {'color': "red", 'width': 4}, 'thickness': 0.75, 'value': 50}
127
  }
128
+ ))
129
+ gauge_fig.update_layout(height=250, margin=dict(l=20, r=20, t=50, b=20))
130
+
131
+ return {result_label_val: confidence}, gauge_fig
132
+
133
+ # ==================== BATCH ANALYSIS LOGIC ====================
134
+ def batch_analysis(file_obj):
135
+ """Performs batch prediction on an uploaded CSV file."""
136
+ if not model:
137
+ raise gr.Error("Model is not loaded. Cannot perform batch analysis.")
138
+ if file_obj is None:
139
+ return None, "Please upload a file first."
140
+
141
+ try:
142
+ df_upload = pd.read_csv(file_obj.name)
143
  except Exception as e:
144
+ return None, f"Error reading CSV: {e}"
 
 
 
145
 
146
+ # For this simplified batch prediction, we only use columns that directly match model features.
147
+ # A more robust implementation would perform full feature engineering for each row.
148
+ X_batch = pd.DataFrame(columns=feature_names)
149
+
150
+ for col in feature_names:
151
+ if col in df_upload.columns:
152
+ X_batch[col] = df_upload[col]
153
+ else:
154
+ X_batch[col] = 0 # Fill missing feature columns with 0
 
 
 
 
 
 
 
 
 
 
155
 
156
+ X_batch = X_batch.fillna(0)
157
+
158
+ # Scale and predict
159
+ X_scaled = scaler.transform(X_batch)
160
+ predictions = model.predict(X_scaled)
161
+ probabilities = model.predict_proba(X_scaled)[:, 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
+ # Add results to a new dataframe for clarity
164
+ results_df = df_upload.copy()
165
+ results_df['prediction'] = ['Planet' if p == 1 else 'False Positive' for p in predictions]
166
+ results_df['planet_probability'] = probabilities
167
+
168
+ return results_df, f"Analysis complete for {len(results_df)} candidates."
169
+
170
+ # ==================== GRADIO UI ====================
171
+ css = """
172
+ .main-header { font-size: 2.5rem; font-weight: bold; text-align: center; }
173
+ .sub-header { text-align: center; color: #666; font-size: 1.2rem; }
174
+ """
175
+
176
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
177
+ gr.Markdown('<div class="main-header">🪐 NASA Exoplanet AI Detector</div>', elem_classes="main-header")
178
+ gr.Markdown(f"<div class='sub-header'>AI-Powered Exoplanet Detection | Model Version: {metadata.get('version', 'N/A')}</div>", elem_classes="sub-header")
179
+
180
+ with gr.Tabs():
181
+ with gr.TabItem("Single Prediction"):
182
+ gr.Markdown("### Analyze a Single Exoplanet Candidate")
183
+ with gr.Row():
184
+ with gr.Column(scale=2):
185
+ with gr.Accordion("Orbital & Planet Properties", open=True):
186
+ period = gr.Slider(0.0, 10000.0, value=10.0, label="Orbital Period (days)")
187
+ duration = gr.Slider(0.0, 48.0, value=3.0, label="Transit Duration (hours)")
188
+ depth = gr.Slider(0.0, 100000.0, value=1000.0, label="Transit Depth (ppm)")
189
+ planet_radius = gr.Slider(0.1, 100.0, value=1.0, label="Planet Radius (Earth radii)")
190
+ equilibrium_temp = gr.Slider(0, 5000, value=288, label="Equilibrium Temperature (K)")
191
+ insolation = gr.Slider(0.0, 10000.0, value=1.0, label="Insolation Flux (Earth units)")
192
+
193
+ with gr.Accordion("Stellar Properties & Mission", open=True):
194
+ star_radius = gr.Slider(0.1, 50.0, value=1.0, label="Star Radius (Solar radii)")
195
+ star_temp = gr.Slider(2000, 50000, value=5778, label="Star Temperature (K)")
196
+ star_logg = gr.Slider(0.0, 5.0, value=4.4, label="Star log(g)")
197
+ mission = gr.Dropdown(metadata.get('missions', ['N/A']), label="Mission", value=metadata.get('missions', ['N/A'])[0])
198
+
199
+ predict_btn = gr.Button("Analyze Candidate", variant="primary")
200
+
201
+ with gr.Column(scale=1):
202
+ gr.Markdown("### Prediction Results")
203
+ result_label = gr.Label(label="Prediction")
204
+ gauge_plot = gr.Plot(label="Probability Gauge")
205
+
206
+ predict_btn.click(
207
+ fn=predict_exoplanet,
208
+ inputs=[period, duration, depth, planet_radius, equilibrium_temp, insolation, star_radius, star_temp, star_logg, mission],
209
+ outputs=[result_label, gauge_plot],
210
+ api_name="predict"
211
+ )
212
+
213
+ with gr.TabItem("Batch Analysis"):
214
+ gr.Markdown("### Batch Analysis of Exoplanet Candidates")
215
+ gr.Info("Upload a CSV file. The file should contain columns matching the model's features for best results.")
216
+ with gr.Row():
217
+ file_input = gr.File(label="Upload CSV", file_types=[".csv"])
218
+ batch_status = gr.Textbox(label="Status", interactive=False)
219
+ batch_run_btn = gr.Button("⚡ Analyze All Candidates", variant="primary")
220
+ gr.Markdown("### Results")
221
+ batch_output_df = gr.DataFrame(label="Batch Results")
222
+
223
+ batch_run_btn.click(fn=batch_analysis, inputs=[file_input], outputs=[batch_output_df, batch_status], api_name="batch_predict")
224
+
225
+ with gr.TabItem("Model Analytics"):
226
+ gr.Markdown("### Model Performance & Dataset Information")
227
+ with gr.Row():
228
+ gr.Textbox(f"{metadata.get('test_accuracy', 0)*100:.2f}%", label="Test Accuracy")
229
+ gr.Textbox(f"{metadata.get('test_precision', 0):.3f}", label="Precision")
230
+ gr.Textbox(f"{metadata.get('test_recall', 0):.3f}", label="Recall")
231
+ gr.Textbox(f"{metadata.get('test_f1_score', 0):.3f}", label="F1 Score")
232
+ gr.Textbox(f"{metadata.get('test_roc_auc', 0):.3f}", label="ROC-AUC")
233
+
234
+ if 'validation_scores' in metadata:
235
+ gr.Markdown("### Individual Model Performance (Validation Set)")
236
+ val_scores_df = pd.DataFrame([{"Model": k, "ROC-AUC": v} for k, v in metadata['validation_scores'].items()]).sort_values('ROC-AUC', ascending=False)
237
+ fig = px.bar(val_scores_df, x='ROC-AUC', y='Model', orientation='h', title='Model Comparison (Validation ROC-AUC)', color='ROC-AUC', color_continuous_scale='viridis')
238
+ fig.update_layout(height=400, yaxis={'categoryorder':'total ascending'})
239
+ gr.Plot(value=fig)
240
+
241
+ with gr.TabItem("ℹ About"):
242
+ gr.Markdown("""
243
+ ### 🚀 Project Overview
244
+ This application provides an interface for an AI model designed to detect exoplanets from NASA telescope data. It is built for the **NASA Space Apps Challenge 2025**.
245
+ ### 📊 Data Sources
246
+ The model was trained on publicly available data from multiple NASA missions, including Kepler, K2, and TESS.
247
+ ### 🤖 Machine Learning Approach
248
+ The core of this system is a sophisticated **ensemble model**, which combines the predictions of several machine learning algorithms to achieve higher accuracy and robustness.
249
+ ### 🔗 Resources
250
+ - [NASA Exoplanet Archive](https://exoplanetarchive.ipac.caltech.edu/)
251
+ - [NASA Space Apps Challenge](https://www.spaceappschallenge.org/)
252
+ - [Hugging Face (for model hosting)](https://huggingface.co/)
253
+ - [Gradio (for the web UI)](https://www.gradio.app/)
254
+ """)
255
 
256
  if __name__ == "__main__":
257
+ demo.launch()
258
+