GeraldoRiberia commited on
Commit
9f43980
·
0 Parent(s):

checkpoint

Browse files

-kinda working so far

requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ websockets
4
+ opencv-python
5
+ numpy
6
+ ultralytics
7
+ deepface
server.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import uvicorn
4
+ import cv2
5
+ import numpy as np
6
+ import base64
7
+ import json
8
+ import logging
9
+ import asyncio
10
+ from concurrent.futures import ThreadPoolExecutor
11
+
12
+ from services.single_tracker import SingleTracker
13
+ from services.multi_tracker import MultiTracker
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Executor for CPU-bound tasks
20
+ executor = ThreadPoolExecutor(max_workers=1)
21
+
22
+ app = FastAPI(title="AFS Tracking Backend")
23
+
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"],
27
+ allow_credentials=True,
28
+ allow_methods=["*"],
29
+ allow_headers=["*"],
30
+ )
31
+
32
+ # Initialize trackers
33
+ single_tracker = SingleTracker()
34
+ multi_tracker = MultiTracker()
35
+
36
+ def decode_binary_image(img_data: bytes):
37
+ """Decodes raw JPEG bytes into an OpenCV numpy array."""
38
+ try:
39
+ nparr = np.frombuffer(img_data, np.uint8)
40
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
41
+ return img
42
+ except Exception as e:
43
+ logger.error(f"Failed to decode image: {e}")
44
+ return None
45
+
46
+ @app.websocket("/ws")
47
+ async def websocket_endpoint(websocket: WebSocket):
48
+ await websocket.accept()
49
+ logger.info("New WebSocket connection established.")
50
+
51
+ current_mode = "single" # Default mode
52
+
53
+ try:
54
+ while True:
55
+ # Receive message (either text JSON or binary frame)
56
+ message = await websocket.receive()
57
+
58
+ if "text" in message:
59
+ try:
60
+ payload = json.loads(message["text"])
61
+ if "mode" in payload and payload["mode"] != current_mode:
62
+ logger.info(f"Switching mode from {current_mode} to {payload['mode']}")
63
+ current_mode = payload["mode"]
64
+ await websocket.send_json({"type": "mode_ack", "mode": current_mode})
65
+ except json.JSONDecodeError:
66
+ logger.error("Invalid JSON received.")
67
+ continue
68
+
69
+ elif "bytes" in message:
70
+ frame_data = message["bytes"]
71
+ frame = decode_binary_image(frame_data)
72
+
73
+ if frame is None:
74
+ await websocket.send_json({"error": "Failed to decode binary frame"})
75
+ continue
76
+
77
+ # Prepare inference function
78
+ def run_inference(f, mode):
79
+ if mode == "single":
80
+ return single_tracker.process_frame(f)
81
+ elif mode == "multi":
82
+ return multi_tracker.process_frame(f)
83
+ else:
84
+ return {"error": f"Unknown mode: {mode}"}
85
+
86
+ # Process Frame in executor
87
+ response_data = {}
88
+ try:
89
+ response_data = await asyncio.get_event_loop().run_in_executor(
90
+ executor, run_inference, frame, current_mode
91
+ )
92
+ except Exception as e:
93
+ logger.error(f"Error processing frame in {current_mode} mode: {e}")
94
+ response_data = {"error": str(e)}
95
+
96
+ # Send results back to client
97
+ response_data["mode"] = current_mode
98
+ await websocket.send_json(response_data)
99
+
100
+ except WebSocketDisconnect:
101
+ logger.info("WebSocket client disconnected.")
102
+ except Exception as e:
103
+ logger.error(f"WebSocket error: {e}")
104
+
105
+ if __name__ == "__main__":
106
+ uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
services/multi_tracker.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import logging
5
+ from ultralytics import YOLO
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ class MultiTracker:
10
+ def __init__(self):
11
+ logger.info("Initializing Multi Tracker (Group Centroid)")
12
+
13
+ # Determine paths
14
+ base_dir = "/Users/adisankarlalan/Documents/GitHub/afs-fl/Model"
15
+ detector_model_path = os.path.join(base_dir, "yolov8n-face.pt")
16
+
17
+ try:
18
+ self.model = YOLO(detector_model_path)
19
+ logger.info(f"Loaded YOLO model from {detector_model_path}")
20
+ except Exception as e:
21
+ logger.error(f"Failed to load YOLO model: {e}")
22
+ self.model = None
23
+
24
+ def process_frame(self, frame):
25
+ """
26
+ Process a single BGR image frame for group object tracking.
27
+ Returns a dictionary with tracking results (individual boxes + aggregate box).
28
+ """
29
+ results_data = {
30
+ "individual_boxes": [],
31
+ "aggregate_box": None,
32
+ "centroid": None,
33
+ "error": None,
34
+ "frame_width": int(frame.shape[1]),
35
+ "frame_height": int(frame.shape[0])
36
+ }
37
+
38
+ if self.model is None:
39
+ results_data["error"] = "Model not initialized"
40
+ return results_data
41
+
42
+ try:
43
+ # RUN BYTETRACK (Detection + Tracking)
44
+ results = self.model.track(frame, persist=True, tracker="bytetrack.yaml", verbose=False)
45
+
46
+ if results and len(results) > 0 and results[0].boxes.id is not None:
47
+ boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
48
+ track_ids = results[0].boxes.id.cpu().numpy().astype(int)
49
+
50
+ all_x1, all_y1, all_x2, all_y2 = [], [], [], []
51
+
52
+ for box, track_id in zip(boxes, track_ids):
53
+ x1, y1, x2, y2 = box.tolist()
54
+
55
+ all_x1.append(x1)
56
+ all_y1.append(y1)
57
+ all_x2.append(x2)
58
+ all_y2.append(y2)
59
+
60
+ results_data["individual_boxes"].append({
61
+ "id": int(track_id),
62
+ "x1": int(x1), "y1": int(y1),
63
+ "x2": int(x2), "y2": int(y2)
64
+ })
65
+
66
+ # Calculate Aggregate Bounding Box if faces exist
67
+ if len(all_x1) > 0:
68
+ agg_x1 = int(min(all_x1))
69
+ agg_y1 = int(min(all_y1))
70
+ agg_x2 = int(max(all_x2))
71
+ agg_y2 = int(max(all_y2))
72
+
73
+ # Aggregate Centroid
74
+ agg_cx = (agg_x1 + agg_x2) // 2
75
+ agg_cy = (agg_y1 + agg_y2) // 2
76
+
77
+ results_data["aggregate_box"] = {
78
+ "x1": agg_x1, "y1": agg_y1,
79
+ "x2": agg_x2, "y2": agg_y2
80
+ }
81
+ results_data["centroid"] = {
82
+ "cx": agg_cx,
83
+ "cy": agg_cy
84
+ }
85
+
86
+ except Exception as e:
87
+ logger.error(f"Error during ByteTrack: {e}")
88
+ results_data["error"] = str(e)
89
+
90
+ return results_data
services/single_tracker.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import pickle
4
+ import numpy as np
5
+ import logging
6
+ from ultralytics import YOLO
7
+ from deepface import DeepFace
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class SingleTracker:
12
+ def __init__(self):
13
+ logger.info("Initializing Single Tracker (Face Priority)")
14
+
15
+ # Configuration matches face_model.py
16
+ self.base_dir = "/Users/adisankarlalan/Documents/GitHub/afs-fl/Model"
17
+ self.reference_video_path = os.path.join(self.base_dir, 'my_scan.mp4')
18
+ self.model_name = "ArcFace"
19
+ self.detector_model_path = os.path.join(self.base_dir, "yolov8n-face.pt")
20
+ self.cache_file = os.path.join(self.base_dir, "embeddings_cache.pkl")
21
+
22
+ # State
23
+ self.priority_track_id = None
24
+ self.known_tracks = {} # {track_id: is_main_user}
25
+ self.track_retries = {} # {track_id: retry_count}
26
+
27
+ self.max_retries = 20
28
+ self.similarity_threshold = 0.70
29
+
30
+ self.main_user_embeddings = []
31
+ self._load_embeddings()
32
+
33
+ try:
34
+ self.model = YOLO(self.detector_model_path)
35
+ logger.info("Loaded YOLO model")
36
+ except Exception as e:
37
+ logger.error(f"Failed to load YOLO model: {e}")
38
+ self.model = None
39
+
40
+ def _is_cache_valid(self, cache_data):
41
+ if not cache_data:
42
+ return False
43
+ if cache_data.get('video_path') != 'my_scan.mp4' and cache_data.get('video_path') != self.reference_video_path:
44
+ return False
45
+ if cache_data.get('model_name') != self.model_name:
46
+ return False
47
+ if cache_data.get('version', 1) < 2:
48
+ return False
49
+ return True
50
+
51
+ def _load_embeddings(self):
52
+ logger.info("Loading main user embeddings...")
53
+ cache_loaded = False
54
+
55
+ if os.path.exists(self.cache_file):
56
+ try:
57
+ with open(self.cache_file, 'rb') as f:
58
+ cache_data = pickle.load(f)
59
+
60
+ if self._is_cache_valid(cache_data):
61
+ self.main_user_embeddings = cache_data['embeddings']
62
+ logger.info("Loaded master signature from cache")
63
+ cache_loaded = True
64
+ except Exception as e:
65
+ logger.error(f"Could not load cache: {e}")
66
+
67
+ if not cache_loaded:
68
+ logger.warning(f"Cache invalid or not found at {self.cache_file}. Returning empty embeddings. Please run Model/face_model.py to generate cache.")
69
+
70
+ def process_frame(self, frame):
71
+ """
72
+ Process a single BGR image frame for single face tracking.
73
+ Returns a dictionary with tracking results.
74
+ """
75
+ results_data = {
76
+ "boxes": [],
77
+ "priority_id": self.priority_track_id,
78
+ "error": None,
79
+ "frame_width": int(frame.shape[1]),
80
+ "frame_height": int(frame.shape[0])
81
+ }
82
+
83
+ if self.model is None:
84
+ results_data["error"] = "Model not initialized"
85
+ return results_data
86
+
87
+ try:
88
+ # RUN BYTETRACK
89
+ results = self.model.track(frame, persist=True, tracker="bytetrack.yaml", verbose=False)
90
+
91
+ if results and len(results) > 0 and results[0].boxes.id is not None:
92
+ boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
93
+ track_ids = results[0].boxes.id.cpu().numpy().astype(int)
94
+
95
+ for box, track_id in zip(boxes, track_ids):
96
+ x1, y1, x2, y2 = box.tolist()
97
+ track_id = int(track_id)
98
+ max_similarity = 0.0
99
+
100
+ # Lock resolution logic
101
+ if track_id not in self.known_tracks and len(self.main_user_embeddings) > 0:
102
+ if track_id not in self.track_retries:
103
+ self.track_retries[track_id] = 0
104
+
105
+ # Crop face
106
+ face_crop = frame[y1:y2, x1:x2]
107
+
108
+ try:
109
+ # Strict check
110
+ current_face = DeepFace.represent(face_crop, model_name=self.model_name, enforce_detection=False)[0]["embedding"]
111
+
112
+ for user_embedding in self.main_user_embeddings:
113
+ sim = np.dot(user_embedding, current_face) / (np.linalg.norm(user_embedding) * np.linalg.norm(current_face))
114
+ if sim > max_similarity:
115
+ max_similarity = sim
116
+
117
+ max_similarity = float(max_similarity)
118
+
119
+ if max_similarity > self.similarity_threshold:
120
+ self.known_tracks[track_id] = True
121
+ self.priority_track_id = track_id
122
+ results_data["priority_id"] = track_id
123
+ if track_id in self.track_retries:
124
+ del self.track_retries[track_id]
125
+ else:
126
+ self.track_retries[track_id] += 1
127
+ if self.track_retries[track_id] > self.max_retries:
128
+ self.known_tracks[track_id] = False
129
+ del self.track_retries[track_id]
130
+
131
+ except Exception as e:
132
+ # Exception means no face/blur => skip for this frame but count retry
133
+ logger.error(f"DeepFace failed on track_id {track_id}: {e}")
134
+ self.track_retries[track_id] += 1
135
+ if self.track_retries[track_id] > self.max_retries:
136
+ self.known_tracks[track_id] = False
137
+ del self.track_retries[track_id]
138
+ else:
139
+ # Ensures unknown tracks still get registered for scanning display
140
+ if track_id not in self.known_tracks and track_id not in self.track_retries:
141
+ self.track_retries[track_id] = 0
142
+
143
+ # Determine label and color representation
144
+ is_target = self.known_tracks.get(track_id, False)
145
+ if is_target:
146
+ label = f"TARGET LOCKED"
147
+ results_data["boxes"].append({
148
+ "id": track_id,
149
+ "x1": x1, "y1": y1,
150
+ "x2": x2, "y2": y2,
151
+ "is_target": True,
152
+ "label": label,
153
+ "similarity": max_similarity if 'max_similarity' in locals() else -1.0
154
+ })
155
+ elif track_id in self.track_retries:
156
+ # Draw scanning box
157
+ label = f"SCANNING"
158
+ results_data["boxes"].append({
159
+ "id": track_id,
160
+ "x1": x1, "y1": y1,
161
+ "x2": x2, "y2": y2,
162
+ "is_target": False,
163
+ "label": label,
164
+ "similarity": max_similarity if 'max_similarity' in locals() else -1.0
165
+ })
166
+
167
+ except Exception as e:
168
+ logger.error(f"Error during SingleTrack: {e}")
169
+ results_data["error"] = str(e)
170
+
171
+ return results_data