| import logging |
| import logging.handlers |
| import queue |
| import urllib.request |
| from pathlib import Path |
| from typing import List, NamedTuple |
|
|
| try: |
| from typing import Literal |
| except ImportError: |
| from typing_extensions import Literal |
|
|
| import av |
| import cv2 |
| import numpy as np |
| import streamlit as st |
| from aiortc.contrib.media import MediaPlayer |
|
|
| from streamlit_webrtc import ( |
| ClientSettings, |
| VideoTransformerBase, |
| WebRtcMode, |
| webrtc_streamer, |
| ) |
|
|
# Directory containing this script; model and media downloads below are
# stored relative to it.
HERE: Path = Path(__file__).parent

# Logger named after this module (its level is raised in the __main__ guard).
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
| |
def download_file(url, download_to: Path, expected_size=None):
    """Download ``url`` to ``download_to``, showing progress in the Streamlit UI.

    If the destination already exists:

    * with ``expected_size`` given, the download is skipped when the file size
      matches and redone otherwise;
    * without ``expected_size``, the user is told the file exists and the
      download only proceeds if the "Download again?" button is clicked.

    The payload is streamed into a sibling ``<name>.part`` file that replaces
    ``download_to`` only once the transfer has finished. An interrupted
    download therefore never leaves a truncated file that a later run would
    take for a complete one. A failed re-download also leaves the previous
    copy intact.

    Args:
        url: URL fetched with ``urllib.request.urlopen``.
        download_to: Destination path; missing parent directories are created.
        expected_size: Optional size in bytes used to validate an existing file.

    Raises:
        Whatever ``urlopen`` or the file I/O raises. The partial ``.part``
        file is removed first.
    """
    if download_to.exists():
        if expected_size:
            if download_to.stat().st_size == expected_size:
                return
        else:
            st.info(f"{url} is already downloaded.")
            if not st.button("Download again?"):
                return

    download_to.parent.mkdir(parents=True, exist_ok=True)
    # Stage next to the destination so the final move stays on one
    # filesystem and Path.replace can swap the file in a single step.
    tmp_path = download_to.with_name(download_to.name + ".part")

    weights_warning, progress_bar = None, None
    try:
        weights_warning = st.warning("Downloading %s..." % url)
        progress_bar = st.progress(0)
        with open(tmp_path, "wb") as output_file:
            with urllib.request.urlopen(url) as response:
                # Content-Length can be absent (e.g. chunked transfer). In that
                # case only the downloaded amount is reported, with no progress
                # fraction.
                content_length = response.info().get("Content-Length")
                length = int(content_length) if content_length else 0
                counter = 0.0
                MEGABYTES = 2.0 ** 20.0
                while True:
                    data = response.read(8192)
                    if not data:
                        break
                    counter += len(data)
                    output_file.write(data)

                    if length > 0:
                        weights_warning.warning(
                            "Downloading %s... (%6.2f/%6.2f MB)"
                            % (url, counter / MEGABYTES, length / MEGABYTES)
                        )
                        progress_bar.progress(min(counter / length, 1.0))
                    else:
                        weights_warning.warning(
                            "Downloading %s... (%6.2f MB)"
                            % (url, counter / MEGABYTES)
                        )
        # Only reached when the whole body was read without error.
        tmp_path.replace(download_to)
    finally:
        if weights_warning is not None:
            weights_warning.empty()
        if progress_bar is not None:
            progress_bar.empty()
        # Best-effort cleanup of an incomplete download. After a successful
        # replace() the .part file no longer exists, which is fine.
        try:
            tmp_path.unlink()
        except OSError:
            pass
|
|
|
|
# Client settings shared by the demo pages: one public STUN server for ICE
# candidate gathering, and media constraints requesting both video and audio.
# Treat as read-only; copy it before customizing it for a single page.
WEBRTC_CLIENT_SETTINGS = ClientSettings(
    rtc_configuration=dict(
        iceServers=[dict(urls=["stun:stun.l.google.com:19302"])],
    ),
    media_stream_constraints=dict(video=True, audio=True),
)
|
|
|
|
def main():
    """Render the demo header, let the user pick a page, and render that page."""
    st.header("WebRTC demo")

    # Sidebar label -> page renderer. Insertion order is the order of the
    # options shown in the selectbox.
    pages = {
        "Real time object detection (sendrecv)": app_object_detection,
        "Real time video transform with simple OpenCV filters (sendrecv)": (
            app_video_filters
        ),
        "Consuming media files on server-side and streaming it to browser (recvonly)": (
            app_streaming
        ),
        "WebRTC is sendonly and images are shown via st.image() (sendonly)": (
            app_sendonly
        ),
        "Simple video loopback (sendrecv)": app_loopback,
    }
    app_mode = st.sidebar.selectbox("Choose the app mode", list(pages))
    st.subheader(app_mode)

    render_page = pages.get(app_mode)
    if render_page is not None:
        render_page()
|
|
|
|
def app_loopback():
    """Show a SENDRECV WebRTC stream with no video transformer attached."""
    webrtc_streamer(
        key="loopback",
        mode=WebRtcMode.SENDRECV,
        video_transformer_factory=None,
        client_settings=WEBRTC_CLIENT_SETTINGS,
    )
|
|
|
|
def app_video_filters():
    """Video transforms with OpenCV.

    Streams the browser camera through ``OpenCVVideoTransformer``. The filter
    applied to each frame is chosen with a radio button.
    """

    class OpenCVVideoTransformer(VideoTransformerBase):
        # Filter applied by transform(). It is assigned from the Streamlit
        # script (radio button below) while frames are being transformed.
        type: Literal["noop", "cartoon", "edges", "rotate"]

        def __init__(self) -> None:
            self.type = "noop"

        def transform(self, frame: av.VideoFrame) -> np.ndarray:
            """Return the frame as a BGR ndarray with the selected filter applied."""
            img = frame.to_ndarray(format="bgr24")

            if self.type == "noop":
                pass
            elif self.type == "cartoon":
                # Colour layer: downsample twice, smooth repeatedly with a
                # bilateral filter, then upsample back.
                img_color = cv2.pyrDown(cv2.pyrDown(img))
                for _ in range(6):
                    img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
                img_color = cv2.pyrUp(cv2.pyrUp(img_color))
                # pyrDown rounds odd sizes up, so the down/up round trip only
                # restores the input shape when both dimensions are multiples
                # of 4. Resize otherwise so bitwise_and gets equal-size inputs.
                if img_color.shape[:2] != img.shape[:2]:
                    img_color = cv2.resize(img_color, (img.shape[1], img.shape[0]))

                # Edge mask. The frame was decoded as BGR above, so use the
                # BGR conversion codes.
                img_edges = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                img_edges = cv2.adaptiveThreshold(
                    cv2.medianBlur(img_edges, 7),
                    255,
                    cv2.ADAPTIVE_THRESH_MEAN_C,
                    cv2.THRESH_BINARY,
                    9,
                    2,
                )
                img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2BGR)

                # Combine the colour layer with the edge mask.
                img = cv2.bitwise_and(img_color, img_edges)
            elif self.type == "edges":
                # Canny edges, expanded back to 3 channels for display.
                img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
            elif self.type == "rotate":
                # Rotate about the centre by 45 degrees per second of stream
                # time. frame.time is None when the frame carries no timestamp;
                # leave such frames unrotated instead of raising.
                rows, cols, _ = img.shape
                seconds = frame.time if frame.time is not None else 0.0
                M = cv2.getRotationMatrix2D((cols / 2, rows / 2), seconds * 45, 1)
                img = cv2.warpAffine(img, M, (cols, rows))

            return img

    webrtc_ctx = webrtc_streamer(
        key="opencv-filter",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_factory=OpenCVVideoTransformer,
        async_transform=True,
    )

    transform_type = st.radio(
        "Select transform type", ("noop", "cartoon", "edges", "rotate")
    )
    if webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.type = transform_type

    st.markdown(
        "This demo is based on "
        "https://github.com/aiortc/aiortc/blob/2362e6d1f0c730a0f8c387bbea76546775ad2fe8/examples/server/server.py#L34. "
        "Many thanks to the project."
    )
|
|
|
|
def app_object_detection():
    """Object detection demo with MobileNet SSD.
    This model and code are based on
    https://github.com/robmarkcole/object-detection-app
    """
    MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"
    MODEL_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.caffemodel"
    PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"
    PROTOTXT_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.prototxt.txt"

    # Class labels indexed by the class id the network outputs.
    CLASSES = [
        "background",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    ]
    # One random drawing colour per class.
    COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

    download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
    download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)

    DEFAULT_CONFIDENCE_THRESHOLD = 0.5

    class Detection(NamedTuple):
        name: str
        prob: float

    class MobileNetSSDVideoTransformer(VideoTransformerBase):
        confidence_threshold: float
        # Holds at most the most recent detection result. transform() runs for
        # every frame whether or not the label table reads the queue, so an
        # unbounded queue would grow without limit while labels are hidden.
        result_queue: "queue.Queue[List[Detection]]"

        def __init__(self) -> None:
            self._net = cv2.dnn.readNetFromCaffe(
                str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)
            )
            self.confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD
            self.result_queue = queue.Queue(maxsize=1)

        def _annotate_image(self, image, detections):
            """Draw boxes/labels above the threshold; return (image, detections)."""
            (h, w) = image.shape[:2]
            result: List[Detection] = []
            for i in np.arange(0, detections.shape[2]):
                confidence = detections[0, 0, i, 2]

                if confidence > self.confidence_threshold:
                    # Column 1 is the class id; columns 3:7 are the box
                    # corners, scaled below from fractions to pixels.
                    idx = int(detections[0, 0, i, 1])
                    box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
                    (startX, startY, endX, endY) = box.astype("int")

                    name = CLASSES[idx]
                    result.append(Detection(name=name, prob=float(confidence)))

                    label = f"{name}: {round(confidence * 100, 2)}%"
                    cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
                    # Put the label above the box unless too close to the top.
                    y = startY - 15 if startY - 15 > 15 else startY + 15
                    cv2.putText(
                        image,
                        label,
                        (startX, y),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        COLORS[idx],
                        2,
                    )
            return image, result

        def _publish(self, result: List[Detection]) -> None:
            """Make ``result`` the only item in ``result_queue``, dropping a stale one."""
            try:
                self.result_queue.get_nowait()
            except queue.Empty:
                pass
            try:
                self.result_queue.put_nowait(result)
            except queue.Full:
                # Only possible if something else refilled the single slot in
                # between. That entry is at least as fresh, so drop this one.
                pass

        def transform(self, frame: av.VideoFrame) -> np.ndarray:
            """Run detection on a frame and return the annotated BGR image."""
            image = frame.to_ndarray(format="bgr24")
            blob = cv2.dnn.blobFromImage(
                cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
            )
            self._net.setInput(blob)
            detections = self._net.forward()
            annotated_image, result = self._annotate_image(image, detections)

            # NOTE(review): with async_transform=True this presumably runs on a
            # worker thread, not the Streamlit script thread. The queue is the
            # hand-off point for the label table below.
            self._publish(result)

            return annotated_image

    webrtc_ctx = webrtc_streamer(
        key="object-detection",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_factory=MobileNetSSDVideoTransformer,
        async_transform=True,
    )

    confidence_threshold = st.slider(
        "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
    )
    if webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.confidence_threshold = confidence_threshold

    if st.checkbox("Show the detected labels", value=True):
        if webrtc_ctx.state.playing:
            labels_placeholder = st.empty()
            # Labels are produced asynchronously by transform(), so the table
            # is not frame-exact. Poll with a timeout rather than a bare get():
            # a blocking get() would hang this script run forever once frames
            # stop. The timeout lets the loop re-check whether the transformer
            # is still attached and keeps making st.* calls.
            while True:
                video_transformer = webrtc_ctx.video_transformer
                if not video_transformer:
                    break
                try:
                    result = video_transformer.result_queue.get(timeout=1.0)
                except queue.Empty:
                    result = None
                labels_placeholder.table(result)

    st.markdown(
        "This demo uses a model and code from "
        "https://github.com/robmarkcole/object-detection-app. "
        "Many thanks to the project."
    )
|
|
|
|
def app_streaming():
    """Media streamings

    Downloads the selected sample file on the server and streams it to the
    browser over a RECVONLY WebRTC connection.
    """
    import copy

    MEDIAFILES = {
        "big_buck_bunny_720p_2mb.mp4": {
            "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_2mb.mp4",
            "local_file_path": HERE / "data/big_buck_bunny_720p_2mb.mp4",
            "type": "video",
        },
        "big_buck_bunny_720p_10mb.mp4": {
            "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_10mb.mp4",
            "local_file_path": HERE / "data/big_buck_bunny_720p_10mb.mp4",
            "type": "video",
        },
        "file_example_MP3_700KB.mp3": {
            "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_700KB.mp3",
            "local_file_path": HERE / "data/file_example_MP3_700KB.mp3",
            "type": "audio",
        },
        "file_example_MP3_5MG.mp3": {
            "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_5MG.mp3",
            "local_file_path": HERE / "data/file_example_MP3_5MG.mp3",
            "type": "audio",
        },
    }
    media_file_label = st.radio(
        "Select a media file to stream", tuple(MEDIAFILES.keys())
    )
    media_file_info = MEDIAFILES[media_file_label]
    download_file(media_file_info["url"], media_file_info["local_file_path"])

    def create_player():
        """Player factory for webrtc_streamer: plays the selected local file."""
        return MediaPlayer(str(media_file_info["local_file_path"]))

    # Ask only for the track kind this file provides. Customize a copy rather
    # than calling .update() on the module-level WEBRTC_CLIENT_SETTINGS, which
    # the other pages share.
    client_settings = copy.copy(WEBRTC_CLIENT_SETTINGS)
    client_settings.update(
        {
            "media_stream_constraints": {
                "video": media_file_info["type"] == "video",
                "audio": media_file_info["type"] == "audio",
            }
        }
    )

    webrtc_streamer(
        key=f"media-streaming-{media_file_label}",
        mode=WebRtcMode.RECVONLY,
        client_settings=client_settings,
        player_factory=create_player,
    )
|
|
|
|
def app_sendonly():
    """A sample to use WebRTC in sendonly mode to transfer frames
    from the browser to the server and to render frames via `st.image`.

    Frames are pulled from the video receiver until none arrive within one
    second. The receiver is then stopped and the loop ends.
    """
    webrtc_ctx = webrtc_streamer(
        # Distinct from app_loopback's "loopback" key: that page is SENDRECV,
        # and sharing a key would give both differently configured components
        # the same identity and state when switching pages.
        key="sendonly",
        mode=WebRtcMode.SENDONLY,
        client_settings=WEBRTC_CLIENT_SETTINGS,
    )

    if webrtc_ctx.video_receiver:
        image_loc = st.empty()
        while True:
            try:
                frame = webrtc_ctx.video_receiver.get_frame(timeout=1)
            except queue.Empty:
                logger.warning("Queue is empty. Stop the loop.")
                webrtc_ctx.video_receiver.stop()
                break

            img_rgb = frame.to_ndarray(format="rgb24")
            image_loc.image(img_rgb)
|
|
|
|
if __name__ == "__main__":
    import sys

    log_format = (
        "[%(asctime)s] %(levelname)7s from %(name)s in %(filename)s:%(lineno)d: "
        "%(message)s"
    )
    if sys.version_info >= (3, 8):
        logging.basicConfig(format=log_format, force=True)
    else:
        # basicConfig() only accepts `force` from Python 3.8 and raises
        # ValueError before that. This file still supports 3.7 (see the
        # typing_extensions fallback for Literal), so do by hand what force=True
        # does: drop and close the existing root handlers first.
        root_logger = logging.getLogger()
        for handler in root_logger.handlers[:]:
            root_logger.removeHandler(handler)
            handler.close()
        logging.basicConfig(format=log_format)

    logger.setLevel(level=logging.DEBUG)

    st_webrtc_logger = logging.getLogger("streamlit_webrtc")
    st_webrtc_logger.setLevel(logging.DEBUG)

    main()
|
|