ADharsh22 commited on
Commit
f10370d
·
0 Parent(s):

Initial commit: Ready for Streamlit deployment

Browse files
Files changed (12) hide show
  1. .DS_Store +0 -0
  2. .gitignore +12 -0
  3. Facerecognition +0 -0
  4. app.py +151 -0
  5. requirements.txt +6 -0
  6. run.py +117 -0
  7. src/detect.py +51 -0
  8. src/embed.py +38 -0
  9. src/preload.py +26 -0
  10. src/recognize.py +45 -0
  11. src/register.py +27 -0
  12. src/utils.py +46 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Virtual environment directory
2
+ .venv/
3
+ # Python files generated by the system
4
+ __pycache__/
5
+ *.pyc
6
+ # Data storage (embeddings change locally, shouldn't be in the repo)
7
+ data/embeddings.pkl
8
+ # Local Streamlit cache and config
9
+ .streamlit/
10
+ # Python dependencies (listed in requirements.txt)
11
+ /lib/
12
+ /include/
Facerecognition ADDED
File without changes
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
3
+ import cv2
4
+ import numpy as np
5
+
6
+ # --- Import your modular code ---
7
+ from src.utils import load_embeddings, save_embeddings
8
+ from src.detect import detect_face
9
+ from src.embed import get_embedding
10
+ from src.recognize import recognize_face_by_embedding
11
+ from src.register import register_new_user # Ensure this is in your src folder
12
+
13
# Cached loader: Streamlit keeps a single copy of the registered-user store
# per process until load_registered_data.clear() is invoked.
@st.cache_resource
def load_registered_data():
    """Return the persisted data store dict ({'embeddings': [...], 'names': [...]})."""
    return load_embeddings()
18
+
19
# Global variables for capturing a registration frame
# NOTE(review): this module-level global is written from the webrtc worker
# thread and is per-process, not per-session — confirm this is acceptable
# when multiple users are connected simultaneously.
REGISTRATION_FRAME = None
# Initialize frame_lock in session state if it doesn't exist
if 'frame_lock' not in st.session_state:
    st.session_state['frame_lock'] = False
# Snapshot of the lock taken at script-run time; not refreshed afterwards.
FRAME_LOCK = st.session_state['frame_lock']
25
+
26
class FaceRecognitionTransformer(VideoTransformerBase):
    """Per-frame video processor for the streamlit-webrtc stream.

    In "Recognition" mode the largest detected face is embedded and matched
    against the persisted store; in "Registration" mode the latest face crop
    is held in the module-level REGISTRATION_FRAME for the Register button.
    """

    def __init__(self, data_store, detector_key, mode):
        self.data_store = data_store      # {'embeddings': [...], 'names': [...]}
        self.detector_key = detector_key  # 'cnn' or 'classical'
        self.mode = mode                  # "Recognition" or "Registration"

    def transform(self, frame):
        # FIX: PyAV's VideoFrame.to_ndarray takes a pixel-format name; the
        # packed BGR format is "bgr24" — "bgr" is not a valid format and
        # raises at runtime.
        img = frame.to_ndarray(format="bgr24")

        # 1. Detection
        detected_faces = detect_face(img, detector_type=self.detector_key)

        if detected_faces:
            # We focus on the largest face (by bounding-box area) for simplicity
            main_face = max(detected_faces, key=lambda x: x['box'][2] * x['box'][3])
            x, y, w, h = main_face['box']

            # Crop the face for processing (Alignment & Cropping)
            face_img = img[y:y+h, x:x+w]

            # 2. Defaults for the bounding box and label until mode logic runs
            label = "Processing..."
            color = (255, 255, 0)  # Yellow/Cyan

            if self.mode == "Recognition":
                # 3. Recognition Logic: embed the crop and compare to the store
                embedding = get_embedding(face_img)

                name, distance = recognize_face_by_embedding(
                    embedding,
                    self.data_store['embeddings'],
                    self.data_store['names']
                )

                label = f"{name} (Dist: {distance:.2f})"
                if name != "Unknown":
                    color = (0, 255, 0)  # Green for known
                else:
                    color = (0, 0, 255)  # Red for unknown

            elif self.mode == "Registration":
                # 3. Registration capture mode
                label = "Ready to Register"
                color = (0, 255, 255)  # Yellow

                # Hold the latest face crop for the Register button.
                # NOTE(review): transform() runs on a webrtc worker thread;
                # reading st.session_state from here is unreliable in
                # streamlit-webrtc — confirm the frame-lock actually holds.
                global REGISTRATION_FRAME
                if not st.session_state['frame_lock']:
                    REGISTRATION_FRAME = face_img

            # Draw the box and label onto the outgoing frame
            cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
            cv2.putText(img, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        return img
84
+
85
# --- Streamlit UI ---

st.set_page_config(page_title="Smart Office Face Recognition", layout="wide")
st.title("👨‍💼 Smart Office Access System")
st.markdown("This system identifies staff members using a webcam, welcomes them, or registers new visitors.")

st.sidebar.title("System Controls")
mode = st.sidebar.radio("Select Mode", ["Recognition", "Registration"])
detector_type = st.sidebar.radio(
    "Select Detector (Bonus Feature)",
    ["CNN-based (Default)", "Classical (Haar Cascade)"]
)
# Map the human-readable radio label to the key src.detect expects.
detector_key = 'cnn' if detector_type == 'CNN-based (Default)' else 'classical'

# Load the persistent data (cached across reruns by @st.cache_resource)
data_store = load_registered_data()

if mode == "Registration":
    st.header("📝 New User Registration")
    st.info("Press 'Start' to activate the camera. When ready and a face is detected, enter the staff name and press 'Capture & Register'.")

    user_name = st.text_input("Enter Staff Name", key='reg_name')

    # ------------------ Registration Stream ------------------
    webrtc_ctx_reg = webrtc_streamer(
        key="registration_stream",
        video_transformer_factory=lambda: FaceRecognitionTransformer(data_store, detector_key, mode),
        rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
        media_stream_constraints={"video": True, "audio": False},
        async_transform=True,
    )

    # NOTE(review): REGISTRATION_FRAME is a module-level global written by the
    # webrtc worker thread; a Streamlit rerun (triggered by this button click)
    # may re-initialize it — confirm the captured frame survives the rerun.
    if st.button("Capture & Register") and webrtc_ctx_reg.state.playing:
        if not user_name:
            st.error("Please enter a name for registration.")
        elif REGISTRATION_FRAME is None:
            st.error("No face detected in the frame. Please look at the camera.")
        else:
            # Set lock to prevent transformer from overwriting REGISTRATION_FRAME
            st.session_state['frame_lock'] = True

            with st.spinner(f"Registering {user_name} with {detector_type} detector..."):
                if register_new_user(REGISTRATION_FRAME, user_name):
                    st.success(f"Registration successful for **{user_name}**! Embedding stored to {{data/embeddings.pkl}}")
                    # Force reload the data store to include the new user
                    load_registered_data.clear()
                    st.session_state['frame_lock'] = False
                else:
                    st.error("Registration failed. Could not generate embedding.")

elif mode == "Recognition":
    st.header("🔑 Real-time Recognition Check")
    st.write(f"Currently **{len(data_store['embeddings'])}** users are registered.")

    # ------------------ Recognition Stream ------------------
    webrtc_ctx_rec = webrtc_streamer(
        key="recognition_stream",
        video_transformer_factory=lambda: FaceRecognitionTransformer(data_store, detector_key, mode),
        rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
        media_stream_constraints={"video": True, "audio": False},
        async_transform=True,
    )

    # Status banner while the stream is live.
    if webrtc_ctx_rec.state.playing:
        st.success(f"Recognition running with **{detector_type}** detector...")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ streamlit-webrtc
3
+ opencv-python
4
+ deepface
5
+ numpy
6
+ scikit-learn
run.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# run.py
# The environment variables below MUST be set before TensorFlow is imported
# (directly or transitively via deepface/src modules) to take effect.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Forces TensorFlow to use CPU only
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppress warnings
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"  # For GPU, but harmless on CPU
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # Limit threads for NumPy/NumExpr
import argparse
import cv2
import sys
import numpy as np
import tensorflow as tf
# Import your core logic modules
from src.utils import load_embeddings
from src.detect import detect_face
from src.embed import get_embedding
from src.recognize import recognize_face_by_embedding
from src.register import register_new_user
20
+
21
# Define the logic to process a captured frame for registration or recognition
def process_frame(frame, mode, detector_key, data_store):
    """Detect the largest face in `frame` and register or recognize it.

    Args:
        frame (np.array): BGR image; a box/label is drawn on it in place.
        mode (str): 'register', 'recognize' or 'identify'.
        detector_key (str): 'cnn' (MTCNN) or 'classical' (Haar Cascade).
        data_store (dict): {'embeddings': [...], 'names': [...]} persistent store.

    Returns:
        str: Human-readable status message for the CLI.
    """
    # 1. Detection
    detected_faces = detect_face(frame, detector_type=detector_key)

    if detected_faces:
        # Focus on the largest face for processing
        main_face = max(detected_faces, key=lambda x: x['box'][2] * x['box'][3])
        x, y, w, h = main_face['box']

        # FIX: MTCNN can report slightly negative x/y near frame edges;
        # negative values wrap around in NumPy slicing and yield an empty
        # or wrong crop, so clamp the origin to the frame first.
        x, y = max(x, 0), max(y, 0)
        face_img = frame[y:y+h, x:x+w]

        # Draw bounding box on the displayed frame
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)

        if mode == 'register':
            # Ask for name when face is captured
            name = input("Enter name to register: ")

            # 2. Registration Logic
            if register_new_user(face_img, name):
                return f"✅ Registration successful for {name}! Please run recognize mode next."
            else:
                return "❌ Registration failed. Could not generate embedding."

        elif mode in ['recognize', 'identify']:
            # 3. Recognition Logic
            embedding = get_embedding(face_img)

            if embedding is None:
                return "⚠️ Could not generate embedding for recognition."

            name, distance = recognize_face_by_embedding(
                embedding,
                data_store['embeddings'],
                data_store['names']
            )

            if name != "Unknown":
                cv2.putText(frame, name, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
                return f"✅ Recognized: {name} (Dist: {distance:.2f})"
            else:
                cv2.putText(frame, "Unknown", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
                return f"⚠️ Unknown Person (Dist: {distance:.2f})"

    return "No face detected in the frame."
67
+
68
+ # --- Main CLI Function ---
69
+ def main():
70
+ parser = argparse.ArgumentParser(description="Smart Office Face Recognition System CLI.")
71
+ parser.add_argument('--mode', required=True, choices=['register', 'recognize', 'identify'], help="Operation mode (register or identify/recognize).")
72
+ parser.add_argument('--detector', required=True, choices=['cnn', 'classical'], help="Detector backend: cnn (MTCNN) or classical (Haar Cascade).")
73
+
74
+ args = parser.parse_args()
75
+ data_store = load_embeddings() # Initial load of persistent data
76
+
77
+ cap = cv2.VideoCapture(0)
78
+ if not cap.isOpened():
79
+ print("Error: Could not open webcam.")
80
+ sys.exit(1)
81
+
82
+ print(f"\n--- Running in {args.mode.upper()} mode with {args.detector.upper()} detector. ---")
83
+ print("Press 'c' to CAPTURE/PROCESS a face or 'q' to QUIT.\n")
84
+
85
+ while True:
86
+ ret, frame = cap.read()
87
+ if not ret:
88
+ break
89
+
90
+ # Display instructions on the frame
91
+ cv2.putText(frame, f"MODE: {args.mode.upper()} | DETECTOR: {args.detector.upper()}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
92
+ cv2.putText(frame, "Press 'c' to capture.", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
93
+ cv2.imshow('Smart Office Face Recognition CLI', frame)
94
+
95
+ key = cv2.waitKey(1) & 0xFF
96
+
97
+ if key == ord('q'):
98
+ break
99
+
100
+ elif key == ord('c'):
101
+ print(f"Capture command received. Processing frame in {args.mode} mode...")
102
+
103
+ # Process the captured frame
104
+ result = process_frame(frame, args.mode, args.detector, data_store)
105
+ print(result)
106
+
107
+ # FIX: Pylance Scope Fix - If registered, reload the local data_store variable for immediate recognition.
108
+ if "Registration successful" in result:
109
+ data_store = load_embeddings()
110
+ print("Data store reloaded for new user recognition.")
111
+
112
+
113
+ cap.release()
114
+ cv2.destroyAllWindows()
115
+
116
+ if __name__ == '__main__':
117
+ main()
src/detect.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from deepface import DeepFace
3
+
4
def detect_face(image, detector_type='cnn'):
    """
    Detects faces in the image using the specified method (MTCNN or Haar Cascade).

    Args:
        image (np.array): The input BGR image frame.
        detector_type (str): 'cnn' for MTCNN (default) or 'classical' for Haar Cascade.

    Returns:
        list: A list of dicts containing detected face info.
              Format: [{'box': (x, y, w, h), 'landmarks': {...}, 'confidence': float}]
              Empty on unknown detector_type or any backend failure.
    """
    # Map detector type to DeepFace backend name
    if detector_type == 'cnn':
        backend = 'mtcnn'   # Multi-task Cascaded Convolutional Neural Network
    elif detector_type == 'classical':
        backend = 'opencv'  # DeepFace uses 'opencv' for Haar Cascade
    else:
        return []

    results = []

    try:
        # DeepFace handles detection, alignment, and returns landmarks (for MTCNN)
        detected_faces = DeepFace.extract_faces(
            img_path=image,
            detector_backend=backend,
            enforce_detection=False  # Allow processing even if no face is initially found
        )

        for face_info in detected_faces:
            # FIX: read the box fields by key. The old
            # `facial_area.values()` unpacking relied on dict insertion
            # order and breaks on DeepFace versions whose facial_area also
            # carries 'left_eye'/'right_eye' entries.
            area = face_info['facial_area']
            x, y, w, h = area['x'], area['y'], area['w'], area['h']

            results.append({
                'box': (x, y, w, h),
                # Landmarks are useful for visualizing the alignment process
                'landmarks': face_info.get('landmarks', {}),
                'confidence': face_info.get('confidence', 1.0)
            })

    except Exception:
        # Best-effort: any backend failure is treated as "no faces found".
        pass

    return results
src/embed.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from deepface import DeepFace
3
+
4
+ # Global variables to hold the model configuration
5
+ MODEL_NAME = "Facenet"
6
+ DIMENSIONS = 128
7
+
8
def get_embedding(face_image):
    """
    Extract a 128-dimensional FaceNet embedding from a face image.

    Args:
        face_image (np.array): BGR image expected to contain a (pre-cropped) face.

    Returns:
        np.array or None: The 128-D embedding vector, or None when it
        cannot be computed.
    """
    try:
        # DeepFace performs alignment, preprocessing and model inference
        # internally; detection is disabled because the caller already
        # supplies a face crop.
        representations = DeepFace.represent(
            img_path=face_image,
            model_name=MODEL_NAME,
            enforce_detection=False
        )
        if not representations:
            return None
        # Use the first (and expected only) face's vector.
        return np.array(representations[0]["embedding"])
    except Exception:
        # Best-effort: any backend failure yields "no embedding".
        return None
src/preload.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# preload.py — warms the DeepFace model cache (~/.deepface/weights/) so the
# first real request does not pay the download/initialization cost.

import numpy as np
from deepface import DeepFace

print("--- Starting DeepFace Model Preloading ---")

# A tiny black image is enough to force model construction; detection is
# disabled/lenient so no real face needs to be present.
_dummy = np.zeros((64, 64, 3), dtype=np.uint8)

# 1. Force download and initialization of FaceNet (for embeddings)
print("Loading FaceNet model (128-dimensional embeddings)...")
try:
    # FIX: `import deepface` does not import the `basemodels` submodule, so
    # `deepface.basemodels.FaceNet.loadModel()` raised AttributeError (and
    # loadModel was removed in newer releases). Use the stable public API.
    DeepFace.represent(img_path=_dummy, model_name="Facenet", enforce_detection=False)
    print("✅ FaceNet loaded successfully.")
except Exception as e:
    print(f"❌ Error loading FaceNet: {e}")

# 2. Force download and initialization of MTCNN (CNN-based detection)
print("Loading MTCNN detector (CNN-based detection)...")
try:
    # FIX: same issue as above for `deepface.detectors.FaceDetector`;
    # running an extraction with the mtcnn backend builds and caches it.
    DeepFace.extract_faces(img_path=_dummy, detector_backend="mtcnn", enforce_detection=False)
    print("✅ MTCNN loaded successfully.")
except Exception as e:
    print(f"❌ Error loading MTCNN: {e}")

print("--- Preloading Complete ---")
src/recognize.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.metrics.pairwise import cosine_similarity
3
+
4
+ # The user is identified if cosine distance < 0.5.
5
+ # Cosine Similarity = 1 - Cosine Distance
6
+ SIMILARITY_THRESHOLD = 0.5
7
+
8
+ def recognize_face_by_embedding(current_embedding, known_embeddings, known_names):
9
+ """
10
+ Identifies a known person by comparing the current face embedding against stored ones.
11
+
12
+ Args:
13
+ current_embedding (np.array): The 128-D embedding of the live face.
14
+ known_embeddings (list): List of stored embedding vectors (numpy arrays).
15
+ known_names (list): List of names corresponding to the embeddings.
16
+
17
+ Returns:
18
+ tuple: (Identified Name or "Unknown", Cosine Distance)
19
+ """
20
+ if not known_embeddings or current_embedding is None:
21
+ return "Unknown", 1.0 # No users registered or failed embedding
22
+
23
+ # Reshape for comparison: (1, 128)
24
+ current_embedding = current_embedding.reshape(1, -1)
25
+
26
+ # Convert list of known embeddings to a NumPy array (N, 128)
27
+ known_embeddings_array = np.array(known_embeddings)
28
+
29
+ # Calculate Cosine Similarity (vector comparison)
30
+ similarities = cosine_similarity(current_embedding, known_embeddings_array)[0]
31
+
32
+ # Find the best match (highest similarity)
33
+ best_match_index = np.argmax(similarities)
34
+ best_similarity = similarities[best_match_index]
35
+
36
+ # Calculate distance for checking the threshold: Distance = 1 - Similarity
37
+ best_distance = 1.0 - best_similarity
38
+
39
+ # Check if the similarity surpasses the threshold (i.e., distance < 0.5)
40
+ if best_distance < SIMILARITY_THRESHOLD:
41
+ identified_name = known_names[best_match_index]
42
+ return identified_name, best_distance
43
+ else:
44
+ # System notes the person as unknown [cite: 44]
45
+ return "Unknown", best_distance
src/register.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .embed import get_embedding
2
+ from .utils import load_embeddings, save_embeddings
3
+
4
def register_new_user(face_image, name):
    """
    Registers a new user by generating an embedding and persisting it.

    Args:
        face_image (np.array): The input BGR image frame (must contain a face).
        name (str): The name of the user to register.

    Returns:
        bool: True on successful registration, False otherwise.
    """
    embedding = get_embedding(face_image)
    if embedding is None:
        # No embedding could be produced -> nothing to store.
        return False

    # Append the new user to the persistent store and write it back
    # to {data/embeddings.pkl}.
    store = load_embeddings()
    store['embeddings'].append(embedding)
    store['names'].append(name)
    return save_embeddings(store)
src/utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import os
3
+
4
# Define the persistent storage path, navigating two levels up from src/
EMBEDDINGS_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'embeddings.pkl')

def load_embeddings():
    """
    Loads all registered embeddings and associated names from the pickle file.

    Returns:
        dict: A dictionary with 'embeddings' (list of numpy arrays) and
        'names' (list of strings); empty lists when the store is missing,
        corrupt, or malformed.
    """
    empty = {'embeddings': [], 'names': []}

    # No store written yet: start fresh.
    if not os.path.exists(EMBEDDINGS_PATH):
        return empty

    try:
        # NOTE(review): pickle fully trusts this local file; keep it out of
        # user-writable/untrusted locations.
        with open(EMBEDDINGS_PATH, 'rb') as fh:
            data = pickle.load(fh)
        # Reject payloads that lack the expected structure.
        if 'embeddings' not in data or 'names' not in data:
            return empty
        return data
    except Exception as e:
        print(f"Error loading embeddings: {e}")
        # A corrupt/unreadable file is treated as an empty store.
        return empty
29
+
30
def save_embeddings(data):
    """
    Saves the updated embeddings and names back to the pickle file.

    Args:
        data (dict): The dictionary containing 'embeddings' and 'names' lists.

    Returns:
        bool: True when the file was written, False on any I/O error.
    """
    # Create the data/ directory on first save.
    os.makedirs(os.path.dirname(EMBEDDINGS_PATH), exist_ok=True)

    try:
        with open(EMBEDDINGS_PATH, 'wb') as fh:
            pickle.dump(data, fh)
    except Exception as e:
        print(f"Error saving embeddings: {e}")
        return False
    return True