Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,11 +21,17 @@ except ImportError as e:
|
|
| 21 |
class FaceRecognitionSystem:
|
| 22 |
def __init__(self):
|
| 23 |
self.setup_database()
|
| 24 |
-
global FACE_NET_AVAILABLE
|
| 25 |
|
| 26 |
if FACE_NET_AVAILABLE:
|
| 27 |
try:
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
self.resnet = InceptionResnetV1(pretrained='vggface2', classify=False, device='cpu').eval()
|
| 30 |
print("✅ FaceNet models loaded successfully")
|
| 31 |
except Exception as e:
|
|
@@ -79,19 +85,31 @@ class FaceRecognitionSystem:
|
|
| 79 |
pil_image = Image.open(image).convert('RGB')
|
| 80 |
elif isinstance(image, np.ndarray):
|
| 81 |
# numpy array
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
| 83 |
else:
|
| 84 |
# Handle Gradio file object
|
| 85 |
pil_image = Image.open(image).convert('RGB')
|
| 86 |
|
| 87 |
-
# Detect
|
| 88 |
-
|
| 89 |
-
|
|
|
|
| 90 |
return None, "No face detected in the image"
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
return embedding_np, "Face embedding extracted successfully"
|
| 97 |
|
|
@@ -155,7 +173,7 @@ class FaceRecognitionSystem:
|
|
| 155 |
# Find the best match
|
| 156 |
best_match = None
|
| 157 |
min_distance = float('inf')
|
| 158 |
-
threshold = 0
|
| 159 |
|
| 160 |
for student_id, name, class_name, embedding_blob in students:
|
| 161 |
known_embedding = np.frombuffer(embedding_blob, dtype=np.float32)
|
|
@@ -192,15 +210,7 @@ class FaceRecognitionSystem:
|
|
| 192 |
result_message = f"⚠️ {name} already marked today (ID: {student_id})"
|
| 193 |
|
| 194 |
# Convert image for display
|
| 195 |
-
|
| 196 |
-
display_image = cv2.imread(image)
|
| 197 |
-
elif isinstance(image, np.ndarray):
|
| 198 |
-
display_image = image.copy()
|
| 199 |
-
else:
|
| 200 |
-
# Handle Gradio file object
|
| 201 |
-
display_image = np.array(Image.open(image))
|
| 202 |
-
if len(display_image.shape) == 3 and display_image.shape[2] == 3:
|
| 203 |
-
display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
|
| 204 |
|
| 205 |
# Add text to image
|
| 206 |
if display_image is not None:
|
|
@@ -210,12 +220,32 @@ class FaceRecognitionSystem:
|
|
| 210 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
| 211 |
cv2.putText(display_image, f"Class: {class_name}", (10, 90),
|
| 212 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
|
|
|
|
|
|
| 213 |
|
| 214 |
-
display_image = cv2.cvtColor(display_image, cv2.COLOR_BGR2RGB)
|
| 215 |
return display_image, result_message
|
| 216 |
|
| 217 |
return image, "❌ No matching student found"
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
def get_all_students(self):
|
| 220 |
"""Get all registered students"""
|
| 221 |
self.cursor.execute('SELECT name, student_id, class FROM students ORDER BY name')
|
|
@@ -240,6 +270,12 @@ class FaceRecognitionSystem:
|
|
| 240 |
''')
|
| 241 |
return self.cursor.fetchall()
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
# Initialize the system
|
| 244 |
face_system = FaceRecognitionSystem()
|
| 245 |
|
|
|
|
| 21 |
class FaceRecognitionSystem:
|
| 22 |
def __init__(self):
|
| 23 |
self.setup_database()
|
| 24 |
+
global FACE_NET_AVAILABLE
|
| 25 |
|
| 26 |
if FACE_NET_AVAILABLE:
|
| 27 |
try:
|
| 28 |
+
# Initialize MTCNN with simpler settings
|
| 29 |
+
self.mtcnn = MTCNN(
|
| 30 |
+
keep_all=False, # Only detect one face for simplicity
|
| 31 |
+
min_face_size=40,
|
| 32 |
+
thresholds=[0.6, 0.7, 0.7],
|
| 33 |
+
device='cpu'
|
| 34 |
+
)
|
| 35 |
self.resnet = InceptionResnetV1(pretrained='vggface2', classify=False, device='cpu').eval()
|
| 36 |
print("✅ FaceNet models loaded successfully")
|
| 37 |
except Exception as e:
|
|
|
|
| 85 |
pil_image = Image.open(image).convert('RGB')
|
| 86 |
elif isinstance(image, np.ndarray):
|
| 87 |
# numpy array
|
| 88 |
+
if len(image.shape) == 3:
|
| 89 |
+
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
| 90 |
+
else:
|
| 91 |
+
return None, "Invalid image format"
|
| 92 |
else:
|
| 93 |
# Handle Gradio file object
|
| 94 |
pil_image = Image.open(image).convert('RGB')
|
| 95 |
|
| 96 |
+
# Detect face and get cropped face
|
| 97 |
+
face_tensor, prob = self.mtcnn(pil_image, return_prob=True)
|
| 98 |
+
|
| 99 |
+
if face_tensor is None:
|
| 100 |
return None, "No face detected in the image"
|
| 101 |
|
| 102 |
+
if prob < 0.9: # Confidence threshold
|
| 103 |
+
return None, "Low confidence face detection"
|
| 104 |
+
|
| 105 |
+
# Ensure proper tensor dimensions [3, 160, 160]
|
| 106 |
+
if face_tensor.dim() == 4:
|
| 107 |
+
face_tensor = face_tensor.squeeze(0) # Remove batch dimension if present
|
| 108 |
+
|
| 109 |
+
# Get embedding
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
embedding = self.resnet(face_tensor.unsqueeze(0)) # Add batch dimension
|
| 112 |
+
embedding_np = embedding.detach().numpy().flatten()
|
| 113 |
|
| 114 |
return embedding_np, "Face embedding extracted successfully"
|
| 115 |
|
|
|
|
| 173 |
# Find the best match
|
| 174 |
best_match = None
|
| 175 |
min_distance = float('inf')
|
| 176 |
+
threshold = 1.0 # Similarity threshold
|
| 177 |
|
| 178 |
for student_id, name, class_name, embedding_blob in students:
|
| 179 |
known_embedding = np.frombuffer(embedding_blob, dtype=np.float32)
|
|
|
|
| 210 |
result_message = f"⚠️ {name} already marked today (ID: {student_id})"
|
| 211 |
|
| 212 |
# Convert image for display
|
| 213 |
+
display_image = self.prepare_display_image(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
# Add text to image
|
| 216 |
if display_image is not None:
|
|
|
|
| 220 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
| 221 |
cv2.putText(display_image, f"Class: {class_name}", (10, 90),
|
| 222 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
| 223 |
+
cv2.putText(display_image, f"Similarity: {1-distance:.3f}", (10, 120),
|
| 224 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
| 225 |
|
|
|
|
| 226 |
return display_image, result_message
|
| 227 |
|
| 228 |
return image, "❌ No matching student found"
|
| 229 |
|
| 230 |
+
def prepare_display_image(self, image):
|
| 231 |
+
"""Convert image to display format"""
|
| 232 |
+
try:
|
| 233 |
+
if isinstance(image, str):
|
| 234 |
+
display_image = cv2.imread(image)
|
| 235 |
+
elif isinstance(image, np.ndarray):
|
| 236 |
+
display_image = image.copy()
|
| 237 |
+
if len(display_image.shape) == 3 and display_image.shape[2] == 3:
|
| 238 |
+
display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
|
| 239 |
+
else:
|
| 240 |
+
# Handle Gradio file object
|
| 241 |
+
display_image = np.array(Image.open(image))
|
| 242 |
+
if len(display_image.shape) == 3 and display_image.shape[2] == 3:
|
| 243 |
+
display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
|
| 244 |
+
|
| 245 |
+
return display_image
|
| 246 |
+
except:
|
| 247 |
+
return None
|
| 248 |
+
|
| 249 |
def get_all_students(self):
|
| 250 |
"""Get all registered students"""
|
| 251 |
self.cursor.execute('SELECT name, student_id, class FROM students ORDER BY name')
|
|
|
|
| 270 |
''')
|
| 271 |
return self.cursor.fetchall()
|
| 272 |
|
| 273 |
+
# Import torch after FaceNet imports to avoid conflicts
|
| 274 |
+
try:
|
| 275 |
+
import torch
|
| 276 |
+
except ImportError:
|
| 277 |
+
torch = None
|
| 278 |
+
|
| 279 |
# Initialize the system
|
| 280 |
face_system = FaceRecognitionSystem()
|
| 281 |
|