Spaces:
Running on Zero
| import hashlib | |
| import json | |
| from functools import lru_cache | |
| from pathlib import Path | |
| import cv2 | |
def existing_file(path):
    """Return *path* as a string when it names an existing filesystem entry, else None."""
    if not path:
        return None
    return str(path) if Path(path).exists() else None
def sanitize_text(value):
    """Collapse all whitespace runs in *value* to single spaces; falsy input becomes ""."""
    raw = value or ""
    # str.split() with no argument already drops leading/trailing whitespace.
    return " ".join(str(raw).split())
def select_predict_focus_time(badas_context, reason_context):
    """Pick the timestamp (seconds) the Predict conditioning clip should center on.

    Priority order: Reason's critical_risk_time, then the midpoint of the
    BADAS peak prediction window, then the first top prediction's time,
    then the BADAS alert time, and finally 0.0 when nothing usable exists.
    """
    badas = badas_context or {}
    reason = reason_context or {}

    critical = reason.get("critical_risk_time")
    if isinstance(critical, (int, float)):
        return float(critical)

    summary = badas.get("prediction_window_summary") or {}
    start = summary.get("peak_window_start_time")
    end = summary.get("peak_window_end_time")
    if isinstance(start, (int, float)) and isinstance(end, (int, float)):
        return (float(start) + float(end)) / 2.0

    predictions = badas.get("top_predictions") or []
    if predictions:
        # A missing/None time_sec on the top prediction maps to 0.0.
        return float(predictions[0].get("time_sec") or 0.0)

    alert = badas.get("alert_time")
    if isinstance(alert, (int, float)):
        return float(alert)

    return 0.0
def build_conditioning_window(badas_context, reason_context):
    """Describe the short, sparsely sampled frame window used for conditioning.

    The window starts at the BADAS peak-window start time (clamped to at most
    0.5 s before the focus time) when available, otherwise 1 s before the
    focus time, and never before 0. Five frames are sampled 0.25 s apart.
    """
    focus = select_predict_focus_time(badas_context or {}, reason_context or {})

    summary = (badas_context or {}).get("prediction_window_summary") or {}
    peak_start = summary.get("peak_window_start_time")
    if isinstance(peak_start, (int, float)):
        start = min(float(peak_start), focus - 0.50)
    else:
        start = focus - 1.0
    start = max(0.0, start)

    spacing = 0.25
    count = 5
    return {
        "focus_time_sec": float(focus),
        "start_time_sec": float(start),
        "end_time_sec": float(start + spacing * (count - 1)),
        "frame_spacing_sec": float(spacing),
        "frame_count": int(count),
    }
def build_conditioning_clip(source_video_path, window, output_path):
    """Extract a short, sparsely sampled conditioning clip around the risk window.

    Reads ``window["frame_count"]`` frames from *source_video_path*, spaced
    ``window["frame_spacing_sec"]`` seconds apart starting at
    ``window["start_time_sec"]``, and writes them as a 4 fps mp4 to
    *output_path* (parent directories are created).

    Returns a metadata dict: clip path, per-frame timestamps, frame count,
    dimensions, and output fps.

    Raises:
        RuntimeError: when the source cannot be opened, yields no frames,
            or the output file cannot be opened for writing.
    """
    source_video_path = str(source_video_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    cap = cv2.VideoCapture(source_video_path)
    timestamps = []
    frames = []
    try:
        # Fail fast on unreadable sources instead of silently producing an
        # empty clip (cv2.VideoCapture does not raise on a bad path).
        if not cap.isOpened():
            raise RuntimeError(f"Could not open source video: {source_video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        duration_sec = float(total_frames / fps) if fps else 0.0
        for idx in range(int(window["frame_count"])):
            # Clamp each sample time to the video duration so a late window
            # still reads the final frame instead of seeking past the end.
            timestamp = min(
                duration_sec,
                float(window["start_time_sec"]) + idx * float(window["frame_spacing_sec"]),
            )
            frame_index = max(0, int(round(timestamp * fps))) if fps else 0
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ok, frame = cap.read()
            if not ok:
                break  # keep whatever frames were read so far
            timestamps.append(float(timestamp))
            frames.append(frame)
    finally:
        # Always release the capture handle, even on error paths.
        cap.release()

    if not frames:
        raise RuntimeError("No frames available for Cosmos Predict conditioning clip")

    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(
        str(output_path), cv2.VideoWriter_fourcc(*"mp4v"), 4.0, (width, height)
    )
    try:
        # VideoWriter also fails silently; verify it opened before writing.
        if not writer.isOpened():
            raise RuntimeError(f"Could not open output video for writing: {output_path}")
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()

    return {
        "clip_path": str(output_path),
        "frame_timestamps_sec": timestamps,
        "frame_count": int(len(frames)),
        "width": int(width),
        "height": int(height),
        "fps": 4.0,
    }
def build_fallback_conditioning_metadata(fallback_conditioning_path):
    """Probe an existing clip and describe it in conditioning-metadata form.

    Used when the context-aware conditioning clip cannot be built and a
    pre-rendered BADAS focus clip is substituted. Per-frame timestamps are
    unknown for such clips, so ``frame_timestamps_sec`` is empty. If the
    clip cannot be probed, the numeric fields come back as zeros rather
    than raising, so the fallback path itself never fails here.
    """
    clip_path = Path(fallback_conditioning_path)
    cap = cv2.VideoCapture(str(clip_path))
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    finally:
        # Release the capture handle even if a probe call raises.
        cap.release()
    return {
        "clip_path": str(clip_path),
        "frame_timestamps_sec": [],
        "frame_count": int(frame_count),
        "width": int(width),
        "height": int(height),
        "fps": float(fps),
    }
def infer_preventive_action(reason_context):
    """Infer from Reason's text fields which evasive action the counterfactual continues.

    Scans the explanation, counterfactual prompt, and scene summary for
    braking / yielding / steering vocabulary (in that priority order) and
    returns a short clause describing the preventive action; falls back to
    a generic clause, with a special case for near-miss incidents.
    """
    context = reason_context or {}

    def _clean(value):
        # Same normalization as sanitize_text, inlined to keep this
        # classifier self-contained.
        return " ".join(str(value or "").strip().split())

    corpus = " ".join(
        _clean(context.get(field)).lower()
        for field in ("explanation", "counterfactual_prompt", "scene_summary")
        if _clean(context.get(field))
    )

    braking_terms = ("brake", "braking", "slow", "slowing", "stop", "stopped")
    if any(term in corpus for term in braking_terms):
        return "the visible braking and speed reduction continue"

    if any(term in corpus for term in ("yield", "yielding", "gave way")):
        return "the yielding behavior continues and conflict space is cleared"

    steering_terms = ("steer", "steering", "swerve", "lane correction", "turn away")
    if any(term in corpus for term in steering_terms):
        return "the evasive steering correction continues and the vehicles maintain separation"

    if (context.get("incident_type") or "").strip().lower() == "near_miss":
        return "the evasive action visible in the near-miss continues and prevents contact"

    return "the most plausible evasive action visible in the scene continues and reduces the chance of collision"
def build_predict_prompt(badas_context, reason_context, mode, window):
    """Compose the natural-language prompt for one Cosmos Predict generation.

    Combines the Reason scene/incident context with BADAS timing details,
    then appends the task instruction for *mode* ("prevented_continuation"
    adds a counterfactual assumption; any other mode asks for the observed
    continuation). The final prompt is capped at 290 words.
    """
    badas = badas_context or {}
    reason = reason_context or {}

    scene_summary = sanitize_text(reason.get("scene_summary")) or (
        "A traffic interaction is developing at a monitored road junction."
    )
    incident_type = sanitize_text(reason.get("incident_type")) or "unclear"
    severity_label = sanitize_text(reason.get("severity_label")) or "unknown"
    explanation = sanitize_text(reason.get("explanation"))
    at_risk_agent = sanitize_text(reason.get("at_risk_agent")) or "the interacting road users"

    alert_time = badas.get("alert_time")
    summary = badas.get("prediction_window_summary") or {}
    peak_start = summary.get("peak_window_start_time")
    peak_end = summary.get("peak_window_end_time")

    risk_parts = []
    if isinstance(alert_time, (int, float)):
        risk_parts.append(f"BADAS detected a high-risk interaction near {float(alert_time):.2f}s")
    if isinstance(peak_start, (int, float)) and isinstance(peak_end, (int, float)):
        risk_parts.append(
            f"the strongest risk window runs from {float(peak_start):.2f}s to {float(peak_end):.2f}s"
        )
    risk_parts.append(f"Reason classified the event as {incident_type} with {severity_label} severity")
    if explanation:
        risk_parts.append(explanation)

    sections = [
        f"Observed scene context: {scene_summary}",
        f"Risk context: {'; '.join(risk_parts)}.",
        f"Focus on the road users already visible in the conditioning video, especially {at_risk_agent}.",
        f"This conditioning clip is centered on the critical interaction around {float(window['focus_time_sec']):.2f}s.",
    ]

    if mode == "prevented_continuation":
        sections.append(f"Counterfactual assumption: {infer_preventive_action(reason)}.")
        sections.append(
            "Task: Generate the next few seconds of physically plausible traffic evolution in which the preventive action continues to hold and the collision is reduced or avoided."
        )
    else:
        sections.append(
            "Task: Generate the next few seconds of physically plausible traffic evolution, preserving the likely immediate continuation of the observed event."
        )
    sections.append(
        "Preserve the same camera viewpoint, traffic layout, and agent identities. Avoid impossible physics, abrupt scene changes, visual glitches, or dramatic cinematic effects."
    )

    # Cap the prompt length at 290 words; all text is single-spaced, so
    # splitting and rejoining leaves shorter prompts unchanged.
    words = " ".join(sections).split()
    return " ".join(words[:290])
def build_cache_key(source_video_path, badas_context, reason_context, mode, model_name, conditioning_source):
    """Derive a short deterministic cache key for one Predict generation.

    Hashes the inputs that influence the output video — source path, mode,
    model name, conditioning source, and the relevant BADAS/Reason fields —
    and returns the first 12 hex characters of the SHA-256 digest.
    """
    badas = badas_context or {}
    reason = reason_context or {}
    fingerprint = {
        "source_video_path": str(source_video_path),
        "mode": mode,
        "model_name": model_name,
        "conditioning_source": conditioning_source,
        "badas": {
            "alert_time": badas.get("alert_time"),
            "confidence": badas.get("confidence"),
            "valid_prediction_max": badas.get("valid_prediction_max"),
            "prediction_window_summary": badas.get("prediction_window_summary"),
            # Only the top three predictions participate in cache identity.
            "top_predictions": (badas.get("top_predictions") or [])[:3],
        },
        "reason": {
            key: reason.get(key)
            for key in (
                "incident_type",
                "severity_label",
                "critical_risk_time",
                "scene_summary",
                "explanation",
                "counterfactual_prompt",
            )
        },
    }
    serialized = json.dumps(fingerprint, sort_keys=True).encode("utf-8")
    return hashlib.sha256(serialized).hexdigest()[:12]
def get_predict_inference(model_name, output_root_str, disable_guardrails=True):
    """Construct a Cosmos Predict ``Inference`` handle.

    Parameters:
        model_name: Cosmos Predict model identifier (e.g. "2B/post-trained").
        output_root_str: directory where Predict writes its outputs.
        disable_guardrails: forwarded to ``SetupArguments``.

    Raises:
        RuntimeError: when the ``cosmos_predict2`` package is not importable.
    """
    try:
        from cosmos_predict2.config import SetupArguments
        from cosmos_predict2.inference import Inference
    except ImportError as exc:
        # Chain the original ImportError so the real failure stays visible.
        raise RuntimeError("Cosmos Predict is not installed in this environment.") from exc
    setup_args = SetupArguments(
        output_dir=Path(output_root_str),
        model=model_name,
        keep_going=True,
        disable_guardrails=disable_guardrails,
    )
    return Inference(setup_args)
def prepare_conditioning_input(source_video_path, badas_context, reason_context, output_root, fallback_conditioning_path=None):
    """Build (or fall back to) the conditioning clip used by Predict.

    Tries to cut a context-aware segment from the source video. When that
    fails and *fallback_conditioning_path* names an existing file, the
    pre-rendered fallback clip is substituted instead; otherwise the
    original exception propagates.

    Returns a dict with the conditioning source, window, clip metadata,
    and whether/why the fallback was applied.
    """
    root = Path(output_root)
    window = build_conditioning_window(badas_context, reason_context)

    # Cache key over the two inputs that fully determine the clip content.
    fingerprint = json.dumps(
        {"source_video_path": str(source_video_path), "conditioning_window": window},
        sort_keys=True,
    )
    clip_key = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()[:12]
    clip_path = root / "conditioning" / f"conditioning_{clip_key}.mp4"

    try:
        metadata = build_conditioning_clip(source_video_path, window, clip_path)
    except Exception as exc:
        fallback_path = existing_file(fallback_conditioning_path)
        if not fallback_path:
            raise
        return {
            "conditioning_source": "badas_focus_clip",
            "conditioning_window": window,
            "conditioning_metadata": build_fallback_conditioning_metadata(fallback_path),
            "fallback_applied": True,
            "fallback_reason": str(exc),
        }

    return {
        "conditioning_source": "context_aware_segment",
        "conditioning_window": window,
        "conditioning_metadata": metadata,
        "fallback_applied": False,
        "fallback_reason": None,
    }
def execute_predict_generation(output_root, model_name, sample_name, conditioning_path, prompt):
    """Run a single video2world generation and return its output path.

    Returns the first generated path, or None when generation produced
    nothing.

    Raises:
        RuntimeError: when the ``cosmos_predict2`` package is not importable.
    """
    try:
        from cosmos_predict2.config import InferenceArguments
    except ImportError as exc:
        # Chain the original ImportError so the real failure stays visible.
        raise RuntimeError("Cosmos Predict is not installed in this environment.") from exc
    inference = get_predict_inference(model_name, str(output_root), True)
    inference_args = InferenceArguments(
        inference_type="video2world",
        name=sample_name,
        input_path=Path(conditioning_path),
        prompt=prompt,
        guidance=6,
        num_output_frames=77,
        num_steps=20,
    )
    output_paths = inference.generate([inference_args], output_root)
    return output_paths[0] if output_paths else None
def run_predict_scenario(source_video_path, badas_context=None, reason_context=None, mode="prevented_continuation", model_name="2B/post-trained", output_root="./predict_outputs", force_regenerate=False, fallback_conditioning_path=None):
    """Run one Cosmos Predict scenario end to end, with caching and fallback.

    Builds a conditioning clip, derives a cache key from all generation
    inputs, and either returns a previously generated video (unless
    *force_regenerate*) or invokes Predict. If generation with the
    context-aware clip fails and a usable fallback clip exists, it retries
    once with the fallback clip under a fallback-specific cache key.

    Returns a result dict: success/cached flags, mode, model, cache key,
    conditioning details, prompt, and output paths.
    """
    source_video_path = str(source_video_path)
    badas_context = badas_context or {}
    reason_context = reason_context or {}
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    # May itself fall back to the pre-rendered clip if segment extraction fails.
    conditioning_info = prepare_conditioning_input(
        source_video_path,
        badas_context,
        reason_context,
        output_root,
        fallback_conditioning_path=fallback_conditioning_path,
    )
    conditioning_source = conditioning_info["conditioning_source"]
    conditioning_window = conditioning_info["conditioning_window"]
    conditioning_metadata = conditioning_info["conditioning_metadata"]
    # Cache key covers every input that influences the generated video.
    cache_key = build_cache_key(source_video_path, badas_context, reason_context, mode, model_name, conditioning_source)
    prompt = build_predict_prompt(badas_context, reason_context, mode, conditioning_window)
    sample_name = f"predict_{mode}_{cache_key}"
    output_video_path = output_root / f"{sample_name}.mp4"
    output_args_path = output_root / f"{sample_name}.json"
    # Cache hit: reuse the previously generated video for these exact inputs.
    if output_video_path.exists() and not force_regenerate:
        return {
            "success": True,
            "cached": True,
            "mode": mode,
            "model_name": model_name,
            "cache_key": cache_key,
            "source_video_path": source_video_path,
            "conditioning_source": conditioning_source,
            "conditioning_clip": existing_file(conditioning_metadata.get("clip_path")),
            "conditioning_metadata": conditioning_metadata,
            "conditioning_window": conditioning_window,
            "fallback_applied": conditioning_info.get("fallback_applied", False),
            "fallback_reason": conditioning_info.get("fallback_reason"),
            "prompt": prompt,
            "output_video": existing_file(output_video_path),
            "output_args_json": existing_file(output_args_path),
        }
    try:
        output_video = execute_predict_generation(
            output_root,
            model_name,
            sample_name,
            conditioning_metadata["clip_path"],
            prompt,
        )
    except Exception as exc:
        # Retry with the fallback clip only when (a) we were not already on
        # it, (b) it exists, and (c) it is not the clip that just failed.
        fallback_path = existing_file(fallback_conditioning_path)
        if conditioning_source == "badas_focus_clip" or not fallback_path or fallback_path == conditioning_metadata.get("clip_path"):
            raise
        fallback_conditioning_metadata = build_fallback_conditioning_metadata(fallback_path)
        fallback_conditioning_source = "badas_focus_clip"
        # New cache key/sample name: the fallback clip is a different input.
        fallback_cache_key = build_cache_key(source_video_path, badas_context, reason_context, mode, model_name, fallback_conditioning_source)
        sample_name = f"predict_{mode}_{fallback_cache_key}"
        output_video_path = output_root / f"{sample_name}.mp4"
        output_args_path = output_root / f"{sample_name}.json"
        # Cache hit on the fallback-keyed output.
        if output_video_path.exists() and not force_regenerate:
            return {
                "success": True,
                "cached": True,
                "mode": mode,
                "model_name": model_name,
                "cache_key": fallback_cache_key,
                "source_video_path": source_video_path,
                "conditioning_source": fallback_conditioning_source,
                "conditioning_clip": existing_file(fallback_conditioning_metadata.get("clip_path")),
                "conditioning_metadata": fallback_conditioning_metadata,
                "conditioning_window": conditioning_window,
                "fallback_applied": True,
                "fallback_reason": str(exc),
                "prompt": prompt,
                "output_video": existing_file(output_video_path),
                "output_args_json": existing_file(output_args_path),
            }
        # NOTE(review): the prompt is reused as built for the original
        # conditioning window — presumably intentional; confirm.
        output_video = execute_predict_generation(
            output_root,
            model_name,
            sample_name,
            fallback_conditioning_metadata["clip_path"],
            prompt,
        )
        # Rebind the reporting variables so the final return reflects the
        # fallback generation, not the failed first attempt.
        conditioning_source = fallback_conditioning_source
        conditioning_metadata = fallback_conditioning_metadata
        cache_key = fallback_cache_key
        conditioning_info["fallback_applied"] = True
        conditioning_info["fallback_reason"] = str(exc)
    return {
        "success": bool(output_video),
        "cached": False,
        "mode": mode,
        "model_name": model_name,
        "cache_key": cache_key,
        "source_video_path": source_video_path,
        "conditioning_source": conditioning_source,
        "conditioning_clip": existing_file(conditioning_metadata.get("clip_path")),
        "conditioning_metadata": conditioning_metadata,
        "conditioning_window": conditioning_window,
        "fallback_applied": conditioning_info.get("fallback_applied", False),
        "fallback_reason": conditioning_info.get("fallback_reason"),
        "prompt": prompt,
        "output_video": existing_file(output_video),
        "output_args_json": existing_file(output_args_path),
    }
def run_predict_bundle(source_video_path, badas_context=None, reason_context=None, modes=None, model_name="2B/post-trained", output_root="./predict_outputs", force_regenerate=False, fallback_conditioning_path=None):
    """Run Predict for several modes and aggregate results and artifacts.

    Defaults to generating both the prevented and the observed
    continuation. Returns per-mode results plus a flat artifact map,
    summary success/fallback flags, and the first result's conditioning
    source.
    """
    selected_modes = list(modes) if modes else ["prevented_continuation", "observed_continuation"]
    results = {}
    artifacts = {}
    for mode in selected_modes:
        outcome = run_predict_scenario(
            source_video_path,
            badas_context=badas_context,
            reason_context=reason_context,
            mode=mode,
            model_name=model_name,
            output_root=output_root,
            force_regenerate=force_regenerate,
            fallback_conditioning_path=fallback_conditioning_path,
        )
        results[mode] = outcome
        clip = outcome.get("conditioning_clip")
        # Record the conditioning clip once, from the first mode that has one.
        if clip and "predict_conditioning_clip" not in artifacts:
            artifacts["predict_conditioning_clip"] = clip
        video = outcome.get("output_video")
        if video:
            artifacts[f"predict_{mode}_video"] = video
    first_result = next(iter(results.values()), {})
    fallback_reasons = {
        mode: outcome.get("fallback_reason")
        for mode, outcome in results.items()
        if outcome.get("fallback_reason")
    }
    return {
        "success": any(outcome.get("success") for outcome in results.values()),
        "source_video_path": str(source_video_path),
        "model_name": model_name,
        "modes": list(selected_modes),
        "results": results,
        "artifacts": artifacts,
        "fallback_applied": any(outcome.get("fallback_applied") for outcome in results.values()),
        "fallback_reasons": fallback_reasons,
        "conditioning_source": first_result.get("conditioning_source"),
    }