Spaces:
Running on Zero
Running on Zero
| import torch | |
| import cv2 | |
| import os | |
| import sys | |
| import json | |
| import numpy as np | |
| from functools import lru_cache | |
# Hugging Face access token: prefer HF_TOKEN, otherwise fall back to the
# HUGGINGFACE_HUB_TOKEN variable.
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
if hf_token:
    # Publish the token under HF_TOKEN only when that variable is not set at
    # all; an existing HF_TOKEN entry is never overwritten.
    os.environ.setdefault("HF_TOKEN", hf_token)

# Directory (next to this file) that is searched for a local badas_loader module.
_MODULE_DIR = os.path.dirname(__file__)
LOCAL_BADAS_MODEL_PATH = os.path.join(_MODULE_DIR, "nexar_data", "badas_model")
def load_badas_loader():
    """Locate and return the ``load_badas_model`` factory function.

    Sources are tried in this order:

    1. an importable ``badas_loader`` module (``LOCAL_BADAS_MODEL_PATH`` is
       appended to ``sys.path`` first when that directory exists);
    2. the ``badas`` package;
    3. ``badas_loader.py`` downloaded from the Hugging Face Hub repository
       named by the ``BADAS_MODEL_REPO`` environment variable
       (default ``nexar-ai/badas-open``), authenticated with ``hf_token``.

    Returns:
        The ``load_badas_model`` callable from the first source that imports.

    Raises:
        Any exception raised by the hub download or by the final import;
        failures of the first two sources never propagate.
    """
    if os.path.isdir(LOCAL_BADAS_MODEL_PATH) and LOCAL_BADAS_MODEL_PATH not in sys.path:
        sys.path.append(LOCAL_BADAS_MODEL_PATH)

    try:
        from badas_loader import load_badas_model as local_load_badas_model
        return local_load_badas_model
    except ImportError:
        # Module (or one of its imports) is unavailable: try the next source.
        pass
    except Exception as exc:
        # The module exists but is broken. Keep falling back (best effort),
        # but do not hide the reason from whoever reads the logs.
        print(f"Warning: local badas_loader failed to import: {exc!r}", file=sys.stderr)

    try:
        from badas import load_badas_model as package_load_badas_model
        return package_load_badas_model
    except ImportError:
        pass
    except Exception as exc:
        print(f"Warning: badas package failed to import: {exc!r}", file=sys.stderr)

    # Last resort: fetch the loader script from the Hugging Face Hub.
    from huggingface_hub import hf_hub_download

    loader_path = hf_hub_download(
        repo_id=os.environ.get("BADAS_MODEL_REPO", "nexar-ai/badas-open"),
        filename="badas_loader.py",
        token=hf_token,
    )
    loader_parent = os.path.dirname(loader_path)
    if loader_parent not in sys.path:
        # Front of the path so the downloaded script wins over other copies.
        sys.path.insert(0, loader_parent)
    from badas_loader import load_badas_model as hub_load_badas_model
    return hub_load_badas_model
def get_badas_model():
    """Resolve the BADAS loader function and return the model it builds."""
    load_model = load_badas_loader()
    return load_model()
def extract_window_frames(video_path, end_time_sec, target_fps, frame_count):
    """Read a short window of frames from a video file.

    The window consists of ``frame_count`` timestamps spaced ``1 / target_fps``
    seconds apart, starting ``(frame_count - 1) / target_fps`` seconds before
    ``end_time_sec`` (the start is clamped to 0, in which case the window
    extends past ``end_time_sec``). When ``target_fps`` is falsy a single
    frame at ``end_time_sec`` is requested.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        end_time_sec: Time, in seconds, the window is anchored to.
        target_fps: Sampling rate of the window; falsy means "one frame".
        frame_count: Number of frames requested.

    Returns:
        List of frames as returned by ``VideoCapture.read``. The list is
        empty when the reported FPS is not positive, and shorter than
        ``frame_count`` when a read fails (extraction stops at that point).
    """
    cap = cv2.VideoCapture(video_path)
    try:
        original_fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        if original_fps <= 0:
            return []

        if target_fps:
            fps = float(target_fps)
            duration_sec = max(0.0, float(frame_count - 1) / fps)
            start_time_sec = max(0.0, float(end_time_sec) - duration_sec)
            timestamps = [start_time_sec + (idx / fps) for idx in range(frame_count)]
        else:
            timestamps = [float(end_time_sec)]

        frames = []
        for timestamp in timestamps:
            # Convert the timestamp to the nearest source-frame index.
            frame_index = max(0, int(round(timestamp * original_fps)))
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(frame)
        return frames
    finally:
        # Always free the capture handle, including when a seek/read raises.
        cap.release()
def prepare_video_tensor(model, frames_bgr):
    """Convert BGR frames into the tensor fed to the model.

    Conversion strategies, in order of preference:

    1. ``model.processor`` (called with ``return_tensors="pt"``); the
       ``pixel_values_videos`` entry is preferred, then ``pixel_values``,
       then the first value of the returned mapping. The leading dimension
       is squeezed. Any exception here falls through to the next strategy.
    2. ``model.transform`` applied per frame (``transform(image=...)["image"]``)
       and stacked.
    3. Plain scaling: frames stacked, axes reordered with
       ``transpose(0, 3, 1, 2)`` and divided by 255.

    Returns:
        A tensor, or ``None`` when ``frames_bgr`` is empty.
    """
    if not frames_bgr:
        return None

    rgb_frames = []
    for bgr in frames_bgr:
        rgb_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    if model.processor:
        try:
            processed = model.processor(rgb_frames, return_tensors="pt")
            for key in ("pixel_values_videos", "pixel_values"):
                if key in processed:
                    return processed[key].squeeze(0)
            return list(processed.values())[0].squeeze(0)
        except Exception:
            # Best effort: a failing processor is not fatal, use the
            # transform / raw-scaling paths below instead.
            pass

    if model.transform:
        per_frame = [model.transform(image=rgb)["image"] for rgb in rgb_frames]
        return torch.stack(per_frame)

    stacked = np.stack(rgb_frames).transpose(0, 3, 1, 2)
    return torch.from_numpy(stacked).float() / 255.0
def build_gradient_saliency_strip(model, video_path, focus_time_sec, sampled_fps):
    """Render an input-gradient saliency strip for the window ending at a time.

    Frames for the window are extracted, run through ``model.model`` and the
    logit at ``[0, 1]`` is back-propagated to the input. The absolute input
    gradient is averaged over dimension 2 of the batched input (assumed to be
    the channel axis - TODO confirm against the model's input layout),
    normalised per frame, colour-mapped and blended over up to four evenly
    spaced frames. The frames are laid out side by side at 220 px height and
    written to ``badas_gradient_saliency.png`` in the working directory.

    Args:
        model: Object exposing ``frame_count``, ``device``, ``model`` and the
            attributes used by ``prepare_video_tensor``.
        video_path: Path of the video to read.
        focus_time_sec: Time, in seconds, the frame window is anchored to.
        sampled_fps: Sampling rate used to build the frame window.

    Returns:
        The output image path, or ``None`` when no frames/tensor could be
        produced, the model output is not ``(batch, >=2)`` shaped, or the
        image could not be written.
    """
    frames_bgr = extract_window_frames(video_path, focus_time_sec, sampled_fps, model.frame_count)
    if not frames_bgr:
        return None
    video_tensor = prepare_video_tensor(model, frames_bgr)
    if video_tensor is None:
        return None

    input_tensor = video_tensor.unsqueeze(0).to(model.device)
    input_tensor.requires_grad_(True)
    model.model.zero_grad(set_to_none=True)
    logits = model.model(input_tensor)
    if logits.ndim != 2 or logits.shape[-1] < 2:
        return None
    positive_logit = logits[0, 1]
    positive_logit.backward()
    gradients = input_tensor.grad.detach().abs().mean(dim=2).squeeze(0).cpu().numpy()
    # Only the input gradient is needed; release the parameter gradients
    # that backward() just populated.
    model.model.zero_grad(set_to_none=True)

    overlay_frames = []
    selected_indices = np.linspace(0, len(frames_bgr) - 1, min(4, len(frames_bgr)), dtype=int)
    for frame_index in selected_indices:
        frame = frames_bgr[int(frame_index)].copy()
        grad_map = gradients[int(frame_index)]
        if grad_map.max() > 0:
            # Scale to [0, 1] so the colour map uses its full range.
            grad_map = grad_map / grad_map.max()
        heatmap = np.uint8(255 * grad_map)
        heatmap = cv2.resize(heatmap, (frame.shape[1], frame.shape[0]))
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        rendered = cv2.addWeighted(frame, 0.6, heatmap, 0.4, 0)
        cv2.putText(
            rendered,
            f"BADAS gradient saliency | t={focus_time_sec:.2f}s",
            (20, 36),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 255, 255),
            2,
        )
        overlay_frames.append(rendered)
    if not overlay_frames:
        return None

    resized_frames = []
    resize_height = 220
    for frame in overlay_frames:
        height, width = frame.shape[:2]
        scale = resize_height / max(height, 1)
        resized_frames.append(cv2.resize(frame, (max(1, int(round(width * scale))), resize_height)))
    strip = cv2.hconcat(resized_frames)

    output_path = "badas_gradient_saliency.png"
    # Trust imwrite's own status: the path is fixed, so a file left over from
    # an earlier run would otherwise make a failed write look successful.
    written = cv2.imwrite(output_path, strip)
    return output_path if written and os.path.exists(output_path) else None
def summarize_prediction_window(prediction_series, sampled_fps):
    """Find the sliding window with the highest mean probability.

    The window length is ``round(sampled_fps)`` entries (at least 1). When the
    series is shorter than one window, the whole series is averaged as a
    single partial window and ``window_count`` is reported as 0.

    Args:
        prediction_series: Sequence of dicts with ``'probability'`` and
            ``'time_sec'`` keys.
        sampled_fps: Sampling rate used to size the window.

    Returns:
        Dict with ``window_count``, ``max_average_probability``,
        ``peak_window_start_time`` and ``peak_window_end_time``; the last
        three are ``None`` for an empty series.
    """
    if not prediction_series:
        return {
            'window_count': 0,
            'max_average_probability': None,
            'peak_window_start_time': None,
            'peak_window_end_time': None,
        }

    span = max(1, int(round(sampled_fps)))
    probs = [entry['probability'] for entry in prediction_series]
    full_windows = len(probs) - span + 1

    # Always evaluate at least the window starting at index 0.
    start_positions = range(max(1, full_windows))
    averages = [float(np.mean(probs[pos:pos + span])) for pos in start_positions]

    # max() keeps the first of equal candidates, i.e. the earliest peak window.
    peak_start = max(start_positions, key=averages.__getitem__)
    peak_end = min(len(prediction_series) - 1, peak_start + span - 1)

    return {
        'window_count': max(0, full_windows),
        'max_average_probability': averages[peak_start],
        'peak_window_start_time': float(prediction_series[peak_start]['time_sec']),
        'peak_window_end_time': float(prediction_series[peak_end]['time_sec']),
    }
def summarize_threshold_runs(collision_frames, sampled_fps):
    """Group threshold-crossing frames into runs of consecutive frame indices.

    Args:
        collision_frames: Sequence of ``(frame_index, probability)`` pairs.
            A pair extends the current run only when its index is exactly one
            more than the previous pair's index.
        sampled_fps: Sampling rate used to convert frame indices to seconds.

    Returns:
        Dict with the crossing count and times, one entry per contiguous run
        (frame/time bounds, duration, max and mean probability), and the
        length of the longest run in frames and seconds (the earliest run
        wins ties).
    """
    if not collision_frames:
        return {
            'threshold_crossing_count': 0,
            'threshold_crossing_times': [],
            'contiguous_alert_runs': [],
            'longest_alert_run_frames': 0,
            'longest_alert_run_sec': 0.0,
        }

    crossing_times = [float(idx / sampled_fps) for idx, _ in collision_frames]

    # Each entry is [first_frame, last_frame, probabilities_in_run].
    grouped = []
    for idx, prob in collision_frames:
        if grouped and idx == grouped[-1][1] + 1:
            grouped[-1][1] = idx
            grouped[-1][2].append(float(prob))
        else:
            grouped.append([idx, idx, [float(prob)]])

    alert_runs = []
    for first, last, run_probs in grouped:
        length = last - first + 1
        alert_runs.append(
            {
                'start_frame': int(first),
                'end_frame': int(last),
                'start_time': float(first / sampled_fps),
                'end_time': float(last / sampled_fps),
                'duration_frames': int(length),
                'duration_sec': float(length / sampled_fps),
                'max_probability': float(max(run_probs)),
                'mean_probability': float(np.mean(run_probs)),
            }
        )

    longest = max(alert_runs, key=lambda run: run['duration_frames'])
    return {
        'threshold_crossing_count': int(len(collision_frames)),
        'threshold_crossing_times': crossing_times,
        'contiguous_alert_runs': alert_runs,
        'longest_alert_run_frames': int(longest['duration_frames']),
        'longest_alert_run_sec': float(longest['duration_sec']),
    }
def run_badas_detector(video_path, confidence_threshold=0.5):
    """Run BADAS-Open collision detection on video.

    Loads the model, scores the whole video with ``model.predict``, prints a
    diagnostic report and returns a summary dict. Non-finite prediction
    entries are excluded from all statistics and counted as warm-up frames.
    A gradient-saliency image is attempted for the highest-scoring frame;
    failure there only produces a warning.

    Args:
        video_path: Path of the video to analyse.
        confidence_threshold: A sampled frame is an alert candidate when its
            probability is strictly greater than this value.

    Returns:
        Dict describing the run. ``collision_detected`` tells whether any
        frame crossed the threshold. When one did, the ``alert_*`` fields and
        ``confidence`` describe the *earliest* crossing frame
        (``alert_source == 'threshold_crossing'``). Otherwise they describe a
        fallback frame - the first valid prediction index, or a quarter of
        the prediction length when there is none - with ``confidence`` 0.0
        (``alert_source == 'fallback_first_valid_frame'``). The remaining
        keys (video metadata, prediction statistics, window/threshold
        summaries, top predictions, full series, saliency artifact) are the
        same in both cases.
    """
    print("Loading BADAS-Open model...")
    model = get_badas_model()
    model_info = model.get_model_info()

    # Run prediction on entire video
    print(f"Analyzing video: {video_path}")
    predictions = np.asarray(model.predict(video_path), dtype=np.float32)

    # Get video properties
    cap = cv2.VideoCapture(video_path)
    try:
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    # The model may sample at its own rate; fall back to the source rate.
    sampled_fps = float(model_info.get('target_fps') or original_fps)
    video_duration_sec = float(total_frames / original_fps) if original_fps else 0.0

    valid_mask = np.isfinite(predictions)
    valid_indices = np.flatnonzero(valid_mask)
    valid_predictions = predictions[valid_mask]
    nan_count = int((~valid_mask).sum())
    has_valid = len(valid_predictions) > 0

    prediction_series = [
        {
            'sampled_frame': int(idx),
            'original_frame_approx': int(round((idx / sampled_fps) * original_fps)),
            'time_sec': float(idx / sampled_fps),
            'probability': float(predictions[idx]),
        }
        for idx in valid_indices
    ]
    top_predictions = sorted(
        prediction_series,
        key=lambda item: item['probability'],
        reverse=True,
    )[:5]

    # Saliency is rendered for the single highest-probability frame.
    saliency_focus_time = float(top_predictions[0]['time_sec']) if top_predictions else None
    gradient_saliency_image = None
    if saliency_focus_time is not None:
        try:
            gradient_saliency_image = build_gradient_saliency_strip(model, video_path, saliency_focus_time, sampled_fps)
            if gradient_saliency_image:
                print(f"Gradient saliency artifact saved: {gradient_saliency_image}")
        except Exception as exc:
            # The artifact is optional; never let it abort detection.
            print(f"Warning: failed to build BADAS gradient saliency artifact: {exc}")

    prediction_window_summary = summarize_prediction_window(prediction_series, sampled_fps)

    if has_valid:
        valid_prediction_summary = {
            'min': float(valid_predictions.min()),
            'max': float(valid_predictions.max()),
            'mean': float(valid_predictions.mean()),
            'median': float(np.median(valid_predictions)),
            'std': float(valid_predictions.std()),
            'p90': float(np.percentile(valid_predictions, 90)),
            'p95': float(np.percentile(valid_predictions, 95)),
        }
    else:
        valid_prediction_summary = dict.fromkeys(
            ('min', 'max', 'mean', 'median', 'std', 'p90', 'p95')
        )

    print(f"Model info: {json.dumps(model_info)}")
    print(f"Video info: {total_frames} frames at {original_fps:.2f} FPS")
    print(f"Video duration: {video_duration_sec:.2f}s")
    print(f"Model sampling FPS: {sampled_fps:.2f}")
    print(f"Prediction array length: {len(predictions)}")
    print(f"Valid predictions: {len(valid_predictions)} | NaN warmup frames: {nan_count}")
    if has_valid:
        print(f"Prediction range: {valid_prediction_summary['min']:.3f} - {valid_prediction_summary['max']:.3f}")
        print(f"Mean prediction: {valid_prediction_summary['mean']:.3f}")
        print(f"Median prediction: {valid_prediction_summary['median']:.3f}")
        print(f"Prediction std: {valid_prediction_summary['std']:.3f}")
        print(f"P90/P95: {valid_prediction_summary['p90']:.3f} / {valid_prediction_summary['p95']:.3f}")
        first_valid_idx = int(valid_indices[0])
        print(f"First valid predictive frame: sampled frame {first_valid_idx} ({first_valid_idx / sampled_fps:.2f}s)")
        print("Top prediction frames:")
        for item in top_predictions:
            frame_idx = item['sampled_frame']
            prob = item['probability']
            time_sec = item['time_sec']
            original_frame = item['original_frame_approx']
            print(
                f" sampled_frame={frame_idx} original_frame≈{original_frame} time={time_sec:.2f}s prob={prob:.2%}"
            )
        if prediction_window_summary['max_average_probability'] is not None:
            print(
                f"Peak 1-second window: {prediction_window_summary['peak_window_start_time']:.2f}s"
                f" - {prediction_window_summary['peak_window_end_time']:.2f}s"
                f" avg={prediction_window_summary['max_average_probability']:.2%}"
            )
    else:
        print("Prediction range: no valid predictions produced")

    # Find frames with high collision probability (lower threshold for better detection)
    collision_frames = []
    for frame_idx, prob in enumerate(predictions):
        if not np.isfinite(prob):
            continue
        if prob > confidence_threshold:
            collision_frames.append((frame_idx, prob))
            time_sec = frame_idx / sampled_fps
            original_frame = int(round(time_sec * original_fps))
            print(
                f"⚠️ Collision risk at sampled frame {frame_idx} "
                f"(original frame ≈ {original_frame}, {time_sec:.2f}s): {prob:.2%}"
            )

    threshold_summary = summarize_threshold_runs(collision_frames, sampled_fps)
    if threshold_summary['contiguous_alert_runs']:
        print("Threshold-crossing runs:")
        for run in threshold_summary['contiguous_alert_runs']:
            print(
                f" {run['start_time']:.2f}s - {run['end_time']:.2f}s"
                f" duration={run['duration_sec']:.2f}s max={run['max_probability']:.2%}"
            )

    if collision_frames:
        # Report the earliest high-risk frame; the confidence is that
        # frame's own probability, not the maximum over the video.
        earliest_frame, earliest_prob = min(collision_frames, key=lambda x: x[0])
        alert_time = earliest_frame / sampled_fps
        alert_original_frame = int(round(alert_time * original_fps))
        print(
            f"🚨 BADAS Alert: Collision detected at {alert_time:.2f}s "
            f"(sampled frame {earliest_frame}, original frame ≈ {alert_original_frame}) "
            f"with {earliest_prob:.2%} confidence"
        )
        collision_detected = True
        alert_frame = earliest_frame
        confidence = float(earliest_prob)
        alert_source = 'threshold_crossing'
    else:
        print("⚠️ BADAS completed but no collision detected, using default alert time")
        # Default alert for pipeline continuation
        alert_frame = next(iter(valid_indices), len(predictions) // 4)
        alert_time = float(alert_frame / sampled_fps) if sampled_fps else 0.0
        alert_original_frame = int(round(alert_time * original_fps))
        collision_detected = False
        confidence = 0.0
        alert_source = 'fallback_first_valid_frame'

    # Single result layout shared by both outcomes (same key order as before).
    return {
        'collision_detected': collision_detected,
        'alert_frame_sampled': int(alert_frame),
        'alert_frame_original_approx': alert_original_frame,
        'alert_time': alert_time,
        'confidence': confidence,
        'threshold': float(confidence_threshold),
        'original_fps': float(original_fps),
        'sampled_fps': sampled_fps,
        'model_info': model_info,
        'video_metadata': {
            'video_path': video_path,
            'total_frames': int(total_frames),
            'original_fps': float(original_fps),
            'sampled_fps': float(sampled_fps),
            'duration_sec': video_duration_sec,
        },
        'prediction_count': int(len(predictions)),
        'valid_prediction_count': int(len(valid_predictions)),
        'nan_warmup_count': nan_count,
        'valid_prediction_min': valid_prediction_summary['min'],
        'valid_prediction_max': valid_prediction_summary['max'],
        'valid_prediction_mean': valid_prediction_summary['mean'],
        'valid_prediction_median': valid_prediction_summary['median'],
        'valid_prediction_std': valid_prediction_summary['std'],
        'valid_prediction_p90': valid_prediction_summary['p90'],
        'valid_prediction_p95': valid_prediction_summary['p95'],
        'first_valid_time': float(valid_indices[0] / sampled_fps) if len(valid_indices) > 0 else None,
        'alert_source': alert_source,
        'prediction_window_summary': prediction_window_summary,
        'threshold_summary': threshold_summary,
        'top_predictions': top_predictions,
        'prediction_series': prediction_series,
        'gradient_saliency_image': gradient_saliency_image,
        'gradient_saliency_focus_time': saliency_focus_time,
    }
if __name__ == "__main__":
    # Script entry point: analyse the video given as the first CLI argument,
    # or the bundled sample clip when no argument is supplied.
    cli_args = sys.argv[1:]
    video_path = cli_args[0] if cli_args else "./nexar_data/sample_videos/sample_dashcam_2.mp4"
    result = run_badas_detector(video_path)
    if not result['collision_detected']:
        print("No collision detected by BADAS")
    else:
        print(f"BADAS Alert: Collision detected at {result['alert_time']:.2f}s with {result['confidence']:.2%} confidence")
        print("This would be our System 1 trigger for the Pure Cosmos Pipeline")
    # Full result as a single tagged JSON line on stdout.
    print(f"BADAS_JSON: {json.dumps(result)}")