N3tron committed on
Commit
daf7f5e
·
verified ·
1 Parent(s): f0c931e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -54
app.py CHANGED
@@ -5,6 +5,7 @@ import streamlit as st
5
  from insightface.app import FaceAnalysis
6
  from glob import glob
7
  from tqdm import tqdm
 
8
  import shutil
9
  import zipfile
10
 
def extract_zip(zip_file_path, extract_dir):
    """Unpack every member of the zip archive at *zip_file_path* into *extract_dir*."""
    with zipfile.ZipFile(zip_file_path, 'r') as archive:
        archive.extractall(extract_dir)
 
16
# Function to recognize faces
def recognize_faces(frame, names, embeddings, app, threshold=0.7):
    """Detect faces in *frame*, match them against known embeddings, and
    annotate the frame in place.

    Parameters
    ----------
    frame : numpy.ndarray
        BGR image; bounding boxes and labels are drawn directly onto it.
    names : sequence
        Labels aligned index-for-index with *embeddings*.
    embeddings : array-like
        Known face embeddings, one row per identity (assumed normalized,
        matching ``face.normed_embedding`` — TODO confirm against the
        embedding-generation code).
    app : FaceAnalysis
        Prepared insightface detector/embedder.
    threshold : float, optional
        Minimum similarity score to accept a match. Defaults to 0.7,
        the value previously hard-coded, so existing callers are unaffected.

    Returns
    -------
    numpy.ndarray
        The annotated frame (the same object that was passed in).
    """
    # Perform face analysis on the frame
    faces = app.get(frame)

    # The known-embedding matrix is loop-invariant: build it once instead
    # of converting `embeddings` anew for every detected face.
    known = np.asarray(embeddings)

    # Process each detected face separately
    for face in faces:
        # Retrieve the embedding for the detected face
        detected_embedding = face.normed_embedding

        # Similarity scores against every known identity (dot product of
        # normalized vectors), clipped into [0, 1].
        scores = np.clip(np.dot(detected_embedding, known.T), 0., 1.)

        # Best-scoring identity; accept it only above the threshold.
        idx = np.argmax(scores)
        recognized_name = names[idx] if scores[idx] >= threshold else "Unknown"

        # Draw bounding box around the detected face
        bbox = face.bbox.astype(int)
        cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
        # Write recognized name above the bounding box
        cv2.putText(frame, recognized_name, (bbox[0], bbox[1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    return frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Function to get embeddings
50
  def get_embeddings(db_dir):
@@ -124,39 +140,21 @@ def main():
124
  uploaded_embeddings = st.file_uploader("Upload embeddings.npy", type="npy")
125
 
126
  if uploaded_names and uploaded_embeddings:
127
- # Load names and embeddings
128
  names = np.load(uploaded_names)
129
  embeddings = np.load(uploaded_embeddings)
130
 
131
- # Initialize FaceAnalysis app
132
- app = FaceAnalysis(name='buffalo_l')
133
- app.prepare(ctx_id=0, det_size=(640, 640))
134
-
135
- # Display a button to start webcam
136
- if st.button("Start Webcam"):
137
- # Start capturing video from webcam
138
- cap = cv2.VideoCapture(0)
139
-
140
- # Process each frame in real-time
141
- while True:
142
- # Capture frame-by-frame
143
- ret, frame = cap.read()
144
- if not ret:
145
- break
146
-
147
- # Perform face recognition
148
- frame = recognize_faces(frame, names, embeddings, app)
149
-
150
- # Display the resulting frame
151
- st.image(frame, channels="BGR", use_column_width=True)
152
 
153
- # Break the loop if 'q' is pressed or the user closes the Streamlit app
154
- if st.button("Stop"):
155
- break
 
 
 
156
 
157
- # Release the capture
158
- cap.release()
159
- cv2.destroyAllWindows()
160
 
161
 
162
 
 
5
  from insightface.app import FaceAnalysis
6
  from glob import glob
7
  from tqdm import tqdm
8
+ from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
9
  import shutil
10
  import zipfile
11
 
 
14
  with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
15
  zip_ref.extractall(extract_dir)
16
 
 
 
 
 
17
 
18
class FaceRecognitionTransformer(VideoTransformerBase):
    """streamlit-webrtc transformer that annotates each video frame with
    the names of recognized faces.

    ``names`` and ``embeddings`` start as ``None`` and are expected to be
    injected after construction; until both are set, frames pass through
    unmodified.
    """

    def __init__(self):
        # Heavy model setup runs once per streamer session.
        self.app = FaceAnalysis(name='buffalo_l')
        self.app.prepare(ctx_id=0, det_size=(640, 640))
        self.names = None        # labels aligned with `embeddings`
        self.embeddings = None   # known face embeddings, one row per identity

    def _recognize_faces(self, frame):
        """Annotate *frame* (a BGR ndarray) in place and return it."""
        if self.names is None or self.embeddings is None:
            # No reference database loaded yet — nothing to match against.
            return frame

        # Perform face analysis on the frame
        faces = self.app.get(frame)

        # Loop-invariant: build the known-embedding matrix once per frame
        # instead of once per detected face.
        known = np.array(self.embeddings)

        # Process each detected face separately
        for face in faces:
            # Retrieve the embedding for the detected face
            detected_embedding = face.normed_embedding

            # Similarity scores with every known identity, clipped to [0, 1].
            scores = np.clip(np.dot(detected_embedding, known.T), 0., 1.)

            # Index of the best match and its score.
            idx = np.argmax(scores)
            max_score = scores[idx]

            # Accept the best match only above a similarity threshold
            # (adjust as needed).
            threshold = 0.7
            recognized_name = self.names[idx] if max_score >= threshold else "Unknown"

            # Draw bounding box around the detected face
            bbox = face.bbox.astype(int)
            cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
            # Write recognized name above the bounding box
            cv2.putText(frame, recognized_name, (bbox[0], bbox[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        return frame

    def transform(self, frame):
        # BUG FIX: streamlit-webrtc passes an av.VideoFrame here, not an
        # ndarray, so the previous cv2.cvtColor(frame, ...) call would
        # raise. Convert to a BGR ndarray first. A returned ndarray is
        # interpreted as bgr24 by the streamer, so we also stay in BGR
        # rather than converting to RGB (which inverted on-screen colors).
        img = frame.to_ndarray(format="bgr24")
        return self._recognize_faces(img)
64
 
65
  # Function to get embeddings
66
  def get_embeddings(db_dir):
 
140
  uploaded_embeddings = st.file_uploader("Upload embeddings.npy", type="npy")
141
 
142
  if uploaded_names and uploaded_embeddings:
 
143
  names = np.load(uploaded_names)
144
  embeddings = np.load(uploaded_embeddings)
145
 
146
+ # Initialize transformer with names and embeddings
147
+ transformer = FaceRecognitionTransformer()
148
+ transformer.names = names
149
+ transformer.embeddings = embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ # Create WebRTC streamer
152
+ webrtc_ctx = webrtc_streamer(
153
+ key="example",
154
+ video_transformer_factory=FaceRecognitionTransformer,
155
+ async_transform=True,
156
+ )
157
 
 
 
 
158
 
159
 
160