Spaces:
Running
Running
File size: 39,541 Bytes
1c44b5a 9f43980 679b5d0 b53629f 9f43980 b53629f 679b5d0 b53629f 1c44b5a b53629f e448396 94741c6 b53629f 9f43980 b53629f 9f43980 679b5d0 d5fe478 679b5d0 1949831 679b5d0 1c44b5a 679b5d0 514a298 679b5d0 9f43980 b53629f 79a6aec 9f43980 b53629f 6ae3d39 b53629f e448396 b53629f 514a298 b53629f e448396 b53629f e448396 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 9f43980 514a298 9f43980 514a298 679b5d0 1949831 514a298 679b5d0 514a298 679b5d0 514a298 1c44b5a 514a298 679b5d0 514a298 679b5d0 1c44b5a 514a298 679b5d0 514a298 679b5d0 514a298 679b5d0 514a298 679b5d0 1c44b5a 514a298 679b5d0 514a298 679b5d0 1c44b5a 1949831 1c44b5a 514a298 1c44b5a 514a298 1c44b5a 679b5d0 514a298 679b5d0 514a298 679b5d0 514a298 679b5d0 514a298 679b5d0 514a298 679b5d0 514a298 d5fe478 514a298 956dd89 6ae3d39 956dd89 6ae3d39 956dd89 514a298 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 514a298 d5fe478 6ae3d39 b53629f 27ab41f 514a298 6ae3d39 514a298 6ae3d39 514a298 27ab41f 6ae3d39 27ab41f 1949831 27ab41f 6ae3d39 b53629f d5fe478 6ae3d39 d5fe478 b53629f 514a298 b53629f 514a298 b53629f 514a298 94741c6 9f43980 1949831 514a298 9f43980 514a298 94741c6 9f43980 1949831 9f43980 94741c6 1949831 679b5d0 1949831 679b5d0 1949831 679b5d0 d5fe478 679b5d0 d5fe478 679b5d0 9f43980 514a298 9f43980 94741c6 9f43980 94741c6 9f43980 94741c6 9f43980 1949831 9f43980 514a298 679b5d0 514a298 d5fe478 514a298 679b5d0 514a298 679b5d0 514a298 679b5d0 514a298 679b5d0 9f43980 679b5d0 9f43980 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 09083dc 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 b53629f 514a298 9216a20 
514a298 9216a20 988f34d 9216a20 514a298 9216a20 988f34d 9216a20 988f34d 514a298 9216a20 514a298 988f34d 514a298 988f34d 514a298 988f34d 9216a20 988f34d 9216a20 988f34d 514a298 9216a20 988f34d 6ae3d39 988f34d 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 514a298 9216a20 6ae3d39 9216a20 988f34d 6ae3d39 988f34d 514a298 9216a20 988f34d 6ae3d39 988f34d 514a298 6ae3d39 514a298 6ae3d39 514a298 988f34d 6ae3d39 988f34d 6ae3d39 514a298 1c44b5a 514a298 1c44b5a 514a298 1c44b5a 514a298 1c44b5a 514a298 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 514a298 6ae3d39 988f34d 9f43980 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 
375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 
875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, status, Depends, UploadFile, File, Form, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import uvicorn
import cv2
import numpy as np
import json
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
import threading
import os
import base64
import hashlib
import math
from pydantic import BaseModel, Field
from pymongo import AsyncMongoClient
import bcrypt
import pickle
from bson import ObjectId
from jose import JWTError, jwt
from dotenv import load_dotenv
from pathlib import Path
import shutil
import uuid
from services.single_tracker import SingleTracker
from services.multi_tracker import MultiTracker
from services.face_recognition import FaceRecognitionService
from services.audio_processing import AudioProcessor
# Load environment variables from .env file
# (done first so the os.getenv() lookups further down can see .env values).
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Executor for CPU-bound tasks
# max_workers=1: jobs submitted here (tracker inference, embedding extraction)
# run one at a time on a single worker thread.
executor = ThreadPoolExecutor(max_workers=1)
# --- OBS and Recording State ---
# These module-level flags are shared by every connection to the /ws endpoint.
latest_obs_frame = None # Store the latest JPEG encoded cropped frame for the OBS feed
obs_frame_lock = threading.Lock()
is_obs_active = False
is_recording = False
video_writer = None
recording_filename = ""
# --- Center Stage State (EMA Smoothing) ---
# Crop centre is stored as a fraction of the frame size (0.5 = middle);
# scale 1.0 means "no zoom". zoom_multiplier is set by clients over /ws.
current_cx = 0.5
current_cy = 0.5
current_scale = 1.0
zoom_multiplier = 1.0
# --- Real-time Target Tracking State ---
# Both are None while no target is being tracked.
current_target_angle = None
current_target_distance = None
# Configurable parameters for smooth panning
# Lower is smoother but slower (similar to Dart's TweenAnimation)
SMOOTHING_FACTOR = 0.1
TARGET_ASPECT_RATIO = 16.0 / 9.0 # Assuming output is meant to be 16:9
app = FastAPI(title="AFS Tracking Backend")
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS policy - confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize trackers and services
# MODEL_DIR is the "Model" directory that sits next to this file.
MODEL_DIR = Path(__file__).parent / "Model"
single_tracker = SingleTracker()
multi_tracker = MultiTracker()
face_service = FaceRecognitionService(str(MODEL_DIR))
audio_processor = AudioProcessor(str(MODEL_DIR))
# MongoDB state
# All of these stay None until a connection succeeds (startup handler or the
# reconnect loop); handlers must treat None as "database unavailable".
mongo_client: AsyncMongoClient | None = None
users_collection = None
audio_recordings_collection = None
audio_settings_collection = None
audio_angles_collection = None
# JWT Configuration
# NOTE(review): the fallback secret is a placeholder string; tokens signed with
# it are forgeable, so JWT_SECRET_KEY should always be set in real deployments.
SECRET_KEY = os.getenv(
    "JWT_SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
# Bearer-token extractor used by the get_current_user dependency.
security = HTTPBearer()
# Request body accepted by POST /auth/register.
class RegisterRequest(BaseModel):
    # Display name, 2-80 characters; the handler strips surrounding whitespace.
    full_name: str = Field(min_length=2, max_length=80)
    # Only length-checked here (5-254 chars); the handler lower-cases and trims it.
    email: str = Field(min_length=5, max_length=254)
    # Plain-text password, 8-128 characters; stored only as a bcrypt hash.
    password: str = Field(min_length=8, max_length=128)
# Request body accepted by POST /auth/login.
class LoginRequest(BaseModel):
    # Length-checked only (5-254 chars); the handler normalises it before lookup.
    email: str = Field(min_length=5, max_length=254)
    # Plain-text password, 8-128 characters, checked against the stored hash.
    password: str = Field(min_length=8, max_length=128)
# Public view of a user document (no password hash or embeddings).
class UserPublic(BaseModel):
    # String form of the MongoDB "_id".
    id: str
    full_name: str
    email: str
# Response body returned by the register and login endpoints.
class AuthResponse(BaseModel):
    ok: bool
    # Human-readable status text.
    message: str
    user: UserPublic
    # JWT access token produced by create_access_token().
    token: str
def normalize_email(email: str) -> str:
    """Return *email* with surrounding whitespace removed and lower-cased."""
    trimmed = email.strip()
    return trimmed.lower()
def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt and a fresh salt; the hash is returned as text."""
    raw_password = password.encode('utf-8')
    hashed = bcrypt.hashpw(raw_password, bcrypt.gensalt())
    return hashed.decode('utf-8')
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the bcrypt hash *hashed_password*."""
    candidate = plain_password.encode('utf-8')
    stored_hash = hashed_password.encode('utf-8')
    return bcrypt.checkpw(candidate, stored_hash)
def require_users_collection():
    """Return the users collection, or raise HTTP 503 while the database is not ready."""
    if users_collection is not None:
        return users_collection
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Database is not initialized yet. Please retry.",
    )
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Create a signed JWT containing *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token; the dict is copied, not mutated.
        expires_delta: Lifetime of the token. When omitted (or falsy) the
            module default ACCESS_TOKEN_EXPIRE_MINUTES is used.

    Returns:
        The encoded token string.
    """
    # Local import: the module header only brings in datetime and timedelta.
    from datetime import timezone

    lifetime = expires_delta if expires_delta else timedelta(
        minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # datetime.utcnow() is deprecated; an aware UTC timestamp yields the same
    # "exp" value without relying on naive datetimes.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that resolves the bearer token to a ``UserPublic``.

    Raises:
        HTTPException: 503 when the database is not initialised; 401 when the
            token is invalid/expired, carries no ``sub`` claim, the ``sub`` is
            not a valid ObjectId, or no matching user exists.
    """
    # Local import keeps this block self-contained (bson is already a
    # dependency of the file).
    from bson.errors import InvalidId

    collection = require_users_collection()
    token = credentials.credentials
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
        )
    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
        )
    # Only a malformed id is an authentication failure. Database errors are
    # deliberately NOT caught here: a bare "except:" used to turn outages (and
    # task cancellation) into a misleading 401 "User not found".
    try:
        object_id = ObjectId(user_id)
    except (InvalidId, TypeError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )
    user_doc = await collection.find_one({"_id": object_id})
    if user_doc is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )
    return UserPublic(
        id=str(user_doc["_id"]),
        full_name=user_doc["full_name"],
        email=user_doc["email"],
    )
def decode_binary_image(img_data: bytes):
    """Turn encoded image bytes into an OpenCV colour image.

    Returns whatever ``cv2.imdecode`` produces, or None when decoding raises
    (the error is logged).
    """
    try:
        byte_view = np.frombuffer(img_data, np.uint8)
        return cv2.imdecode(byte_view, cv2.IMREAD_COLOR)
    except Exception as e:
        logger.error(f"Failed to decode image: {e}")
    return None
def _crop_target_from_box(box, frame_w, frame_h, margin):
    """Convert a pixel box into (box_cx, box_cy, norm_cx, norm_cy, scale).

    ``scale`` is 1 / (largest box side as a fraction of the frame * margin),
    clamped to [1.0, 3.0]. A zero-size or inverted box yields 1.0 (no zoom)
    instead of dividing by zero.
    """
    box_cx = (box["x1"] + box["x2"]) / 2.0
    box_cy = (box["y1"] + box["y2"]) / 2.0
    max_dim = max((box["x2"] - box["x1"]) / frame_w,
                  (box["y2"] - box["y1"]) / frame_h)
    if max_dim > 0:
        scale = max(1.0, min(1.0 / (max_dim * margin), 3.0))
    else:
        scale = 1.0
    return box_cx, box_cy, box_cx / frame_w, box_cy / frame_h, scale


def apply_center_stage_crop(frame, tracking_data):
    """
    Applies an exponential moving average (EMA) to smoothly pan and zoom
    the frame based on the tracking target bounding box.

    In "multi" mode the target is ``tracking_data["aggregate_box"]``; otherwise
    it is the first entry of ``tracking_data["boxes"]`` flagged ``is_target``.
    Without a target the crop eases back to the full, centred frame.

    Side effects: updates the module-level smoothing state (current_cx,
    current_cy, current_scale) and the current_target_angle /
    current_target_distance values (None when no target is found).

    Returns the cropped frame (a slice of ``frame``).
    """
    global current_cx, current_cy, current_scale, current_target_angle, current_target_distance, zoom_multiplier
    h, w = frame.shape[:2]
    # Defaults: centred, no zoom
    target_cx = 0.5
    target_cy = 0.5
    target_scale = 1.0

    # Pick the box to follow and the framing margin for the active mode.
    if tracking_data.get("mode") == "multi":
        target_box = tracking_data.get("aggregate_box")
        margin = 1.5  # from Dart: max dimension proportion * 1.5 margin
    else:
        boxes = tracking_data.get("boxes") or []
        target_box = next((b for b in boxes if b.get("is_target")), None)
        margin = 2.0  # slightly tighter for single person

    if target_box:
        box_cx, box_cy, target_cx, target_cy, target_scale = _crop_target_from_box(
            target_box, w, h, margin)
        # Apply user zoom multiplier
        target_scale = max(1.0, min(target_scale * zoom_multiplier, 10.0))
        # Distance and angle from the frame centre to the box centre
        dx = box_cx - w / 2.0
        dy = box_cy - h / 2.0
        current_target_distance = math.hypot(dx, dy)
        # Convert atan2 result to 0-360 degrees
        current_target_angle = math.degrees(math.atan2(dy, dx)) % 360.0
    else:
        current_target_angle = None
        current_target_distance = None

    # Apply EMA smoothing
    current_cx += (target_cx - current_cx) * SMOOTHING_FACTOR
    current_cy += (target_cy - current_cy) * SMOOTHING_FACTOR
    current_scale += (target_scale - current_scale) * SMOOTHING_FACTOR

    # When scale is S, the crop width is w / S. Never let a side reach 0,
    # otherwise the returned slice would be empty.
    crop_w = max(1, int(w / current_scale))
    crop_h = max(1, int(h / current_scale))
    # Enforce aspect ratio by shrinking whichever side is too large
    if crop_w / crop_h > TARGET_ASPECT_RATIO:
        crop_w = max(1, int(crop_h * TARGET_ASPECT_RATIO))
    else:
        crop_h = max(1, int(crop_w / TARGET_ASPECT_RATIO))

    # Top-left corner of the crop, clamped to the frame boundaries
    start_x = max(0, int(current_cx * w) - crop_w // 2)
    start_y = max(0, int(current_cy * h) - crop_h // 2)
    if start_x + crop_w > w:
        start_x = w - crop_w
    if start_y + crop_h > h:
        start_y = h - crop_h

    return frame[start_y:start_y + crop_h, start_x:start_x + crop_w]
async def generate_obs_stream():
    """Generator for the MJPEG stream used by OBS.

    Yields one multipart chunk per iteration containing the most recent JPEG
    frame, at roughly 30 fps. While no frame is available yet it polls every
    0.1 s.
    """
    while True:
        # Hold the lock only long enough to read the reference. Awaiting or
        # yielding while a threading.Lock is held would block the event loop
        # thread as soon as the frame producer tries to take the same lock.
        with obs_frame_lock:
            frame_bytes = latest_obs_frame
        if frame_bytes is None:
            # If no frame yet, wait and check again
            await asyncio.sleep(0.1)
            continue
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')
        # Use asyncio sleep to prevent blocking the event loop
        await asyncio.sleep(0.033)  # roughly 30 fps
@app.get("/obs_feed")
async def obs_feed():
    """Serve the MJPEG stream that an OBS Media Source connects to."""
    frame_stream = generate_obs_stream()
    return StreamingResponse(
        frame_stream,
        media_type="multipart/x-mixed-replace; boundary=frame",
    )
async def vcam_generator_loop():
    """Background task: forward the latest frame to the virtual camera ~30 times per second."""
    # NOTE(review): vcam and latest_vcam_frame are not defined in the part of
    # the file visible here - confirm they exist at module level, otherwise
    # every tick with is_obs_active set logs a NameError.
    frame_interval = 1/30
    while True:
        try:
            if is_obs_active:
                if vcam is not None:
                    if latest_vcam_frame is not None:
                        vcam.send(latest_vcam_frame)
        except Exception as e:
            logger.error(f"vcam loop error: {e}")
        await asyncio.sleep(frame_interval)
@app.get("/")
async def health_check():
    """Report service liveness and whether the users collection is initialised."""
    db_ready = users_collection is not None
    return {
        "status": "ok",
        "service": "AFS Tracking Backend",
        "mongodb": "connected" if db_ready else "disconnected",
    }
async def mongodb_reconnect_loop():
    """Background task to attempt MongoDB reconnection if disconnected.

    Every 10 seconds, while ``users_collection`` is None, it opens a new
    client, pings the server, ensures the unique email index and then publishes
    the client and collection handles to the module-level globals.
    """
    # audio_angles_collection must be listed here too; without it the
    # assignment below only creates a local and the module variable stays None.
    global mongo_client, users_collection, audio_recordings_collection, audio_settings_collection, audio_angles_collection
    while True:
        if users_collection is None:
            mongo_uri = os.getenv("MONGODB_URI", "mongodb://localhost:27017")
            mongo_db_name = os.getenv("MONGODB_DB", "afs")
            client = None
            try:
                logger.info("Attempting to reconnect to MongoDB...")
                client = AsyncMongoClient(
                    mongo_uri, serverSelectionTimeoutMS=5000)
                # Ping to force connection verification
                await client.admin.command('ping')
                db = client[mongo_db_name]
                await db["users"].create_index("email", unique=True)
                # Publish only after everything succeeded so request handlers
                # never observe a half-initialised state.
                mongo_client = client
                users_collection = db["users"]
                audio_recordings_collection = db["audio_recordings"]
                audio_settings_collection = db["audio_settings"]
                audio_angles_collection = db["audio_angles"]
                logger.info("Successfully reconnected to MongoDB.")
            except Exception as e:
                logger.error(f"MongoDB reconnection failed: {e}")
                if client is not None:
                    # Release the failed client instead of leaking one per retry.
                    try:
                        await client.close()
                    except Exception as close_error:
                        logger.error(f"Closing failed MongoDB client raised: {close_error}")
                mongo_client = None
                users_collection = None
                audio_recordings_collection = None
                audio_settings_collection = None
                audio_angles_collection = None
        # Wait before next check (e.g., 10 seconds)
        await asyncio.sleep(10)
@app.on_event("startup")
async def startup_event():
    """Connect to MongoDB (best effort) and start the background tasks.

    A failed connection is not fatal: the globals stay None and
    ``mongodb_reconnect_loop`` keeps retrying.
    """
    global mongo_client, users_collection, audio_recordings_collection, audio_settings_collection, audio_angles_collection
    mongo_uri = os.getenv("MONGODB_URI", "mongodb://localhost:27017")
    mongo_db_name = os.getenv("MONGODB_DB", "afs")
    client = None
    try:
        client = AsyncMongoClient(
            mongo_uri, serverSelectionTimeoutMS=5000)
        # Ping to force connection verification
        await client.admin.command('ping')
        db = client[mongo_db_name]
        await db["users"].create_index("email", unique=True)
        mongo_client = client
        users_collection = db["users"]
        audio_recordings_collection = db["audio_recordings"]
        audio_settings_collection = db["audio_settings"]
        audio_angles_collection = db["audio_angles"]
        logger.info("Connected to MongoDB and initialized collections.")
    except Exception as e:
        logger.warning(f"MongoDB connection failed on startup: {e}. Starting reconnection loop.")
        if client is not None:
            # Do not leak the client whose connection attempt failed.
            try:
                await client.close()
            except Exception as close_error:
                logger.warning(f"Closing failed MongoDB client raised: {close_error}")
        mongo_client = None
        users_collection = None
        audio_recordings_collection = None
        audio_settings_collection = None
        audio_angles_collection = None
    # The event loop only keeps weak references to tasks; hold strong ones so
    # the long-running loops cannot be garbage-collected mid-flight.
    app.state.background_tasks = [
        asyncio.create_task(vcam_generator_loop()),
        asyncio.create_task(mongodb_reconnect_loop()),
    ]
@app.on_event("shutdown")
async def shutdown_event():
    """Close the MongoDB client, if one is connected, when the app shuts down."""
    global mongo_client
    if mongo_client is not None:
        # AsyncMongoClient.close() is a coroutine; calling it without awaiting
        # never actually closes the connection.
        await mongo_client.close()
        mongo_client = None
        logger.info("MongoDB connection closed.")
@app.post("/auth/register", response_model=AuthResponse)
async def register(payload: RegisterRequest):
    """Create a user account and return it together with a fresh access token.

    Raises:
        HTTPException: 503 when the database is not ready; 409 when an account
            with the same (normalised) email already exists.
    """
    # Local import keeps this block self-contained (pymongo is already a
    # dependency of the file).
    from pymongo.errors import DuplicateKeyError

    collection = require_users_collection()
    email = normalize_email(payload.email)
    conflict = HTTPException(
        status_code=status.HTTP_409_CONFLICT,
        detail="An account with this email already exists.",
    )
    existing_user = await collection.find_one({"email": email})
    if existing_user:
        raise conflict
    now = datetime.utcnow()
    user_doc = {
        "full_name": payload.full_name.strip(),
        "email": email,
        "password_hash": get_password_hash(payload.password),
        "created_at": now,
        "updated_at": now,
    }
    try:
        insert_result = await collection.insert_one(user_doc)
    except DuplicateKeyError:
        # Two concurrent registrations can both pass the find_one check; the
        # unique index on "email" then rejects the second insert. Report it as
        # the same 409 rather than an unhandled 500.
        raise conflict
    user_id = str(insert_result.inserted_id)
    access_token = create_access_token(data={"sub": user_id})
    return AuthResponse(
        ok=True,
        message="Account created successfully.",
        user=UserPublic(
            id=user_id,
            full_name=user_doc["full_name"],
            email=user_doc["email"],
        ),
        token=access_token,
    )
@app.post("/auth/login", response_model=AuthResponse)
async def login(payload: LoginRequest):
    """Check the submitted credentials and hand back the user with a new token.

    Unknown email and wrong password both produce the same 401 response.
    """
    collection = require_users_collection()
    user_doc = await collection.find_one({"email": normalize_email(payload.email)})
    # verify_password is only evaluated when a user document was found.
    credentials_ok = bool(user_doc) and verify_password(
        payload.password, user_doc["password_hash"])
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password.",
        )
    user_id = str(user_doc["_id"])
    public_user = UserPublic(
        id=user_id,
        full_name=user_doc["full_name"],
        email=user_doc["email"],
    )
    return AuthResponse(
        ok=True,
        message="Login successful.",
        user=public_user,
        token=create_access_token(data={"sub": user_id}),
    )
@app.get("/auth/verify", response_model=UserPublic)
async def verify_token(current_user: UserPublic = Depends(get_current_user)):
    """Echo the user resolved from the bearer token.

    All validation happens in the ``get_current_user`` dependency; reaching
    this body means the token was accepted.
    """
    authenticated_user = current_user
    return authenticated_user
@app.post("/api/enroll_face")
async def enroll_face(
    video: UploadFile = File(...),
    current_user: UserPublic = Depends(get_current_user)
):
    """Extract face embeddings from an uploaded video and store them on the user.

    The upload is written to a temporary file, processed on the shared
    executor, and the pickled embeddings are saved in the user's document.

    Raises:
        HTTPException: 503 when the database is not ready; 500 when saving,
            extraction or the database update fails.
    """
    # Resolve the collection up front so an unavailable database yields the
    # usual 503 instead of an AttributeError reported as a 500.
    collection = require_users_collection()
    temp_path = f"temp_enroll_{uuid.uuid4()}.mp4"
    try:
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(video.file, buffer)
        logger.info(f"Extracting embeddings for user {current_user.id}")
        embeddings, num_frames = await asyncio.get_running_loop().run_in_executor(
            executor, face_service.extract_embeddings_from_video, temp_path
        )
        await collection.update_one(
            {"_id": ObjectId(current_user.id)},
            {"$set": {"embeddings": pickle.dumps(embeddings)}}
        )
        return {"ok": True, "message": "Face enrolled successfully", "frames_used": num_frames}
    except Exception as e:
        logger.error(f"Enrollment failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Single cleanup point for both the success and the failure path.
        if os.path.exists(temp_path):
            os.remove(temp_path)
# Frame size (width, height) the active cv2.VideoWriter was opened with.
# Tracked explicitly because VideoWriter.get() takes VIDEOWRITER_PROP_* ids:
# passing CAP_PROP_FRAME_WIDTH/HEIGHT does not return the frame size.
_recording_frame_size = None


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Main tracking socket.

    Text messages are JSON control messages:
      * ``{"mode": ...}``                      switch tracking mode (acked with ``mode_ack``)
      * ``{"type": "auth", "token": ...}``     load the user's stored face embeddings
      * ``{"zoom_scale": ...}``                set the global zoom multiplier
      * ``{"command": ...}``                   start/stop recording or the OBS feed

    Binary messages are encoded frames: each one is run through the active
    tracker, the result is sent back as JSON, and the cropped frame is pushed
    to the OBS feed and/or the recording when those are enabled.

    When the connection ends, the OBS feed and any recording are stopped.
    """
    global is_recording, video_writer, recording_filename, latest_obs_frame, is_obs_active, zoom_multiplier
    global _recording_frame_size
    await websocket.accept()
    logger.info("New WebSocket connection established.")
    current_mode = "single"  # Default mode
    ws_user_embeddings = None

    # Defined once per connection instead of once per frame.
    def run_inference(f, mode, embeddings=None):
        if mode == "single":
            return single_tracker.process_frame(f, custom_embeddings=embeddings)
        elif mode == "multi":
            return multi_tracker.process_frame(f)
        else:
            return {"error": f"Unknown mode: {mode}"}

    try:
        while True:
            # Receive message (either text JSON or binary frame)
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                # receive() reports a disconnect as a message; calling it again
                # would raise RuntimeError, so leave the loop here.
                logger.info("WebSocket client disconnected.")
                break
            # A key may be present with a None value, so test the value itself.
            text_data = message.get("text")
            frame_data = message.get("bytes")
            if text_data is not None:
                try:
                    payload = json.loads(text_data)
                except json.JSONDecodeError:
                    logger.error("Invalid JSON received.")
                    continue
                if not isinstance(payload, dict):
                    logger.error("Invalid JSON received.")
                    continue
                if "mode" in payload and payload["mode"] != current_mode:
                    logger.info(f"Switching mode from {current_mode} to {payload['mode']}")
                    current_mode = payload["mode"]
                    await websocket.send_json({"type": "mode_ack", "mode": current_mode})
                elif payload.get("type") == "auth":
                    token = payload.get("token")
                    try:
                        token_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
                        user_id = token_data.get("sub")
                        if user_id:
                            user = await users_collection.find_one({"_id": ObjectId(user_id)})
                            if user and user.get("embeddings"):
                                # SECURITY NOTE: pickle.loads executes arbitrary code if the
                                # stored blob was tampered with; this is only safe while the
                                # database content is fully trusted.
                                ws_user_embeddings = pickle.loads(user["embeddings"])
                                logger.info(f"Loaded custom face embeddings for user {user_id}")
                                await websocket.send_json({"type": "auth_ack", "status": "enrolled"})
                            else:
                                await websocket.send_json({"type": "auth_ack", "status": "no_enrollment"})
                    except Exception as e:
                        logger.error(f"WS Auth failed: {e}")
                elif "zoom_scale" in payload:
                    # Client-supplied value: a bad one must not tear down the socket.
                    try:
                        requested_zoom = float(payload["zoom_scale"])
                    except (TypeError, ValueError):
                        requested_zoom = None
                    if requested_zoom is None or not math.isfinite(requested_zoom):
                        logger.error(f"Ignoring invalid zoom_scale: {payload['zoom_scale']!r}")
                    else:
                        zoom_multiplier = requested_zoom
                        logger.info(f"Updated zoom multiplier to {zoom_multiplier}")
                elif "command" in payload:
                    # Handle recording / OBS commands
                    command = payload["command"]
                    if command == "start_recording":
                        if not is_recording:
                            is_recording = True
                            recording_filename = f"capture_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4"
                            logger.info(f"Started recording to {recording_filename}")
                            await websocket.send_json({"type": "recording_ack", "status": "started"})
                    elif command == "stop_recording":
                        if is_recording:
                            is_recording = False
                            if video_writer is not None:
                                video_writer.release()
                                video_writer = None
                            _recording_frame_size = None
                            logger.info(f"Stopped recording. File saved as {recording_filename}")
                    elif command == "start_obs":
                        if not is_obs_active:
                            is_obs_active = True
                            logger.info("Started OBS MJPEG stream")
                            await websocket.send_json({"type": "obs_ack", "status": "started"})
                    elif command == "stop_obs":
                        if is_obs_active:
                            is_obs_active = False
                            logger.info("Stopped OBS MJPEG stream")
                            await websocket.send_json({"type": "obs_ack", "status": "stopped"})
            elif frame_data is not None:
                frame = decode_binary_image(frame_data)
                if frame is None:
                    await websocket.send_json({"error": "Failed to decode binary frame"})
                    continue
                # Process Frame in executor
                try:
                    response_data = await asyncio.get_running_loop().run_in_executor(
                        executor, run_inference, frame, current_mode, ws_user_embeddings
                    )
                except Exception as e:
                    logger.error(f"Error processing frame in {current_mode} mode: {e}")
                    response_data = {"error": str(e)}
                # Send results back to client
                response_data["mode"] = current_mode
                await websocket.send_json(response_data)
                # Apply Crop and Handle OBS / Recording
                try:
                    cropped_frame = apply_center_stage_crop(
                        frame, response_data)
                    # 1. Update OBS Feed
                    if is_obs_active:
                        ret, buffer = cv2.imencode('.jpg', cropped_frame)
                        if ret:
                            with obs_frame_lock:
                                latest_obs_frame = buffer.tobytes()
                    # 2. Update Recording Output
                    if is_recording:
                        h, w = cropped_frame.shape[:2]
                        if video_writer is None:
                            # Initialize writer with the exact dimensions of the FIRST cropped frame
                            fourcc = cv2.VideoWriter_fourcc(*'avc1')
                            video_writer = cv2.VideoWriter(
                                recording_filename, fourcc, 5.0, (w, h))
                            _recording_frame_size = (w, h)
                            if not video_writer.isOpened():
                                logger.error(f"Could not open video writer for {recording_filename}")
                        if _recording_frame_size is None:
                            _recording_frame_size = (w, h)
                        # The crop size drifts with the zoom level; every frame
                        # must match the size the writer was opened with.
                        if (w, h) != _recording_frame_size:
                            cropped_frame = cv2.resize(
                                cropped_frame, _recording_frame_size)
                        video_writer.write(cropped_frame)
                except Exception as e:
                    logger.error(f"Error handling post-process crops: {e}")
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected.")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        is_obs_active = False
        # Cleanup Recording
        if video_writer is not None:
            video_writer.release()
            video_writer = None
        _recording_frame_size = None
        is_recording = False
# === FACE RECOGNITION ENDPOINTS ===
@app.post("/api/face/upload-video")
async def upload_reference_video(
    file: UploadFile = File(...),
    current_user: UserPublic = Depends(get_current_user)
):
    """Upload a 360-degree reference video for face recognition training.

    The upload is stored as ``my_scan.mp4`` in MODEL_DIR, embeddings are
    extracted on the shared executor and then written to the embeddings cache.

    Raises:
        HTTPException: 400 for a missing or unsupported file name; 500 when
            saving or processing fails.
    """
    # filename may be None, and extensions are matched case-insensitively so
    # "CLIP.MP4" is accepted like "clip.mp4".
    filename = (file.filename or "").lower()
    if not filename.endswith(('.mp4', '.avi', '.mov', '.mkv')):
        raise HTTPException(
            status_code=400, detail="Invalid video format. Use mp4, avi, mov, or mkv")
    video_path = MODEL_DIR / "my_scan.mp4"
    try:
        with open(video_path, 'wb') as f:
            shutil.copyfileobj(file.file, f)
        embeddings, num_frames = await asyncio.get_running_loop().run_in_executor(
            executor, face_service.extract_embeddings_from_video, str(
                video_path)
        )
        face_service.save_embeddings_cache(
            embeddings, str(video_path), num_frames)
        return {
            "ok": True,
            "message": "Video processed successfully",
            "frames_used": num_frames,
            "embeddings_count": len(embeddings)
        }
    except Exception as e:
        logger.error(f"Error processing video: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/face/upload-image")
async def upload_reference_image(
    file: UploadFile = File(...),
    current_user: UserPublic = Depends(get_current_user)
):
    """Upload a reference image for face recognition.

    The image is saved in MODEL_DIR as ``ref_<name>`` and embeddings are
    extracted from it on the shared executor.

    Raises:
        HTTPException: 400 for a missing or unsupported file name; 500 when
            saving or processing fails.
    """
    original_name = file.filename or ""
    if not original_name.lower().endswith(('.jpg', '.jpeg', '.png')):
        raise HTTPException(
            status_code=400, detail="Invalid image format. Use jpg, jpeg, or png")
    # The file name is client-controlled: keep only its last path component
    # (for both "/" and "\" separators) so it cannot point outside MODEL_DIR.
    safe_name = os.path.basename(original_name.replace("\\", "/"))
    image_path = MODEL_DIR / f"ref_{safe_name}"
    try:
        with open(image_path, 'wb') as f:
            shutil.copyfileobj(file.file, f)
        embeddings = await asyncio.get_running_loop().run_in_executor(
            executor, face_service.extract_embeddings_from_image, str(
                image_path)
        )
        return {
            "ok": True,
            "message": "Image processed successfully",
            "embeddings_count": len(embeddings),
            "saved_path": str(image_path)
        }
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/face/cache-status")
async def get_cache_status(current_user: UserPublic = Depends(get_current_user)):
    """Describe the face-embedding cache, or say that none exists yet."""
    cache_data = face_service.load_embeddings_cache()
    if not cache_data:
        return {
            "ok": True,
            "cached": False,
            "message": "No cache found. Please upload a reference video or image."
        }
    # Copy the reported fields out of the cache; absent ones come back as None.
    reported_fields = ('video_path', 'model_name', 'num_frames_used', 'version')
    summary = {"ok": True, "cached": True}
    for field_name in reported_fields:
        summary[field_name] = cache_data.get(field_name)
    return summary
# === AUDIO STREAMING ENDPOINTS ===
@app.post("/api/audio/start-stream")
async def start_audio_stream(
    sample_rate: int = Form(16000),
    channels: int = Form(1),
    current_user: UserPublic = Depends(get_current_user)
):
    """Open a new audio recording stream under a freshly generated session id.

    Responds with the session id, the file name reported by the audio
    processor and the requested format; failures become HTTP 500.
    """
    session_id = str(uuid.uuid4())
    try:
        filename = audio_processor.create_audio_stream(
            session_id, sample_rate, channels)
    except Exception as e:
        logger.error(f"Error starting audio stream: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "ok": True,
        "session_id": session_id,
        "filename": filename,
        "sample_rate": sample_rate,
        "channels": channels
    }
@app.websocket("/ws/audio/{session_id}")
async def websocket_audio_stream(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for streaming audio with angle data.

    Accepted messages:
      * binary frame: raw audio chunk, written as-is
      * JSON ``{"audio_data": <base64>, "angle": <number>}``: chunk plus angle
      * JSON ``{"command": "stop"}``: close the stream and end the session

    The stream for ``session_id`` is created on connect when it does not exist
    and is closed when the client disconnects or an error occurs.
    """
    await websocket.accept()
    logger.info(
        f"Audio WebSocket connection established for session {session_id}")
    # Auto-create stream if not exists
    if session_id not in audio_processor.active_streams:
        audio_processor.create_audio_stream(session_id)
        logger.info(f"Auto-created audio stream for session {session_id}")
    try:
        while True:
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                # receive() returns the disconnect as a message; route it to the
                # regular disconnect handling instead of failing on the next call.
                raise WebSocketDisconnect(message.get("code", 1000))
            # A key may be present with a None value, so test the value itself.
            audio_data = message.get("bytes")
            text_data = message.get("text")
            if audio_data is not None:
                audio_processor.write_audio_chunk(session_id, audio_data)
                await websocket.send_json({"status": "received", "bytes": len(audio_data)})
            elif text_data is not None:
                try:
                    payload = json.loads(text_data)
                except json.JSONDecodeError:
                    logger.error("Invalid JSON in audio stream")
                    continue
                if not isinstance(payload, dict):
                    logger.error("Invalid JSON in audio stream")
                    continue
                if "audio_data" in payload and "angle" in payload:
                    # Client-supplied values: one malformed chunk is skipped
                    # rather than ending the whole session.
                    try:
                        audio_bytes = base64.b64decode(payload["audio_data"])
                        angle = float(payload["angle"])
                    except (TypeError, ValueError) as e:
                        logger.error(f"Invalid audio chunk in audio stream: {e}")
                        continue
                    audio_processor.write_audio_chunk(
                        session_id, audio_bytes, angle)
                    await websocket.send_json({"status": "received", "angle": angle})
                elif payload.get("command") == "stop":
                    audio_processor.close_audio_stream(session_id)
                    await websocket.send_json({"status": "stopped", "message": "Stream closed"})
                    break
    except WebSocketDisconnect:
        logger.info(
            f"Audio WebSocket client disconnected for session {session_id}")
        if session_id in audio_processor.active_streams:
            audio_processor.close_audio_stream(session_id)
    except Exception as e:
        logger.error(f"Audio WebSocket error: {e}")
        if session_id in audio_processor.active_streams:
            audio_processor.close_audio_stream(session_id)
@app.post("/api/audio/stop-stream/{session_id}")
async def stop_audio_stream(
    session_id: str,
    current_user: UserPublic = Depends(get_current_user)
):
    """Close the audio recording stream identified by *session_id*.

    Any error raised while closing is logged and returned as HTTP 500.
    """
    try:
        audio_processor.close_audio_stream(session_id)
    except Exception as e:
        logger.error(f"Error stopping audio stream: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "ok": True,
        "message": "Audio stream stopped successfully"
    }
@app.get("/api/audio/recordings")
async def list_audio_recordings(current_user: UserPublic = Depends(get_current_user)):
    """Return the recordings reported by the audio processor together with their count."""
    try:
        recordings = audio_processor.get_audio_files()
        response = {"ok": True, "recordings": recordings}
        response["count"] = len(recordings)
        return response
    except Exception as e:
        logger.error(f"Error listing recordings: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/audio/active-sessions")
async def get_active_sessions():
    """Return the keys of ``audio_processor.active_streams`` and how many there are.

    Raises:
        HTTPException: 500, carrying the error text, if reading the
            active-stream mapping fails.
    """
    try:
        session_ids = list(audio_processor.active_streams.keys())
        total = len(session_ids)
    except Exception as exc:
        logger.error(f"Error getting active sessions: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))

    return {"ok": True, "active_sessions": session_ids, "count": total}
@app.get("/api/audio/angles")
async def get_audio_angles():
    """Return the angle samples stored in the newest ``*_metadata.txt`` file.

    The file is chosen by modification time among the matches under
    ``MODEL_DIR / "audio_recordings"``. Each non-blank line is split on
    commas; the first two fields are parsed as ``timestamp`` and ``angle``.
    The first line is dropped when it contains the word ``timestamp``; lines
    with fewer than two fields or non-numeric values are skipped.

    Returns:
        ``{"ok": True, "file": <name>, "angles": [...], "count": <n>}``.

    Raises:
        HTTPException: 404 when no metadata file exists; 500, carrying the
            error text, on any other failure.
    """
    import os

    try:
        candidates = list((MODEL_DIR / "audio_recordings").glob("*_metadata.txt"))
        if not candidates:
            raise HTTPException(status_code=404, detail="No metadata found")

        # Newest file by modification time.
        newest = max(candidates, key=os.path.getmtime)
        with open(newest, 'r') as handle:
            rows = handle.readlines()

        # Drop a header row if one is present.
        if rows and 'timestamp' in rows[0]:
            rows = rows[1:]

        samples = []
        for raw in rows:
            stripped = raw.strip()
            if not stripped:
                continue
            fields = stripped.split(',')
            if len(fields) < 2:
                continue
            try:
                samples.append(
                    {"timestamp": float(fields[0]), "angle": float(fields[1])})
            except ValueError:
                # Malformed numeric field: ignore this line.
                continue

        return {
            "ok": True,
            "file": newest.name,
            "angles": samples,
            "count": len(samples),
        }
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error retrieving angles: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/audio/upload")
async def upload_audio_file(
    file: UploadFile = File(...)
):
    """Read an uploaded audio file and store it in ``audio_recordings_collection``.

    The whole upload is read into memory and inserted as one document holding
    the filename, raw content, content type and a UTC timestamp.

    NOTE(review): when ``audio_recordings_collection`` is ``None`` nothing is
    stored, yet the same success response is returned.

    Returns:
        ``{"ok": True, "message": ..., "filename": ..., "size": <bytes read>}``.

    Raises:
        HTTPException: 500, carrying the error text, if reading or inserting
            fails.
    """
    try:
        payload = await file.read()

        if audio_recordings_collection is not None:
            record = {
                "filename": file.filename,
                "content": payload,  # stored as binary
                "content_type": file.content_type,
                "timestamp": datetime.utcnow(),
            }
            await audio_recordings_collection.insert_one(record)

        return {
            "ok": True,
            "message": "Audio file saved to database successfully",
            "filename": file.filename,
            "size": len(payload),
        }
    except Exception as exc:
        logger.error(f"Error saving audio to DB: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/audio/set-angle")
async def set_desired_angle(
    angle: float = Form(...)
):
    """Validate a desired angle and upsert it under the ``latest_angle`` key.

    The value is written to ``audio_angles_collection`` (when it is not
    ``None``) together with a UTC ``updated_at`` timestamp.

    Returns:
        ``{"ok": True, "message": ..., "angle": <angle>}``.

    Raises:
        HTTPException: 400 when ``angle`` is not within 0..360 inclusive;
            500, carrying the error text, if the database write fails.
    """
    # Written as a negated chained comparison so NaN is rejected as well.
    within_range = 0 <= angle <= 360
    if not within_range:
        raise HTTPException(
            status_code=400,
            detail="Angle must be between 0 and 360 degrees"
        )

    try:
        if audio_angles_collection is not None:
            change = {"$set": {"value": angle, "updated_at": datetime.utcnow()}}
            await audio_angles_collection.update_one(
                {"key": "latest_angle"}, change, upsert=True)

        logger.info(f"Set and persisted desired angle {angle}° to DB")
        return {
            "ok": True,
            "message": f"Desired angle set to {angle}° and saved to DB",
            "angle": angle,
        }
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error setting angle in DB: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@app.get("/api/audio/get-angle")
async def get_current_angle():
    """
    Get the currently tracked angle of the target person.

    If no person is tracked, fall back to the angle previously set via
    set-angle (the ``latest_angle`` document in ``audio_angles_collection``).

    Returns:
        A dict with ``ok``, ``angle`` and ``distance`` keys. ``source`` is
        ``"tracking"`` when the module-level ``current_target_angle`` is set
        and ``"database"`` when the saved angle is used. When neither is
        available, ``ok`` is ``False`` and a ``message`` explains why.
        ``distance`` is ``None`` whenever no distance value is available.

    Raises:
        HTTPException: 500, carrying the error text, on unexpected failure.
    """
    try:
        # Read the module-level tracking values once so the check below and
        # the response are built from the same snapshot.
        tracked_angle = current_target_angle
        tracked_distance = current_target_distance

        # The message must be a format string; passing the raw values as
        # (msg, arg) makes logging fail while formatting the record.
        logger.info(
            "Current target angle=%s, distance=%s",
            tracked_angle, tracked_distance)

        # If a person is actively being tracked, return their real-time angle
        if tracked_angle is not None:
            return {
                "ok": True,
                "source": "tracking",
                "angle": round(tracked_angle, 2),
                # An angle can exist without a distance; round(None) would raise.
                "distance": (
                    round(tracked_distance, 2)
                    if tracked_distance is not None else None
                ),
            }

        # Fallback to the saved angle if no target is actively tracked
        if audio_angles_collection is not None:
            saved_angle_doc = await audio_angles_collection.find_one(
                {"key": "latest_angle"})
            if saved_angle_doc and "value" in saved_angle_doc:
                return {
                    "ok": True,
                    "source": "database",
                    "angle": float(saved_angle_doc["value"]),
                    "distance": None,
                }

        return {
            "ok": False,
            "message": "No active tracking and no saved angle found",
            "angle": None,
            "distance": None,
        }
    except Exception as e:
        logger.error(f"Error retrieving angle: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/audio/settings")
async def get_audio_settings():
    """Return the stored audio settings as a single ``key -> value`` mapping.

    Up to 100 documents are fetched from ``audio_settings_collection`` with
    ``_id`` excluded; documents without a ``key`` field are ignored.

    Returns:
        ``{"ok": True, "settings": {...}}``, or
        ``{"ok": False, "message": "Database not connected"}`` when the
        collection is ``None``.

    Raises:
        HTTPException: 500, carrying the error text, if the query or the
            conversion fails.
    """
    if audio_settings_collection is None:
        return {"ok": False, "message": "Database not connected"}

    try:
        documents = await audio_settings_collection.find(
            {}, {"_id": 0}).to_list(length=100)

        # Flatten the list of {key, value} documents into one dictionary.
        collected = {}
        for doc in documents:
            if "key" in doc:
                collected[doc["key"]] = doc["value"]

        return {"ok": True, "settings": collected}
    except Exception as exc:
        logger.error(f"Error retrieving audio settings: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/audio/settings")
async def update_audio_settings(
    settings: dict = Body(...)
):
    """Update general audio settings in MongoDB.

    Every ``key: value`` pair in the request body is upserted into
    ``audio_settings_collection`` as its own document, stamped with a UTC
    ``updated_at`` time.

    Returns:
        ``{"ok": True, "message": ..., "updated_keys": [...]}``.

    Raises:
        HTTPException: 503 when ``audio_settings_collection`` is ``None``;
            500, carrying the error text, if a database write fails.
    """
    try:
        if audio_settings_collection is None:
            raise HTTPException(
                status_code=503, detail="Database not connected")

        for key, value in settings.items():
            await audio_settings_collection.update_one(
                {"key": key},
                {"$set": {"value": value, "updated_at": datetime.utcnow()}},
                upsert=True
            )

        return {
            "ok": True,
            "message": "Audio settings updated successfully",
            "updated_keys": list(settings.keys())
        }
    except HTTPException:
        # Let the deliberate 503 through; without this clause the generic
        # handler below would turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error updating audio settings: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Start uvicorn on all interfaces, port 8000, with auto-reload enabled,
    # when this file is executed directly.
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
|