# app.py — Space Collision Dashboard (Gradio app)
# Revision e0e3f52 ("Update app.py", verified) by nadahesham.
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import os, random, math
# --------------------------------------------------
# Paths
# --------------------------------------------------
# Input CSV files; relative paths resolve against the process working directory.
PATH_FINAL = "final_combined_full_table.csv"  # main per-conjunction table used by the UI
PATH_XTEST = "X_test_target.csv"
PATH_OUTPUT = "output.csv"

# Radius of the rendered globe; 6371 is Earth's mean radius in kilometres.
GLOBE_R = 6_371
# --------------------------------------------------
# Utility: seeded random vector
# --------------------------------------------------
def seeded_vec(seed):
    """Return a reproducible pseudo-random unit vector for ``seed``.

    The azimuth is drawn uniformly from [0, 2*pi) and the z component
    (cosine of the polar angle) uniformly from [-1, 1). This samples
    directions uniformly over the unit sphere. The same seed always yields
    the same vector.

    Returns:
        A length-3 ``numpy`` float array ``[x, y, z]`` with norm 1.
    """
    rng = random.Random(seed)
    azimuth = rng.random() * 2 * math.pi  # drawn first: draw order fixes the result
    z = 2 * rng.random() - 1
    radial = math.sqrt(1 - z * z)  # length of the projection onto the xy-plane
    components = (radial * math.cos(azimuth), radial * math.sin(azimuth), z)
    return np.array(components)
# --------------------------------------------------
# Load data
# --------------------------------------------------
# Read the three CSV inputs, in the same order as the path constants.
# NOTE(review): df_xtest and df_output are not referenced anywhere else in this
# file; if that stays true, these reads (and the files they require at startup)
# could be dropped.
df_final, df_xtest, df_output = (
    pd.read_csv(csv_path) for csv_path in (PATH_FINAL, PATH_XTEST, PATH_OUTPUT)
)
# --------------------------------------------------
# Synthesize positions if missing
# --------------------------------------------------
if not {"relative_position_r", "relative_position_t", "relative_position_n"}.issubset(df_final.columns):
    # The table lacks RTN position components, so fabricate them.
    # Each conjunction gets a deterministic direction seeded by its ID,
    # scaled by rel_pos_mag. A missing magnitude counts as 1.
    mags = df_final["rel_pos_mag"].fillna(1).values
    vectors = [
        seeded_vec(int(cid)) * float(mag)
        for cid, mag in zip(df_final["conjunction_id"], mags)
    ]
    xs = [vec[0] for vec in vectors]
    ys = [vec[1] for vec in vectors]
    zs = [vec[2] for vec in vectors]
    df_final["relative_position_r"] = xs
    df_final["relative_position_t"] = ys
    df_final["relative_position_n"] = zs
# --------------------------------------------------
# Color mapping
# --------------------------------------------------
def alert_color(level):
    """Map an alert level label to a hex marker colour.

    Matching is case-insensitive and by substring, checked in the order
    HIGH, MEDIUM, LOW. For example, "very high" is red. Any other value,
    including None or NaN, falls back to grey.
    """
    label = str(level).upper()
    palette = (
        ("HIGH", "#ff1744"),
        ("MEDIUM", "#ff9100"),
        ("LOW", "#00e676"),
    )
    for keyword, colour in palette:  # first match wins, so order sets precedence
        if keyword in label:
            return colour
    return "#9e9e9e"
# Precompute each event's marker colour from its alert level.
df_final["_color"] = df_final["dl_score_level"].map(alert_color)
# --------------------------------------------------
# 3D Earth
# --------------------------------------------------
def make_earth():
    """Build a plotly surface trace for a sphere of radius ``GLOBE_R``.

    The sphere is sampled on a 72 (longitude) x 36 (colatitude) grid,
    drawn at 90% opacity with a black-to-dark-blue colour scale and no
    colour bar.
    """
    lon_grid, colat_grid = np.meshgrid(
        np.linspace(0, 2 * np.pi, 72),
        np.linspace(0, np.pi, 36),
    )
    ring = np.sin(colat_grid)  # shared factor for the x and y coordinates
    x = GLOBE_R * np.cos(lon_grid) * ring
    y = GLOBE_R * np.sin(lon_grid) * ring
    z = GLOBE_R * np.cos(colat_grid)
    return go.Surface(
        x=x,
        y=y,
        z=z,
        colorscale=[[0, "black"], [1, "#1e3a8a"]],
        opacity=0.90,
        showscale=False,
    )
# --------------------------------------------------
# Build 3D Visualization
# --------------------------------------------------
def _event_coords(events, scale):
    """Return plot-space (x, y, z) Series for the RTN position columns of ``events``.

    Each component is cast to float, multiplied by ``scale`` and offset by GLOBE_R.
    NOTE(review): GLOBE_R is added to every axis, so all points are shifted by
    the vector (R, R, R) rather than placed around the globe's surface;
    confirm this layout is intended.
    """
    return tuple(
        GLOBE_R + events[column].astype(float) * scale
        for column in ("relative_position_r", "relative_position_t", "relative_position_n")
    )


def build_orbit_plot(orbit, frame, highlight):
    """Render the 3-D globe together with the conjunction events of one frame.

    Args:
        orbit: ``orbit_regime`` value to keep. "ALL", or an empty/None
            dropdown selection, keeps every regime.
        frame: keeps only events whose ``conjunction_id % 200`` equals this
            value (the UI slider spans 0-199).
        highlight: optional conjunction ID as text. If it parses as an int
            and matches rows in ``df_final``, those rows are drawn in yellow
            regardless of the orbit/frame filters. A non-numeric ID is ignored.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    SCALE = 4.5

    # Boolean filtering returns new frames, so df_final itself is never mutated.
    events = df_final
    # An unset dropdown arrives as None/"". Treat it like "ALL" instead of
    # matching it against orbit_regime, which would drop every event.
    if orbit and orbit != "ALL":
        events = events[events["orbit_regime"] == orbit]
    events = events[events["conjunction_id"] % 200 == frame]

    fig = go.Figure()
    fig.add_trace(make_earth())

    xs, ys, zs = _event_coords(events, SCALE)
    fig.add_trace(
        go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode="markers",
            marker=dict(size=3, color=events["_color"], opacity=0.95),
            text=events["conjunction_id"].astype(str),
            hoverinfo="text",
            name="Conjunction Events",
        )
    )

    if highlight:
        # Only the ID parse is allowed to fail quietly. Other errors should
        # surface instead of being swallowed.
        try:
            cid = int(highlight)
        except (TypeError, ValueError):
            cid = None
        if cid is not None:
            track = df_final[df_final["conjunction_id"] == cid]
            if not track.empty:
                xs2, ys2, zs2 = _event_coords(track, SCALE)
                fig.add_trace(
                    go.Scatter3d(
                        x=xs2,
                        y=ys2,
                        z=zs2,
                        mode="lines+markers",
                        line=dict(width=6, color="yellow"),
                        marker=dict(size=6, color="yellow"),
                        name=f"Track {cid}",
                    )
                )

    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ),
        template="plotly_dark",
        height=650,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig
# --------------------------------------------------
# Top alerts table
# --------------------------------------------------
def top_alerts(n):
    """Return the ``n`` events of ``df_final`` with the highest ``dl_score_fixed``.

    Only the table's display columns are returned. Rows with a missing score
    sort last (pandas' default ``na_position``).

    Args:
        n: number of rows to return. Numeric input is truncated to int.

    Returns:
        A DataFrame with the display columns. It is empty, but still carries
        those columns, when ``n <= 0``.
    """
    cols = [
        "conjunction_id",
        "orbit_regime",
        "dl_score_fixed",
        "dl_score_level",
        "ppo_action_name",
        "final_mode_fixed",
    ]
    # The slider value may arrive as a float (e.g. 10.0), and DataFrame.head()
    # rejects non-integer positions.
    n = int(n)
    if n <= 0:
        # Keep the header so the table UI does not lose its columns.
        return pd.DataFrame(columns=cols)
    return df_final.sort_values("dl_score_fixed", ascending=False)[cols].head(n)
# --------------------------------------------------
# Inspector
# --------------------------------------------------
def inspect_event(cid):
    """Look up one conjunction by ID for the Event Inspector panel.

    Args:
        cid: conjunction ID, typically the textbox string (e.g. "42").

    Returns:
        ``(rows, text)``:
        - a non-integer ID gives an empty DataFrame and "Invalid ID";
        - an unknown ID gives an empty DataFrame and "Not Found";
        - otherwise, all matching rows of ``df_final`` plus a summary of the
          first match.
    """
    try:
        cid = int(cid)
    except (TypeError, ValueError, OverflowError):
        # Only parse failures map to "Invalid ID". Other errors are no longer
        # hidden by a bare except.
        return pd.DataFrame(), "Invalid ID"
    row = df_final[df_final["conjunction_id"] == cid]
    if row.empty:
        return pd.DataFrame(), "Not Found"
    r = row.iloc[0]
    # Series.get returns None for any column the table lacks.
    txt = (
        "\n"
        f"Orbit Regime: {r.get('orbit_regime')}\n"
        f"Miss Distance: {r.get('miss_distance')}\n"
        f"DL Score: {r.get('dl_score_fixed')} ({r.get('dl_score_level')})\n"
        f"PPO Action: {r.get('ppo_action_name')}\n"
        f"Final Mode: {r.get('final_mode_fixed')}\n"
    )
    return row, txt
# --------------------------------------------------
# Custom Dark Mode CSS
# --------------------------------------------------
# Dark-mode stylesheet, one rule per entry. The empty first and last entries
# give the string a leading and trailing newline.
dark_css = "\n".join(
    [
        "",
        "body { background-color: #111 !important; color: white !important; }",
        ".gradio-container { background-color: #111 !important; }",
        "label, input, textarea { color: white !important; }",
        "",
    ]
)
# --------------------------------------------------
# UI
# --------------------------------------------------
# Dropdown choices: "ALL" followed by the distinct non-null regimes, sorted.
orbit_list = ["ALL"] + sorted(df_final["orbit_regime"].dropna().unique().tolist())

with gr.Blocks(title="Space Collision Dashboard") as demo:
    # Inject the dark-mode stylesheet through an inline <style> tag.
    gr.HTML(f"<style>{dark_css}</style>")
    gr.Markdown(
        "<h1 style='text-align:center;color:white;'>🚀 Space Collision Dashboard — Premium Dark Mode</h1>"
    )
    with gr.Row():
        with gr.Column(scale=2):
            # Default to "ALL" so the first render shows every regime instead
            # of an unset (None) selection.
            orbit = gr.Dropdown(orbit_list, value="ALL", label="Orbit Regime")
            # Range 0-199 matches the conjunction_id % 200 bucketing in build_orbit_plot.
            frame = gr.Slider(0, 199, value=0, step=1, label="Frame Index")
            highlight = gr.Textbox(label="Highlight Conjunction ID")
            plot = gr.Plot(label="3D Orbit Visualizer")
            btn_plot = gr.Button("Render 3D View")
            btn_plot.click(
                fn=build_orbit_plot,
                inputs=[orbit, frame, highlight],
                outputs=plot,
            )
        with gr.Column(scale=1):
            gr.Markdown("### Alert Statistics")
            # Classify through alert_color so the counts follow the same rule
            # as the plotted marker colours: case-insensitive substring match
            # with HIGH > MEDIUM > LOW precedence.
            tiers = df_final["dl_score_level"].map(alert_color)
            high = int((tiers == alert_color("HIGH")).sum())
            med = int((tiers == alert_color("MEDIUM")).sum())
            low = int((tiers == alert_color("LOW")).sum())
            gr.Markdown(
                f"\n- **High Alerts:** {high}\n"
                f"- **Medium Alerts:** {med}\n"
                f"- **Low Alerts:** {low}\n"
            )
            gr.Markdown("### Top Alerts Table")
            top_n = gr.Slider(5, 50, value=10, step=5, label="Top N Alerts")
            top_table = gr.Dataframe()
            btn_top = gr.Button("Load Top Alerts")
            btn_top.click(fn=top_alerts, inputs=top_n, outputs=top_table)
            gr.Markdown("### Event Inspector")
            cid_box = gr.Textbox(label="Enter Conjunction ID")
            inspect_table = gr.Dataframe()
            inspect_text = gr.Textbox(label="Details")
            btn_insp = gr.Button("Inspect")
            btn_insp.click(
                fn=inspect_event, inputs=cid_box, outputs=[inspect_table, inspect_text]
            )

demo.launch()