# Distraction detection demo — annotates uploaded videos with focus status and gaze direction.
import base64
import cv2
import numpy as np
import requests
import gradio as gr
from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine
import os
from dotenv import load_dotenv
load_dotenv()
# Constants
API_KEY = os.getenv("ROBOFLOW_API_KEY")
DISTANCE_TO_OBJECT = 1000 # mm
HEIGHT_OF_HUMAN_FACE = 250 # mm
GAZE_DETECTION_URL = f"http://localhost:9001/gaze/gaze_detection?api_key={API_KEY}"
def detect_gazes(frame: np.ndarray):
"""Detect gazes from the inference server."""
if frame is None or frame.size == 0:
print("Error: Empty or invalid frame passed to detect_gazes.")
return []
try:
_, img_encode = cv2.imencode(".jpg", frame)
img_base64 = base64.b64encode(img_encode)
resp = requests.post(
GAZE_DETECTION_URL,
json={
"api_key": API_KEY,
"image": {"type": "base64", "value": img_base64.decode("utf-8")},
},
)
resp.raise_for_status()
gazes = resp.json()[0]["predictions"]
return gazes
except Exception as e:
print(f"Error in detect_gazes: {e}")
return []
def draw_gaze(img: np.ndarray, gaze: dict):
"""Draw gaze direction and keypoints on the image."""
face = gaze["face"]
x_min = int(face["x"] - face["width"] / 2)
x_max = int(face["x"] + face["width"] / 2)
y_min = int(face["y"] - face["height"] / 2)
y_max = int(face["y"] + face["height"] / 2)
cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 3)
# Draw gaze arrow
_, imgW = img.shape[:2]
arrow_length = imgW / 2
dx = -arrow_length * np.sin(gaze["yaw"]) * np.cos(gaze["pitch"])
dy = -arrow_length * np.sin(gaze["pitch"])
cv2.arrowedLine(
img,
(int(face["x"]), int(face["y"])),
(int(face["x"] + dx), int(face["y"] + dy)),
(0, 0, 255),
2,
cv2.LINE_AA,
tipLength=0.18,
)
# Draw keypoints
for keypoint in face["landmarks"]:
x, y = int(keypoint["x"]), int(keypoint["y"])
cv2.circle(img, (x, y), 2, (0, 255, 0), -1)
return img
def process_video(video_path):
"""Process video and return frames."""
cap = cv2.VideoCapture(video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
face_detector = FaceDetector("assets/face_detector.onnx")
mark_detector = MarkDetector("assets/face_landmarks.onnx")
pose_estimator = PoseEstimator(frame_width, frame_height)
frame_counter = 0
output_frames = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
faces, _ = face_detector.detect(frame, 0.7)
if len(faces) > 0:
face = refine(faces, frame_width, frame_height, 0.15)[0]
x1, y1, x2, y2 = face[:4].astype(int)
patch = frame[y1:y2, x1:x2]
marks = mark_detector.detect([patch])[0].reshape([68, 2])
marks *= (x2 - x1)
marks[:, 0] += x1
marks[:, 1] += y1
distraction_status, pose_vectors = pose_estimator.detect_distraction(marks)
status_text = "Distracted" if distraction_status else "Focused"
cv2.putText(
frame,
f"Status: {status_text}",
(10, 50),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 255, 0) if not distraction_status else (0, 0, 255),
2,
)
# frame_counter += 1
# if frame_counter % 15 == 0:
gazes = detect_gazes(frame)
if gazes:
for gaze in gazes:
frame = draw_gaze(frame, gaze)
cv2.putText(
frame,
f"Yaw: {gaze['yaw']:.2f}, Pitch: {gaze['pitch']:.2f}",
(10, 100),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(255, 255, 0),
2,
)
output_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()
return output_frames
def gradio_interface(video_path):
"""Interface for Gradio."""
output_frames = process_video(video_path)
output_video_path = "output_video.mp4"
# Save frames as video with a compatible codec
if len(output_frames) > 0:
height, width, _ = output_frames[0].shape
out = cv2.VideoWriter(
output_video_path,
cv2.VideoWriter_fourcc(*"avc1"), # Use H.264 codec for compatibility
30, # Frame rate
(width, height),
)
for frame in output_frames:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
else:
print("No frames were processed.")
return output_video_path
# Launch Gradio App
gr.Interface(
fn=gradio_interface,
inputs=gr.Video(label="Input Video"),
outputs=gr.Video(label="Processed Video"),
title="Distraction Detection",
description="Upload a video to detect distraction status and gaze direction.",
).launch()