# Astro-Hunters / app.py
# Author: FatimahEmadEldin
# Commit message: "Update app.py"
# Commit: 0ce524d (verified)
#!/usr/bin/env python3
"""
NASA Space Apps 2025 - Comprehensive Exoplanet Hunter Interface (HDF5 Version)
- "Play the Game" tab (second tab): A game to find transits in the large, pre-processed HDF5 dataset.
- "AI Prediction" tab (first tab): An AI-powered tool to predict transits in user-uploaded FITS files.
- Uses the robust v4 model.
"""
import gradio as gr
import numpy as np
import pickle
import matplotlib.pyplot as plt
import traceback
import lightkurve as lk
from scipy.ndimage import median_filter
from scipy.stats import skew
import h5py # <-- CHANGE: Import h5py for HDF5 file handling
# ============================================================================
# LOAD ALL NECESSARY DATA AND MODELS
# ============================================================================
# --- Path Configuration ---
# HDF5 file holding the pre-processed data used by the "Play the Game" tab.
GAME_DATA_PATH = "preprocessed_transit_data_v5.h5"
# Pickled production model bundle used by the "AI Prediction" tab.
MODEL_PATH = "exoplanet_detector_v4.pkl"
# --- Load the AI Model ---
# Defaults so these names always exist; the prediction code checks MODEL_LOADED
# before touching them.
xgb_model = None
model_config = None
try:
    # SECURITY: pickle.load can execute arbitrary code stored in the file.
    # MODEL_PATH must point to a trusted artifact shipped with the app, never to
    # user-supplied data.
    with open(MODEL_PATH, 'rb') as f:
        production_model = pickle.load(f)
    xgb_model = production_model['xgboost_classifier']
    model_config = production_model['config']
    print(f"✅ Successfully loaded AI model '{MODEL_PATH}'. Ready for predictions.")
    MODEL_LOADED = True
except Exception as e:
    # Don't leave a half-initialised model around (e.g. classifier loaded but the
    # 'config' key missing).
    xgb_model = None
    model_config = None
    print(f"❌ ERROR: Could not load AI model from '{MODEL_PATH}'. Prediction tab will fail. Error: {e}")
    MODEL_LOADED = False
# --- Load the Game Data (from HDF5) ---
# The HDF5 file stays open for the lifetime of the app so that each star's rows
# can be read on demand; only the star-ID column is loaded into memory up front.
hf_game_data = None
try:
    hf_game_data = h5py.File(GAME_DATA_PATH, 'r')
    # Star IDs may come back as bytes (HDF5 strings) or as other scalars
    # (e.g. numeric IDs); normalise everything to str.
    all_star_ids = np.array([
        s.decode('utf-8') if isinstance(s, bytes) else str(s)
        for s in hf_game_data['star_ids'][:]
    ])
    # Plain Python strings for the dropdown (np.unique also sorts them).
    stars_available = np.unique(all_star_ids).tolist()
    print(f"✅ Successfully opened HDF5 game data. Found {len(stars_available)} unique stars.")
    GAME_DATA_LOADED = True
except Exception as e:
    # If the file opened but was unusable, release the handle instead of leaking it.
    if hf_game_data is not None:
        hf_game_data.close()
        hf_game_data = None
    print(f"ℹ️ NOTE: Could not load game data from '{GAME_DATA_PATH}'. The 'Play the Game' tab will be disabled. Error: {e}")
    GAME_DATA_LOADED = False
    all_star_ids = np.array([], dtype=str)
    stars_available = []
# ============================================================================
# GAME STATE (Only for the "Play the Game" tab)
# ============================================================================
# Mutable progress for the game tab; the game callbacks read and update it in place.
# NOTE(review): this is one module-level dict, so every visitor of the app shares
# the same score and current star. Per-session state (gr.State) would isolate users.
game_state = dict(
    score=0,
    attempts=0,
    correct=0,
    current_star=None,
    user_selections=[],
)
# ============================================================================
# BACKEND FUNCTIONS
# ============================================================================
### --- Functions for the "AI Prediction" Tab (Unchanged) --- ###
def extract_features_for_prediction(time, flux, config):
    """Build the per-point feature matrix fed to the classifier.

    The feature definitions must stay identical to the training script.

    Args:
        time: 1-D array of timestamps, same length as ``flux``.
        flux: 1-D array of flux values.
        config: Model config dict; ``config['window_size']`` is the half-width
            ``W`` of the window. Each point ``i`` uses ``detrended[i-W:i+W]``.

    Returns:
        ``(X_new, valid_time, valid_flux)``:
        ``X_new`` has shape ``(n - 2*W, 7)``. Its columns are the point, the
        window mean, point minus mean, the window min, point divided by mean,
        the window std and the window skew. ``valid_time`` and ``valid_flux``
        are the matching timestamps and detrended flux. If ``n <= 2*W`` all
        three are empty and ``X_new`` has shape ``(0, 7)``.
    """
    WINDOW = config['window_size']
    n_features = 7
    # Remove slow trends with a running median (size must match training).
    detrended = flux / median_filter(flux, size=1001, mode='reflect')
    n_points = len(detrended)
    n_valid = n_points - 2 * WINDOW
    valid_time = time[WINDOW:n_points - WINDOW]
    valid_flux = detrended[WINDOW:n_points - WINDOW]
    if n_valid <= 0:
        # Too short for a single full window. Return a correctly shaped empty
        # matrix so callers can detect it instead of crashing downstream.
        return np.empty((0, n_features)), valid_time, valid_flux

    # windows[k] == detrended[k : k + 2*WINDOW], the window for point
    # i = k + WINDOW. This replaces a per-point Python loop with array operations.
    windows = np.lib.stride_tricks.sliding_window_view(detrended, 2 * WINDOW)[:n_valid]
    center = valid_flux
    window_mean = windows.mean(axis=1)
    X_new = np.column_stack([
        center,
        window_mean,
        center - window_mean,
        windows.min(axis=1),
        center / window_mean,
        windows.std(axis=1),
        skew(windows, axis=1),
    ])
    return X_new, valid_time, valid_flux
def _build_prediction_figure(time, flux, probabilities, transit_indices, threshold, label):
    """Plot processed flux, AI-flagged points and the model confidence trace."""
    fig, ax1 = plt.subplots(figsize=(16, 7))
    ax1.plot(time, flux, 'k.', markersize=2, alpha=0.3, label='Processed Flux')
    if len(transit_indices) > 0:
        ax1.scatter(time[transit_indices], flux[transit_indices], c='red', s=25,
                    alpha=0.9, label='AI Detected Transit', zorder=10)
    ax1.set_xlabel('Time (BTJD)')
    ax1.set_ylabel('Normalized Brightness')
    ax1.legend(loc='upper left')
    ax1.grid(True, alpha=0.3)
    # Secondary y-axis: per-point model probability on a fixed 0..1 scale.
    ax2 = ax1.twinx()
    ax2.plot(time, probabilities, 'b-', alpha=0.6, lw=1.5, label='AI Confidence Score')
    ax2.set_ylabel('Model Probability (Confidence)', color='blue')
    ax2.tick_params(axis='y', labelcolor='blue')
    ax2.set_ylim(0, 1)
    ax2.axhline(y=threshold, color='r', linestyle='--', alpha=0.7, label='Current Threshold')
    ax2.legend(loc='upper right')
    fig.suptitle(f'AI Prediction for {label}', fontsize=16, fontweight='bold')
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    # Unregister from pyplot so a long-running server does not accumulate open
    # figures. The Figure object itself can still be rendered by Gradio.
    plt.close(fig)
    return fig


def predict_on_upload(uploaded_file, probability_threshold):
    """Run the full pipeline on an uploaded FITS light curve and plot the result.

    The pipeline:
    1. Read the file with lightkurve and rebuild it from unitless arrays.
    2. Drop NaNs and 5-sigma outliers, then flatten.
    3. Build the model features.
    4. Score every point and flag those above ``probability_threshold``.

    Args:
        uploaded_file: Value from the ``gr.File`` component. This is either an
            object exposing the temp-file path as ``.name`` or the path string
            itself.
        probability_threshold: Points with probability strictly above this
            value are reported as transit points.

    Returns:
        ``(figure_or_None, markdown_info, status_text)`` for the three output
        components. Failures are reported in the text outputs, not raised.
    """
    if not MODEL_LOADED:
        return None, "❌ AI Model is not loaded.", "Error"
    if uploaded_file is None:
        return None, "Please upload a FITS file first.", "Waiting for file"
    # Older Gradio passes a tempfile wrapper (path in `.name`); newer versions pass the path.
    file_path = getattr(uploaded_file, 'name', uploaded_file)
    try:
        lc_raw = lk.read(file_path)
        lc_unitless = lk.LightCurve(time=lc_raw.time.value, flux=lc_raw.flux.value)
        lc_clean = lc_unitless.remove_nans().remove_outliers(sigma=5)
        lc_flat = lc_clean.flatten(window_length=1001)
        X_new, time_processed, flux_processed = extract_features_for_prediction(
            lc_flat.time.value, lc_flat.flux.value, model_config)
        if len(X_new) == 0:
            # Feature extraction needs more than 2 * window_size points.
            min_points = 2 * model_config['window_size'] + 1
            return (None,
                    f"❌ Light curve is too short for the model: need at least {min_points} points after cleaning.",
                    "Error: light curve too short")
        probabilities = xgb_model.predict_proba(X_new)[:, 1]
        transit_indices = np.where(probabilities > probability_threshold)[0]
        star_label = lc_raw.label or "uploaded light curve"
        fig = _build_prediction_figure(time_processed, flux_processed, probabilities,
                                       transit_indices, probability_threshold, star_label)
        info = f"## AI Prediction Results\n**Star:** {star_label}\n**AI Detected Transit Points:** {len(transit_indices)}"
        return fig, info, f"✅ Prediction successful: {len(transit_indices)} points found."
    except Exception as e:
        return None, f"ERROR: {str(e)}\n{traceback.format_exc()}", "Error during prediction"
### --- Functions for the "Play the Game" Tab (Modified for HDF5) --- ###
def _hdf5_row_selector(indices):
    """Return a slice if the sorted row indices form one contiguous run, else the indices.

    h5py reads a contiguous slice as a single selection, while point-wise fancy
    indexing with long index lists is much slower.
    """
    if len(indices) > 0 and indices[-1] - indices[0] + 1 == len(indices):
        return slice(int(indices[0]), int(indices[-1]) + 1)
    return indices


def load_star_for_game(star_name):
    """Load one star's pre-processed rows from the HDF5 file and start a new round.

    Reads the ``X`` and ``y`` rows that belong to ``star_name`` and stores what
    ``check_answer`` needs in ``game_state``, clearing earlier selections. It
    then plots column 0 of ``X`` against the row's position within the star.

    Returns:
        ``(figure_or_None, instructions_markdown, status_text)``.
    """
    if not GAME_DATA_LOADED:
        return None, "ERROR: Game data not loaded.", ""
    try:
        # Rows of the datasets that belong to the selected star (ascending order).
        indices = np.where(all_star_ids == star_name)[0]
        if len(indices) == 0:
            return None, f"Could not find data for {star_name}", "Error"
        rows = _hdf5_row_selector(indices)
        features = hf_game_data['X'][rows]
        labels = hf_game_data['y'][rows]
        # Column 0 is presumably the detrended flux (as feature 0 of the
        # prediction pipeline) — confirm against the preprocessing script.
        flux = features[:, 0]
        time_index = np.arange(len(flux))  # x-axis is the data-point index
        true_transits = (labels == 1)
        game_state.update({'current_star': star_name, 'time': time_index, 'flux': flux,
                           'true_transits': true_transits, 'user_selections': []})
        fig, ax = plt.subplots(figsize=(16, 6))
        ax.plot(time_index, flux, 'k.', markersize=2, alpha=0.4)
        ax.set_title(f'{star_name} - Find the transit dips!', fontsize=15, fontweight='bold')
        ax.set_xlabel('Data Point Index')
        ax.set_ylabel('Normalized Brightness')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        # Unregister from pyplot to avoid accumulating open figures; Gradio can still render it.
        plt.close(fig)
        instructions = f"## How to Play\n**Star:** {star_name}\n\n1. Look for U-shaped dips in the data points.\n2. Enter the START and END index of a dip.\n3. Click 'Add Selection'.\n4. Click 'Check Answer' when you're ready."
        return fig, instructions, f"✅ Loaded {star_name}"
    except Exception as e:
        return None, f"ERROR: {str(e)}\n{traceback.format_exc()}", "Error"
def add_selection(start, end):
    """Record a [start, end] data-index range the player thinks contains a transit.

    Returns a status message for the game's status box. Nothing is recorded if
    no star is loaded, or if either bound is missing or start is not before end.
    """
    if game_state['current_star'] is None:
        return "⚠️ Load a star first"
    bound_missing = start is None or end is None
    if bound_missing or start >= end:
        return "⚠️ Invalid index range"
    picks = game_state['user_selections']
    picks.append((float(start), float(end)))
    count = len(picks)
    return f"✅ Added selection #{count}: {start:.0f} - {end:.0f}"
def check_answer():
    """Score the player's selections against the labelled transit points.

    An attempt is correct when at least one selected index range contains more
    than 5 labelled transit points. A correct attempt scores +100; a wrong one
    scores -20, and the score never drops below 0. The selections are used up
    by the check, so the same selections cannot be submitted again for repeated
    points. Each new attempt needs freshly marked ranges.

    Returns:
        ``(figure_or_None, stats_markdown, status_text)``.
    """
    selections = game_state.get('user_selections')
    if not selections:
        return None, "⚠️ Add at least one selection", ""
    time, flux, true_transits = game_state['time'], game_state['flux'], game_state['true_transits']
    # Correct if any user-selected range covers more than 5 true transit points.
    found = any(np.sum(true_transits[(time >= t_start) & (time <= t_end)]) > 5
                for t_start, t_end in selections)
    game_state['attempts'] += 1
    if found:
        game_state['correct'] += 1
        game_state['score'] += 100
        result, color = "🎉 CORRECT!", 'green'
    else:
        game_state['score'] = max(0, game_state['score'] - 20)
        result, color = "❌ Not quite", 'orange'
    fig, ax = plt.subplots(figsize=(16, 6))
    ax.plot(time, flux, 'k.', markersize=2, alpha=0.3, label='Data')
    ax.scatter(time[true_transits], flux[true_transits], c='lime', s=15, alpha=0.7, label='Actual Transits')
    for i, (t_start, t_end) in enumerate(selections):
        # Label only the first span so the legend shows a single entry.
        ax.axvspan(t_start, t_end, alpha=0.3, color=color, label='Your Selection' if i == 0 else '')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_title(result)
    fig.tight_layout()
    # Unregister from pyplot to avoid accumulating open figures; Gradio can still render it.
    plt.close(fig)
    # Use up the selections. Otherwise repeated clicks on "Check Answer" would
    # re-score the same ranges and add +100 each time.
    game_state['user_selections'] = []
    accuracy = 100 * game_state['correct'] / max(1, game_state['attempts'])
    stats = (f"## {result}\n**Score:** {game_state['score']}\n"
             f"**Accuracy:** {game_state['correct']}/{game_state['attempts']} ({accuracy:.0f}%)\n\n"
             "_Selections cleared — mark new dips for your next attempt._")
    return fig, stats, f"Score: {game_state['score']}"
# ============================================================================
# GRADIO INTERFACE (with Tabs)
# ============================================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔭 Exoplanet Transit Hunter\n### Find planets with a game or let our AI do it for you!")
    with gr.Tabs():
        # --- Tab 1: AI Prediction ---
        with gr.TabItem("AI Prediction"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 1️⃣ Upload FITS File")
                    file_upload = gr.File(label="Upload a TESS FITS light curve file", file_types=['.fits'])
                    gr.Markdown("### 2️⃣ Adjust Sensitivity")
                    threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, value=0.50, step=0.05, label="Probability Threshold")
                    predict_btn = gr.Button("Run AI Prediction", variant="primary")
                    status_predict = gr.Textbox(label="Status", lines=3)
                with gr.Column(scale=2):
                    plot_predict = gr.Plot()
                    info_predict = gr.Markdown()
        # --- Tab 2: Play the Game ---
        with gr.TabItem("Play the Game"):
            if not GAME_DATA_LOADED:
                # Show the configured path rather than a hard-coded copy of it.
                gr.Markdown(f"## Game data (`{GAME_DATA_PATH}`) could not be loaded.\nThis tab is disabled. Please place the file in the same directory to enable the game.")
            else:
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 1️⃣ Choose Star")
                        # Guard against an HDF5 file that opened fine but contains no stars.
                        star_drop = gr.Dropdown(choices=stars_available, label="Star",
                                                value=stars_available[0] if stars_available else None)
                        load_btn = gr.Button("Load Star", variant="primary")
                        gr.Markdown("### 2️⃣ Mark Dips")
                        # The game plots data-point indices, not timestamps.
                        start_box = gr.Number(label="Start Index")
                        end_box = gr.Number(label="End Index")
                        add_btn = gr.Button("Add Selection")
                        status_game = gr.Textbox(label="Status", lines=2)
                        gr.Markdown("### 3️⃣ Check")
                        check_btn = gr.Button("Check Answer", variant="primary")
                    with gr.Column(scale=2):
                        plot_game = gr.Plot()
                        info_game = gr.Markdown()
    # --- Wire up the components for both tabs ---
    # AI Prediction Tab
    predict_btn.click(predict_on_upload, inputs=[file_upload, threshold_slider], outputs=[plot_predict, info_predict, status_predict])
    # Game Tab (its components exist only if the data was loaded)
    if GAME_DATA_LOADED:
        load_btn.click(load_star_for_game, star_drop, [plot_game, info_game, status_game])
        add_btn.click(add_selection, [start_box, end_box], status_game)
        check_btn.click(check_answer, outputs=[plot_game, info_game, status_game])

# Launch only when run as a script, so importing this module (e.g. by tooling
# that looks for `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)