Fix ImportError by replacing MediaMode with WebRtcMode in streamlit_webrtc import and webrtc_streamer configuration
89a037c
verified
| import streamlit as st | |
| import cv2 | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| from collections import deque | |
| from transformers import AutoFeatureExtractor, AutoModelForVideoClassification | |
| from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode | |
# Constants
# Number of frames that make up one clip: used both as the sample count for
# uploaded videos and as the rolling-buffer length for the webcam stream.
NUM_FRAMES = 16
# Hugging Face Hub id of the video-classification checkpoint.
MODEL_NAME = "jatinmehra/Accident-Detection-using-Dashcam"
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
@st.cache_resource
def load_model_and_extractor():
    """Load the frame preprocessor and the video classifier.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the checkpoint is
    downloaded/instantiated once and the same objects are reused afterwards.

    Returns:
        tuple: ``(extractor, model)`` where ``extractor`` is the feature
        extractor of ``facebook/timesformer-base-finetuned-k400`` and
        ``model`` is the ``MODEL_NAME`` checkpoint loaded with two labels,
        moved to ``DEVICE`` and switched to eval mode.
    """
    extractor = AutoFeatureExtractor.from_pretrained(
        "facebook/timesformer-base-finetuned-k400"
    )
    model = AutoModelForVideoClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    # Inference only: put dropout/normalisation layers in eval mode.
    model.eval()
    return extractor, model
# Module-level handles used by both the upload path and the webcam path.
extractor, model = load_model_and_extractor()

# Page header and a short hint on how to read the reported score.
st.title("Dashcam Accident Predictor")
st.write("**higher score = higher accident probability**")
# Function to run inference on a saved video file
def run_inference_on_video(video_path):
    """Score a video file with the classifier.

    ``NUM_FRAMES`` frames are sampled at evenly spaced positions across the
    video, converted to RGB, resized to 224x224 and passed through the
    module-level ``extractor`` and ``model``.

    Args:
        video_path: Path of a video file readable by ``cv2.VideoCapture``.

    Returns:
        float | None: Softmax probability of class index 1 (shown in the UI
        as the accident probability), or ``None`` when the reported frame
        count is not positive; in that case an error is shown via
        ``st.error``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            st.error("Failed to read video frames.")
            return None

        # Uniform sampling
        indices = np.linspace(0, total_frames - 1, NUM_FRAMES, dtype=int)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if not ret:
                # Keep the clip length fixed: substitute a black frame for
                # any position that could not be decoded.
                frames.append(np.zeros((224, 224, 3), dtype=np.uint8))
            else:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(cv2.resize(rgb, (224, 224)))
    finally:
        # Release the capture on every path, including the early return
        # above and any exception raised while decoding.
        cap.release()

    # Preprocess and predict
    inputs = extractor(frames, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(DEVICE)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values).logits
    prob = torch.softmax(outputs, dim=1)[0, 1].item()
    return prob
# UI Selection
source = st.radio("Choose input source", ("Upload Video", "Webcam"))
if source == "Upload Video":
    uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
    if uploaded_file is not None:
        st.video(uploaded_file)
        st.write("Running inference...")
        # OpenCV needs a real path, so spill the upload to disk. The file is
        # created inside a TemporaryDirectory so it is removed once inference
        # is done instead of accumulating one leftover file per upload.
        with tempfile.TemporaryDirectory() as tmp_dir:
            with tempfile.NamedTemporaryFile(
                dir=tmp_dir, suffix=".mp4", delete=False
            ) as tfile:
                # getvalue() returns the full payload regardless of the
                # current read position of the uploaded buffer.
                tfile.write(uploaded_file.getvalue())
            # The inner `with` has closed (and therefore flushed) the file,
            # so every byte is on disk before OpenCV opens it.
            score = run_inference_on_video(tfile.name)
        if score is not None:
            st.success(f"Accident probability: {score:.2f}")
| else: | |
| # Webcam stream processing | |
| class AcciTransformer(VideoTransformerBase): | |
| def __init__(self): | |
| self.buffer = deque(maxlen=NUM_FRAMES) | |
| def transform(self, frame): | |
| img = frame.to_ndarray(format="bgr24") | |
| rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
| resized = cv2.resize(rgb, (224,224)) | |
| self.buffer.append(resized) | |
| if len(self.buffer) == NUM_FRAMES: | |
| inputs = extractor(list(self.buffer), return_tensors="pt") | |
| pixel_values = inputs['pixel_values'].to(DEVICE) | |
| with torch.no_grad(): | |
| outputs = model(pixel_values=pixel_values).logits | |
| prob = torch.softmax(outputs, dim=1)[0,1].item() | |
| cv2.putText(img, f"Prob: {prob:.2f}", (10,30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 2) | |
| return img | |
| webrtc_streamer( | |
| key="dashcam-webcam", | |
| mode=WebRtcMode.RECVONLY, | |
| rtc_configuration=RTCConfiguration({ | |
| "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}] | |
| }), | |
| video_transformer_factory=AcciTransformer | |
| ) |