# Cosmos_Sentinel / app.py
# Ryukijano's picture
# Decrease ZeroGPU duration to 60s to accommodate quota limits
# 1baccce
# Use the Hugging Face `spaces` package when running on a Space; otherwise
# install a no-op stand-in so @spaces.GPU decorations degrade gracefully.
try:
    import spaces
except ImportError:
    class _SpacesFallback:
        """Minimal stand-in exposing a pass-through GPU decorator."""

        @staticmethod
        def GPU(fn=None, **_kwargs):
            # Supports both bare @GPU and parameterized @GPU(duration=...).
            if fn is not None:
                return fn
            return lambda inner: inner

    spaces = _SpacesFallback()
# Standard-library imports (previously `os` and `Path` were imported twice;
# the duplicates are removed here).
import json
import os
import traceback
from pathlib import Path

# Persist the Hugging Face cache on the Space's /data volume when available so
# model downloads survive restarts.  This must run before any HF-aware library
# reads HF_HOME, hence it precedes the third-party imports below.
if Path("/data").exists():
    os.environ["HF_HOME"] = "/data/.huggingface"

# Third-party imports.
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
def format_html_cards(badas_result, reason_result):
    """Render BADAS + Reason results as a flex row of HTML status cards.

    NOTE(review): this definition is shadowed by a later ``format_html_cards``
    further down this file, so this version is dead code as written — confirm
    which implementation is intended and delete the other.
    """
    # Placeholder until at least one pipeline stage has produced output.
    if not badas_result and not reason_result:
        return "<div style='color: #94a3b8; font-style: italic;'>Run the cached sample or upload a video to execute the pipeline.</div>"
    collision_triggered = bool((badas_result or {}).get("collision_detected"))
    gate_color = "#22c55e" if collision_triggered else "#64748b"
    gate_text = "Triggered" if collision_triggered else "Watching"
    incident = ((reason_result or {}).get("incident_type") or "unclear").lower()
    # (background, accent) color pairs per incident type; unmatched types fall
    # back to neutral slate below.
    incident_colors = {
        "collision": ("#7f1d1d", "#ef4444"),
        "near_miss": ("#7c2d12", "#f97316"),
        "hazard": ("#713f12", "#f59e0b"),
    }
    inc_bg, inc_accent = incident_colors.get(incident, ("#1e293b", "#94a3b8"))
    severity = str((reason_result or {}).get("severity_label") or "unknown").lower()
    # Keys here are the numeric strings "1".."5", while badge_palette()
    # elsewhere expects word labels ("low"/"high"/...).
    # TODO(review): confirm the actual format of severity_label.
    severity_colors = {
        "1": ("#14532d", "#22c55e"),
        "2": ("#713f12", "#f59e0b"),
        "3": ("#7c2d12", "#f97316"),
        "4": ("#7f1d1d", "#ef4444"),
        "5": ("#4c0519", "#e11d48"),
    }
    sev_bg, sev_accent = severity_colors.get(severity, ("#1e293b", "#94a3b8"))
    risk_score = (reason_result or {}).get("risk_score", 0)
    # NOTE(review): rendered as "/10" here, but make_risk_gauge() uses a 0-5
    # scale — confirm the intended risk-score range.
    cards_html = f"""
<div style="display: flex; gap: 1rem; flex-wrap: wrap; margin-bottom: 1rem; font-family: sans-serif;">
<div style="flex: 1; min-width: 200px; background: #0f172a; border-left: 4px solid {gate_color}; padding: 1rem; border-radius: 0.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);">
<div style="color: #94a3b8; font-size: 0.875rem; text-transform: uppercase; font-weight: 600;">Collision Gate</div>
<div style="color: {gate_color}; font-size: 1.5rem; font-weight: 700; margin-top: 0.25rem;">{gate_text}</div>
<div style="color: #64748b; font-size: 0.75rem; margin-top: 0.25rem;">BADAS V-JEPA2</div>
</div>
<div style="flex: 1; min-width: 200px; background: {inc_bg}; border-left: 4px solid {inc_accent}; padding: 1rem; border-radius: 0.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);">
<div style="color: #94a3b8; font-size: 0.875rem; text-transform: uppercase; font-weight: 600;">Incident</div>
<div style="color: #f8fafc; font-size: 1.5rem; font-weight: 700; margin-top: 0.25rem; text-transform: capitalize;">{incident.replace("_", " ")}</div>
<div style="color: #cbd5e1; font-size: 0.75rem; margin-top: 0.25rem;">Cosmos Reason 2</div>
</div>
<div style="flex: 1; min-width: 200px; background: {sev_bg}; border-left: 4px solid {sev_accent}; padding: 1rem; border-radius: 0.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);">
<div style="color: #94a3b8; font-size: 0.875rem; text-transform: uppercase; font-weight: 600;">Severity</div>
<div style="color: #f8fafc; font-size: 1.5rem; font-weight: 700; margin-top: 0.25rem;">{severity.upper()}</div>
<div style="color: #cbd5e1; font-size: 0.75rem; margin-top: 0.25rem;">Scale 1-5</div>
</div>
<div style="flex: 1; min-width: 200px; background: #0f172a; border-left: 4px solid #3b82f6; padding: 1rem; border-radius: 0.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);">
<div style="color: #94a3b8; font-size: 0.875rem; text-transform: uppercase; font-weight: 600;">Risk Score</div>
<div style="color: #3b82f6; font-size: 1.5rem; font-weight: 700; margin-top: 0.25rem;">{risk_score}/10</div>
<div style="color: #64748b; font-size: 0.75rem; margin-top: 0.25rem;">Overall hazard rating</div>
</div>
</div>
"""
    return cards_html
def format_reason_html(reason_result):
    """Render the Reason narrative plus reasoning trace as an HTML panel.

    NOTE(review): shadowed by a later ``format_reason_html`` definition in
    this file (which emits markdown instead of HTML) — dead code as written;
    confirm which implementation is intended and delete the other.
    """
    narrative = (reason_result or {}).get("narrative", "")
    reasoning = (reason_result or {}).get("reasoning", "")
    # No narrative yet => show the idle placeholder.
    if not narrative:
        return "<div style='color: #94a3b8; font-style: italic;'>Run BADAS + Reason to populate the narrative panel.</div>"
    html = f"""
<div style="background: rgba(15,23,42,0.65); border: 1px solid rgba(51,65,85,0.8); border-radius: 0.5rem; padding: 1.25rem; margin-top: 1rem; font-family: sans-serif;">
<h3 style="color: #e2e8f0; margin-top: 0; margin-bottom: 0.5rem; font-size: 1.125rem;">Cosmos Reason 2 Narrative</h3>
<p style="color: #cbd5e1; line-height: 1.6; margin: 0;">{narrative}</p>
<div style="margin-top: 1rem; padding-top: 1rem; border-top: 1px solid rgba(51,65,85,0.5);">
<h4 style="color: #94a3b8; margin-top: 0; margin-bottom: 0.5rem; font-size: 0.875rem; text-transform: uppercase;">Reasoning Trace</h4>
<p style="color: #94a3b8; font-family: monospace; font-size: 0.875rem; line-height: 1.5; margin: 0; white-space: pre-wrap;">{reasoning}</p>
</div>
</div>
"""
    return html
from space_backend import (
PREDICT_MODEL_NAME,
build_pipeline_overview,
cache_uploaded_video,
ensure_sample_video,
preload_runtime,
run_pipeline,
run_predict_only,
)
# Env toggles for model preloading during warmup; any value other than "1"
# disables the corresponding preload (default: enabled).
AUTO_PRELOAD_BADAS = os.environ.get("COSMOS_PRELOAD_BADAS", "1") == "1"
AUTO_PRELOAD_REASON = os.environ.get("COSMOS_PRELOAD_REASON", "1") == "1"
AUTO_PRELOAD_PREDICT = os.environ.get("COSMOS_PRELOAD_PREDICT", "1") == "1"
# Markdown rendered in the "Supplementary Materials" tab: project, docs, and
# paper links.
SUPPLEMENTARY_MARKDOWN = """
## Supplementary Materials
- [Project repository](https://github.com/Ryukijano/Nvidia-Cosmos-Cookoff)
- [NVIDIA Cosmos overview](https://docs.nvidia.com/cosmos/latest/introduction.html)
- [Cosmos Reason 2 docs](https://docs.nvidia.com/cosmos/latest/reason2/index.html)
- [Cosmos Reason 2 repo](https://github.com/nvidia-cosmos/cosmos-reason2)
- [Cosmos Predict 2.5 repo](https://github.com/nvidia-cosmos/cosmos-predict2.5)
- [Cosmos Cookbook](https://nvidia-cosmos.github.io/cosmos-cookbook/index.html)
- [BADAS paper](https://arxiv.org/abs/2510.14876)
- [Nexar dataset and challenge paper](https://openaccess.thecvf.com/content/CVPR2025W/WAD/papers/Moura_Nexar_Dashcam_Collision_Prediction_Dataset_and_Challenge_CVPRW_2025_paper.pdf)
"""
def badge_palette(value, kind):
normalized = (value or "unknown").strip().lower()
if kind == "incident":
palettes = {
"no_incident": ("#0f172a", "#22c55e", "No incident"),
"near_miss": ("#1f2937", "#f59e0b", "Near miss"),
"collision": ("#2b1a1a", "#ef4444", "Collision"),
"multi_vehicle_collision": ("#2a1025", "#dc2626", "Multi-vehicle collision"),
"unclear": ("#172033", "#64748b", "Unclear"),
}
else:
palettes = {
"none": ("#0f172a", "#22c55e", "None"),
"low": ("#10261b", "#22c55e", "Low"),
"moderate": ("#2b2110", "#f59e0b", "Moderate"),
"high": ("#30151a", "#ef4444", "High"),
"critical": ("#2a1025", "#dc2626", "Critical"),
"unknown": ("#172033", "#64748b", "Unknown"),
}
return palettes.get(normalized, ("#172033", "#64748b", value or "Unknown"))
def make_badas_figure(badas_result):
    """Plot the BADAS collision-probability timeline as a filled line chart.

    A dashed threshold line and a dotted alert-time marker are added when the
    payload carries "threshold" / "alert_time" values.
    """
    payload = badas_result or {}
    figure = go.Figure()
    points = payload.get("prediction_series") or []
    if points:
        figure.add_trace(
            go.Scatter(
                x=[point["time_sec"] for point in points],
                y=[point["probability"] for point in points],
                mode="lines+markers",
                line=dict(color="#22c55e", width=3),
                marker=dict(size=6, color="#22c55e"),
                fill="tozeroy",
                fillcolor="rgba(34,197,94,0.18)",
                name="Collision probability",
            )
        )
    threshold = payload.get("threshold")
    if threshold is not None:
        figure.add_hline(y=threshold, line_dash="dash", line_color="#f59e0b")
    alert_time = payload.get("alert_time")
    if alert_time is not None:
        figure.add_vline(x=alert_time, line_dash="dot", line_color="#ef4444")
    figure.update_layout(
        title="BADAS predictive collision timeline",
        height=320,
        margin=dict(l=20, r=20, t=50, b=20),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(15,23,42,0.1)",
        xaxis_title="Time (s)",
        yaxis_title="Probability",
        yaxis=dict(range=[0, 1]),
    )
    return figure
def make_badas_heatmap(badas_result):
    """Render the BADAS probability series as a single-row risk heatmap."""
    series = (badas_result or {}).get("prediction_series") or []
    if not series:
        return go.Figure()
    frame = pd.DataFrame(series)
    row = frame["probability"].tolist()
    labels = [f"{stamp:.1f}s" for stamp in frame["time_sec"].tolist()]
    # Green -> amber -> red ramp over the fixed [0, 1] probability range.
    heatmap = go.Heatmap(
        z=[row],
        x=labels,
        y=["BADAS risk"],
        colorscale=[[0.0, "#052e16"], [0.35, "#166534"], [0.6, "#f59e0b"], [1.0, "#dc2626"]],
        zmin=0.0,
        zmax=1.0,
    )
    figure = go.Figure(data=heatmap)
    figure.update_layout(title="BADAS risk heatmap", height=180, margin=dict(l=20, r=20, t=50, b=20))
    return figure
def make_reason_coverage_heatmap(reason_result):
    """Single-row heatmap of Reason's sampled frames, shaded by bbox count."""
    payload = reason_result or {}
    metadata = payload.get("frame_metadata") or {}
    timestamps = metadata.get("sampled_timestamps_sec") or []
    if not timestamps:
        return go.Figure()
    # Floor the intensity at 1 so sampled frames remain visible even with no
    # detected boxes.
    bbox_count = max(1, int(payload.get("bbox_count") or 0))
    figure = go.Figure(
        data=go.Heatmap(
            z=[[bbox_count] * len(timestamps)],
            x=[f"{stamp:.1f}s" for stamp in timestamps],
            y=["Reason coverage"],
            colorscale="Blues",
        )
    )
    figure.update_layout(title="Reason sampled-frame coverage", height=180, margin=dict(l=20, r=20, t=50, b=20))
    return figure
def make_risk_gauge(reason_result):
    """Build a 0-5 gauge indicator for the Cosmos Reason 2 risk score."""
    risk_score = (reason_result or {}).get("risk_score") or 0
    # Bar color escalates green -> amber -> red with the score.
    if risk_score >= 4:
        bar_color = "#ef4444"
    elif risk_score >= 3:
        bar_color = "#f59e0b"
    else:
        bar_color = "#22c55e"
    indicator = go.Indicator(
        mode="gauge+number",
        value=risk_score,
        number={"suffix": "/5"},
        title={"text": "Cosmos Reason 2 risk score"},
        gauge={
            "axis": {"range": [0, 5]},
            "bar": {"color": bar_color},
            "steps": [
                {"range": [0, 2], "color": "rgba(34,197,94,0.20)"},
                {"range": [2, 4], "color": "rgba(245,158,11,0.20)"},
                {"range": [4, 5], "color": "rgba(239,68,68,0.24)"},
            ],
        },
    )
    figure = go.Figure(indicator)
    figure.update_layout(height=280, margin=dict(l=20, r=20, t=50, b=20))
    return figure
def make_artifact_figure(artifacts):
    """Bar chart marking each key artifact as ready (1, green) or missing (0)."""
    predict_ready = any(
        key.startswith("predict_") and artifacts.get(key) for key in artifacts.keys()
    )
    readiness = {
        "Clip": artifacts.get("extracted_clip"),
        "BBox": artifacts.get("bbox_image"),
        "Risk": artifacts.get("risk_image"),
        "GIF": artifacts.get("overlay_gif"),
        "Predict": predict_ready,
    }
    labels = list(readiness)
    values = [1 if present else 0 for present in readiness.values()]
    colors = ["#22c55e" if value else "#475569" for value in values]
    texts = ["ready" if value else "missing" for value in values]
    figure = go.Figure(go.Bar(x=labels, y=values, marker_color=colors, text=texts, textposition="outside"))
    figure.update_layout(title="Artifact readiness", height=240, yaxis=dict(range=[0, 1.2], showticklabels=False), margin=dict(l=20, r=20, t=50, b=20))
    return figure
def format_html_cards(badas_result, reason_result, pipeline_payload=None, predict_payload=None):
    """Build the markdown run-summary panel with incident/severity badges.

    Bug fix: the previous version read ``pipeline_payload`` and
    ``predict_payload`` as free (module-global) names that are never defined
    at module scope, raising ``NameError`` at call time.  They are now
    optional parameters defaulting to ``None``, which preserves the existing
    two-argument call in ``build_outputs``.

    Parameters:
        badas_result: BADAS step result dict (may be empty/None).
        reason_result: Reason step result dict (may be empty/None).
        pipeline_payload: optional full pipeline payload; its "overview" is
            preferred over recomputing via build_pipeline_overview().
        predict_payload: optional Predict payload; when absent the summary
            shows "Not run".
    """
    reason_result = reason_result or {}
    overview = (pipeline_payload or {}).get("overview") or build_pipeline_overview(badas_result, reason_result)
    incident_bg, incident_accent, incident_label = badge_palette(reason_result.get("incident_type"), "incident")
    severity_bg, severity_accent, severity_label = badge_palette(reason_result.get("severity_label"), "severity")
    predict_modes = ", ".join((predict_payload or {}).get("modes") or []) if predict_payload else "Not run"
    return f"""
### Run summary
- **Collision gate triggered:** `{overview.get('collision_gate_triggered')}`
- **Alert time:** `{overview.get('alert_time_sec')}`
- **BADAS confidence:** `{overview.get('alert_confidence')}`
- **Incident:** `{incident_label}`
- **Severity:** `{severity_label}`
- **Reason risk score:** `{overview.get('reason_risk_score')}`
- **Predict modes:** `{predict_modes}`
<div style="display:flex; gap:0.9rem; flex-wrap:wrap; margin-top:0.75rem;">
<div style="background:{incident_bg}; border:1px solid {incident_accent}; border-radius:999px; padding:0.55rem 0.9rem; color:#f8fafc; font-weight:700;">Incident: {incident_label}</div>
<div style="background:{severity_bg}; border:1px solid {severity_accent}; border-radius:999px; padding:0.55rem 0.9rem; color:#f8fafc; font-weight:700;">Severity: {severity_label}</div>
</div>
"""
def format_reason_html(reason_result):
    """Format the Cosmos Reason 2 output as a markdown narrative panel.

    Appends reliability warnings, validation flags, and fallback-override
    details when present.  (Note: this redefines the earlier
    ``format_reason_html`` in this file.)
    """
    if not reason_result:
        return "Run BADAS + Reason to populate the narrative panel."
    payload = reason_result or {}
    validation = payload.get("validation") or {}
    flags = validation.get("flags") or {}
    override = payload.get("fallback_override") or {}
    sections = [
        "### Cosmos Reason 2 narrative",
        f"**Scene summary:** {payload.get('scene_summary') or 'N/A'}",
        f"**At-risk agent:** {payload.get('at_risk_agent') or 'N/A'}",
        f"**Explanation:** {payload.get('explanation') or 'N/A'}",
        f"**Time to impact:** {payload.get('time_to_impact')}",
        f"**Critical risk time:** {payload.get('critical_risk_time')}",
        "",
        "```text",
        payload.get("text") or "No Reason output captured.",
        "```",
    ]
    if not validation.get("is_reliable", True):
        sections.append("- **Warning:** Reason output was flagged as unreliable against BADAS evidence.")
    if flags:
        sections.append(f"- **Validation flags:** `{json.dumps(flags)}`")
    if override.get("applied"):
        sections.append(f"- **Fallback override:** `{json.dumps(override)}`")
    return "\n".join(sections)
def empty_outputs(status_message="Ready."):
    """Return placeholder values for every UI output component (25 entries).

    The tuple ordering must match the ``outputs=`` lists wired to the run and
    predict buttons.
    """
    placeholder_json = {}
    placeholder_plot = go.Figure()
    return (
        (
            status_message,
            "",
            "Run the cached sample or upload a video to execute BADAS + Reason, then trigger Predict manually.",
        )
        + (None,) * 9  # video/image artifact slots
        + (placeholder_plot,) * 5  # plot slots
        + ("Run BADAS + Reason to populate the narrative panel.",)
        + (placeholder_json,) * 4  # JSON payload slots
        + (None,) * 3  # predict clip/video slots
    )
def build_outputs(pipeline_payload, logs, preview_video_path, predict_payload=None):
    """Map a pipeline payload onto the 25-element UI output tuple.

    The tuple ordering must match the ``outputs=`` lists wired to the run and
    predict buttons — keep them in sync.

    Parameters:
        pipeline_payload: dict with "iterations", "artifacts", and optionally
            "predict" keys.
        logs: text shown in the pipeline-logs textbox.
        preview_video_path: path echoed back to the preview component.
        predict_payload: optional explicit Predict payload (predict-only
            runs); falls back to pipeline_payload["predict"].
    """
    # Use the most recent iteration's step results; [{}] guards empty lists.
    iteration = ((pipeline_payload.get("iterations") or [{}])[-1])
    steps = iteration.get("steps") or {}
    badas_result = (steps.get("badas") or {}).get("result") or {}
    reason_result = (steps.get("reason") or {}).get("result") or {}
    artifacts = (pipeline_payload.get("artifacts") or {})
    # Prefer an explicitly supplied predict payload over the embedded one.
    predict_payload = predict_payload if predict_payload is not None else (pipeline_payload.get("predict") or {})
    first_predict_result = next(iter((predict_payload.get("results") or {}).values()), {}) if predict_payload else {}
    prevented_video = ((predict_payload.get("results") or {}).get("prevented_continuation") or {}).get("output_video")
    observed_video = ((predict_payload.get("results") or {}).get("observed_continuation") or {}).get("output_video")
    return (
        # NOTE(review): the falsy branch is only reachable for an empty dict;
        # a None payload would already have raised on .get() above.
        "Pipeline completed." if pipeline_payload else "No pipeline payload available.",
        logs,
        format_html_cards(badas_result, reason_result),
        pipeline_payload,
        preview_video_path,
        # Artifact slots (videos/images).
        artifacts.get("extracted_clip"),
        artifacts.get("badas_gradient_saliency"),
        artifacts.get("bbox_image"),
        artifacts.get("risk_image"),
        artifacts.get("overlay_gif"),
        artifacts.get("badas_frame_strip"),
        artifacts.get("reason_frame_strip"),
        # Plot slots.
        make_badas_figure(badas_result),
        make_badas_heatmap(badas_result),
        make_reason_coverage_heatmap(reason_result),
        make_risk_gauge(reason_result),
        make_artifact_figure(artifacts),
        format_reason_html(reason_result),
        # Structured JSON payload slots.
        badas_result,
        reason_result,
        pipeline_payload,
        predict_payload or {},
        # Predict clip/video slots.
        first_predict_result.get("conditioning_clip"),
        prevented_video,
        observed_video,
    )
def on_upload(uploaded_file):
    """Cache an uploaded MP4 and echo its path to the UI state and preview."""
    if not uploaded_file:
        return None, None, "No uploaded file selected."
    cached = cache_uploaded_video(uploaded_file)
    status = f"Cached uploaded video: `{Path(cached).name}`"
    return cached, cached, status
def use_sample_video():
    """Point the UI state and preview at the cached sample clip."""
    sample_path = ensure_sample_video()
    status = f"Using cached sample: `{Path(sample_path).name}`"
    return sample_path, sample_path, status
def bootstrap_space():
    """demo.load handler: cache the sample clip and seed the status widgets."""
    sample_path = ensure_sample_video()
    warmup_text = "\n".join([f"Sample cached at: {sample_path}"])
    return (
        sample_path,
        sample_path,
        "Space loaded. Click 'Run BADAS + Reason' to begin.",
        warmup_text,
    )
@spaces.GPU(duration=60)
def warmup_models():
    """Preload model runtimes on a ZeroGPU worker (60s quota budget).

    Consistency fix: honor the COSMOS_PRELOAD_* env toggles (the
    AUTO_PRELOAD_* module constants, previously defined but unused) instead
    of hardcoding True for all three.  Defaults are unchanged: every model is
    preloaded unless its env var is set to something other than "1".
    """
    return preload_runtime(
        preload_badas=AUTO_PRELOAD_BADAS,
        preload_reason=AUTO_PRELOAD_REASON,
        preload_predict=AUTO_PRELOAD_PREDICT,
    )
@spaces.GPU(duration=60)
def run_pipeline_action(input_video_path):
    """Run BADAS + Reason on the selected clip and map results onto the UI.

    Raises gr.Error when no input video has been selected.
    """
    if not input_video_path:
        raise gr.Error("Upload a video or choose the sample clip first.")
    video_path = str(input_video_path)
    result = run_pipeline(video_path, include_predict=False)
    return build_outputs(
        result["pipeline_payload"],
        result["logs"],
        video_path,
        result.get("predict_payload"),
    )
@spaces.GPU(duration=60)
def run_predict_action(pipeline_payload, input_video_path, selection):
    """Run the Cosmos Predict stage against an existing pipeline payload.

    Improvements over the previous version: the `try` is narrowed to the
    predict call itself, and the original exception is chained
    (``raise ... from exc``) so server logs retain the full traceback while
    the user still sees a concise gr.Error message.
    """
    if not pipeline_payload:
        raise gr.Error("Run BADAS + Reason before Predict.")
    try:
        predict_payload, merged_payload = run_predict_only(
            pipeline_payload, selection=selection, predict_model_name=PREDICT_MODEL_NAME
        )
    except Exception as exc:  # UI boundary: surface any failure to the user.
        raise gr.Error(f"Predict failed: {exc}") from exc
    return build_outputs(merged_payload, "Predict completed.", input_video_path, predict_payload)
def safe_warmup_message():
    """Return warmup_models() output, or the formatted traceback on failure.

    The broad except is deliberate: warmup failures should appear in the
    warmup-log textbox rather than crashing the UI event handler.
    """
    try:
        return warmup_models()
    except Exception:
        return traceback.format_exc()
# ---------------------------------------------------------------------------
# Gradio UI: dark-themed Blocks layout plus event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue").set(body_background_fill="#0f172a", body_text_color="#f1f5f9"), title="Cosmos Sentinel") as demo:
    # Cross-event state: latest pipeline payload and active input video path.
    pipeline_state = gr.State(None)
    input_video_state = gr.State(None)
    gr.Markdown("# 🚦 Cosmos Sentinel")
    gr.Markdown(
        "A Gradio-first Hugging Face Space for the full Cosmos Sentinel pipeline: BADAS collision gating, Cosmos Reason 2 narrative analysis, and manual Cosmos Predict continuation rollouts on ZeroGPU-backed hardware."
    )
    with gr.Row():
        startup_status = gr.Markdown("Booting Space...")
        warmup_log = gr.Textbox(label="Warmup / preload log", lines=5)
    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=1):
            upload = gr.File(label="Upload MP4 footage", file_types=[".mp4"])
            sample_button = gr.Button("Use cached sample video")
            warmup_button = gr.Button("Warm up BADAS + Reason + Predict")
            run_button = gr.Button("Run BADAS + Reason", variant="primary")
            predict_selection = gr.Dropdown(
                choices=["both", "prevented_continuation", "observed_continuation"],
                value="both",
                label="Cosmos Predict rollout set",
            )
            predict_button = gr.Button("Run Cosmos Predict")
            selected_video_path = gr.Textbox(label="Active input video path")
        # Right column: input preview.
        with gr.Column(scale=1):
            preview_video = gr.Video(label="Preview")
    with gr.Tabs():
        with gr.Tab("Run Outputs"):
            status_markdown = gr.Markdown("Ready.")
            summary_html = gr.HTML("<div style='color: #94a3b8; font-style: italic;'>Run the cached sample or upload a video to execute the pipeline.</div>")
            pipeline_logs = gr.Textbox(label="Pipeline logs", lines=18)
            reason_html = gr.HTML("<div style='color: #94a3b8; font-style: italic;'>Run BADAS + Reason to populate the narrative panel.</div>")
            with gr.Row():
                extracted_clip = gr.Video(label="BADAS-focused clip")
                conditioning_clip = gr.Video(label="Predict conditioning clip")
            with gr.Row():
                prevented_video = gr.Video(label="Predict prevented continuation")
                observed_video = gr.Video(label="Predict observed continuation")
        with gr.Tab("Visual Diagnostics"):
            with gr.Row():
                badas_plot = gr.Plot(label="BADAS timeline")
                risk_gauge = gr.Plot(label="Reason risk gauge")
            with gr.Row():
                badas_heatmap = gr.Plot(label="BADAS heatmap")
                reason_heatmap = gr.Plot(label="Reason coverage")
            # Full-width artifact-readiness chart below the paired rows.
            # NOTE(review): nesting reconstructed from a whitespace-stripped
            # source — confirm whether this belongs inside the row above.
            artifact_plot = gr.Plot(label="Artifact readiness")
            with gr.Row():
                badas_gradient = gr.Image(label="BADAS gradient saliency")
                bbox_image = gr.Image(label="Reason bounding-box visualization")
            with gr.Row():
                risk_image = gr.Image(label="Reason risk visualization")
                overlay_gif = gr.Image(label="Annotated overlay GIF")
            # Wide frame strips rendered full-width.
            badas_frame_strip = gr.Image(label="BADAS frame strip")
            reason_frame_strip = gr.Image(label="Reason frame strip")
        with gr.Tab("Structured Payloads"):
            badas_json = gr.JSON(label="BADAS payload")
            reason_json = gr.JSON(label="Reason payload")
            predict_json = gr.JSON(label="Predict payload")
            pipeline_json = gr.JSON(label="Pipeline payload")
        with gr.Tab("Supplementary Materials"):
            gr.Markdown(SUPPLEMENTARY_MARKDOWN)

    # --- Event wiring ------------------------------------------------------
    # Selecting the sample also mirrors the chosen path into the textbox.
    sample_button.click(
        fn=use_sample_video,
        inputs=None,
        outputs=[input_video_state, preview_video, startup_status],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )
    upload.upload(
        fn=on_upload,
        inputs=upload,
        outputs=[input_video_state, preview_video, startup_status],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )
    warmup_button.click(fn=safe_warmup_message, inputs=None, outputs=warmup_log)
    # NOTE: this 25-output list must stay in sync with build_outputs() and
    # empty_outputs().
    run_button.click(
        fn=run_pipeline_action,
        inputs=input_video_state,
        outputs=[
            status_markdown,
            pipeline_logs,
            summary_html,
            pipeline_state,
            preview_video,
            extracted_clip,
            badas_gradient,
            bbox_image,
            risk_image,
            overlay_gif,
            badas_frame_strip,
            reason_frame_strip,
            badas_plot,
            badas_heatmap,
            reason_heatmap,
            risk_gauge,
            artifact_plot,
            reason_html,
            badas_json,
            reason_json,
            pipeline_json,
            predict_json,
            conditioning_clip,
            prevented_video,
            observed_video,
        ],
    )
    # Predict reuses the same output list and ordering as the run button.
    predict_button.click(
        fn=run_predict_action,
        inputs=[pipeline_state, input_video_state, predict_selection],
        outputs=[
            status_markdown,
            pipeline_logs,
            summary_html,
            pipeline_state,
            preview_video,
            extracted_clip,
            badas_gradient,
            bbox_image,
            risk_image,
            overlay_gif,
            badas_frame_strip,
            reason_frame_strip,
            badas_plot,
            badas_heatmap,
            reason_heatmap,
            risk_gauge,
            artifact_plot,
            reason_html,
            badas_json,
            reason_json,
            pipeline_json,
            predict_json,
            conditioning_clip,
            prevented_video,
            observed_video,
        ],
    )
    # On page load: cache the sample video and seed the status widgets.
    demo.load(
        fn=bootstrap_space,
        inputs=None,
        outputs=[input_video_state, preview_video, startup_status, warmup_log],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )

# Serialize GPU jobs (one at a time, matching the ZeroGPU quota) and launch
# without server-side rendering.
demo.queue(default_concurrency_limit=1)
demo.launch(ssr_mode=False)