Afnan214 committed on
Commit
9c57041
·
unverified ·
1 Parent(s): 45d9f6f
Files changed (4) hide show
  1. app.py +56 -52
  2. face_detection.py +2 -2
  3. mark_detection.py +1 -1
  4. requirements.txt +0 -1
app.py CHANGED
@@ -1,99 +1,103 @@
1
  import cv2
2
  import streamlit as st
3
  import tempfile
4
- import time
5
  import numpy as np
6
  from face_detection import FaceDetector
7
  from mark_detection import MarkDetector
8
  from pose_estimation import PoseEstimator
9
  from utils import refine
10
 
 
11
 
12
- st.title("Pose-estimation")
13
-
14
  file_type = st.selectbox("Choose the type of file you want to upload", ("Image", "Video"))
15
- if file_type == "Image":
16
- uploaded_file = st.file_uploader("Upload an image of your face", type=["jpg","jpeg", "png"])
17
- else:
18
- uploaded_video = st.file_uploader("Upload a video of your face", type=["mp4","mov","avi","mkv"])
 
 
 
19
 
20
  if uploaded_file is not None:
 
21
  if file_type == "Video":
22
  tfile = tempfile.NamedTemporaryFile(delete=False)
23
  tfile.write(uploaded_file.read())
24
  cap = cv2.VideoCapture(tfile.name)
25
- print(f"Video source: {tfile.name}")
26
 
27
- #getting frame sizes
28
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
29
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
30
 
31
- #face detection
32
  face_detector = FaceDetector("assets/face_detector.onnx")
33
- #landmark detection
34
  mark_detector = MarkDetector("assets/face_landmarks.onnx")
35
- #pose estimation
36
  pose_estimator = PoseEstimator(frame_width, frame_height)
37
 
38
- tm = cv2.TickMeter()
39
-
40
- while True:
41
-
42
- # Read a frame.
43
- frame_got, frame = cap.read()
44
- if frame_got is False:
45
  break
46
 
47
-
48
- # Step 1: Get faces from current frame.
49
  faces, _ = face_detector.detect(frame, 0.7)
50
 
51
- # Any valid face found?
52
  if len(faces) > 0:
53
- tm.start()
54
-
55
- # Step 2: Detect landmarks. Crop and feed the face area into the
56
- # mark detector. Note only the first face will be used for
57
- # demonstration.
58
  face = refine(faces, frame_width, frame_height, 0.15)[0]
59
  x1, y1, x2, y2 = face[:4].astype(int)
60
  patch = frame[y1:y2, x1:x2]
61
 
62
- # Run the mark detection.
63
  marks = mark_detector.detect([patch])[0].reshape([68, 2])
64
-
65
- # Convert the locations from local face area to the global image.
66
  marks *= (x2 - x1)
67
  marks[:, 0] += x1
68
  marks[:, 1] += y1
69
 
70
- # Step 3: Try pose estimation with 68 points.
71
  pose = pose_estimator.solve(marks)
72
 
73
- tm.stop()
74
-
75
- # All done. The best way to show the result would be drawing the
76
- # pose on the frame in realtime.
77
-
78
- # Do you want to see the pose annotation?
79
  pose_estimator.visualize(frame, pose, color=(0, 255, 0))
80
 
81
- # Do you want to see the axes?
82
- # pose_estimator.draw_axes(frame, pose)
 
83
 
84
- # Do you want to see the marks?
85
- # mark_detector.visualize(frame, marks, color=(0, 255, 0))
86
 
87
- # Do you want to see the face bounding boxes?
88
- # face_detector.visualize(frame, faces)
 
 
 
89
 
90
- # Draw the FPS on screen.
91
- cv2.rectangle(frame, (0, 0), (90, 30), (0, 0, 0), cv2.FILLED)
92
- cv2.putText(frame, f"FPS: {tm.getFPS():.0f}", (10, 20),
93
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
94
-
95
- # Show preview.
96
- cv2.imshow("Preview", frame)
97
- if cv2.waitKey(1) == 27:
98
- break
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit app: upload an image or a video and draw the estimated head pose.

The uploaded file is run through a face detector, a 68-point landmark
detector and a pose estimator; the annotated result is shown in the page.
"""

import os
import tempfile

import cv2
import numpy as np
import streamlit as st

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine


def _draw_pose(frame, face_detector, mark_detector, pose_estimator,
               frame_width, frame_height):
    """Detect the first face in *frame* and draw its pose onto the frame.

    Args:
        frame: BGR image as a NumPy array; it is modified in place by
            ``pose_estimator.visualize``.
        face_detector: object exposing ``detect(image, threshold)``.
        mark_detector: object exposing ``detect([patch])``.
        pose_estimator: object exposing ``solve(marks)`` and ``visualize(...)``.
        frame_width: width of *frame* in pixels.
        frame_height: height of *frame* in pixels.

    Returns:
        True when a face was found and the pose was drawn, False otherwise.
    """
    faces, _ = face_detector.detect(frame, 0.7)
    if len(faces) == 0:
        return False

    # Only the first detected face is used.
    face = refine(faces, frame_width, frame_height, 0.15)[0]
    x1, y1, x2, y2 = face[:4].astype(int)
    patch = frame[y1:y2, x1:x2]

    marks = mark_detector.detect([patch])[0].reshape([68, 2])

    # Map landmarks from patch-local to full-image coordinates.
    # NOTE(review): both axes are scaled by the box width, which presumably
    # relies on refine() returning a square box and on the detector emitting
    # coordinates normalised to the patch -- confirm against utils.refine.
    marks *= (x2 - x1)
    marks[:, 0] += x1
    marks[:, 1] += y1

    pose = pose_estimator.solve(marks)
    pose_estimator.visualize(frame, pose, color=(0, 255, 0))
    return True


st.title("Pose Estimation")

# Choose between Image or Video file upload.
file_type = st.selectbox("Choose the type of file you want to upload", ("Image", "Video"))
uploaded_file = st.file_uploader(
    "Upload an image or video file of your face",
    type=["jpg", "jpeg", "png", "mp4", "mov", "avi", "mkv"]
)

# Placeholder that is overwritten with each processed video frame.
FRAME_WINDOW = st.image([])

if uploaded_file is not None:
    # Both branches use the same detection models; only the pose estimator
    # depends on the frame size.
    face_detector = FaceDetector("assets/face_detector.onnx")
    mark_detector = MarkDetector("assets/face_landmarks.onnx")

    if file_type == "Video":
        # OpenCV opens videos by path, so spill the upload to disk. Leaving
        # the ``with`` block closes (and therefore flushes) the file before
        # VideoCapture reads it.
        with tempfile.NamedTemporaryFile(delete=False) as tfile:
            tfile.write(uploaded_file.read())

        cap = cv2.VideoCapture(tfile.name)
        st.write(f"Video source: {tfile.name}")

        try:
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            pose_estimator = PoseEstimator(frame_width, frame_height)

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                _draw_pose(frame, face_detector, mark_detector, pose_estimator,
                           frame_width, frame_height)

                # OpenCV frames are BGR; Streamlit expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                FRAME_WINDOW.image(frame_rgb)
        finally:
            # Release the capture and remove the temp file even if a frame
            # fails to process.
            cap.release()
            os.unlink(tfile.name)

    elif file_type == "Image":
        # Decode with OpenCV so the array is always 3-channel BGR, matching
        # the video path and the BGR->RGB conversion below.
        raw = np.frombuffer(uploaded_file.read(), dtype=np.uint8)
        image = cv2.imdecode(raw, cv2.IMREAD_COLOR)

        if image is None:
            # imdecode returns None when the bytes are not a readable image,
            # e.g. a video uploaded while "Image" is selected.
            st.error("Could not decode the uploaded file as an image.")
        else:
            frame_height, frame_width = image.shape[:2]
            pose_estimator = PoseEstimator(frame_width, frame_height)

            _draw_pose(image, face_detector, mark_detector, pose_estimator,
                       frame_width, frame_height)

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            st.image(image_rgb, caption="Pose Estimated Image", use_column_width=True)
face_detection.py CHANGED
@@ -36,8 +36,8 @@ class FaceDetector:
36
  assert os.path.exists(model_file), f"File not found: {model_file}"
37
  self.center_cache = {}
38
  self.nms_threshold = 0.4
39
- self.session = onnxruntime.InferenceSession(
40
- model_file, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
41
 
42
  # Get model configurations from the model file.
43
  # What is the input like?
 
36
  assert os.path.exists(model_file), f"File not found: {model_file}"
37
  self.center_cache = {}
38
  self.nms_threshold = 0.4
39
+ self.session = onnxruntime.InferenceSession(model_file, providers=['CPUExecutionProvider'])
40
+
41
 
42
  # Get model configurations from the model file.
43
  # What is the input like?
mark_detection.py CHANGED
@@ -17,7 +17,7 @@ class MarkDetector:
17
  assert os.path.exists(model_file), f"File not found: {model_file}"
18
  self._input_size = 128
19
  self.model = ort.InferenceSession(
20
- model_file, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
21
 
22
  def _preprocess(self, bgrs):
23
  """Preprocess the inputs to meet the model's needs.
 
17
  assert os.path.exists(model_file), f"File not found: {model_file}"
18
  self._input_size = 128
19
  self.model = ort.InferenceSession(
20
+ model_file, providers=["CPUExecutionProvider"])
21
 
22
  def _preprocess(self, bgrs):
23
  """Preprocess the inputs to meet the model's needs.
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  opencv-python-headless
2
  numpy
3
- tempfile
4
  time
5
  onnxruntime
6
  os
 
1
  opencv-python-headless
2
  numpy
 
3
  time
4
  onnxruntime
5
  os