# Hugging Face Spaces app — Running on Zero (ZeroGPU hardware).
try:
    import spaces
except ImportError:
    # When the ``spaces`` package is not installed, provide a no-op stand-in so
    # both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)`` keep working.
    class _SpacesFallback:
        """Minimal stand-in exposing a pass-through ``GPU`` decorator."""

        @staticmethod
        def GPU(fn=None, **_kwargs):
            """Return *fn* unchanged, or a pass-through decorator when called with options only.

            ``staticmethod`` is required: the fallback is accessed through an
            instance (``spaces = _SpacesFallback()``), and a plain function
            attribute would be bound, receiving that instance as ``fn``.
            """
            if fn is None:
                def decorator(inner):
                    return inner

                return decorator
            return fn

    spaces = _SpacesFallback()
import os
from pathlib import Path

# When a /data directory exists (presumably the Space's persistent storage
# mount), keep the Hugging Face cache under it. This runs before gradio and the
# backend are imported further down, presumably so they pick up HF_HOME.
_HF_DATA_ROOT = "/data"
if Path(_HF_DATA_ROOT).exists():
    os.environ["HF_HOME"] = _HF_DATA_ROOT + "/.huggingface"
import html
import json
import os
import traceback
from pathlib import Path

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
def format_html_cards(badas_result, reason_result):
    """Render BADAS / Reason results as a row of four coloured HTML cards.

    NOTE: a second ``format_html_cards`` is defined later in this module and
    replaces this one at import time.

    Args:
        badas_result: BADAS result dict (reads ``collision_detected``) or None.
        reason_result: Cosmos Reason result dict (reads ``incident_type``,
            ``severity_label`` and ``risk_score``) or None.

    Returns:
        An HTML string. Model-provided values are HTML-escaped before being
        interpolated. When both results are empty a placeholder is returned.
    """
    if not badas_result and not reason_result:
        return "<div style='color: #94a3b8; font-style: italic;'>Run the cached sample or upload a video to execute the pipeline.</div>"
    badas = badas_result or {}
    reason = reason_result or {}

    collision_triggered = bool(badas.get("collision_detected"))
    gate_color = "#22c55e" if collision_triggered else "#64748b"
    gate_text = "Triggered" if collision_triggered else "Watching"

    incident = str(reason.get("incident_type") or "unclear").strip().lower()
    incident_colors = {
        "collision": ("#7f1d1d", "#ef4444"),
        "multi_vehicle_collision": ("#4c0519", "#e11d48"),
        "near_miss": ("#7c2d12", "#f97316"),
        "hazard": ("#713f12", "#f59e0b"),
        "no_incident": ("#14532d", "#22c55e"),
    }
    inc_bg, inc_accent = incident_colors.get(incident, ("#1e293b", "#94a3b8"))

    # Severity may arrive as a 1-5 value or as one of the word labels that
    # badge_palette() understands; colour both forms.
    severity = str(reason.get("severity_label") or "unknown").strip().lower()
    severity_colors = {
        "1": ("#14532d", "#22c55e"),
        "2": ("#713f12", "#f59e0b"),
        "3": ("#7c2d12", "#f97316"),
        "4": ("#7f1d1d", "#ef4444"),
        "5": ("#4c0519", "#e11d48"),
        "none": ("#14532d", "#22c55e"),
        "low": ("#14532d", "#22c55e"),
        "moderate": ("#713f12", "#f59e0b"),
        "high": ("#7f1d1d", "#ef4444"),
        "critical": ("#4c0519", "#e11d48"),
    }
    sev_bg, sev_accent = severity_colors.get(severity, ("#1e293b", "#94a3b8"))

    risk_score = reason.get("risk_score")
    if risk_score is None:
        risk_score = 0

    # Escape everything that originates from model output.
    incident_text = html.escape(incident.replace("_", " "))
    severity_text = html.escape(severity.upper())
    risk_text = html.escape(str(risk_score))

    cards_html = f"""
<div style="display: flex; gap: 1rem; flex-wrap: wrap; margin-bottom: 1rem; font-family: sans-serif;">
  <div style="flex: 1; min-width: 200px; background: #0f172a; border-left: 4px solid {gate_color}; padding: 1rem; border-radius: 0.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);">
    <div style="color: #94a3b8; font-size: 0.875rem; text-transform: uppercase; font-weight: 600;">Collision Gate</div>
    <div style="color: {gate_color}; font-size: 1.5rem; font-weight: 700; margin-top: 0.25rem;">{gate_text}</div>
    <div style="color: #64748b; font-size: 0.75rem; margin-top: 0.25rem;">BADAS V-JEPA2</div>
  </div>
  <div style="flex: 1; min-width: 200px; background: {inc_bg}; border-left: 4px solid {inc_accent}; padding: 1rem; border-radius: 0.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);">
    <div style="color: #94a3b8; font-size: 0.875rem; text-transform: uppercase; font-weight: 600;">Incident</div>
    <div style="color: #f8fafc; font-size: 1.5rem; font-weight: 700; margin-top: 0.25rem; text-transform: capitalize;">{incident_text}</div>
    <div style="color: #cbd5e1; font-size: 0.75rem; margin-top: 0.25rem;">Cosmos Reason 2</div>
  </div>
  <div style="flex: 1; min-width: 200px; background: {sev_bg}; border-left: 4px solid {sev_accent}; padding: 1rem; border-radius: 0.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);">
    <div style="color: #94a3b8; font-size: 0.875rem; text-transform: uppercase; font-weight: 600;">Severity</div>
    <div style="color: #f8fafc; font-size: 1.5rem; font-weight: 700; margin-top: 0.25rem;">{severity_text}</div>
    <div style="color: #cbd5e1; font-size: 0.75rem; margin-top: 0.25rem;">Scale 1-5</div>
  </div>
  <div style="flex: 1; min-width: 200px; background: #0f172a; border-left: 4px solid #3b82f6; padding: 1rem; border-radius: 0.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);">
    <div style="color: #94a3b8; font-size: 0.875rem; text-transform: uppercase; font-weight: 600;">Risk Score</div>
    <div style="color: #3b82f6; font-size: 1.5rem; font-weight: 700; margin-top: 0.25rem;">{risk_text}/5</div>
    <div style="color: #64748b; font-size: 0.75rem; margin-top: 0.25rem;">Overall hazard rating</div>
  </div>
</div>
"""
    return cards_html
def format_reason_html(reason_result):
    """Render the Reason ``narrative`` and ``reasoning`` fields as an HTML panel.

    NOTE: a second ``format_reason_html`` is defined later in this module and
    replaces this one at import time.

    Both fields are model-generated text and are HTML-escaped before
    interpolation. Returns an italic placeholder when there is no narrative.
    """
    reason = reason_result or {}
    narrative = reason.get("narrative") or ""
    reasoning = reason.get("reasoning") or ""
    if not narrative:
        return "<div style='color: #94a3b8; font-style: italic;'>Run BADAS + Reason to populate the narrative panel.</div>"
    # Named ``panel`` (not ``html``) so the stdlib ``html`` module stays usable here.
    panel = f"""
<div style="background: rgba(15,23,42,0.65); border: 1px solid rgba(51,65,85,0.8); border-radius: 0.5rem; padding: 1.25rem; margin-top: 1rem; font-family: sans-serif;">
  <h3 style="color: #e2e8f0; margin-top: 0; margin-bottom: 0.5rem; font-size: 1.125rem;">Cosmos Reason 2 Narrative</h3>
  <p style="color: #cbd5e1; line-height: 1.6; margin: 0;">{html.escape(str(narrative))}</p>
  <div style="margin-top: 1rem; padding-top: 1rem; border-top: 1px solid rgba(51,65,85,0.5);">
    <h4 style="color: #94a3b8; margin-top: 0; margin-bottom: 0.5rem; font-size: 0.875rem; text-transform: uppercase;">Reasoning Trace</h4>
    <p style="color: #94a3b8; font-family: monospace; font-size: 0.875rem; line-height: 1.5; margin: 0; white-space: pre-wrap;">{html.escape(str(reasoning))}</p>
  </div>
</div>
"""
    return panel
| from space_backend import ( | |
| PREDICT_MODEL_NAME, | |
| build_pipeline_overview, | |
| cache_uploaded_video, | |
| ensure_sample_video, | |
| preload_runtime, | |
| run_pipeline, | |
| run_predict_only, | |
| ) | |
def _env_flag(name, default="1"):
    """Return True when environment variable *name* (falling back to *default*) equals "1"."""
    return os.environ.get(name, default) == "1"


# Preload switches read from the environment; each is enabled unless set to
# something other than "1".
# NOTE(review): not referenced elsewhere in this module — warmup_models()
# always preloads all three; confirm whether these should gate it.
AUTO_PRELOAD_BADAS = _env_flag("COSMOS_PRELOAD_BADAS")
AUTO_PRELOAD_REASON = _env_flag("COSMOS_PRELOAD_REASON")
AUTO_PRELOAD_PREDICT = _env_flag("COSMOS_PRELOAD_PREDICT")

# Markdown shown in the "Supplementary Materials" tab.
SUPPLEMENTARY_MARKDOWN = "\n".join(
    [
        "",
        "## Supplementary Materials",
        "- [Project repository](https://github.com/Ryukijano/Nvidia-Cosmos-Cookoff)",
        "- [NVIDIA Cosmos overview](https://docs.nvidia.com/cosmos/latest/introduction.html)",
        "- [Cosmos Reason 2 docs](https://docs.nvidia.com/cosmos/latest/reason2/index.html)",
        "- [Cosmos Reason 2 repo](https://github.com/nvidia-cosmos/cosmos-reason2)",
        "- [Cosmos Predict 2.5 repo](https://github.com/nvidia-cosmos/cosmos-predict2.5)",
        "- [Cosmos Cookbook](https://nvidia-cosmos.github.io/cosmos-cookbook/index.html)",
        "- [BADAS paper](https://arxiv.org/abs/2510.14876)",
        "- [Nexar dataset and challenge paper](https://openaccess.thecvf.com/content/CVPR2025W/WAD/papers/Moura_Nexar_Dashcam_Collision_Prediction_Dataset_and_Challenge_CVPRW_2025_paper.pdf)",
        "",
    ]
)
# (background, accent, display label) for each known incident type.
_INCIDENT_BADGES = {
    "no_incident": ("#0f172a", "#22c55e", "No incident"),
    "near_miss": ("#1f2937", "#f59e0b", "Near miss"),
    "collision": ("#2b1a1a", "#ef4444", "Collision"),
    "multi_vehicle_collision": ("#2a1025", "#dc2626", "Multi-vehicle collision"),
    "unclear": ("#172033", "#64748b", "Unclear"),
}

# (background, accent, display label) for each known severity label.
_SEVERITY_BADGES = {
    "none": ("#0f172a", "#22c55e", "None"),
    "low": ("#10261b", "#22c55e", "Low"),
    "moderate": ("#2b2110", "#f59e0b", "Moderate"),
    "high": ("#30151a", "#ef4444", "High"),
    "critical": ("#2a1025", "#dc2626", "Critical"),
    "unknown": ("#172033", "#64748b", "Unknown"),
}


def badge_palette(value, kind):
    """Map an incident type or severity label to badge colours and a display label.

    Args:
        value: Raw label from the Reason result. Non-string values (e.g. a
            numeric severity) are converted with ``str`` instead of crashing
            on ``.strip()``. Falsy values are treated as ``"unknown"``.
        kind: ``"incident"`` selects the incident palette; any other value
            selects the severity palette.

    Returns:
        ``(background_hex, accent_hex, display_label)``. Unrecognised values
        get a neutral palette and their own (stripped) text as the label, or
        ``"Unknown"`` when empty.
    """
    text = str(value).strip() if value else ""
    normalized = (text or "unknown").lower()
    palettes = _INCIDENT_BADGES if kind == "incident" else _SEVERITY_BADGES
    return palettes.get(normalized, ("#172033", "#64748b", text or "Unknown"))
def make_badas_figure(badas_result):
    """Plot BADAS collision probability over time.

    Draws ``prediction_series`` points (``time_sec`` vs ``probability``) as a
    filled line, then adds a dashed line at ``threshold`` and a dotted marker
    at ``alert_time`` when those keys are not None. The y-axis is fixed to
    [0, 1].
    """
    result = badas_result or {}
    points = result.get("prediction_series") or []
    fig = go.Figure()

    if points:
        xs = [point["time_sec"] for point in points]
        ys = [point["probability"] for point in points]
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                name="Collision probability",
                mode="lines+markers",
                fill="tozeroy",
                fillcolor="rgba(34,197,94,0.18)",
                line={"color": "#22c55e", "width": 3},
                marker={"size": 6, "color": "#22c55e"},
            )
        )

    threshold = result.get("threshold")
    if threshold is not None:
        fig.add_hline(y=threshold, line_dash="dash", line_color="#f59e0b")

    alert_time = result.get("alert_time")
    if alert_time is not None:
        fig.add_vline(x=alert_time, line_dash="dot", line_color="#ef4444")

    fig.update_layout(
        title="BADAS predictive collision timeline",
        height=320,
        margin={"l": 20, "r": 20, "t": 50, "b": 20},
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(15,23,42,0.1)",
        xaxis_title="Time (s)",
        yaxis_title="Probability",
        yaxis={"range": [0, 1]},
    )
    return fig
def make_badas_heatmap(badas_result):
    """Render the BADAS probability series as a one-row heatmap.

    The colour scale is pinned to [0, 1]. Columns are labelled with the
    ``time_sec`` value formatted to one decimal. Returns an empty figure when
    there is no ``prediction_series``.
    """
    points = (badas_result or {}).get("prediction_series") or []
    if not points:
        return go.Figure()

    frame = pd.DataFrame(points)
    probability_row = [frame["probability"].tolist()]
    column_labels = [f"{seconds:.1f}s" for seconds in frame["time_sec"].tolist()]
    heatmap = go.Heatmap(
        z=probability_row,
        x=column_labels,
        y=["BADAS risk"],
        zmin=0.0,
        zmax=1.0,
        colorscale=[[0.0, "#052e16"], [0.35, "#166534"], [0.6, "#f59e0b"], [1.0, "#dc2626"]],
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(title="BADAS risk heatmap", height=180, margin={"l": 20, "r": 20, "t": 50, "b": 20})
    return fig
def make_reason_coverage_heatmap(reason_result):
    """Show the timestamps Reason sampled as a one-row heatmap.

    Every sampled timestamp gets the same cell value: ``bbox_count`` clamped
    to at least 1. Returns an empty figure when
    ``frame_metadata.sampled_timestamps_sec`` is missing or empty.
    """
    reason = reason_result or {}
    metadata = reason.get("frame_metadata") or {}
    sampled = metadata.get("sampled_timestamps_sec") or []
    if not sampled:
        return go.Figure()

    cell_value = max(1, int(reason.get("bbox_count") or 0))
    fig = go.Figure(
        data=go.Heatmap(
            z=[[cell_value] * len(sampled)],
            x=[f"{seconds:.1f}s" for seconds in sampled],
            y=["Reason coverage"],
            colorscale="Blues",
        )
    )
    fig.update_layout(title="Reason sampled-frame coverage", height=180, margin={"l": 20, "r": 20, "t": 50, "b": 20})
    return fig
def make_risk_gauge(reason_result):
    """Draw the Cosmos Reason 2 ``risk_score`` as a gauge on a 0-5 scale.

    The bar is red at >= 4, amber at >= 3 and green otherwise. Missing or None
    scores show as 0. Numeric strings (e.g. ``"4"``) are converted with
    ``float``; any other non-numeric value also shows as 0 instead of raising
    ``TypeError`` in the threshold comparisons below.
    """
    raw_score = (reason_result or {}).get("risk_score") or 0
    if isinstance(raw_score, (int, float)):
        risk_score = raw_score
    else:
        # NOTE(review): assumes the score may arrive as text from model output.
        try:
            risk_score = float(raw_score)
        except (TypeError, ValueError):
            risk_score = 0

    if risk_score >= 4:
        bar_color = "#ef4444"
    elif risk_score >= 3:
        bar_color = "#f59e0b"
    else:
        bar_color = "#22c55e"

    figure = go.Figure(
        go.Indicator(
            mode="gauge+number",
            value=risk_score,
            number={"suffix": "/5"},
            title={"text": "Cosmos Reason 2 risk score"},
            gauge={
                "axis": {"range": [0, 5]},
                "bar": {"color": bar_color},
                "steps": [
                    {"range": [0, 2], "color": "rgba(34,197,94,0.20)"},
                    {"range": [2, 4], "color": "rgba(245,158,11,0.20)"},
                    {"range": [4, 5], "color": "rgba(239,68,68,0.24)"},
                ],
            },
        )
    )
    figure.update_layout(height=280, margin=dict(l=20, r=20, t=50, b=20))
    return figure
def make_artifact_figure(artifacts):
    """Bar chart marking which pipeline artifacts exist (1 = ready, 0 = missing).

    The "Predict" bar counts as ready when any artifact whose key starts with
    ``predict_`` has a truthy value.
    """
    has_predict = any(artifacts.get(name) for name in artifacts if name.startswith("predict_"))
    readiness = [
        ("Clip", bool(artifacts.get("extracted_clip"))),
        ("BBox", bool(artifacts.get("bbox_image"))),
        ("Risk", bool(artifacts.get("risk_image"))),
        ("GIF", bool(artifacts.get("overlay_gif"))),
        ("Predict", has_predict),
    ]
    names = [name for name, _ in readiness]
    heights = [1 if ready else 0 for _, ready in readiness]
    bar = go.Bar(
        x=names,
        y=heights,
        marker_color=["#22c55e" if h else "#475569" for h in heights],
        text=["ready" if h else "missing" for h in heights],
        textposition="outside",
    )
    fig = go.Figure(bar)
    fig.update_layout(
        title="Artifact readiness",
        height=240,
        yaxis={"range": [0, 1.2], "showticklabels": False},
        margin={"l": 20, "r": 20, "t": 50, "b": 20},
    )
    return fig
def format_html_cards(badas_result, reason_result, pipeline_payload=None, predict_payload=None):
    """Render the run summary (key metrics plus incident/severity badges) as HTML.

    The result feeds a ``gr.HTML`` component. It is therefore built as HTML,
    because Markdown syntax would be displayed literally there, and every
    pipeline-provided value is HTML-escaped.

    Args:
        badas_result: BADAS result dict (may be empty or None).
        reason_result: Cosmos Reason result dict (may be empty or None); reads
            ``incident_type`` and ``severity_label``.
        pipeline_payload: Optional full pipeline payload. Its ``overview`` dict
            is used when present; otherwise the overview is rebuilt with
            ``build_pipeline_overview``. Its ``predict`` entry is used when
            *predict_payload* is None.
        predict_payload: Optional Predict payload whose ``modes`` list is shown.

    Returns:
        An HTML string. The "Predict modes" row appears only when one of the
        two optional payloads was supplied.
    """
    reason = reason_result or {}
    overview = (
        (pipeline_payload or {}).get("overview")
        or build_pipeline_overview(badas_result, reason_result)
        or {}
    )
    incident_bg, incident_accent, incident_label = badge_palette(reason.get("incident_type"), "incident")
    severity_bg, severity_accent, severity_label = badge_palette(reason.get("severity_label"), "severity")

    rows = [
        ("Collision gate triggered", overview.get("collision_gate_triggered")),
        ("Alert time", overview.get("alert_time_sec")),
        ("BADAS confidence", overview.get("alert_confidence")),
        ("Incident", incident_label),
        ("Severity", severity_label),
        ("Reason risk score", overview.get("reason_risk_score")),
    ]
    # Only report Predict status when there is something to base it on.
    if predict_payload is not None or pipeline_payload is not None:
        if predict_payload is None:
            predict_payload = pipeline_payload.get("predict")
        modes = (predict_payload or {}).get("modes") or []
        rows.append(("Predict modes", ", ".join(str(mode) for mode in modes) or "Not run"))

    items = "\n".join(
        f"    <li><strong>{label}:</strong> <code>{html.escape(str(value))}</code></li>"
        for label, value in rows
    )
    badge_style = "border-radius:999px; padding:0.55rem 0.9rem; color:#f8fafc; font-weight:700;"
    return f"""
<div style="font-family: sans-serif;">
  <h3 style="margin: 0 0 0.5rem 0;">Run summary</h3>
  <ul style="margin: 0; padding-left: 1.25rem; line-height: 1.6;">
{items}
  </ul>
  <div style="display:flex; gap:0.9rem; flex-wrap:wrap; margin-top:0.75rem;">
    <div style="background:{incident_bg}; border:1px solid {incident_accent}; {badge_style}">Incident: {html.escape(str(incident_label))}</div>
    <div style="background:{severity_bg}; border:1px solid {severity_accent}; {badge_style}">Severity: {html.escape(str(severity_label))}</div>
  </div>
</div>
"""
def format_reason_html(reason_result):
    """Render the Cosmos Reason 2 narrative panel as HTML.

    The result is shown in a ``gr.HTML`` component, which does not interpret
    Markdown. The panel is therefore built as HTML, and all model-generated
    text is HTML-escaped. The raw Reason output goes in a ``<pre>`` block.

    Args:
        reason_result: Reason result dict. Reads ``scene_summary``,
            ``at_risk_agent``, ``explanation``, ``time_to_impact``,
            ``critical_risk_time``, ``text``, ``validation`` (``is_reliable``,
            ``flags``) and ``fallback_override`` (``applied``).

    Returns:
        An HTML string, or a plain placeholder sentence when *reason_result*
        is empty.
    """
    if not reason_result:
        return "Run BADAS + Reason to populate the narrative panel."
    validation = reason_result.get("validation") or {}
    validation_flags = validation.get("flags") or {}
    fallback_override = reason_result.get("fallback_override") or {}

    def field(label, value):
        return f'<p style="margin: 0.35rem 0;"><strong>{label}:</strong> {html.escape(str(value))}</p>'

    raw_text = reason_result.get("text") or "No Reason output captured."
    parts = [
        '<div style="font-family: sans-serif; line-height: 1.5;">',
        '<h3 style="margin: 0 0 0.5rem 0;">Cosmos Reason 2 narrative</h3>',
        field("Scene summary", reason_result.get("scene_summary") or "N/A"),
        field("At-risk agent", reason_result.get("at_risk_agent") or "N/A"),
        field("Explanation", reason_result.get("explanation") or "N/A"),
        field("Time to impact", reason_result.get("time_to_impact")),
        field("Critical risk time", reason_result.get("critical_risk_time")),
        '<pre style="white-space: pre-wrap; font-family: monospace; font-size: 0.875rem; '
        'background: rgba(15,23,42,0.65); padding: 0.75rem; border-radius: 0.5rem;">'
        f"{html.escape(str(raw_text))}</pre>",
    ]

    notes = []
    if not validation.get("is_reliable", True):
        notes.append("<strong>Warning:</strong> Reason output was flagged as unreliable against BADAS evidence.")
    if validation_flags:
        notes.append(f"<strong>Validation flags:</strong> <code>{html.escape(json.dumps(validation_flags))}</code>")
    if fallback_override.get("applied"):
        notes.append(f"<strong>Fallback override:</strong> <code>{html.escape(json.dumps(fallback_override))}</code>")
    if notes:
        parts.append("<ul>" + "".join(f"<li>{note}</li>" for note in notes) + "</ul>")

    parts.append("</div>")
    return "\n".join(parts)
def empty_outputs(status_message="Ready."):
    """Return a 25-value tuple that resets every pipeline output component.

    Slots are positionally aligned with the tuple produced by build_outputs():
    status text, empty logs, a summary placeholder, nine ``None`` media/state
    slots, five blank figures, a narrative placeholder, four empty JSON dicts
    and three ``None`` video slots. As before, one shared figure object and one
    shared dict are reused across their slots.
    """
    blank_payload = {}
    blank_figure = go.Figure()
    return (
        status_message,
        "",
        "Run the cached sample or upload a video to execute BADAS + Reason, then trigger Predict manually.",
        *([None] * 9),
        *([blank_figure] * 5),
        "Run BADAS + Reason to populate the narrative panel.",
        *([blank_payload] * 4),
        *([None] * 3),
    )
def build_outputs(pipeline_payload, logs, preview_video_path, predict_payload=None):
    """Map a pipeline payload onto the 25 output components wired in the UI.

    Args:
        pipeline_payload: Pipeline payload dict. Reads the last entry of
            ``iterations`` (``steps.badas.result`` / ``steps.reason.result``),
            ``artifacts`` and, when *predict_payload* is None, ``predict``.
            None is tolerated and yields empty results.
        logs: Text for the pipeline-log textbox.
        preview_video_path: Video path shown in the preview player.
        predict_payload: Optional Predict payload that overrides
            ``pipeline_payload["predict"]``. Its ``results`` dict is keyed by
            rollout name.

    Returns:
        A 25-tuple whose order matches the ``outputs=[...]`` list used by the
        run and predict buttons.
    """
    payload = pipeline_payload or {}
    iteration = (payload.get("iterations") or [{}])[-1] or {}
    steps = iteration.get("steps") or {}
    badas_result = (steps.get("badas") or {}).get("result") or {}
    reason_result = (steps.get("reason") or {}).get("result") or {}
    artifacts = payload.get("artifacts") or {}

    if predict_payload is None:
        predict_payload = payload.get("predict") or {}
    predict_results = (predict_payload or {}).get("results") or {}
    # The conditioning clip is read from the first non-empty rollout result.
    first_predict_result = next((result for result in predict_results.values() if result), None) or {}
    prevented_video = (predict_results.get("prevented_continuation") or {}).get("output_video")
    observed_video = (predict_results.get("observed_continuation") or {}).get("output_video")

    return (
        "Pipeline completed." if pipeline_payload else "No pipeline payload available.",
        logs,
        format_html_cards(badas_result, reason_result),
        pipeline_payload,
        preview_video_path,
        artifacts.get("extracted_clip"),
        artifacts.get("badas_gradient_saliency"),
        artifacts.get("bbox_image"),
        artifacts.get("risk_image"),
        artifacts.get("overlay_gif"),
        artifacts.get("badas_frame_strip"),
        artifacts.get("reason_frame_strip"),
        make_badas_figure(badas_result),
        make_badas_heatmap(badas_result),
        make_reason_coverage_heatmap(reason_result),
        make_risk_gauge(reason_result),
        make_artifact_figure(artifacts),
        format_reason_html(reason_result),
        badas_result,
        reason_result,
        pipeline_payload,
        predict_payload or {},
        first_predict_result.get("conditioning_clip"),
        prevented_video,
        observed_video,
    )
def on_upload(uploaded_file):
    """Cache an uploaded video and select it as the active input.

    Returns ``(state_path, preview_path, status_markdown)``; both paths are
    ``None`` when nothing was uploaded.
    """
    if uploaded_file:
        stored_path = cache_uploaded_video(uploaded_file)
        file_name = Path(stored_path).name
        return stored_path, stored_path, f"Cached uploaded video: `{file_name}`"
    return None, None, "No uploaded file selected."
def use_sample_video():
    """Select the cached sample clip as the active input.

    Returns ``(state_path, preview_path, status_markdown)``.
    """
    clip = ensure_sample_video()
    message = "Using cached sample: `{}`".format(Path(clip).name)
    return clip, clip, message
def bootstrap_space():
    """Page-load handler: cache the sample clip and preselect it.

    Returns ``(state_path, preview_path, startup_status_markdown,
    warmup_log_text)``.
    """
    clip = ensure_sample_video()
    log_text = f"Sample cached at: {clip}"
    banner = "Space loaded. Click 'Run BADAS + Reason' to begin."
    return clip, clip, banner, log_text
def warmup_models():
    """Preload BADAS, Reason and Predict through ``preload_runtime`` and return its result.

    All three preload flags are always enabled here, independent of the
    ``AUTO_PRELOAD_*`` environment switches.
    """
    flags = {"preload_badas": True, "preload_reason": True, "preload_predict": True}
    return preload_runtime(**flags)
def run_pipeline_action(input_video_path):
    """Run BADAS + Reason (Predict skipped) on the active video and build UI outputs.

    Args:
        input_video_path: Active input video path held in ``input_video_state``.

    Returns:
        The 25-tuple produced by ``build_outputs``.

    Raises:
        gr.Error: when no input is selected, or when the backend pipeline
            fails. The original exception is chained as ``__cause__``, so the
            user sees the message instead of a generic error, as in
            ``run_predict_action``.
    """
    if not input_video_path:
        raise gr.Error("Upload a video or choose the sample clip first.")
    video_path = str(input_video_path)
    try:
        result = run_pipeline(video_path, include_predict=False)
    except gr.Error:
        raise
    except Exception as exc:
        raise gr.Error(f"Pipeline failed: {exc}") from exc
    return build_outputs(result["pipeline_payload"], result["logs"], video_path, result.get("predict_payload"))
def run_predict_action(pipeline_payload, input_video_path, selection):
    """Run Cosmos Predict on top of an existing BADAS + Reason payload.

    Args:
        pipeline_payload: Payload held in ``pipeline_state`` from a prior run.
        input_video_path: Active input video path, passed through to the preview.
        selection: Rollout set from the dropdown (``"both"``,
            ``"prevented_continuation"`` or ``"observed_continuation"``).

    Returns:
        The 25-tuple produced by ``build_outputs`` for the merged payload.

    Raises:
        gr.Error: when no pipeline payload exists yet, or when Predict itself
            fails (the original exception is chained as ``__cause__``).
    """
    if not pipeline_payload:
        raise gr.Error("Run BADAS + Reason before Predict.")
    try:
        predict_payload, merged_payload = run_predict_only(
            pipeline_payload,
            selection=selection,
            predict_model_name=PREDICT_MODEL_NAME,
        )
    except gr.Error:
        raise
    except Exception as exc:
        raise gr.Error(f"Predict failed: {exc}") from exc
    # Kept outside the try so UI-building bugs are not reported as Predict failures.
    return build_outputs(merged_payload, "Predict completed.", input_video_path, predict_payload)
def safe_warmup_message():
    """Call warmup_models(), returning the formatted traceback instead of raising.

    Wired to the warm-up button, so failures land in the warm-up log textbox.
    """
    try:
        message = warmup_models()
    except Exception:
        message = traceback.format_exc()
    return message
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue").set(
        body_background_fill="#0f172a",
        body_text_color="#f1f5f9",
    ),
    title="Cosmos Sentinel",
) as demo:
    # Session state: the latest pipeline payload and the active input video path.
    pipeline_state = gr.State(None)
    input_video_state = gr.State(None)

    gr.Markdown("# 🚦 Cosmos Sentinel")
    gr.Markdown(
        "A Gradio-first Hugging Face Space for the full Cosmos Sentinel pipeline: BADAS collision gating, Cosmos Reason 2 narrative analysis, and manual Cosmos Predict continuation rollouts on ZeroGPU-backed hardware."
    )
    with gr.Row():
        startup_status = gr.Markdown("Booting Space...")
        warmup_log = gr.Textbox(label="Warmup / preload log", lines=5)

    # Input controls on the left, preview player on the right.
    with gr.Row():
        with gr.Column(scale=1):
            upload = gr.File(label="Upload MP4 footage", file_types=[".mp4"])
            sample_button = gr.Button("Use cached sample video")
            warmup_button = gr.Button("Warm up BADAS + Reason + Predict")
            run_button = gr.Button("Run BADAS + Reason", variant="primary")
            predict_selection = gr.Dropdown(
                choices=["both", "prevented_continuation", "observed_continuation"],
                value="both",
                label="Cosmos Predict rollout set",
            )
            predict_button = gr.Button("Run Cosmos Predict")
            selected_video_path = gr.Textbox(label="Active input video path")
        with gr.Column(scale=1):
            preview_video = gr.Video(label="Preview")

    with gr.Tabs():
        with gr.Tab("Run Outputs"):
            status_markdown = gr.Markdown("Ready.")
            summary_html = gr.HTML("<div style='color: #94a3b8; font-style: italic;'>Run the cached sample or upload a video to execute the pipeline.</div>")
            pipeline_logs = gr.Textbox(label="Pipeline logs", lines=18)
            reason_html = gr.HTML("<div style='color: #94a3b8; font-style: italic;'>Run BADAS + Reason to populate the narrative panel.</div>")
            with gr.Row():
                extracted_clip = gr.Video(label="BADAS-focused clip")
                conditioning_clip = gr.Video(label="Predict conditioning clip")
            with gr.Row():
                prevented_video = gr.Video(label="Predict prevented continuation")
                observed_video = gr.Video(label="Predict observed continuation")
        with gr.Tab("Visual Diagnostics"):
            with gr.Row():
                badas_plot = gr.Plot(label="BADAS timeline")
                risk_gauge = gr.Plot(label="Reason risk gauge")
            with gr.Row():
                badas_heatmap = gr.Plot(label="BADAS heatmap")
                reason_heatmap = gr.Plot(label="Reason coverage")
            artifact_plot = gr.Plot(label="Artifact readiness")
            with gr.Row():
                badas_gradient = gr.Image(label="BADAS gradient saliency")
                bbox_image = gr.Image(label="Reason bounding-box visualization")
            with gr.Row():
                risk_image = gr.Image(label="Reason risk visualization")
                overlay_gif = gr.Image(label="Annotated overlay GIF")
            badas_frame_strip = gr.Image(label="BADAS frame strip")
            reason_frame_strip = gr.Image(label="Reason frame strip")
        with gr.Tab("Structured Payloads"):
            badas_json = gr.JSON(label="BADAS payload")
            reason_json = gr.JSON(label="Reason payload")
            predict_json = gr.JSON(label="Predict payload")
            pipeline_json = gr.JSON(label="Pipeline payload")
        with gr.Tab("Supplementary Materials"):
            gr.Markdown(SUPPLEMENTARY_MARKDOWN)

    # Single source of truth for the run and predict handlers. The order must
    # match, position for position, the 25-tuple returned by build_outputs().
    pipeline_outputs = [
        status_markdown,
        pipeline_logs,
        summary_html,
        pipeline_state,
        preview_video,
        extracted_clip,
        badas_gradient,
        bbox_image,
        risk_image,
        overlay_gif,
        badas_frame_strip,
        reason_frame_strip,
        badas_plot,
        badas_heatmap,
        reason_heatmap,
        risk_gauge,
        artifact_plot,
        reason_html,
        badas_json,
        reason_json,
        pipeline_json,
        predict_json,
        conditioning_clip,
        prevented_video,
        observed_video,
    ]

    # Selecting a source updates the input state, then mirrors the path into the textbox.
    sample_button.click(
        fn=use_sample_video,
        inputs=None,
        outputs=[input_video_state, preview_video, startup_status],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )
    upload.upload(
        fn=on_upload,
        inputs=upload,
        outputs=[input_video_state, preview_video, startup_status],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )
    warmup_button.click(fn=safe_warmup_message, inputs=None, outputs=warmup_log)
    run_button.click(
        fn=run_pipeline_action,
        inputs=input_video_state,
        outputs=pipeline_outputs,
    )
    predict_button.click(
        fn=run_predict_action,
        inputs=[pipeline_state, input_video_state, predict_selection],
        outputs=pipeline_outputs,
    )
    demo.load(
        fn=bootstrap_space,
        inputs=None,
        outputs=[input_video_state, preview_video, startup_status, warmup_log],
    ).then(
        fn=lambda path: path,
        inputs=input_video_state,
        outputs=selected_video_path,
    )
# Give every event listener a default concurrency limit of 1, then start the
# server with server-side rendering disabled. ``queue`` returns the Blocks
# instance, so the calls chain.
demo.queue(default_concurrency_limit=1).launch(ssr_mode=False)