BasitAliii commited on
Commit
4ff793f
·
verified ·
1 Parent(s): 02d08c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -20
app.py CHANGED
@@ -21,11 +21,17 @@ except ImportError as e:
21
  class FaceRecognitionSystem:
22
  def __init__(self):
23
  self.setup_database()
24
- global FACE_NET_AVAILABLE # Declare global at the start of method
25
 
26
  if FACE_NET_AVAILABLE:
27
  try:
28
- self.mtcnn = MTCNN(keep_all=True, device='cpu')
 
 
 
 
 
 
29
  self.resnet = InceptionResnetV1(pretrained='vggface2', classify=False, device='cpu').eval()
30
  print("✅ FaceNet models loaded successfully")
31
  except Exception as e:
@@ -79,19 +85,31 @@ class FaceRecognitionSystem:
79
  pil_image = Image.open(image).convert('RGB')
80
  elif isinstance(image, np.ndarray):
81
  # numpy array
82
- pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
 
 
 
83
  else:
84
  # Handle Gradio file object
85
  pil_image = Image.open(image).convert('RGB')
86
 
87
- # Detect faces and extract embeddings
88
- faces = self.mtcnn(pil_image)
89
- if faces is None:
 
90
  return None, "No face detected in the image"
91
 
92
- # Get embedding for the first face
93
- embedding = self.resnet(faces.unsqueeze(0))
94
- embedding_np = embedding.detach().numpy().flatten()
 
 
 
 
 
 
 
 
95
 
96
  return embedding_np, "Face embedding extracted successfully"
97
 
@@ -155,7 +173,7 @@ class FaceRecognitionSystem:
155
  # Find the best match
156
  best_match = None
157
  min_distance = float('inf')
158
- threshold = 0.8 # Similarity threshold
159
 
160
  for student_id, name, class_name, embedding_blob in students:
161
  known_embedding = np.frombuffer(embedding_blob, dtype=np.float32)
@@ -192,15 +210,7 @@ class FaceRecognitionSystem:
192
  result_message = f"⚠️ {name} already marked today (ID: {student_id})"
193
 
194
  # Convert image for display
195
- if isinstance(image, str):
196
- display_image = cv2.imread(image)
197
- elif isinstance(image, np.ndarray):
198
- display_image = image.copy()
199
- else:
200
- # Handle Gradio file object
201
- display_image = np.array(Image.open(image))
202
- if len(display_image.shape) == 3 and display_image.shape[2] == 3:
203
- display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
204
 
205
  # Add text to image
206
  if display_image is not None:
@@ -210,12 +220,32 @@ class FaceRecognitionSystem:
210
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
211
  cv2.putText(display_image, f"Class: {class_name}", (10, 90),
212
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
 
 
213
 
214
- display_image = cv2.cvtColor(display_image, cv2.COLOR_BGR2RGB)
215
  return display_image, result_message
216
 
217
  return image, "❌ No matching student found"
218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  def get_all_students(self):
220
  """Get all registered students"""
221
  self.cursor.execute('SELECT name, student_id, class FROM students ORDER BY name')
@@ -240,6 +270,12 @@ class FaceRecognitionSystem:
240
  ''')
241
  return self.cursor.fetchall()
242
 
 
 
 
 
 
 
243
  # Initialize the system
244
  face_system = FaceRecognitionSystem()
245
 
 
21
  class FaceRecognitionSystem:
22
  def __init__(self):
23
  self.setup_database()
24
+ global FACE_NET_AVAILABLE
25
 
26
  if FACE_NET_AVAILABLE:
27
  try:
28
+ # Initialize MTCNN with simpler settings
29
+ self.mtcnn = MTCNN(
30
+ keep_all=False, # Only detect one face for simplicity
31
+ min_face_size=40,
32
+ thresholds=[0.6, 0.7, 0.7],
33
+ device='cpu'
34
+ )
35
  self.resnet = InceptionResnetV1(pretrained='vggface2', classify=False, device='cpu').eval()
36
  print("✅ FaceNet models loaded successfully")
37
  except Exception as e:
 
85
  pil_image = Image.open(image).convert('RGB')
86
  elif isinstance(image, np.ndarray):
87
  # numpy array
88
+ if len(image.shape) == 3:
89
+ pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
90
+ else:
91
+ return None, "Invalid image format"
92
  else:
93
  # Handle Gradio file object
94
  pil_image = Image.open(image).convert('RGB')
95
 
96
+ # Detect face and get cropped face
97
+ face_tensor, prob = self.mtcnn(pil_image, return_prob=True)
98
+
99
+ if face_tensor is None:
100
  return None, "No face detected in the image"
101
 
102
+ if prob < 0.9: # Confidence threshold
103
+ return None, "Low confidence face detection"
104
+
105
+ # Ensure proper tensor dimensions [3, 160, 160]
106
+ if face_tensor.dim() == 4:
107
+ face_tensor = face_tensor.squeeze(0) # Remove batch dimension if present
108
+
109
+ # Get embedding
110
+ with torch.no_grad():
111
+ embedding = self.resnet(face_tensor.unsqueeze(0)) # Add batch dimension
112
+ embedding_np = embedding.detach().numpy().flatten()
113
 
114
  return embedding_np, "Face embedding extracted successfully"
115
 
 
173
  # Find the best match
174
  best_match = None
175
  min_distance = float('inf')
176
+ threshold = 1.0 # Similarity threshold
177
 
178
  for student_id, name, class_name, embedding_blob in students:
179
  known_embedding = np.frombuffer(embedding_blob, dtype=np.float32)
 
210
  result_message = f"⚠️ {name} already marked today (ID: {student_id})"
211
 
212
  # Convert image for display
213
+ display_image = self.prepare_display_image(image)
 
 
 
 
 
 
 
 
214
 
215
  # Add text to image
216
  if display_image is not None:
 
220
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
221
  cv2.putText(display_image, f"Class: {class_name}", (10, 90),
222
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
223
+ cv2.putText(display_image, f"Similarity: {1-distance:.3f}", (10, 120),
224
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
225
 
 
226
  return display_image, result_message
227
 
228
  return image, "❌ No matching student found"
229
 
230
+ def prepare_display_image(self, image):
231
+ """Convert image to display format"""
232
+ try:
233
+ if isinstance(image, str):
234
+ display_image = cv2.imread(image)
235
+ elif isinstance(image, np.ndarray):
236
+ display_image = image.copy()
237
+ if len(display_image.shape) == 3 and display_image.shape[2] == 3:
238
+ display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
239
+ else:
240
+ # Handle Gradio file object
241
+ display_image = np.array(Image.open(image))
242
+ if len(display_image.shape) == 3 and display_image.shape[2] == 3:
243
+ display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
244
+
245
+ return display_image
246
+ except:
247
+ return None
248
+
249
  def get_all_students(self):
250
  """Get all registered students"""
251
  self.cursor.execute('SELECT name, student_id, class FROM students ORDER BY name')
 
270
  ''')
271
  return self.cursor.fetchall()
272
 
273
+ # Import torch after FaceNet imports to avoid conflicts
274
+ try:
275
+ import torch
276
+ except ImportError:
277
+ torch = None
278
+
279
  # Initialize the system
280
  face_system = FaceRecognitionSystem()
281