"""Cosmos Sentinel — Gradio front-end for a Hugging Face Space.

Wires together the full pipeline: BADAS (V-JEPA2) collision gating,
Cosmos Reason 2 narrative analysis, and manually triggered Cosmos Predict
continuation rollouts, running on ZeroGPU-backed hardware.

NOTE(review): the original module defined ``format_html_cards`` and
``format_reason_html`` twice; the first definitions were fully shadowed
dead code and have been removed. The surviving ``format_html_cards``
referenced ``pipeline_payload``/``predict_payload`` that were never in
scope (a guaranteed NameError) — they are now optional parameters.
"""

try:
    import spaces
except ImportError:
    # Fallback so the app also runs outside Hugging Face ZeroGPU: the
    # decorator becomes a no-op usable both bare and with keyword args
    # (e.g. @spaces.GPU(duration=60)).
    class _SpacesFallback:
        @staticmethod
        def GPU(fn=None, **_kwargs):
            if fn is None:
                def decorator(inner):
                    return inner

                return decorator
            return fn

    spaces = _SpacesFallback()

import os
from pathlib import Path

# Persist the Hugging Face cache on Spaces persistent storage when present.
# This must happen before any library reads HF_HOME.
if Path("/data").exists():
    os.environ["HF_HOME"] = "/data/.huggingface"

import json
import traceback

import gradio as gr
import pandas as pd
import plotly.graph_objects as go

from space_backend import (
    PREDICT_MODEL_NAME,
    build_pipeline_overview,
    cache_uploaded_video,
    ensure_sample_video,
    preload_runtime,
    run_pipeline,
    run_predict_only,
)

# NOTE(review): these flags are not consumed anywhere in this module —
# warmup_models() hardcodes True for all three. Kept for backward
# compatibility; confirm against space_backend before removing.
AUTO_PRELOAD_BADAS = os.environ.get("COSMOS_PRELOAD_BADAS", "1") == "1"
AUTO_PRELOAD_REASON = os.environ.get("COSMOS_PRELOAD_REASON", "1") == "1"
AUTO_PRELOAD_PREDICT = os.environ.get("COSMOS_PRELOAD_PREDICT", "1") == "1"

# Placeholder strings reused by components and by empty/early-return paths.
PLACEHOLDER_SUMMARY = "Run the cached sample or upload a video to execute the pipeline."
PLACEHOLDER_REASON = "Run BADAS + Reason to populate the narrative panel."

SUPPLEMENTARY_MARKDOWN = """\
## Supplementary Materials

- [Project repository](https://github.com/Ryukijano/Nvidia-Cosmos-Cookoff)
- [NVIDIA Cosmos overview](https://docs.nvidia.com/cosmos/latest/introduction.html)
- [Cosmos Reason 2 docs](https://docs.nvidia.com/cosmos/latest/reason2/index.html)
- [Cosmos Reason 2 repo](https://github.com/nvidia-cosmos/cosmos-reason2)
- [Cosmos Predict 2.5 repo](https://github.com/nvidia-cosmos/cosmos-predict2.5)
- [Cosmos Cookbook](https://nvidia-cosmos.github.io/cosmos-cookbook/index.html)
- [BADAS paper](https://arxiv.org/abs/2510.14876)
- [Nexar dataset and challenge paper](https://openaccess.thecvf.com/content/CVPR2025W/WAD/papers/Moura_Nexar_Dashcam_Collision_Prediction_Dataset_and_Challenge_CVPRW_2025_paper.pdf)
"""


def badge_palette(value, kind):
    """Map an incident/severity label to (background, accent, display label).

    Unknown labels fall back to a neutral slate palette with the raw value
    (or "Unknown") as the display text.
    """
    normalized = (value or "unknown").strip().lower()
    if kind == "incident":
        palettes = {
            "no_incident": ("#0f172a", "#22c55e", "No incident"),
            "near_miss": ("#1f2937", "#f59e0b", "Near miss"),
            "collision": ("#2b1a1a", "#ef4444", "Collision"),
            "multi_vehicle_collision": ("#2a1025", "#dc2626", "Multi-vehicle collision"),
            "unclear": ("#172033", "#64748b", "Unclear"),
        }
    else:
        palettes = {
            "none": ("#0f172a", "#22c55e", "None"),
            "low": ("#10261b", "#22c55e", "Low"),
            "moderate": ("#2b2110", "#f59e0b", "Moderate"),
            "high": ("#30151a", "#ef4444", "High"),
            "critical": ("#2a1025", "#dc2626", "Critical"),
            "unknown": ("#172033", "#64748b", "Unknown"),
        }
    return palettes.get(normalized, ("#172033", "#64748b", value or "Unknown"))


def make_badas_figure(badas_result):
    """Plot the BADAS collision-probability timeline with threshold/alert markers."""
    series = (badas_result or {}).get("prediction_series") or []
    figure = go.Figure()
    if series:
        times = [item["time_sec"] for item in series]
        probs = [item["probability"] for item in series]
        figure.add_trace(
            go.Scatter(
                x=times,
                y=probs,
                mode="lines+markers",
                line=dict(color="#22c55e", width=3),
                marker=dict(size=6, color="#22c55e"),
                fill="tozeroy",
                fillcolor="rgba(34,197,94,0.18)",
                name="Collision probability",
            )
        )
    threshold = (badas_result or {}).get("threshold")
    if threshold is not None:
        figure.add_hline(y=threshold, line_dash="dash", line_color="#f59e0b")
    alert_time = (badas_result or {}).get("alert_time")
    if alert_time is not None:
        figure.add_vline(x=alert_time, line_dash="dot", line_color="#ef4444")
    figure.update_layout(
        title="BADAS predictive collision timeline",
        height=320,
        margin=dict(l=20, r=20, t=50, b=20),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(15,23,42,0.1)",
        xaxis_title="Time (s)",
        yaxis_title="Probability",
        yaxis=dict(range=[0, 1]),
    )
    return figure


def make_badas_heatmap(badas_result):
    """Render the BADAS probability series as a single-row risk heatmap."""
    series = (badas_result or {}).get("prediction_series") or []
    if not series:
        return go.Figure()
    df = pd.DataFrame(series)
    z = [df["probability"].tolist()]
    x = [f"{time_sec:.1f}s" for time_sec in df["time_sec"].tolist()]
    figure = go.Figure(
        data=go.Heatmap(
            z=z,
            x=x,
            y=["BADAS risk"],
            colorscale=[
                [0.0, "#052e16"],
                [0.35, "#166534"],
                [0.6, "#f59e0b"],
                [1.0, "#dc2626"],
            ],
            zmin=0.0,
            zmax=1.0,
        )
    )
    figure.update_layout(
        title="BADAS risk heatmap", height=180, margin=dict(l=20, r=20, t=50, b=20)
    )
    return figure


def make_reason_coverage_heatmap(reason_result):
    """Show which timestamps Cosmos Reason sampled, weighted by bbox count."""
    frame_metadata = (reason_result or {}).get("frame_metadata") or {}
    timestamps = frame_metadata.get("sampled_timestamps_sec") or []
    if not timestamps:
        return go.Figure()
    bbox_count = max(1, int((reason_result or {}).get("bbox_count") or 0))
    z = [[bbox_count for _ in timestamps]]
    figure = go.Figure(
        data=go.Heatmap(
            z=z,
            x=[f"{timestamp:.1f}s" for timestamp in timestamps],
            y=["Reason coverage"],
            colorscale="Blues",
        )
    )
    figure.update_layout(
        title="Reason sampled-frame coverage",
        height=180,
        margin=dict(l=20, r=20, t=50, b=20),
    )
    return figure


def make_risk_gauge(reason_result):
    """Gauge for the Cosmos Reason 2 risk score on a 0-5 scale."""
    risk_score = (reason_result or {}).get("risk_score") or 0
    # Bar color escalates green -> amber -> red as the score rises.
    bar_color = "#ef4444" if risk_score >= 4 else "#f59e0b" if risk_score >= 3 else "#22c55e"
    figure = go.Figure(
        go.Indicator(
            mode="gauge+number",
            value=risk_score,
            number={"suffix": "/5"},
            title={"text": "Cosmos Reason 2 risk score"},
            gauge={
                "axis": {"range": [0, 5]},
                "bar": {"color": bar_color},
                "steps": [
                    {"range": [0, 2], "color": "rgba(34,197,94,0.20)"},
                    {"range": [2, 4], "color": "rgba(245,158,11,0.20)"},
                    {"range": [4, 5], "color": "rgba(239,68,68,0.24)"},
                ],
            },
        )
    )
    figure.update_layout(height=280, margin=dict(l=20, r=20, t=50, b=20))
    return figure


def make_artifact_figure(artifacts):
    """Bar chart flagging which pipeline artifacts are present."""
    labels = ["Clip", "BBox", "Risk", "GIF", "Predict"]
    values = [
        1 if artifacts.get("extracted_clip") else 0,
        1 if artifacts.get("bbox_image") else 0,
        1 if artifacts.get("risk_image") else 0,
        1 if artifacts.get("overlay_gif") else 0,
        1 if any(key.startswith("predict_") and artifacts.get(key) for key in artifacts) else 0,
    ]
    colors = ["#22c55e" if value else "#475569" for value in values]
    figure = go.Figure(
        go.Bar(
            x=labels,
            y=values,
            marker_color=colors,
            text=["ready" if v else "missing" for v in values],
            textposition="outside",
        )
    )
    figure.update_layout(
        title="Artifact readiness",
        height=240,
        yaxis=dict(range=[0, 1.2], showticklabels=False),
        margin=dict(l=20, r=20, t=50, b=20),
    )
    return figure


def format_html_cards(badas_result, reason_result, pipeline_payload=None, predict_payload=None):
    """Build the run-summary panel (markdown bullets plus HTML badge spans).

    ``pipeline_payload`` and ``predict_payload`` are optional so existing
    two-argument callers keep working; without them the overview is rebuilt
    from the BADAS/Reason results and Predict is reported as "Not run".
    (Fix: the original body read these two names without them ever being
    parameters, which raised NameError on every call.)
    """
    overview = (pipeline_payload or {}).get("overview") or build_pipeline_overview(
        badas_result, reason_result
    )
    incident_bg, incident_accent, incident_label = badge_palette(
        (reason_result or {}).get("incident_type"), "incident"
    )
    severity_bg, severity_accent, severity_label = badge_palette(
        (reason_result or {}).get("severity_label"), "severity"
    )
    predict_modes = (
        ", ".join((predict_payload or {}).get("modes") or []) if predict_payload else "Not run"
    )
    badge_style = "display:inline-block;padding:4px 10px;border-radius:999px;margin-right:8px;"
    return f"""### Run summary
- **Collision gate triggered:** `{overview.get('collision_gate_triggered')}`
- **Alert time:** `{overview.get('alert_time_sec')}`
- **BADAS confidence:** `{overview.get('alert_confidence')}`
- **Incident:** `{incident_label}`
- **Severity:** `{severity_label}`
- **Reason risk score:** `{overview.get('reason_risk_score')}`
- **Predict modes:** `{predict_modes}`

<span style="{badge_style}background:{incident_bg};color:{incident_accent};">Incident: {incident_label}</span>
<span style="{badge_style}background:{severity_bg};color:{severity_accent};">Severity: {severity_label}</span>
"""


def format_reason_html(reason_result):
    """Render the Cosmos Reason 2 narrative, raw text, and validation notes."""
    if not reason_result:
        return PLACEHOLDER_REASON
    validation = reason_result.get("validation") or {}
    validation_flags = validation.get("flags") or {}
    fallback_override = reason_result.get("fallback_override") or {}
    lines = [
        "### Cosmos Reason 2 narrative",
        f"**Scene summary:** {reason_result.get('scene_summary') or 'N/A'}",
        f"**At-risk agent:** {reason_result.get('at_risk_agent') or 'N/A'}",
        f"**Explanation:** {reason_result.get('explanation') or 'N/A'}",
        f"**Time to impact:** {reason_result.get('time_to_impact')}",
        f"**Critical risk time:** {reason_result.get('critical_risk_time')}",
        "",
        "```text",
        reason_result.get("text") or "No Reason output captured.",
        "```",
    ]
    if not validation.get("is_reliable", True):
        lines.append(
            "- **Warning:** Reason output was flagged as unreliable against BADAS evidence."
        )
    if validation_flags:
        lines.append(f"- **Validation flags:** `{json.dumps(validation_flags)}`")
    if fallback_override.get("applied"):
        lines.append(f"- **Fallback override:** `{json.dumps(fallback_override)}`")
    return "\n".join(lines)


def empty_outputs(status_message="Ready."):
    """Return the 25-tuple of neutral values matching the run-output components."""
    empty_json = {}
    empty_plot = go.Figure()
    return (
        status_message,
        "",
        "Run the cached sample or upload a video to execute BADAS + Reason, then trigger Predict manually.",
        None,  # pipeline_state
        None,  # preview_video
        None,  # extracted_clip
        None,  # badas_gradient
        None,  # bbox_image
        None,  # risk_image
        None,  # overlay_gif
        None,  # badas_frame_strip
        None,  # reason_frame_strip
        empty_plot,
        empty_plot,
        empty_plot,
        empty_plot,
        empty_plot,
        PLACEHOLDER_REASON,
        empty_json,
        empty_json,
        empty_json,
        empty_json,
        None,  # conditioning_clip
        None,  # prevented_video
        None,  # observed_video
    )


def build_outputs(pipeline_payload, logs, preview_video_path, predict_payload=None):
    """Map a pipeline payload onto the 25 Gradio output components (same order
    as the shared ``run_outputs`` list wired to both action buttons)."""
    iteration = (pipeline_payload.get("iterations") or [{}])[-1]
    steps = iteration.get("steps") or {}
    badas_result = (steps.get("badas") or {}).get("result") or {}
    reason_result = (steps.get("reason") or {}).get("result") or {}
    artifacts = pipeline_payload.get("artifacts") or {}
    # Prefer an explicitly supplied Predict payload (manual Predict run) over
    # whatever the pipeline payload already carries.
    predict_payload = (
        predict_payload if predict_payload is not None else (pipeline_payload.get("predict") or {})
    )
    predict_results = predict_payload.get("results") or {}
    first_predict_result = next(iter(predict_results.values()), {}) if predict_payload else {}
    prevented_video = (predict_results.get("prevented_continuation") or {}).get("output_video")
    observed_video = (predict_results.get("observed_continuation") or {}).get("output_video")
    return (
        "Pipeline completed." if pipeline_payload else "No pipeline payload available.",
        logs,
        format_html_cards(
            badas_result,
            reason_result,
            pipeline_payload=pipeline_payload,
            predict_payload=predict_payload,
        ),
        pipeline_payload,
        preview_video_path,
        artifacts.get("extracted_clip"),
        artifacts.get("badas_gradient_saliency"),
        artifacts.get("bbox_image"),
        artifacts.get("risk_image"),
        artifacts.get("overlay_gif"),
        artifacts.get("badas_frame_strip"),
        artifacts.get("reason_frame_strip"),
        make_badas_figure(badas_result),
        make_badas_heatmap(badas_result),
        make_reason_coverage_heatmap(reason_result),
        make_risk_gauge(reason_result),
        make_artifact_figure(artifacts),
        format_reason_html(reason_result),
        badas_result,
        reason_result,
        pipeline_payload,
        predict_payload or {},
        first_predict_result.get("conditioning_clip"),
        prevented_video,
        observed_video,
    )


def on_upload(uploaded_file):
    """Cache an uploaded video and mirror its path into state + preview."""
    if not uploaded_file:
        return None, None, "No uploaded file selected."
    cached_path = cache_uploaded_video(uploaded_file)
    return cached_path, cached_path, f"Cached uploaded video: `{Path(cached_path).name}`"


def use_sample_video():
    """Point the app at the bundled sample clip."""
    sample_path = ensure_sample_video()
    return sample_path, sample_path, f"Using cached sample: `{Path(sample_path).name}`"


def bootstrap_space():
    """Initial load: cache the sample clip and report readiness."""
    sample_path = ensure_sample_video()
    status_lines = [f"Sample cached at: {sample_path}"]
    return (
        sample_path,
        sample_path,
        "Space loaded. Click 'Run BADAS + Reason' to begin.",
        "\n".join(status_lines),
    )


@spaces.GPU(duration=60)
def warmup_models():
    """Preload all three model runtimes on a GPU worker."""
    return preload_runtime(
        preload_badas=True,
        preload_reason=True,
        preload_predict=True,
    )


@spaces.GPU(duration=60)
def run_pipeline_action(input_video_path):
    """Run BADAS + Reason on the active video (Predict stays manual)."""
    if not input_video_path:
        raise gr.Error("Upload a video or choose the sample clip first.")
    result = run_pipeline(str(input_video_path), include_predict=False)
    return build_outputs(
        result["pipeline_payload"],
        result["logs"],
        str(input_video_path),
        result.get("predict_payload"),
    )


@spaces.GPU(duration=60)
def run_predict_action(pipeline_payload, input_video_path, selection):
    """Run Cosmos Predict rollouts on top of an existing pipeline payload."""
    if not pipeline_payload:
        raise gr.Error("Run BADAS + Reason before Predict.")
    try:
        predict_payload, merged_payload = run_predict_only(
            pipeline_payload,
            selection=selection,
            predict_model_name=PREDICT_MODEL_NAME,
        )
        return build_outputs(merged_payload, "Predict completed.", input_video_path, predict_payload)
    except Exception as e:
        raise gr.Error(f"Predict failed: {e}")


def safe_warmup_message():
    """Warm up models, surfacing any failure as text instead of crashing the UI."""
    try:
        return warmup_models()
    except Exception:
        return traceback.format_exc()


with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue").set(
        body_background_fill="#0f172a",
        body_text_color="#f1f5f9",
    ),
    title="Cosmos Sentinel",
) as demo:
    pipeline_state = gr.State(None)
    input_video_state = gr.State(None)

    gr.Markdown("# 🚦 Cosmos Sentinel")
    gr.Markdown(
        "A Gradio-first Hugging Face Space for the full Cosmos Sentinel pipeline: "
        "BADAS collision gating, Cosmos Reason 2 narrative analysis, and manual "
        "Cosmos Predict continuation rollouts on ZeroGPU-backed hardware."
    )

    with gr.Row():
        startup_status = gr.Markdown("Booting Space...")
        warmup_log = gr.Textbox(label="Warmup / preload log", lines=5)

    with gr.Row():
        with gr.Column(scale=1):
            upload = gr.File(label="Upload MP4 footage", file_types=[".mp4"])
            sample_button = gr.Button("Use cached sample video")
            warmup_button = gr.Button("Warm up BADAS + Reason + Predict")
            run_button = gr.Button("Run BADAS + Reason", variant="primary")
            predict_selection = gr.Dropdown(
                choices=["both", "prevented_continuation", "observed_continuation"],
                value="both",
                label="Cosmos Predict rollout set",
            )
            predict_button = gr.Button("Run Cosmos Predict")
            selected_video_path = gr.Textbox(label="Active input video path")
        with gr.Column(scale=1):
            preview_video = gr.Video(label="Preview")

    with gr.Tabs():
        with gr.Tab("Run Outputs"):
            status_markdown = gr.Markdown("Ready.")
            summary_html = gr.HTML(
                f'<div style="padding:12px;color:#94a3b8;">{PLACEHOLDER_SUMMARY}</div>'
            )
            pipeline_logs = gr.Textbox(label="Pipeline logs", lines=18)
            reason_html = gr.HTML(
                f'<div style="padding:12px;color:#94a3b8;">{PLACEHOLDER_REASON}</div>'
            )
            with gr.Row():
                extracted_clip = gr.Video(label="BADAS-focused clip")
                conditioning_clip = gr.Video(label="Predict conditioning clip")
            with gr.Row():
                prevented_video = gr.Video(label="Predict prevented continuation")
                observed_video = gr.Video(label="Predict observed continuation")
        with gr.Tab("Visual Diagnostics"):
            with gr.Row():
                badas_plot = gr.Plot(label="BADAS timeline")
                risk_gauge = gr.Plot(label="Reason risk gauge")
            with gr.Row():
                badas_heatmap = gr.Plot(label="BADAS heatmap")
                reason_heatmap = gr.Plot(label="Reason coverage")
                artifact_plot = gr.Plot(label="Artifact readiness")
            with gr.Row():
                badas_gradient = gr.Image(label="BADAS gradient saliency")
                bbox_image = gr.Image(label="Reason bounding-box visualization")
            with gr.Row():
                risk_image = gr.Image(label="Reason risk visualization")
                overlay_gif = gr.Image(label="Annotated overlay GIF")
            badas_frame_strip = gr.Image(label="BADAS frame strip")
            reason_frame_strip = gr.Image(label="Reason frame strip")
        with gr.Tab("Structured Payloads"):
            badas_json = gr.JSON(label="BADAS payload")
            reason_json = gr.JSON(label="Reason payload")
            predict_json = gr.JSON(label="Predict payload")
            pipeline_json = gr.JSON(label="Pipeline payload")
        with gr.Tab("Supplementary Materials"):
            gr.Markdown(SUPPLEMENTARY_MARKDOWN)

    # Single source of truth for the 25 outputs of both action buttons
    # (order must match build_outputs / empty_outputs).
    run_outputs = [
        status_markdown,
        pipeline_logs,
        summary_html,
        pipeline_state,
        preview_video,
        extracted_clip,
        badas_gradient,
        bbox_image,
        risk_image,
        overlay_gif,
        badas_frame_strip,
        reason_frame_strip,
        badas_plot,
        badas_heatmap,
        reason_heatmap,
        risk_gauge,
        artifact_plot,
        reason_html,
        badas_json,
        reason_json,
        pipeline_json,
        predict_json,
        conditioning_clip,
        prevented_video,
        observed_video,
    ]

    sample_button.click(
        fn=use_sample_video,
        inputs=None,
        outputs=[input_video_state, preview_video, startup_status],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )

    upload.upload(
        fn=on_upload,
        inputs=upload,
        outputs=[input_video_state, preview_video, startup_status],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )

    warmup_button.click(fn=safe_warmup_message, inputs=None, outputs=warmup_log)

    run_button.click(
        fn=run_pipeline_action,
        inputs=input_video_state,
        outputs=run_outputs,
    )

    predict_button.click(
        fn=run_predict_action,
        inputs=[pipeline_state, input_video_state, predict_selection],
        outputs=run_outputs,
    )

    demo.load(
        fn=bootstrap_space,
        inputs=None,
        outputs=[input_video_state, preview_video, startup_status, warmup_log],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )

demo.queue(default_concurrency_limit=1)
demo.launch(ssr_mode=False)