import base64
import os

import cv2
import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine

load_dotenv()
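# load_dotenv() above reads a local .env file; it is expected to provide
# ROBOFLOW_API_KEY (hypothetical example line: ROBOFLOW_API_KEY=abc123).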
# Constants
API_KEY = os.getenv("ROBOFLOW_API_KEY")
DISTANCE_TO_OBJECT = 1000  # mm
HEIGHT_OF_HUMAN_FACE = 250  # mm
GAZE_DETECTION_URL = f"http://localhost:9001/gaze/gaze_detection?api_key={API_KEY}"
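
# Note: DISTANCE_TO_OBJECT and HEIGHT_OF_HUMAN_FACE are not referenced by the
# drawing code below. They are the usual inputs for projecting gaze angles
# onto the screen plane. The helper below is a minimal, hypothetical sketch of
# that projection (simple tan() geometry, not this app's verified method):
def gaze_to_screen_offset(gaze: dict) -> tuple:
    """Approximate on-screen gaze offset in pixels, relative to the face center.

    Assumes the viewer sits DISTANCE_TO_OBJECT mm from the camera and that the
    detected face box height corresponds to HEIGHT_OF_HUMAN_FACE mm.
    """
    mm_per_pixel = HEIGHT_OF_HUMAN_FACE / gaze["face"]["height"]
    dx = -DISTANCE_TO_OBJECT * np.tan(gaze["yaw"]) / mm_per_pixel
    dy = -DISTANCE_TO_OBJECT * np.tan(gaze["pitch"]) / mm_per_pixel
    return dx, dy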


def detect_gazes(frame: np.ndarray):
    """Detect gaze predictions via the local inference server."""
    if frame is None or frame.size == 0:
        print("Error: Empty or invalid frame passed to detect_gazes.")
        return []
    try:
        _, img_encode = cv2.imencode(".jpg", frame)
        img_base64 = base64.b64encode(img_encode)
        resp = requests.post(
            GAZE_DETECTION_URL,
            json={
                "api_key": API_KEY,
                "image": {"type": "base64", "value": img_base64.decode("utf-8")},
            },
            timeout=10,  # avoid hanging if the inference server is down
        )
        resp.raise_for_status()
        return resp.json()[0]["predictions"]
    except Exception as e:
        print(f"Error in detect_gazes: {e}")
        return []
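
# Assumed response shape, inferred from how the fields are consumed below
# (the authoritative schema is the Roboflow /gaze/gaze_detection endpoint's):
# [
#   {
#     "predictions": [
#       {
#         "yaw": <float, radians>,
#         "pitch": <float, radians>,
#         "face": {
#           "x": <center x>, "y": <center y>,
#           "width": ..., "height": ...,
#           "landmarks": [{"x": ..., "y": ...}, ...]
#         }
#       }
#     ]
#   }
# ]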


def draw_gaze(img: np.ndarray, gaze: dict):
    """Draw the gaze direction arrow and facial keypoints on the image."""
    face = gaze["face"]
    x_min = int(face["x"] - face["width"] / 2)
    x_max = int(face["x"] + face["width"] / 2)
    y_min = int(face["y"] - face["height"] / 2)
    y_max = int(face["y"] + face["height"] / 2)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 3)

    # Draw gaze arrow from the face center, scaled to half the image width
    _, imgW = img.shape[:2]
    arrow_length = imgW / 2
    dx = -arrow_length * np.sin(gaze["yaw"]) * np.cos(gaze["pitch"])
    dy = -arrow_length * np.sin(gaze["pitch"])
    cv2.arrowedLine(
        img,
        (int(face["x"]), int(face["y"])),
        (int(face["x"] + dx), int(face["y"] + dy)),
        (0, 0, 255),
        2,
        cv2.LINE_AA,
        tipLength=0.18,
    )

    # Draw facial keypoints
    for keypoint in face["landmarks"]:
        x, y = int(keypoint["x"]), int(keypoint["y"])
        cv2.circle(img, (x, y), 2, (0, 255, 0), -1)
    return img
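
# Worked example of the arrow geometry above (assumed values): for a
# 640-px-wide frame, arrow_length = 320. With yaw = 0.5 rad and
# pitch = 0.2 rad:
#   dx = -320 * sin(0.5) * cos(0.2) ≈ -150.4
#   dy = -320 * sin(0.2)            ≈  -63.6
# so positive yaw/pitch draw the arrow up and to the left of the face center
# (image y grows downward).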


def process_video(video_path):
    """Process a video and return the annotated frames (RGB)."""
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    face_detector = FaceDetector("assets/face_detector.onnx")
    mark_detector = MarkDetector("assets/face_landmarks.onnx")
    pose_estimator = PoseEstimator(frame_width, frame_height)

    frame_counter = 0
    output_frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        faces, _ = face_detector.detect(frame, 0.7)
        if len(faces) > 0:
            # Refine the box, crop the face patch, and detect 68 landmarks
            face = refine(faces, frame_width, frame_height, 0.15)[0]
            x1, y1, x2, y2 = face[:4].astype(int)
            patch = frame[y1:y2, x1:x2]
            marks = mark_detector.detect([patch])[0].reshape([68, 2])
            # Landmarks are normalized to the patch; map back to frame coords
            marks *= (x2 - x1)
            marks[:, 0] += x1
            marks[:, 1] += y1

            distraction_status, pose_vectors = pose_estimator.detect_distraction(marks)
            status_text = "Distracted" if distraction_status else "Focused"
            cv2.putText(
                frame,
                f"Status: {status_text}",
                (10, 50),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 255) if distraction_status else (0, 255, 0),
                2,
            )

        # Gaze detection currently runs on every frame; re-enable the counter
        # below to throttle requests to the inference server.
        # frame_counter += 1
        # if frame_counter % 15 == 0:
        gazes = detect_gazes(frame)
        for gaze in gazes:
            frame = draw_gaze(frame, gaze)
            cv2.putText(
                frame,
                f"Yaw: {gaze['yaw']:.2f}, Pitch: {gaze['pitch']:.2f}",
                (10, 100),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 255, 0),
                2,
            )

        output_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    cap.release()
    return output_frames
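
# process_video buffers every annotated frame in memory, which is fine for
# short clips but can grow large for long videos. Quick local check
# (hypothetical file name):
#
#   frames = process_video("sample.mp4")
#   print(f"annotated {len(frames)} frames")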


def gradio_interface(video_path):
    """Gradio entry point: annotate the video and save it to disk."""
    output_frames = process_video(video_path)
    output_video_path = "output_video.mp4"

    if len(output_frames) == 0:
        print("No frames were processed.")
        return None

    # Save frames as a video with a browser-compatible codec
    height, width, _ = output_frames[0].shape
    out = cv2.VideoWriter(
        output_video_path,
        cv2.VideoWriter_fourcc(*"avc1"),  # H.264 for browser playback
        30,  # Frame rate; assumes the source is ~30 fps
        (width, height),
    )
    for frame in output_frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()
    return output_video_path


# Launch Gradio app
gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Video(label="Processed Video"),
    title="Distraction Detection",
    description="Upload a video to detect distraction status and gaze direction.",
).launch()