Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from io import BytesIO | |
| import base64 | |
| import traceback | |
# The pickle holds a 3-tuple, unpacked in this order: the trained classifier,
# the imputer applied to the feature matrix, and the scaler applied after it.
# NOTE: joblib.load unpickles the file, so only ever point it at a trusted artifact.
_ARTIFACT_FILE = "optimized_exoplanet_model.pkl"
model, imp, scaler = joblib.load(_ARTIFACT_FILE)
def _render_prediction_chart(df):
    """Return an HTML ``<img>`` tag embedding a count plot of ``df["Prediction"]``.

    The figure is created with an explicit handle and always closed: pyplot
    keeps every figure alive until ``plt.close`` is called, and this runs once
    per request inside a long-lived server.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    buf = BytesIO()
    try:
        sns.countplot(x="Prediction", data=df, palette="viridis", ax=ax)
        ax.set_title("Prediction Distribution")
        fig.tight_layout()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    img_b64 = base64.b64encode(buf.getvalue()).decode()
    return f"<img src='data:image/png;base64,{img_b64}' width='400'>"


def predict_exoplanets(file):
    """Classify every row of an uploaded CSV and chart the class counts.

    Parameters
    ----------
    file:
        The value Gradio passes for a ``gr.File`` input: a filepath string
        (newer Gradio) or a tempfile-like object exposing ``.name`` (older
        Gradio). ``None`` when nothing was uploaded.

    Returns
    -------
    tuple[pandas.DataFrame, str]
        The input rows (plus any feature columns that had to be added as NaN)
        with a ``Prediction`` column, and an HTML snippet holding the chart,
        preceded by a red warning when feature columns were missing. When
        there is no file, no rows, or any failure, the result is an empty or
        row-less DataFrame and an HTML message, so both outputs still render.
    """
    import html  # only needed to escape error text placed into the HTML output

    # Feature columns passed to the imputer/scaler/model, in this order.
    features = ["koi_period", "koi_duration", "koi_depth", "koi_prad", "koi_impact"]
    label_map = {0: "FALSE POSITIVE", 1: "CANDIDATE", 2: "CONFIRMED"}

    if file is None:
        return pd.DataFrame(), "<p style='color:red'>Please upload a CSV file.</p>"

    try:
        # Accept both a plain path and an object carrying the path in `.name`.
        path = getattr(file, "name", file)
        print("🔹 Loading CSV...")
        df = pd.read_csv(path)
        print("✅ CSV loaded. Columns:", df.columns.tolist())

        if df.empty:
            return df, "<p style='color:red'>The uploaded CSV contains no rows.</p>"

        # Add any missing feature columns as all-NaN so the imputer fills them.
        missing_cols = [col for col in features if col not in df.columns]
        for col in missing_cols:
            df[col] = np.nan
        print("✅ Missing columns handled:", missing_cols)

        warning_msg = ""
        if missing_cols:
            warning_msg = (
                "⚠️ Missing columns detected and filled with NaN: "
                + ", ".join(missing_cols)
            )

        # Coerce to numbers: stray text or blank cells become NaN and are
        # imputed instead of making the transform raise.
        X = df[features].apply(pd.to_numeric, errors="coerce")
        print("🔹 Features extracted:\n", X.head())
        X_imp = imp.transform(X)
        print("🔹 Imputation done.")
        X_scaled = scaler.transform(X_imp)
        print("🔹 Scaling done.")
        preds = model.predict(X_scaled)
        print("🔹 Predictions done:", preds[:10])

        # An unexpected class code is shown verbatim rather than raising KeyError.
        df["Prediction"] = [label_map.get(p, str(p)) for p in preds]

        img_html = _render_prediction_chart(df)
        if warning_msg:
            img_html = f"<p style='color:red'>{warning_msg}</p>" + img_html
        return df, img_html
    except Exception as e:
        # UI boundary: log the full traceback server-side, and hand the UI an
        # empty table plus escaped HTML text. gr.Dataframe cannot render a
        # free-form error string.
        user_msg = f"Error processing file: {e}"
        print(f"{user_msg}\n\nTraceback:\n{traceback.format_exc()}")
        return pd.DataFrame(), f"<pre style='color:red'>{html.escape(user_msg)}</pre>"
# Gradio UI: one CSV upload in; a predictions table plus an HTML chart out.
demo = gr.Interface(
    fn=predict_exoplanets,
    inputs=gr.File(label="Upload NASA Exoplanet CSV", file_types=[".csv"]),
    outputs=[gr.Dataframe(label="Predictions"), gr.HTML(label="Prediction Chart")],
    title="🚀 NASA Exoplanet Classifier (Optimized + Visualization)",
    description=(
        "Upload a CSV containing exoplanet data — the AI will classify each as "
        "FALSE POSITIVE, CANDIDATE, or CONFIRMED. A chart shows the class "
        "distribution. Missing columns are automatically handled."
    ),
)

# Start the server only when executed as a script, not as a side effect of
# importing this module.
if __name__ == "__main__":
    demo.launch(share=True)