import logging
import os
import tempfile
import time
from io import BytesIO

import av
import cv2
import numpy as np
import requests
import streamlit as st
from PIL import Image
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from ultralytics import YOLO

from utils.download import download_file
from utils.turn import get_ice_servers
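# utils.download and utils.turn are assumed to be local helper modules bundled
# with this Space (a file-download helper and a TURN/STUN lookup for WebRTC).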
# CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.

# Update below string to set display title of analysis
ANALYSIS_TITLE = "YOLOv8 Object Detection, Pose Estimation, and Action Detection"

# Load the YOLOv8 models
pose_model = YOLO("yolov8n-pose.pt")
object_model = YOLO("yolov8n.pt")
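# The .pt weights above are fetched automatically by ultralytics on first use if
# they are not already present in the working directory.

# Default detection mode. show_labels is normally set by the radio button in
# analysis_init(), but the WebRTC frame callback can fire before that runs, so a
# module-level fallback avoids a NameError; "Both" is assumed as the default here.
show_labels = "Both"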


def detect_action(keypoints, prev_keypoints=None):
    # Heuristic action classification for a single person's pose keypoints,
    # optionally using the previous frame's keypoints to estimate movement.

    # COCO keypoint index -> body part (kept for reference)
    keypoint_dict = {
        0: "Nose", 1: "Left Eye", 2: "Right Eye", 3: "Left Ear", 4: "Right Ear",
        5: "Left Shoulder", 6: "Right Shoulder", 7: "Left Elbow", 8: "Right Elbow",
        9: "Left Wrist", 10: "Right Wrist", 11: "Left Hip", 12: "Right Hip",
        13: "Left Knee", 14: "Right Knee", 15: "Left Ankle", 16: "Right Ankle"
    }

    confidence_threshold = 0.5
    # Keypoint coordinates are in pixels, so this movement threshold may need
    # tuning for your frame resolution.
    movement_threshold = 0.05

    def get_keypoint(idx):
        # Return the (x, y) position of a keypoint, or None if it is missing or
        # below the confidence threshold.
        if idx < len(keypoints[0]):
            x, y, conf = keypoints[0][idx]
            return np.array([x, y]) if conf > confidence_threshold else None
        return None

    def calculate_angle(a, b, c):
        # Angle at vertex b (in degrees) formed by points a-b-c.
        if a is None or b is None or c is None:
            return None
        ba = a - b
        bc = c - b
        denom = np.linalg.norm(ba) * np.linalg.norm(bc)
        if denom == 0:
            return None
        cosine_angle = np.clip(np.dot(ba, bc) / denom, -1.0, 1.0)
        return np.degrees(np.arccos(cosine_angle))

    def calculate_movement(current, previous):
        # Euclidean distance a point moved between frames.
        if current is None or previous is None:
            return None
        return np.linalg.norm(current - previous)
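    # keypoints is expected to have shape (1, 17, 3) with (x, y, confidence) per
    # keypoint, as produced by Results.keypoints[i].data in analyze_frame() below.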
    nose = get_keypoint(0)
    left_shoulder = get_keypoint(5)
    right_shoulder = get_keypoint(6)
    left_elbow = get_keypoint(7)
    right_elbow = get_keypoint(8)
    left_wrist = get_keypoint(9)
    right_wrist = get_keypoint(10)
    left_hip = get_keypoint(11)
    right_hip = get_keypoint(12)
    left_knee = get_keypoint(13)
    right_knee = get_keypoint(14)
    left_ankle = get_keypoint(15)
    right_ankle = get_keypoint(16)

    if all(kp is None for kp in [nose, left_shoulder, right_shoulder, left_hip, right_hip, left_ankle, right_ankle]):
        return "waiting"

    # Calculate midpoints
    shoulder_midpoint = ((left_shoulder + right_shoulder) / 2
                         if left_shoulder is not None and right_shoulder is not None else None)
    hip_midpoint = ((left_hip + right_hip) / 2
                    if left_hip is not None and right_hip is not None else None)
    ankle_midpoint = ((left_ankle + right_ankle) / 2
                      if left_ankle is not None and right_ankle is not None else None)

    # Calculate angles
    spine_angle = calculate_angle(shoulder_midpoint, hip_midpoint, ankle_midpoint)
    left_arm_angle = calculate_angle(left_shoulder, left_elbow, left_wrist)
    right_arm_angle = calculate_angle(right_shoulder, right_elbow, right_wrist)
    left_leg_angle = calculate_angle(left_hip, left_knee, left_ankle)
    right_leg_angle = calculate_angle(right_hip, right_knee, right_ankle)

    # Calculate movement of the ankle midpoint between frames
    movement = None
    if prev_keypoints is not None:
        prev_ankle_midpoint = ((prev_keypoints[0][15][:2] + prev_keypoints[0][16][:2]) / 2
                               if len(prev_keypoints[0]) > 16 else None)
        movement = calculate_movement(ankle_midpoint, prev_ankle_midpoint)
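    # Example of the spine-angle heuristic below: for an upright person the
    # shoulder, hip, and ankle midpoints are nearly collinear, so the angle at the
    # hip midpoint is close to 180 degrees and falls into the > 160 "standing" branch.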
    # Detect actions
    if spine_angle is not None:
        if spine_angle > 160:
            if movement is not None and movement > movement_threshold:
                if movement > movement_threshold * 3:
                    return "running"
                else:
                    return "walking"
            return "standing"
        elif 70 < spine_angle < 110:
            return "sitting"
        elif spine_angle < 30:
            return "lying"

    # Detect pointing
    if (left_arm_angle is not None and left_arm_angle > 150) or (right_arm_angle is not None and right_arm_angle > 150):
        return "pointing"

    # Detect kicking
    if (left_leg_angle is not None and left_leg_angle > 120) or (right_leg_angle is not None and right_leg_angle > 120):
        return "kicking"

    # Detect hitting
    if ((left_arm_angle is not None and 80 < left_arm_angle < 120) or
            (right_arm_angle is not None and 80 < right_arm_angle < 120)):
        if movement is not None and movement > movement_threshold * 2:
            return "hitting"

    return "waiting"
def analyze_frame(frame: np.ndarray):
    start_time = time.time()
    img_container["input"] = frame
    frame = frame.copy()

    detections = []

    if show_labels in ["Object Detection", "Both"]:
        # Run YOLOv8 object detection on the frame
        object_results = object_model(frame, conf=0.5)
        for i, box in enumerate(object_results[0].boxes):
            class_id = int(box.cls)
            detection = {
                "label": object_model.names[class_id],
                "score": float(box.conf),
                "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()]
            }
            detections.append(detection)
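        # Each entry in detections is a plain dict, e.g.
        # {"label": "person", "score": 0.91, "box_coords": [x1, y1, x2, y2]},
        # which Streamlit can render directly as a table in publish_frame().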
| if show_labels in ["Pose Estimation", "Both"]: | |
| # Run YOLOv8 pose estimation on the frame | |
| pose_results = pose_model(frame, conf=0.5) | |
| for i, box in enumerate(pose_results[0].boxes): | |
| class_id = int(box.cls) | |
| detection = { | |
| "label": pose_model.names[class_id], | |
| "score": float(box.conf), | |
| "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()] | |
| } | |
| # Get keypoints for this detection if available | |
| try: | |
| if pose_results[0].keypoints is not None: | |
| keypoints = pose_results[0].keypoints[i].data.cpu().numpy() | |
| # Detect action using the keypoints | |
| prev_keypoints = img_container.get("prev_keypoints") | |
| action = detect_action(keypoints, prev_keypoints) | |
| detection["action"] = action | |
| # Store current keypoints for next frame | |
| img_container["prev_keypoints"] = keypoints | |
                    # Calculate the average position of visible keypoints
                    visible_keypoints = keypoints[0][keypoints[0][:, 2] > 0.5][:, :2]
                    if len(visible_keypoints) > 0:
                        label_x, label_y = np.mean(visible_keypoints, axis=0).astype(int)
                    else:
                        # Fallback to the center of the bounding box if no keypoints are visible
                        x1, y1, x2, y2 = detection["box_coords"]
                        label_x = int((x1 + x2) / 2)
                        label_y = int((y1 + y2) / 2)
                else:
                    detection["action"] = "No keypoint data"
                    # Use the center of the bounding box for label position
                    x1, y1, x2, y2 = detection["box_coords"]
                    label_x = int((x1 + x2) / 2)
                    label_y = int((y1 + y2) / 2)
            except IndexError:
                detection["action"] = "Action detection failed"
                # Use the center of the bounding box for label position
                x1, y1, x2, y2 = detection["box_coords"]
                label_x = int((x1 + x2) / 2)
                label_y = int((y1 + y2) / 2)
            # Only display the action as the label
            label = detection.get('action', '')

            # Increase font scale and thickness to match box label size
            font_scale = 2.0
            thickness = 2

            # Get text size for label
            (label_width, label_height), _ = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)

            # Calculate position for centered label
            label_y = label_y - 10  # 10 pixels above the calculated position

            # Draw filled background for label
            cv2.rectangle(frame, (label_x - label_width // 2 - 5, label_y - label_height - 5),
                          (label_x + label_width // 2 + 5, label_y + 5), (0, 255, 255), -1)
            # Draw black text for label
            cv2.putText(frame, label, (label_x - label_width // 2, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)

            detections.append(detection)

    # Draw detections on the frame (plotting onto `frame` keeps the action labels
    # drawn above visible in the output)
    if show_labels == "Object Detection":
        frame = object_results[0].plot(img=frame)
    elif show_labels == "Pose Estimation":
        frame = pose_results[0].plot(boxes=False, labels=False, kpt_line=True, img=frame)
    else:  # Both
        frame = object_results[0].plot(img=frame)
        frame = pose_results[0].plot(boxes=False, labels=False, kpt_line=True, img=frame)

    end_time = time.time()
    execution_time_ms = round((end_time - start_time) * 1000, 2)
    img_container["analysis_time"] = execution_time_ms

    img_container["detections"] = detections
    img_container["analyzed"] = frame

    return


#
#
#
# DO NOT TOUCH THE CODE BELOW (NO CHANGES NEEDED)
#
#

# Suppress FFmpeg logs
os.environ["FFMPEG_LOG_LEVEL"] = "quiet"

# Suppress Streamlit logs using the logging module
logging.getLogger("streamlit").setLevel(logging.ERROR)

# Container to hold image data and analysis results, shared between the frame
# callback thread and the Streamlit script
img_container = {"input": None, "analyzed": None,
                 "analysis_time": None, "detections": None,
                 "prev_keypoints": None}
# Logger for debugging and information
logger = logging.getLogger(__name__)


# Callback function to process video frames
# This function is called for each video frame in the WebRTC stream.
# It converts the frame to a numpy array in RGB format, analyzes the frame,
# and returns the original frame.
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    img = frame.to_ndarray(format="rgb24")  # Convert frame to numpy array in RGB format
    analyze_frame(img)  # Analyze the frame
    return frame  # Return the original frame
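# Note: analysis runs synchronously inside the callback, so per-frame model
# inference time directly limits how often the analyzed output is refreshed.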

# Get ICE servers for WebRTC
ice_servers = get_ice_servers()

# Streamlit UI configuration
st.set_page_config(layout="wide")

# Custom CSS for the Streamlit page
st.markdown(
    """
    <style>
    .main {
        padding: 2rem;
    }
    h1, h2, h3 {
        font-family: 'Arial', sans-serif;
    }
    h1 {
        font-weight: 700;
        font-size: 2.5rem;
    }
    h2 {
        font-weight: 600;
        font-size: 2rem;
    }
    h3 {
        font-weight: 500;
        font-size: 1.5rem;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# Streamlit page title and subtitle
st.title(ANALYSIS_TITLE)
st.subheader("A Computer Vision Playground")

# Add a link to the README file
st.markdown(
    """
    <div style="text-align: left;">
        <p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
        target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
    </div>
    """,
    unsafe_allow_html=True,
)
# Columns for input and output streams
col1, col2 = st.columns(2)

with col1:
    st.header("Input Stream")
    input_subheader = st.empty()
    input_placeholder = st.empty()  # Placeholder for input frame
    st.subheader("Input Options")

    # WebRTC streamer to get video input from the webcam
    webrtc_ctx = webrtc_streamer(
        key="input-webcam",
        mode=WebRtcMode.SENDONLY,
        rtc_configuration=ice_servers,
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )
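    # SENDONLY means the browser only sends video to the server; the analyzed
    # frames are shown through the Streamlit placeholders rather than being
    # streamed back over WebRTC.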

    # File uploader for images
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"])

    # Text input for image URL
    st.subheader("Or Enter Image URL")
    image_url = st.text_input("Image URL")

    # Text input for YouTube URL
    st.subheader("Enter a YouTube URL")
    youtube_url = st.text_input("YouTube URL")
    yt_error = st.empty()  # Placeholder for YouTube error messages
    # File uploader for videos
    st.subheader("Upload a Video")
    uploaded_video = st.file_uploader(
        "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
    )

    # Text input for video URL
    st.subheader("Or Enter Video Download URL")
    video_url = st.text_input("Video URL")
# Streamlit footer
st.markdown(
    """
    <div style="text-align: center; margin-top: 2rem;">
        <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
    </div>
    """,
    unsafe_allow_html=True
)

# Function to initialize the analysis UI
# This function sets up the placeholders and UI elements in the analysis section.
# It creates placeholders for input and output frames, analysis time, and detected labels.
def analysis_init():
    global progress_bar, status_text, download_button, yt_error, analysis_time, show_labels, labels_placeholder, input_subheader, input_placeholder, output_placeholder

    yt_error.empty()  # Clear any previous YouTube error message

    with col2:
        st.header("Analysis")
        input_subheader.subheader("Input Frame")
        st.subheader("Output Frame")
        output_placeholder = st.empty()  # Placeholder for output frame
        analysis_time = st.empty()  # Placeholder for analysis time
        show_labels = st.radio(
            "Choose Detection Type",
            ("Object Detection", "Pose Estimation", "Both"),
            index=2  # Set default to "Both" (index 2)
        )
        # Create a progress bar
        progress_bar = st.empty()
        status_text = st.empty()
        labels_placeholder = st.empty()  # Placeholder for labels
        download_button = st.empty()  # Placeholder for download button

# Function to publish frames and results to the Streamlit UI
# This function retrieves the latest frames and results from the global container
# and updates the placeholders in the Streamlit UI with the current input frame,
# analyzed frame, analysis time, and detected labels.
def publish_frame():
    img = img_container["input"]
    if img is None:
        return
    input_placeholder.image(img, channels="RGB")  # Display the input frame

    analyzed = img_container["analyzed"]
    if analyzed is None:
        return
    # Display the analyzed frame
    output_placeholder.image(analyzed, channels="RGB")

    analysis_ms = img_container["analysis_time"]
    if analysis_ms is None:
        return
    # Display the analysis time
    analysis_time.text(f"Analysis Time: {analysis_ms} ms")

    detections = img_container["detections"]
    if detections is None:
        return
    if show_labels:
        labels_placeholder.table(detections)  # Display the detections table

# If the WebRTC streamer is playing, initialize and publish frames
if webrtc_ctx.state.playing:
    analysis_init()  # Initialize the analysis UI
    while True:
        publish_frame()  # Publish the frames and results
        time.sleep(0.1)  # Delay to control frame rate
# If an image is uploaded or a URL is provided, process the image
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI

    if uploaded_file is not None:
        image = Image.open(uploaded_file)  # Open the uploaded image
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
    else:
        response = requests.get(image_url)  # Download the image from the URL
        response.raise_for_status()  # Surface download errors instead of failing later
        # Open the downloaded image
        image = Image.open(BytesIO(response.content))
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format

    analyze_frame(img)  # Analyze the image
    publish_frame()  # Publish the results

# Function to process video files
# This function reads frames from a video file, analyzes each frame with the
# selected YOLOv8 models, and updates the Streamlit UI with the current input
# frame, analyzed frame, and detected labels.
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)  # Open the video file

    # Create a temporary file for the annotated video
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_video:
        temp_video_path = temp_video.name
    # save_annotated_video(video_path, temp_video_path)

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # Some streams report 0 fps; fall back to 30
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height))
    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()  # Read a frame from the video
        if not ret:
            break  # Exit the loop if no more frames are available

        # Convert the frame from BGR to RGB format
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Analyze the frame with the selected YOLOv8 models
        analyze_frame(rgb_frame)
        analyzed_frame = img_container["analyzed"]
        if analyzed_frame is not None:
            out.write(cv2.cvtColor(analyzed_frame, cv2.COLOR_RGB2BGR))

        publish_frame()  # Publish the results

        # Update progress (streams may report an unknown frame count of 0)
        frame_count += 1
        if total_frames > 0:
            progress = min(100, int(frame_count / total_frames * 100))
            progress_bar.progress(progress)
            status_text.text(f"Processing video: {progress}% complete")
        else:
            status_text.text(f"Processing video: {frame_count} frames processed")
    cap.release()  # Release the video capture object
    out.release()

    # Add download button for annotated video
    with open(temp_video_path, "rb") as file:
        download_button.download_button(
            label="Download Annotated Video",
            data=file,
            file_name="annotated_video.mp4",
            mime="video/mp4"
        )

    # Clean up the temporary file
    os.unlink(temp_video_path)

# Function to get video URL using Cobalt API
def get_cobalt_video_url(youtube_url):
    cobalt_api_url = "https://api.cobalt.tools/api/json"
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json"
    }
    payload = {
        "url": youtube_url,
        "vCodec": "h264",
        "vQuality": "720",
        "aFormat": "mp3",
        "isAudioOnly": False
    }
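    # Based on the handling below, the Cobalt response is expected to look like
    # {"status": "stream" | "redirect", "url": "..."} on success, or to carry an
    # error message in the "text" field otherwise.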
    try:
        response = requests.post(cobalt_api_url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()

        if data['status'] in ('stream', 'redirect'):
            return data['url']
        else:
            yt_error.error(f"Error: {data['text']}")
            return None
    except requests.exceptions.RequestException as e:
        yt_error.error(f"Error: Unable to process the YouTube URL. {str(e)}")
        return None

# If a YouTube URL is provided, process the video
if youtube_url:
    analysis_init()  # Initialize the analysis UI

    stream_url = get_cobalt_video_url(youtube_url)
    # stream_url = get_youtube_stream_url(youtube_url)

    if stream_url:
        process_video(stream_url)  # Process the video
    else:
        yt_error.error(
            "Unable to process the YouTube video. Please try a different URL or video format.")
# If a video is uploaded or a URL is provided, process the video
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI

    if uploaded_video is not None:
        video_path = uploaded_video.name  # Get the name of the uploaded video
        with open(video_path, "wb") as f:
            # Save the uploaded video to a file
            f.write(uploaded_video.getbuffer())
    else:
        # Download the video from the URL
        video_path = download_file(video_url)

    process_video(video_path)  # Process the video