Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| NASA Space Apps 2025 - Comprehensive Exoplanet Hunter Interface (HDF5 Version) | |
- Tab 1: An AI-powered tool to predict transits in user-uploaded FITS files.
- Tab 2: A game to find transits in the large, pre-processed HDF5 dataset.
| - Uses the robust v4 model. | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import pickle | |
| import matplotlib.pyplot as plt | |
| import traceback | |
| import lightkurve as lk | |
| from scipy.ndimage import median_filter | |
| from scipy.stats import skew | |
| import h5py # <-- CHANGE: Import h5py for HDF5 file handling | |
# ============================================================================
# LOAD ALL NECESSARY DATA AND MODELS
# ============================================================================
# --- Path Configuration ---
# Pre-processed HDF5 dataset that backs the "Play the Game" tab.
GAME_DATA_PATH = "preprocessed_transit_data_v5.h5"
# Pickled production bundle; expected to be a dict with the keys
# 'xgboost_classifier' (the trained classifier) and 'config' (feature settings).
MODEL_PATH = "exoplanet_detector_v4.pkl"

# --- Load the AI Model ---
# NOTE(review): pickle.load can execute arbitrary code stored in the file, so
# MODEL_PATH must only ever point at a trusted, locally shipped model.
# On any failure MODEL_LOADED is False and the prediction callback refuses to run.
try:
    with open(MODEL_PATH, 'rb') as f:
        production_model = pickle.load(f)
    xgb_model = production_model['xgboost_classifier']
    model_config = production_model['config']
except Exception as e:
    print(f"❌ ERROR: Could not load AI model from '{MODEL_PATH}'. Prediction tab will fail. Error: {e}")
    MODEL_LOADED = False
else:
    print(f"✅ Successfully loaded AI model '{MODEL_PATH}'. Ready for predictions.")
    MODEL_LOADED = True
# --- Load the Game Data (from HDF5) ---
# The HDF5 file is opened once and kept open for the life of the process, so a
# single star's rows can be read on demand without loading the whole dataset
# into RAM. Only the star-ID column is read eagerly (for the dropdown).
hf_game_data = None
try:
    hf_game_data = h5py.File(GAME_DATA_PATH, 'r')
    # h5py may hand back the IDs as bytes (fixed-length or variable-length byte
    # strings) or already as str, depending on how the file was written.
    all_star_ids = np.array([
        s.decode('utf-8') if isinstance(s, bytes) else str(s)
        for s in hf_game_data['star_ids'][:]
    ])
    stars_available = list(np.unique(all_star_ids))
    if not stars_available:
        # Without at least one star the game tab cannot be built (the dropdown
        # needs a default), so treat an empty dataset as "not loaded".
        raise ValueError("the 'star_ids' dataset is empty")
    print(f"✅ Successfully opened HDF5 game data. Found {len(stars_available)} unique stars.")
    GAME_DATA_LOADED = True
except Exception as e:
    print(f"ℹ️ NOTE: Could not load game data from '{GAME_DATA_PATH}'. The 'Play the Game' tab will be disabled. Error: {e}")
    GAME_DATA_LOADED = False
    stars_available = []
    if hf_game_data is not None:
        # Loading failed after the file was opened; don't leak the handle.
        hf_game_data.close()
        hf_game_data = None
# ============================================================================
# GAME STATE (Only for the "Play the Game" tab)
# ============================================================================
# A single module-level dict, so all users served by this process share it.
# load_star_for_game() later adds the 'time', 'flux' and 'true_transits' keys.
game_state = dict(
    score=0,             # running score: +100 per correct check, -20 per wrong one (floored at 0)
    attempts=0,          # number of answers checked
    correct=0,           # number of correct answers
    current_star=None,   # star ID currently loaded; None until a star is loaded
    user_selections=[],  # (start, end) index ranges added by the player
)
# ============================================================================
# BACKEND FUNCTIONS
# ============================================================================
### --- Functions for the "AI Prediction" Tab (Unchanged) --- ###
def extract_features_for_prediction(time, flux, config):
    """Build the per-point feature matrix the classifier expects.

    The flux is first divided by a running median (``size=1001``,
    ``mode='reflect'``). Then, for every index ``i`` that has a full window on
    both sides, the window ``detrended[i - W:i + W]`` (``W`` is
    ``config['window_size']``; the upper bound is exclusive) is summarised into
    seven features, in this order:

    1. ``detrended[i]``
    2. window mean
    3. ``detrended[i] - mean``
    4. window minimum
    5. ``detrended[i] / mean``
    6. window standard deviation (``np.std``, ``ddof=0``)
    7. window skewness (``scipy.stats.skew`` defaults)

    The feature order and definitions must match the training script.

    Args:
        time: 1-D array of timestamps aligned with ``flux``.
        flux: 1-D array of flux values.
        config: model configuration; only ``'window_size'`` is read.

    Returns:
        ``(X_new, valid_time, valid_flux)``. ``X_new`` always has shape
        ``(max(len(flux) - 2 * W, 0), 7)``, even when the input is too short for
        a single window. ``valid_time`` and ``valid_flux`` are the timestamps and
        detrended flux for the rows of ``X_new``.
    """
    window_size = config['window_size']
    detrended = flux / median_filter(flux, size=1001, mode='reflect')
    stop = len(detrended) - window_size
    n_rows = max(stop - window_size, 0)
    # Preallocate so short inputs still yield a 2-D (0, 7) matrix.
    X_new = np.empty((n_rows, 7))
    for row, i in enumerate(range(window_size, stop)):
        window = detrended[i - window_size:i + window_size]
        centre = detrended[i]
        window_mean = np.mean(window)  # computed once, reused by three features
        X_new[row] = (
            centre,
            window_mean,
            centre - window_mean,
            np.min(window),
            centre / window_mean,
            np.std(window),
            skew(window),
        )
    valid_time = time[window_size:stop]
    valid_flux = detrended[window_size:stop]
    return X_new, valid_time, valid_flux
def predict_on_upload(uploaded_file, probability_threshold):
    """Full pipeline: load, preprocess and score an uploaded FITS light curve.

    Steps: read the file with lightkurve, rebuild it without astropy units,
    drop NaNs and 5-sigma outliers, flatten it (window 1001), extract the
    model features, and score every remaining point with the XGBoost
    classifier. Finally, plot the processed flux (points above the threshold
    in red) with the probability curve on a twin y-axis.

    Args:
        uploaded_file: value produced by ``gr.File``. Gradio 3.x passes a
            temp-file wrapper whose ``.name`` is the path; Gradio 4+ passes the
            path string itself. ``None`` when nothing has been uploaded.
        probability_threshold: points whose probability is strictly greater
            than this value are reported as transit points.

    Returns:
        ``(figure or None, markdown info, status text)``, matching the three
        outputs wired to this callback. Errors are reported through the
        returned strings rather than raised.
    """
    if not MODEL_LOADED:
        return None, "❌ AI Model is not loaded.", "Error"
    if uploaded_file is None:
        return None, "Please upload a FITS file first.", "Waiting for file"
    # Accept both the Gradio 3.x file wrapper and the Gradio 4+ path string.
    file_path = getattr(uploaded_file, 'name', uploaded_file)
    try:
        lc_raw = lk.read(file_path)
        # Rebuild without astropy units so downstream code sees plain floats.
        lc_unitless = lk.LightCurve(time=lc_raw.time.value, flux=lc_raw.flux.value)
        lc_clean = lc_unitless.remove_nans().remove_outliers(sigma=5)
        lc_flat = lc_clean.flatten(window_length=1001)
        X_new, time_processed, flux_processed = extract_features_for_prediction(
            lc_flat.time.value, lc_flat.flux.value, model_config
        )
        if len(X_new) == 0:
            # Too few points survive cleaning to fill a single feature window;
            # report that clearly instead of feeding an empty matrix to the model.
            message = (
                f"ERROR: The light curve is too short for prediction "
                f"({len(lc_flat.flux.value)} usable points; more than "
                f"{2 * model_config['window_size']} are required)."
            )
            return None, message, "Error during prediction"
        probabilities = xgb_model.predict_proba(X_new)[:, 1]
        transit_indices = np.where(probabilities > probability_threshold)[0]

        fig, ax1 = plt.subplots(figsize=(16, 7))
        try:
            ax1.plot(time_processed, flux_processed, 'k.', markersize=2, alpha=0.3, label='Processed Flux')
            if len(transit_indices) > 0:
                ax1.scatter(time_processed[transit_indices], flux_processed[transit_indices],
                            c='red', s=25, alpha=0.9, label='AI Detected Transit', zorder=10)
            ax1.set_xlabel('Time (BTJD)')
            ax1.set_ylabel('Normalized Brightness')
            ax1.legend(loc='upper left')
            ax1.grid(True, alpha=0.3)

            # Probability curve on a second y-axis sharing the time axis.
            ax2 = ax1.twinx()
            ax2.plot(time_processed, probabilities, 'b-', alpha=0.6, lw=1.5, label='AI Confidence Score')
            ax2.set_ylabel('Model Probability (Confidence)', color='blue')
            ax2.tick_params(axis='y', labelcolor='blue')
            ax2.set_ylim(0, 1)
            ax2.axhline(y=probability_threshold, color='r', linestyle='--', alpha=0.7, label='Current Threshold')
            ax2.legend(loc='upper right')

            fig.suptitle(f'AI Prediction for {lc_raw.label}', fontsize=16, fontweight='bold')
            # Lay out *this* figure; plt.tight_layout() would act on pyplot's
            # global "current figure", which is shared between request threads.
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        finally:
            # Drop pyplot's reference so figures don't accumulate across
            # requests; the Figure object itself is still returned for rendering.
            plt.close(fig)

        info = f"## AI Prediction Results\n**Star:** {lc_raw.label}\n**AI Detected Transit Points:** {len(transit_indices)}"
        return fig, info, f"✅ Prediction successful: {len(transit_indices)} points found."
    except Exception as e:
        return None, f"ERROR: {str(e)}\n{traceback.format_exc()}", "Error during prediction"
### --- Functions for the "Play the Game" Tab (Modified for HDF5) --- ###
def load_star_for_game(star_name):
    """Load one star's pre-processed rows from the HDF5 file and start a round.

    The function:

    - reads the rows of ``X`` and ``y`` whose ``star_ids`` entry equals
      ``star_name``;
    - stores in ``game_state`` the first feature column as the flux to show,
      ``y == 1`` as the hidden answer and a fresh (empty) selection list;
    - plots the flux against its row index.

    Args:
        star_name: star ID chosen in the dropdown.

    Returns:
        ``(figure or None, markdown instructions, status text)``.
    """
    if not GAME_DATA_LOADED:
        return None, "ERROR: Game data not loaded.", ""
    try:
        # Row numbers (ascending) belonging to the selected star.
        indices = np.where(all_star_ids == star_name)[0]
        if len(indices) == 0:
            return None, f"Could not find data for {star_name}", "Error"
        first, last = int(indices[0]), int(indices[-1])
        if last - first + 1 == len(indices):
            # The star's rows form one contiguous run: read it as a plain slice.
            # h5py documents that very long index lists can perform poorly.
            rows = slice(first, last + 1)
        else:
            rows = indices  # strictly increasing, as h5py fancy indexing requires
        features = hf_game_data['X'][rows]
        labels = hf_game_data['y'][rows]

        # Plot the first feature (detrended flux) against a simple row index.
        flux = features[:, 0]
        time_index = np.arange(len(flux))
        true_transits = (labels == 1)
        game_state.update({'current_star': star_name, 'time': time_index, 'flux': flux,
                           'true_transits': true_transits, 'user_selections': []})

        fig, ax = plt.subplots(figsize=(16, 6))
        try:
            ax.plot(time_index, flux, 'k.', markersize=2, alpha=0.4)
            ax.set_title(f'{star_name} - Find the transit dips!', fontsize=15, fontweight='bold')
            ax.set_xlabel('Data Point Index')
            ax.set_ylabel('Normalized Brightness')
            ax.grid(True, alpha=0.3)
            # Lay out this figure explicitly rather than pyplot's global current one.
            fig.tight_layout()
        finally:
            # Release pyplot's reference so figures don't accumulate per click.
            plt.close(fig)

        instructions = f"## How to Play\n**Star:** {star_name}\n\n1. Look for U-shaped dips in the data points.\n2. Enter the START and END index of a dip.\n3. Click 'Add Selection'.\n4. Click 'Check Answer' when you're ready."
        return fig, instructions, f"✅ Loaded {star_name}"
    except Exception as e:
        return None, f"ERROR: {str(e)}\n{traceback.format_exc()}", "Error"
def add_selection(start, end):
    """Record one [start, end] index range the player thinks contains a transit.

    Nothing is recorded when no star is loaded, when either bound is missing,
    or when ``start >= end``; a warning string is returned instead.

    Returns:
        Status text for the game's status box.
    """
    if game_state['current_star'] is None:
        return "⚠️ Load a star first"
    missing_bound = start is None or end is None
    if missing_bound or start >= end:
        return "⚠️ Invalid index range"
    selections = game_state['user_selections']
    selections.append((float(start), float(end)))
    return f"✅ Added selection #{len(selections)}: {start:.0f} - {end:.0f}"
def check_answer():
    """Score the player's selections for the loaded star and plot the answer.

    The round is correct when at least one selection (bounds inclusive) covers
    more than five labelled transit points.

    - Correct: adds 100 points.
    - Wrong: subtracts 20, never going below zero.

    The selections are consumed by the check. Pressing "Check Answer" again
    without adding new selections therefore cannot score the same answer twice.

    Returns:
        ``(figure or None, markdown stats, status text)``.
    """
    selections = game_state.get('user_selections')
    if not selections:
        return None, "⚠️ Add at least one selection", ""
    idx, flux, true_transits = game_state['time'], game_state['flux'], game_state['true_transits']
    min_transit_points = 5  # a selection must contain MORE than this many transit points
    found = any(
        np.count_nonzero(true_transits[(idx >= t_start) & (idx <= t_end)]) > min_transit_points
        for t_start, t_end in selections
    )
    # Consume the selections; otherwise repeated clicks would re-score the same answer.
    game_state['user_selections'] = []
    game_state['attempts'] += 1
    if found:
        game_state['correct'] += 1
        game_state['score'] += 100
        result, color = "🎉 CORRECT!", 'green'
    else:
        game_state['score'] = max(0, game_state['score'] - 20)
        result, color = "❌ Not quite", 'orange'

    fig, ax = plt.subplots(figsize=(16, 6))
    try:
        ax.plot(idx, flux, 'k.', markersize=2, alpha=0.3, label='Data')
        ax.scatter(idx[true_transits], flux[true_transits], c='lime', s=15, alpha=0.7, label='Actual Transits')
        for i, (t_start, t_end) in enumerate(selections):
            # Only the first span gets a legend entry.
            ax.axvspan(t_start, t_end, alpha=0.3, color=color,
                       label='Your Selection' if i == 0 else '_nolegend_')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_title(result)
        # Lay out this figure explicitly rather than pyplot's global current one.
        fig.tight_layout()
    finally:
        # Release pyplot's reference so figures don't accumulate per click.
        plt.close(fig)

    stats = f"## {result}\n**Score:** {game_state['score']}\n**Accuracy:** {game_state['correct']}/{game_state['attempts']} ({100*game_state['correct']/max(1,game_state['attempts']):.0f}%)"
    return fig, stats, f"Score: {game_state['score']}"
# ============================================================================
# GRADIO INTERFACE (with Tabs)
# ============================================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔭 Exoplanet Transit Hunter\n### Find planets with a game or let our AI do it for you!")
    with gr.Tabs():
        # --- Tab 1: AI Prediction ---
        with gr.TabItem("AI Prediction"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 1️⃣ Upload FITS File")
                    file_upload = gr.File(label="Upload a TESS FITS light curve file", file_types=['.fits'])
                    gr.Markdown("### 2️⃣ Adjust Sensitivity")
                    threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, value=0.50, step=0.05, label="Probability Threshold")
                    predict_btn = gr.Button("Run AI Prediction", variant="primary")
                    status_predict = gr.Textbox(label="Status", lines=3)
                with gr.Column(scale=2):
                    plot_predict = gr.Plot()
                    info_predict = gr.Markdown()
        # --- Tab 2: Play the Game ---
        with gr.TabItem("Play the Game"):
            if not GAME_DATA_LOADED:
                gr.Markdown(
                    f"## Game data (`{GAME_DATA_PATH}`) not found.\n"
                    "This tab is disabled. Please place the file in the same directory to enable the game."
                )
            else:
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 1️⃣ Choose Star")
                        star_drop = gr.Dropdown(
                            choices=stars_available,
                            label="Star",
                            # Guard against an empty star list instead of raising IndexError.
                            value=stars_available[0] if stars_available else None,
                        )
                        load_btn = gr.Button("Load Star", variant="primary")
                        gr.Markdown("### 2️⃣ Mark Dips")
                        # Selections are data-point indices, not timestamps.
                        start_box = gr.Number(label="Start Index")
                        end_box = gr.Number(label="End Index")
                        add_btn = gr.Button("Add Selection")
                        status_game = gr.Textbox(label="Status", lines=2)
                        gr.Markdown("### 3️⃣ Check")
                        check_btn = gr.Button("Check Answer", variant="primary")
                    with gr.Column(scale=2):
                        plot_game = gr.Plot()
                        info_game = gr.Markdown()

    # --- Wire up the components for both tabs (inside the Blocks context) ---
    # AI Prediction Tab
    predict_btn.click(predict_on_upload, inputs=[file_upload, threshold_slider], outputs=[plot_predict, info_predict, status_predict])
    # Game Tab: its components only exist when the data was loaded.
    if GAME_DATA_LOADED:
        load_btn.click(load_star_for_game, star_drop, [plot_game, info_game, status_game])
        add_btn.click(add_selection, [start_box, end_box], status_game)
        check_btn.click(check_answer, outputs=[plot_game, info_game, status_game])

# Launch only when run as a script, so importing this module (e.g. for tests or
# `gradio` reload mode) does not start a blocking server.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)