Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,12 +7,19 @@ import json
|
|
| 7 |
import pickle
|
| 8 |
import os
|
| 9 |
from typing import Dict, List, Any
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 12 |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 13 |
|
| 14 |
print("π Starting Eco Finder API...")
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
# Configuration
|
| 17 |
try:
|
| 18 |
import tensorflow as tf
|
|
@@ -90,11 +97,104 @@ train_medians = feature_stats.get("train_medians", {})
|
|
| 90 |
|
| 91 |
BASE = "https://exoplanetarchive.ipac.caltech.edu/cgi-bin/nstedAPI/nph-nstedAPI"
|
| 92 |
|
| 93 |
-
# ====================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
|
|
|
| 95 |
|
| 96 |
def predict_single(features: Dict) -> Dict:
|
| 97 |
-
"""Function to predict a single object
|
| 98 |
try:
|
| 99 |
if model is None or scaler is None or label_encoder is None:
|
| 100 |
return {"error": "Model not available"}
|
|
@@ -244,7 +344,7 @@ def predict_toi_realtime():
|
|
| 244 |
if "error" not in result:
|
| 245 |
results.append(
|
| 246 |
{
|
| 247 |
-
"TOI": row.get("toi", f"
|
| 248 |
"Disposition": row.get("tfopwg_disp", "Unknown"),
|
| 249 |
"Prediction": result["prediction"],
|
| 250 |
"P(Confirmed)": f"{result['probabilities']['CONFIRMED']:.3f}",
|
|
@@ -311,19 +411,22 @@ def predict_manual(
|
|
| 311 |
|
| 312 |
with gr.Blocks(theme=gr.themes.Soft(), title="Eco Finder API") as demo:
|
| 313 |
gr.Markdown("# π Eco Finder API")
|
| 314 |
-
gr.Markdown("Exoplanet classifier")
|
| 315 |
|
| 316 |
-
with gr.Tab("π― API
|
| 317 |
-
gr.Markdown("### Endpoint for frontend consumption")
|
| 318 |
gr.Markdown("""
|
| 319 |
-
|
| 320 |
|
| 321 |
-
|
| 322 |
-
**
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
-
|
|
|
|
| 325 |
```bash
|
| 326 |
-
curl -X POST "https://
|
| 327 |
-H "Content-Type: application/json" \\
|
| 328 |
-d '{
|
| 329 |
"koi_period": 10.0,
|
|
@@ -340,9 +443,51 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Eco Finder API") as demo:
|
|
| 340 |
"koi_num_transits": 3.0
|
| 341 |
}'
|
| 342 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
""")
|
| 344 |
|
| 345 |
-
|
|
|
|
|
|
|
| 346 |
with gr.Row():
|
| 347 |
with gr.Column():
|
| 348 |
period = gr.Number(label="koi_period", value=10.0)
|
|
@@ -360,10 +505,10 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Eco Finder API") as demo:
|
|
| 360 |
snr = gr.Number(label="koi_model_snr", value=10.0)
|
| 361 |
num_transits = gr.Number(label="koi_num_transits", value=3.0)
|
| 362 |
|
| 363 |
-
|
| 364 |
-
|
| 365 |
|
| 366 |
-
|
| 367 |
fn=predict_from_dict,
|
| 368 |
inputs=[
|
| 369 |
period,
|
|
@@ -379,7 +524,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Eco Finder API") as demo:
|
|
| 379 |
snr,
|
| 380 |
num_transits,
|
| 381 |
],
|
| 382 |
-
outputs=
|
| 383 |
)
|
| 384 |
|
| 385 |
with gr.Tab("π Real-time TOI"):
|
|
@@ -388,32 +533,30 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Eco Finder API") as demo:
|
|
| 388 |
toi_output = gr.Markdown()
|
| 389 |
toi_btn.click(predict_toi_realtime, outputs=toi_output)
|
| 390 |
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
fn=predict_manual,
|
| 397 |
-
inputs=[
|
| 398 |
-
period,
|
| 399 |
-
duration,
|
| 400 |
-
depth,
|
| 401 |
-
prad,
|
| 402 |
-
srad,
|
| 403 |
-
teq,
|
| 404 |
-
steff,
|
| 405 |
-
slogg,
|
| 406 |
-
smet,
|
| 407 |
-
kepmag,
|
| 408 |
-
snr,
|
| 409 |
-
num_transits,
|
| 410 |
-
],
|
| 411 |
-
outputs=manual_output,
|
| 412 |
-
)
|
| 413 |
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
|
| 418 |
if __name__ == "__main__":
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import pickle
|
| 8 |
import os
|
| 9 |
from typing import Dict, List, Any
|
| 10 |
+
from flask import Flask, request, jsonify
|
| 11 |
+
from flask_cors import CORS
|
| 12 |
+
import threading
|
| 13 |
|
| 14 |
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 15 |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 16 |
|
| 17 |
print("π Starting Eco Finder API...")
|
| 18 |
|
| 19 |
+
# Initialize Flask app
|
| 20 |
+
flask_app = Flask(__name__)
|
| 21 |
+
CORS(flask_app) # Enable CORS for all routes
|
| 22 |
+
|
| 23 |
# Configuration
|
| 24 |
try:
|
| 25 |
import tensorflow as tf
|
|
|
|
| 97 |
|
| 98 |
BASE = "https://exoplanetarchive.ipac.caltech.edu/cgi-bin/nstedAPI/nph-nstedAPI"
|
| 99 |
|
| 100 |
+
# ==================== FLASK API ENDPOINTS ====================
|
| 101 |
+
|
| 102 |
+
@flask_app.route('/')
|
| 103 |
+
def home():
|
| 104 |
+
return jsonify({
|
| 105 |
+
"message": "Eco Finder API - Exoplanet Classification",
|
| 106 |
+
"version": "1.0.0",
|
| 107 |
+
"endpoints": {
|
| 108 |
+
"health": "/health (GET)",
|
| 109 |
+
"predict": "/predict (POST)",
|
| 110 |
+
"predict_batch": "/predict-batch (POST)",
|
| 111 |
+
"features": "/features (GET)"
|
| 112 |
+
}
|
| 113 |
+
})
|
| 114 |
+
|
| 115 |
+
@flask_app.route('/health', methods=['GET'])
|
| 116 |
+
def health():
|
| 117 |
+
return jsonify({
|
| 118 |
+
"status": "healthy",
|
| 119 |
+
"model_loaded": model is not None,
|
| 120 |
+
"tensorflow_available": TENSORFLOW_AVAILABLE,
|
| 121 |
+
"features_count": len(feature_columns)
|
| 122 |
+
})
|
| 123 |
+
|
| 124 |
+
@flask_app.route('/predict', methods=['POST'])
|
| 125 |
+
def api_predict():
|
| 126 |
+
"""REST API endpoint for single prediction"""
|
| 127 |
+
try:
|
| 128 |
+
data = request.get_json()
|
| 129 |
+
|
| 130 |
+
if not data:
|
| 131 |
+
return jsonify({"error": "No JSON data provided"}), 400
|
| 132 |
+
|
| 133 |
+
# Use default values if parameters are missing
|
| 134 |
+
features = {}
|
| 135 |
+
for feature in feature_columns:
|
| 136 |
+
features[feature] = data.get(feature, train_medians.get(feature, 0))
|
| 137 |
+
|
| 138 |
+
# Make prediction
|
| 139 |
+
result = predict_single(features)
|
| 140 |
+
|
| 141 |
+
if "error" in result:
|
| 142 |
+
return jsonify(result), 500
|
| 143 |
+
|
| 144 |
+
return jsonify(result)
|
| 145 |
+
|
| 146 |
+
except Exception as e:
|
| 147 |
+
return jsonify({"error": str(e)}), 500
|
| 148 |
+
|
| 149 |
+
@flask_app.route('/predict-batch', methods=['POST'])
|
| 150 |
+
def api_predict_batch():
|
| 151 |
+
"""REST API endpoint for batch predictions"""
|
| 152 |
+
try:
|
| 153 |
+
data = request.get_json()
|
| 154 |
+
|
| 155 |
+
if not data or 'objects' not in data:
|
| 156 |
+
return jsonify({"error": "No 'objects' array provided"}), 400
|
| 157 |
+
|
| 158 |
+
predictions = []
|
| 159 |
+
for obj in data['objects']:
|
| 160 |
+
features = {}
|
| 161 |
+
for feature in feature_columns:
|
| 162 |
+
features[feature] = obj.get(feature, train_medians.get(feature, 0))
|
| 163 |
+
|
| 164 |
+
result = predict_single(features)
|
| 165 |
+
predictions.append(result)
|
| 166 |
+
|
| 167 |
+
return jsonify({"predictions": predictions})
|
| 168 |
+
|
| 169 |
+
except Exception as e:
|
| 170 |
+
return jsonify({"error": str(e)}), 500
|
| 171 |
+
|
| 172 |
+
@flask_app.route('/features', methods=['GET'])
|
| 173 |
+
def api_features():
|
| 174 |
+
"""Get information about available features"""
|
| 175 |
+
return jsonify({
|
| 176 |
+
"feature_columns": feature_columns,
|
| 177 |
+
"train_medians": train_medians,
|
| 178 |
+
"feature_descriptions": {
|
| 179 |
+
"koi_period": "Orbital period (days)",
|
| 180 |
+
"koi_duration": "Transit duration (hours)",
|
| 181 |
+
"koi_depth": "Transit depth (ppm)",
|
| 182 |
+
"koi_prad": "Planetary radius (Earth radii)",
|
| 183 |
+
"koi_srad": "Stellar radius (Solar radii)",
|
| 184 |
+
"koi_teq": "Equilibrium temperature (K)",
|
| 185 |
+
"koi_steff": "Stellar effective temperature (K)",
|
| 186 |
+
"koi_slogg": "Stellar surface gravity (log g)",
|
| 187 |
+
"koi_smet": "Stellar metallicity ([Fe/H])",
|
| 188 |
+
"koi_kepmag": "TESS magnitude",
|
| 189 |
+
"koi_model_snr": "Signal-to-noise ratio",
|
| 190 |
+
"koi_num_transits": "Number of transits"
|
| 191 |
+
}
|
| 192 |
+
})
|
| 193 |
|
| 194 |
+
# ==================== PREDICTION FUNCTIONS ====================
|
| 195 |
|
| 196 |
def predict_single(features: Dict) -> Dict:
|
| 197 |
+
"""Function to predict a single object"""
|
| 198 |
try:
|
| 199 |
if model is None or scaler is None or label_encoder is None:
|
| 200 |
return {"error": "Model not available"}
|
|
|
|
| 344 |
if "error" not in result:
|
| 345 |
results.append(
|
| 346 |
{
|
| 347 |
+
"TOI": row.get("toi", f"TOI-{idx}"),
|
| 348 |
"Disposition": row.get("tfopwg_disp", "Unknown"),
|
| 349 |
"Prediction": result["prediction"],
|
| 350 |
"P(Confirmed)": f"{result['probabilities']['CONFIRMED']:.3f}",
|
|
|
|
| 411 |
|
| 412 |
with gr.Blocks(theme=gr.themes.Soft(), title="Eco Finder API") as demo:
|
| 413 |
gr.Markdown("# π Eco Finder API")
|
| 414 |
+
gr.Markdown("Exoplanet classifier with REST API")
|
| 415 |
|
| 416 |
+
with gr.Tab("π― API Documentation"):
|
|
|
|
| 417 |
gr.Markdown("""
|
| 418 |
+
## REST API Endpoints
|
| 419 |
|
| 420 |
+
### Health Check
|
| 421 |
+
**GET** `/health`
|
| 422 |
+
```bash
|
| 423 |
+
curl -X GET "https://your-domain/health"
|
| 424 |
+
```
|
| 425 |
|
| 426 |
+
### Single Prediction
|
| 427 |
+
**POST** `/predict`
|
| 428 |
```bash
|
| 429 |
+
curl -X POST "https://your-domain/predict" \\
|
| 430 |
-H "Content-Type: application/json" \\
|
| 431 |
-d '{
|
| 432 |
"koi_period": 10.0,
|
|
|
|
| 443 |
"koi_num_transits": 3.0
|
| 444 |
}'
|
| 445 |
```
|
| 446 |
+
|
| 447 |
+
### Batch Predictions
|
| 448 |
+
**POST** `/predict-batch`
|
| 449 |
+
```bash
|
| 450 |
+
curl -X POST "https://your-domain/predict-batch" \\
|
| 451 |
+
-H "Content-Type: application/json" \\
|
| 452 |
+
-d '{
|
| 453 |
+
"objects": [
|
| 454 |
+
{"koi_period": 10.0, "koi_duration": 5.0, ...},
|
| 455 |
+
{"koi_period": 15.0, "koi_duration": 6.0, ...}
|
| 456 |
+
]
|
| 457 |
+
}'
|
| 458 |
+
```
|
| 459 |
+
|
| 460 |
+
### Features Information
|
| 461 |
+
**GET** `/features`
|
| 462 |
+
```bash
|
| 463 |
+
curl -X GET "https://your-domain/features"
|
| 464 |
+
```
|
| 465 |
+
|
| 466 |
+
### JavaScript Example
|
| 467 |
+
```javascript
|
| 468 |
+
async function predictExoplanet(features) {
|
| 469 |
+
const response = await fetch('/predict', {
|
| 470 |
+
method: 'POST',
|
| 471 |
+
headers: {
|
| 472 |
+
'Content-Type': 'application/json',
|
| 473 |
+
},
|
| 474 |
+
body: JSON.stringify(features)
|
| 475 |
+
});
|
| 476 |
+
return await response.json();
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
// Usage
|
| 480 |
+
const result = await predictExoplanet({
|
| 481 |
+
koi_period: 10.0,
|
| 482 |
+
koi_duration: 5.0,
|
| 483 |
+
// ... all parameters
|
| 484 |
+
});
|
| 485 |
+
```
|
| 486 |
""")
|
| 487 |
|
| 488 |
+
with gr.Tab("π§ Test Interface"):
|
| 489 |
+
gr.Markdown("Test the prediction model with this interface")
|
| 490 |
+
|
| 491 |
with gr.Row():
|
| 492 |
with gr.Column():
|
| 493 |
period = gr.Number(label="koi_period", value=10.0)
|
|
|
|
| 505 |
snr = gr.Number(label="koi_model_snr", value=10.0)
|
| 506 |
num_transits = gr.Number(label="koi_num_transits", value=3.0)
|
| 507 |
|
| 508 |
+
test_btn = gr.Button("π Test Prediction")
|
| 509 |
+
test_output = gr.JSON()
|
| 510 |
|
| 511 |
+
test_btn.click(
|
| 512 |
fn=predict_from_dict,
|
| 513 |
inputs=[
|
| 514 |
period,
|
|
|
|
| 524 |
snr,
|
| 525 |
num_transits,
|
| 526 |
],
|
| 527 |
+
outputs=test_output,
|
| 528 |
)
|
| 529 |
|
| 530 |
with gr.Tab("π Real-time TOI"):
|
|
|
|
| 533 |
toi_output = gr.Markdown()
|
| 534 |
toi_btn.click(predict_toi_realtime, outputs=toi_output)
|
| 535 |
|
| 536 |
+
# ==================== APPLICATION STARTUP ====================
|
| 537 |
+
|
| 538 |
+
def run_flask():
|
| 539 |
+
"""Run Flask app on port 5000"""
|
| 540 |
+
flask_app.run(host='0.0.0.0', port=5000, debug=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
+
def run_gradio():
|
| 543 |
+
"""Run Gradio app on port 7860"""
|
| 544 |
+
demo.launch(server_name='0.0.0.0', server_port=7860, share=False)
|
| 545 |
|
| 546 |
if __name__ == "__main__":
|
| 547 |
+
print("π Application started successfully!")
|
| 548 |
+
print("π Gradio Interface available at: http://0.0.0.0:7860")
|
| 549 |
+
print("π REST API available at: http://0.0.0.0:5000")
|
| 550 |
+
print("π API Documentation:")
|
| 551 |
+
print(" GET /health")
|
| 552 |
+
print(" POST /predict")
|
| 553 |
+
print(" POST /predict-batch")
|
| 554 |
+
print(" GET /features")
|
| 555 |
+
|
| 556 |
+
# Start both servers
|
| 557 |
+
flask_thread = threading.Thread(target=run_flask)
|
| 558 |
+
flask_thread.daemon = True
|
| 559 |
+
flask_thread.start()
|
| 560 |
+
|
| 561 |
+
# Run Gradio in main thread
|
| 562 |
+
run_gradio()
|