nadahesham's picture
Update app.py
732b44e verified
# app.py
import streamlit as st
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from PIL import Image
import os
# NOTE(review): Streamlit reads its server port/address when the server starts,
# before this script is executed, so assigning them here most likely has no
# effect. Confirm, and pass them via `streamlit run --server.port ...` or
# .streamlit/config.toml instead if that is the intent.
os.environ["STREAMLIT_SERVER_PORT"] = "7860"
os.environ["STREAMLIT_SERVER_ADDRESS"] = "0.0.0.0"
# Kept as the first Streamlit call, which Streamlit requires for set_page_config
# (at least in older releases).
st.set_page_config(layout="wide", page_title="Space Collision 3D Dashboard")
st.title("🌍 Space Collision 3D Dashboard")
# Input CSV files; relative paths, so they resolve against the working directory.
PATH_FINAL = "final_combined_full_table.csv"
PATH_XTEST = "X_test_target.csv"
PATH_OUTPUT = "output.csv"
def _read_csv_if_exists(path):
    """Return ``path`` parsed with ``pd.read_csv``, or None if it is not a regular file.

    ``os.path.isfile`` (rather than ``os.path.exists``) keeps a directory that
    happens to share the name from reaching ``read_csv`` and raising.
    """
    if not os.path.isfile(path):
        return None
    return pd.read_csv(path)


# NOTE(review): st.cache_data also caches a None result, so a CSV added after
# the first run is not picked up until the cache is cleared.
@st.cache_data(show_spinner=False)
def load_final(path):
    """Load the main combined table CSV (cached); None if the file is absent."""
    return _read_csv_if_exists(path)


@st.cache_data(show_spinner=False)
def load_x(path):
    """Load the X_test positions CSV (cached); None if the file is absent."""
    return _read_csv_if_exists(path)


@st.cache_data(show_spinner=False)
def load_out(path):
    """Load the output CSV (cached); None if the file is absent."""
    return _read_csv_if_exists(path)
df_final = load_final(PATH_FINAL)
df_x = load_x(PATH_XTEST)
# NOTE(review): df_out is loaded but not referenced anywhere else in this file.
df_out = load_out(PATH_OUTPUT)
# Every panel reads the combined table, so stop rendering without it.
if df_final is None:
    st.error(f"File not found: {PATH_FINAL}.")
    st.stop()

st.sidebar.header("Controls")
# orbit_regime is optional elsewhere in the app; don't raise KeyError here.
if "orbit_regime" in df_final.columns:
    orbit_values = df_final["orbit_regime"].dropna().unique().tolist()
else:
    orbit_values = []
orbit_opts = ["ALL"] + sorted(orbit_values)
orbit_sel = st.sidebar.selectbox("Orbit regime", orbit_opts)
max_points = st.sidebar.slider("Max points (sample)", 500, 20000, 4000, step=500)
# Synthesised frames are `conjunction_id % 300`, i.e. 0..299 inclusive.
frame_idx = st.sidebar.slider("Animation frame", 0, 299, 0)
track_id = st.sidebar.text_input("Track conjunction_id")
show_globe_texture = st.sidebar.checkbox("Show Earth texture", value=False)
earth_texture_path = "earth.jpg"
def alert_color(row):
    """Map an event row to a marker colour name based on its risk level.

    The level is taken from the first of ``dl_score_level`` then
    ``risk_level_fixed`` that holds a non-blank string. Missing values read
    from CSV arrive as NaN, which is a truthy float, so they are skipped
    explicitly; this lets a missing DL level fall back to ``risk_level_fixed``.

    Args:
        row: Any mapping with ``.get`` -- a ``pd.Series`` row from
            ``DataFrame.apply(axis=1)`` or a plain dict.

    Returns:
        "red" if the level contains HIGH, "orange" for MEDIUM, "green" for
        LOW (case-insensitive, checked in that order), otherwise "gray".
    """
    for key in ("dl_score_level", "risk_level_fixed"):
        value = row.get(key)
        if isinstance(value, str) and value.strip():
            level = value.upper()
            break
    else:
        return "gray"
    for token, color in (("HIGH", "red"), ("MEDIUM", "orange"), ("LOW", "green")):
        if token in level:
            return color
    return "gray"
# Per-event marker colour derived from the risk-level columns.
df_final["_alert_color"] = df_final.apply(alert_color, axis=1)

# Prefer positions from X_test_target.csv when it carries ids and all three
# RTN components; otherwise fall back to the combined table.
_pos_cols = {"relative_position_r", "relative_position_t", "relative_position_n"}
if df_x is not None and ({"conjunction_id"} | _pos_cols).issubset(df_x.columns):
    df_pos = df_x.copy()
    pos_source = "X_test_target.csv"
else:
    if not _pos_cols.issubset(df_final.columns):
        st.error("Missing 3D position columns.")
        st.stop()
    df_pos = df_final.copy()
    pos_source = "final_combined_full_table.csv"
# NOTE(review): `_alert_color` is only added to df_final, so when positions come
# from X_test_target.csv the markers render gray (and the orbit filter is
# skipped unless that file has orbit_regime). Consider merging on conjunction_id.

# Synthesise an animation frame from the id when none is provided. Missing or
# non-numeric ids become NaN and never fall inside a frame window (instead of
# crashing in astype(int)); without an id column no frames are created.
if "_frame" not in df_pos.columns and "conjunction_id" in df_pos.columns:
    df_pos["_frame"] = pd.to_numeric(df_pos["conjunction_id"], errors="coerce") % 300

# Filter first, then sample, so "Max points" caps what is actually drawn.
# Sampling before the +/-2 frame window discarded ~98% of the sample.
df_sample = df_pos
if orbit_sel != "ALL" and "orbit_regime" in df_sample.columns:
    df_sample = df_sample[df_sample["orbit_regime"] == orbit_sel]
if "_frame" in df_sample.columns:
    window = 2
    df_sample = df_sample[
        (df_sample["_frame"] >= frame_idx - window)
        & (df_sample["_frame"] <= frame_idx + window)
    ]
if len(df_sample) > max_points:
    df_sample = df_sample.sample(max_points, random_state=frame_idx)
def make_earth_mesh(radius=6371, resolution=50, texture_path=None):
    """Build a sphere ``go.Surface`` standing in for the Earth.

    Args:
        radius: Sphere radius (6371 is Earth's mean radius in km).
        resolution: Number of azimuth samples; the polar direction uses half
            as many, but at least 2 so the sphere is never degenerate.
        texture_path: Optional image file. When it exists and can be decoded,
            its per-pixel mean brightness is drawn on the sphere in grayscale
            at full opacity.

    Returns:
        A ``go.Surface``. Without a usable texture (no path, missing file, or
        an unreadable image) it is a dark-blue sphere at opacity 0.25.
    """
    n_polar = max(2, resolution // 2)
    azimuth = np.linspace(0, 2 * np.pi, resolution)
    polar = np.linspace(0, np.pi, n_polar)
    azimuth, polar = np.meshgrid(azimuth, polar)  # both shaped (n_polar, resolution)
    x = radius * np.cos(azimuth) * np.sin(polar)
    y = radius * np.sin(azimuth) * np.sin(polar)
    z = radius * np.cos(polar)
    surface = go.Surface(x=x, y=y, z=z, showscale=False, opacity=0.6)

    if texture_path and os.path.exists(texture_path):
        try:
            # Context manager closes the file handle PIL keeps open.
            with Image.open(texture_path) as src:
                # PIL sizes are (width, height) -> array shape (n_polar, resolution),
                # matching the mesh grid.
                img = src.convert("RGB").resize((resolution, n_polar))
            lum = np.asarray(img).astype(np.float32).mean(axis=2) / 255.0
        except (OSError, ValueError):
            # Unreadable/corrupt image (PIL's UnidentifiedImageError is an
            # OSError): fall through to the plain sphere below.
            lum = None
        if lum is not None:
            surface.surfacecolor = lum
            surface.colorscale = "Gray"
            surface.opacity = 1.0
            return surface

    surface.surfacecolor = np.zeros_like(x)
    surface.colorscale = [[0, "rgb(20,30,100)"], [1, "rgb(60,70,150)"]]
    surface.opacity = 0.25
    return surface
_RTN_COLUMNS = ("relative_position_r", "relative_position_t", "relative_position_n")


def _rtn_to_xyz(frame, base_r, fill_missing=True):
    """Return (x, y, z) Series: each RTN position column shifted by ``base_r``."""
    coords = []
    for col in _RTN_COLUMNS:
        values = frame[col].fillna(0) if fill_missing else frame[col]
        coords.append(base_r + values.astype(float))
    return tuple(coords)


def _track_trace(source, highlight_id, base_r):
    """Build the yellow track trace for ``highlight_id`` from ``source``, or None.

    None is returned when the id is not an integer, when ``source`` lacks the
    ``conjunction_id`` or position columns, or when no row matches. Rows are
    ordered by ``_frame`` when that column exists.
    """
    try:
        cid = int(highlight_id)  # int() itself tolerates surrounding whitespace
    except (TypeError, ValueError):
        return None
    if "conjunction_id" not in source.columns or not set(_RTN_COLUMNS).issubset(source.columns):
        return None
    sub = source[source["conjunction_id"] == cid]
    if sub.empty:
        return None
    if "_frame" in sub.columns:
        sub = sub.sort_values("_frame")
    # NaNs are left in so Plotly breaks the line instead of drawing a spike.
    xs, ys, zs = _rtn_to_xyz(sub, base_r, fill_missing=False)
    return go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode="lines+markers",
        line=dict(width=3, color="yellow"),
        marker=dict(size=2, color="yellow"),
        name=f"track {cid}",
    )


def build_3d_figure(df_points, show_earth=True, texture=None, highlight_id=None, df_track=None):
    """Build the 3D orbit figure: optional Earth sphere, event markers, track.

    Args:
        df_points: Rows drawn as markers; must have the three
            ``relative_position_*`` columns. ``_alert_color`` (marker colour)
            and ``conjunction_id`` (hover text) are used when present.
        show_earth: Add the sphere built by ``make_earth_mesh``.
        texture: Optional image path forwarded to ``make_earth_mesh``.
        highlight_id: conjunction_id (int or numeric string) whose rows are
            drawn as a yellow line; blank or non-integer values are ignored.
        df_track: Table searched for ``highlight_id``; defaults to the
            module-level ``df_pos`` (all positions, not just the sample).

    Returns:
        A ``go.Figure`` with hidden axes and a data-proportional aspect ratio.
    """
    fig = go.Figure()
    if show_earth:
        fig.add_trace(make_earth_mesh(radius=6371, resolution=72, texture_path=texture))

    # NOTE(review): the same Earth radius is added to all three RTN components,
    # so markers sit around (6371, 6371, 6371) rather than on the sphere's
    # surface -- confirm this placement is intended.
    base_r = 6371.0
    x, y, z = _rtn_to_xyz(df_points, base_r)
    # Hover shows the raw RTN offsets; x/y/z include the base_r shift.
    raw_rtn = np.column_stack(
        [df_points[col].fillna(0).astype(float) for col in _RTN_COLUMNS]
    )
    colors = df_points.get("_alert_color", ["gray"] * len(df_points))
    texts = df_points["conjunction_id"].astype(str) if "conjunction_id" in df_points else None
    fig.add_trace(go.Scatter3d(
        x=x, y=y, z=z, mode="markers",
        marker=dict(size=3, color=colors, opacity=0.9),
        text=texts,
        customdata=raw_rtn,
        hovertemplate=(
            "id: %{text}<br>r: %{customdata[0]:.1f}<br>t: %{customdata[1]:.1f}"
            "<br>n: %{customdata[2]:.1f}<extra></extra>"
        ),
    ))

    if highlight_id:
        source = df_pos if df_track is None else df_track
        track = _track_trace(source, highlight_id, base_r)
        if track is not None:
            fig.add_trace(track)

    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            aspectmode="data",
        ),
        margin=dict(r=0, l=0, t=30, b=0),
        height=700,
        title="3D Orbit View",
    )
    return fig
st.sidebar.markdown("### Quick stats")
st.sidebar.write("Events:", len(df_final))
if "dl_score_level" in df_final.columns:
    st.sidebar.write("DL HIGH:", int((df_final["dl_score_level"] == "HIGH").sum()))
# Guarded like the DL stat: `df_final.get(col, "") == x` is a plain bool when
# the column is missing, and bool has no `.sum()` (AttributeError).
if "final_mode_fixed" in df_final.columns:
    n_auto = int((df_final["final_mode_fixed"] == "AUTO_MANEUVER").sum())
    st.sidebar.write("Auto maneuvers:", n_auto)

col1, col2 = st.columns([2, 1])
with col1:
    st.subheader(f"3D Orbit map — source: {pos_source}")
    fig3d = build_3d_figure(
        df_sample,
        show_earth=True,
        texture=(earth_texture_path if show_globe_texture else None),
        highlight_id=track_id,
    )
    st.plotly_chart(fig3d, use_container_width=True)
with col2:
    st.subheader("Top alerts")
    topn = st.slider("Top N", 5, 50, 10)
    if "dl_score_fixed" in df_final.columns:
        top = df_final.sort_values("dl_score_fixed", ascending=False).head(topn)
        # Show whichever summary columns exist instead of raising KeyError.
        wanted = [
            "conjunction_id", "orbit_regime", "dl_score_fixed",
            "dl_score_level", "ppo_action_name", "final_mode_fixed",
        ]
        shown = [col for col in wanted if col in top.columns]
        st.dataframe(top[shown].reset_index(drop=True))

st.markdown("### Event Inspector")
conj = st.text_input("Conjunction ID")
if conj:
    try:
        conj_int = int(conj)
    except ValueError:
        # Only a non-integer entry is an invalid ID; other errors should surface.
        st.error("Invalid ID.")
    else:
        if "conjunction_id" in df_final.columns:
            row = df_final[df_final["conjunction_id"] == conj_int]
        else:
            row = df_final.iloc[0:0]
        if row.empty:
            st.warning(f"No event found with conjunction_id {conj_int}.")
        else:
            r = row.iloc[0]
            st.write("Orbit:", r.get("orbit_regime"))
            st.write("Miss distance:", r.get("miss_distance"))
            st.write("DL score:", r.get("dl_score_fixed"), r.get("dl_score_level"))
            st.write("PPO action:", r.get("ppo_action_name"))
            st.write("Final decision:", r.get("final_mode_fixed"))
            if {"miss_distance", "collision_probability"}.issubset(df_final.columns):
                fig = px.scatter(
                    row,
                    x="miss_distance",
                    y="collision_probability",
                    title="Miss distance vs Collision probability",
                )
                st.plotly_chart(fig, use_container_width=True)