# mozzic's picture
# ok
# 15418a7 verified
import base64
import html
import traceback
from io import BytesIO

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Trained artifacts, stored together as one (classifier, imputer, scaler) tuple.
# NOTE: joblib.load unpickles the file, so only point this at a trusted artifact.
_MODEL_FILE = "optimized_exoplanet_model.pkl"
model, imp, scaler = joblib.load(_MODEL_FILE)
def _resolve_upload_path(file):
    """Return a filesystem path for a ``gr.File`` upload.

    Depending on the Gradio version/configuration the upload arrives either
    as a tempfile-like object exposing ``.name`` or as a plain path string;
    accept both.
    """
    return getattr(file, "name", file)


def _add_missing_features(df, features):
    """Add every column of *features* that *df* lacks as an all-NaN column (in place).

    Returns the list of column names that had to be added, in *features* order.
    """
    missing = [col for col in features if col not in df.columns]
    for col in missing:
        # np.nan (not pd.NA) keeps the new column numeric (float64) for the imputer.
        df[col] = np.nan
    return missing


def _render_count_chart(df, column="Prediction"):
    """Render a count plot of ``df[column]`` and return it as an inline base64 <img> tag."""
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.countplot(x=column, data=df, palette="viridis", ax=ax)
        ax.set_title("Prediction Distribution")
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Close explicitly: pyplot keeps every open figure alive, so a
        # long-running server would otherwise leak one figure per request.
        plt.close(fig)
    img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"<img src='data:image/png;base64,{img_b64}' width='400'>"


def predict_exoplanets(file):
    """Classify each row of an uploaded CSV and chart the predicted class counts.

    Args:
        file: The value produced by ``gr.File``: a path string or an object
            with a ``.name`` path attribute; ``None`` if nothing was uploaded.

    Returns:
        A ``(DataFrame, html)`` pair. On success the DataFrame is the uploaded
        table plus any missing feature columns (filled with NaN) and a
        ``Prediction`` column; the HTML is the chart, preceded by a red warning
        when feature columns were missing. On failure the DataFrame is empty and
        the HTML holds the escaped error message and traceback.
    """
    if file is None:
        return pd.DataFrame(), "<p style='color:red'>Please upload a CSV file.</p>"
    try:
        print("🔹 Loading CSV...")
        df = pd.read_csv(_resolve_upload_path(file))
        print("✅ CSV loaded. Columns:", df.columns.tolist())

        features = ["koi_period", "koi_duration", "koi_depth", "koi_prad", "koi_impact"]
        missing_cols = _add_missing_features(df, features)
        print("✅ Missing columns handled:", missing_cols)

        # Prepare features for the model: impute, then scale, then predict.
        X = df[features]
        print("🔹 Features extracted:\n", X.head())
        X_imp = imp.transform(X)
        print("🔹 Imputation done.")
        X_scaled = scaler.transform(X_imp)
        print("🔹 Scaling done.")
        preds = model.predict(X_scaled)
        print("🔹 Predictions done:", preds[:10])

        label_map = {0: "FALSE POSITIVE", 1: "CANDIDATE", 2: "CONFIRMED"}
        # Unknown labels are shown verbatim instead of failing the whole request.
        df["Prediction"] = [label_map.get(p, str(p)) for p in preds]

        img_html = _render_count_chart(df)
        if missing_cols:
            warning_msg = f"⚠️ Missing columns detected and filled with NaN: {', '.join(missing_cols)}"
            img_html = f"<p style='color:red'>{warning_msg}</p>" + img_html
        return df, img_html
    except Exception as e:
        error_msg = f"Error processing file: {e}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        # The table output always receives a DataFrame; the (escaped) error text
        # goes to the HTML output, since it can echo user-controlled content.
        return pd.DataFrame(), f"<pre style='color:red'>{html.escape(error_msg)}</pre>"
# Gradio interface: one CSV upload in; a predictions table and a chart out.
demo = gr.Interface(
    fn=predict_exoplanets,
    inputs=gr.File(label="Upload NASA Exoplanet CSV", file_types=[".csv"]),
    outputs=[gr.Dataframe(label="Predictions"), gr.HTML(label="Prediction Chart")],
    title="🚀 NASA Exoplanet Classifier (Optimized + Visualization)",
    description=(
        "Upload a CSV containing exoplanet data — the AI will classify each as "
        "FALSE POSITIVE, CANDIDATE, or CONFIRMED. A chart shows the class "
        "distribution. Missing columns are automatically handled."
    ),
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. from
    # tests) does not start a server. share=True requests a public share link.
    demo.launch(share=True)