# exoplanet_model / app.py
# (Hugging Face page residue: uploaded by mibrahimzia; commit 5794601,
#  "Upload folder using huggingface_hub", verified)
# app.py
import tempfile
from pathlib import Path

import gradio as gr
import joblib
import pandas as pd
# Load the trained classifier. The path is resolved against this script's
# directory rather than the process's current working directory, so launching
# the app from elsewhere (e.g. `python path/to/app.py`) still finds the model.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only ever load a model file you produced yourself / trust.
MODEL_PATH = Path(__file__).resolve().parent / "exoplanet_model.pkl"
model = joblib.load(MODEL_PATH)
# Model input columns, in the exact order used at training time; the
# prediction functions build their DataFrames with these column names.
FEATURES = [
    "orbital_period",
    "transit_duration",
    "planet_radius",
    "star_temp",
    "star_radius",
]

# Class code returned by the model -> label displayed in the UI.
LABELS = {
    1: "βœ… Confirmed Exoplanet",
    0: "🟑 Candidate",
    -1: "❌ False Positive",
}
# ---- Single Prediction ----
def predict_single(
    orbital_period, transit_duration, planet_radius, star_temp, star_radius
):
    """Classify a single object from the five values entered in the UI.

    The arguments are given in the same order as FEATURES and are packed into
    a one-row DataFrame with those column names before calling the model.

    Returns:
        The display label from LABELS for the predicted class code, or
        "Unknown" when the model's output has no entry in LABELS.
    """
    feature_values = [
        orbital_period,
        transit_duration,
        planet_radius,
        star_temp,
        star_radius,
    ]
    single_row = pd.DataFrame(data=[feature_values], columns=FEATURES)
    predicted_class = model.predict(single_row)[0]
    return LABELS.get(predicted_class, "Unknown")
# ---- Bulk Prediction ----
def predict_bulk(file):
    """Classify every row of an uploaded CSV and offer the results for download.

    Args:
        file: Value of the upload component. Depending on the Gradio version
            this may be a tempfile-like object exposing the path as ``.name``
            or the path string itself; ``None`` when nothing was uploaded.

    Returns:
        A ``(preview, output_path)`` pair. On success, ``preview`` is the first
        10 rows of the input with an added ``prediction`` column and
        ``output_path`` is the path of a CSV containing all rows. On a user
        error (no file, missing columns), ``preview`` is a one-row DataFrame
        with an ``error`` column holding the message and ``output_path`` is
        ``None``.
    """
    if file is None:
        return pd.DataFrame({"error": ["Please upload a CSV file first."]}), None

    # Accept both a tempfile-like wrapper (path in .name) and a plain path.
    path = getattr(file, "name", file)
    df = pd.read_csv(path)

    # Ensure required columns exist
    missing_cols = [col for col in FEATURES if col not in df.columns]
    if missing_cols:
        # The message must be returned as a DataFrame: a bare string is not a
        # valid Dataframe-output value (Gradio may try to read it as a CSV
        # path), so the user would never see it.
        message = f"Missing columns in uploaded file: {missing_cols}"
        return pd.DataFrame({"error": [message]}), None

    preds = model.predict(df[FEATURES])
    df["prediction"] = [LABELS.get(p, "Unknown") for p in preds]

    # Write into a fresh temporary directory so concurrent users of the app do
    # not overwrite each other's results; the file name is unchanged.
    output_file = Path(tempfile.mkdtemp(prefix="exoplanet_")) / "bulk_predictions.csv"
    df.to_csv(output_file, index=False)
    return df.head(10), str(output_file)  # preview + downloadable file
# ---- Gradio UI ----
with gr.Blocks() as demo:
    gr.Markdown("## πŸ”­ Exoplanet Classifier – NASA Space Apps Challenge 2025")

    with gr.Tab("Single Prediction"):
        # One field per model feature, in the order of FEATURES and of
        # predict_single's parameters.
        inputs = [
            gr.Number(label="Orbital Period (days)"),
            gr.Number(label="Transit Duration (hrs)"),
            gr.Number(label="Planet Radius (Earth radii)"),
            gr.Number(label="Star Temperature (K)"),
            gr.Number(label="Star Radius (Solar radii)"),
        ]
        output = gr.Label(label="Prediction")
        single_btn = gr.Button("Run Classification")
        # Wired with an explicit button, like the bulk tab, instead of wrapping
        # these already-created components in a nested gr.Interface (which
        # would have to un-render and re-render them).
        single_btn.click(fn=predict_single, inputs=inputs, outputs=output)

    with gr.Tab("Bulk CSV Prediction"):
        file_input = gr.File(label="Upload CSV", file_types=[".csv"])
        preview = gr.Dataframe(label="Preview (first 10 rows)")
        download = gr.File(label="Download Predictions")
        btn = gr.Button("Run Bulk Classification")
        # The two outputs receive predict_bulk's (preview, file) return pair.
        btn.click(fn=predict_bulk, inputs=file_input, outputs=[preview, download])
def main():
    """Start the Gradio server for the ``demo`` app built above."""
    demo.launch()


if __name__ == "__main__":
    main()