"""AFS Tracking Backend.

FastAPI server providing person tracking over WebSocket, a "center stage"
auto-crop with EMA smoothing, an MJPEG feed for OBS, local MP4 recording,
JWT-authenticated user accounts backed by MongoDB, face-enrollment
endpoints, and audio streaming / angle-metadata endpoints.
"""

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, status, Depends, UploadFile, File, Form, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import uvicorn
import cv2
import numpy as np
import json
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
import threading
import os
import base64
import hashlib
import math
from pydantic import BaseModel, Field
from pymongo import AsyncMongoClient
import bcrypt
import pickle
from bson import ObjectId
from jose import JWTError, jwt
from dotenv import load_dotenv
from pathlib import Path
import shutil
import uuid

from services.single_tracker import SingleTracker
from services.multi_tracker import MultiTracker
from services.face_recognition import FaceRecognitionService
from services.audio_processing import AudioProcessor

# Load environment variables from .env file
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Executor for CPU-bound tasks (single worker: frames are processed serially)
executor = ThreadPoolExecutor(max_workers=1)

# --- OBS and Recording State ---
latest_obs_frame = None  # Latest JPEG-encoded cropped frame for the OBS feed
obs_frame_lock = threading.Lock()
is_obs_active = False
is_recording = False
video_writer = None
recording_filename = ""

# --- Virtual camera state (consumed by vcam_generator_loop) ---
# FIX: these globals were referenced by vcam_generator_loop but never
# defined, which raised NameError on every loop tick while OBS was active.
vcam = None
latest_vcam_frame = None

# --- Center Stage State (EMA Smoothing) ---
current_cx = 0.5
current_cy = 0.5
current_scale = 1.0
zoom_multiplier = 1.0

# --- Real-time Target Tracking State ---
current_target_angle = None
current_target_distance = None

# Configurable parameters for smooth panning
# Lower is smoother but slower (similar to Dart's TweenAnimation)
SMOOTHING_FACTOR = 0.1
TARGET_ASPECT_RATIO = 16.0 / 9.0  # Assuming output is meant to be 16:9

app = FastAPI(title="AFS Tracking Backend")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize trackers and services
MODEL_DIR = Path(__file__).parent / "Model"
single_tracker = SingleTracker()
multi_tracker = MultiTracker()
face_service = FaceRecognitionService(str(MODEL_DIR))
audio_processor = AudioProcessor(str(MODEL_DIR))

# MongoDB state (populated on startup / by the reconnect loop)
mongo_client: AsyncMongoClient | None = None
users_collection = None
audio_recordings_collection = None
audio_settings_collection = None
audio_angles_collection = None

# JWT Configuration
SECRET_KEY = os.getenv(
    "JWT_SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days

security = HTTPBearer()


class RegisterRequest(BaseModel):
    """Payload for POST /auth/register."""
    full_name: str = Field(min_length=2, max_length=80)
    email: str = Field(min_length=5, max_length=254)
    password: str = Field(min_length=8, max_length=128)


class LoginRequest(BaseModel):
    """Payload for POST /auth/login."""
    email: str = Field(min_length=5, max_length=254)
    password: str = Field(min_length=8, max_length=128)


class UserPublic(BaseModel):
    """User fields safe to expose to clients (no password hash)."""
    id: str
    full_name: str
    email: str


class AuthResponse(BaseModel):
    """Response for register/login: user info plus a bearer token."""
    ok: bool
    message: str
    user: UserPublic
    token: str


def normalize_email(email: str) -> str:
    """Canonicalize an email address for storage and lookup."""
    return email.strip().lower()


def get_password_hash(password: str) -> str:
    """Hash a plaintext password with bcrypt (salt embedded in the hash)."""
    return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash."""
    return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))


def require_users_collection():
    """Return the users collection, or raise 503 if MongoDB is not connected."""
    if users_collection is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Database is not initialized yet. Please retry.",
        )
    return users_collection


def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Create a signed JWT carrying *data* plus an ``exp`` claim.

    Defaults to ACCESS_TOKEN_EXPIRE_MINUTES when no delta is given.
    """
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt


async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: decode the bearer token and load the user.

    Raises 401 for any invalid/expired token or unknown user, 503 when the
    database is down.
    """
    collection = require_users_collection()
    token = credentials.credentials
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        user_id: str = payload.get("sub")
        if user_id is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication credentials",
            )
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
        )
    # FIX: was a bare ``except:`` plus a redundant local ObjectId import
    # shadowing the module-level one. A malformed user_id (invalid ObjectId)
    # or a driver error both map to 401 here.
    try:
        user_doc = await collection.find_one({"_id": ObjectId(user_id)})
    except Exception:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )
    if user_doc is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )
    return UserPublic(
        id=str(user_doc["_id"]),
        full_name=user_doc["full_name"],
        email=user_doc["email"],
    )


def decode_binary_image(img_data: bytes):
    """Decodes raw JPEG bytes into an OpenCV numpy array (None on failure)."""
    try:
        nparr = np.frombuffer(img_data, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        return img
    except Exception as e:
        logger.error(f"Failed to decode image: {e}")
        return None


def apply_center_stage_crop(frame, tracking_data):
    """
    Applies an exponential moving average (EMA) to smoothly pan and zoom the
    frame based on the tracking target bounding box. Returns the cropped frame.

    Also updates the module-level real-time angle/distance of the target
    relative to the frame center (used by /api/audio/get-angle).
    """
    global current_cx, current_cy, current_scale, current_target_angle, current_target_distance, zoom_multiplier
    h, w = frame.shape[:2]

    # Defaults (used when no target is found: drift back to full frame)
    target_cx = 0.5
    target_cy = 0.5
    target_scale = 1.0
    target_found = False

    # Calculate target state based on tracking data
    boxes = tracking_data.get("boxes", [])
    if tracking_data.get("mode") == "multi":
        if "aggregate_box" in tracking_data:
            ab = tracking_data["aggregate_box"]
            box_cx = (ab["x1"] + ab["x2"]) / 2.0
            box_cy = (ab["y1"] + ab["y2"]) / 2.0
            box_w = ab["x2"] - ab["x1"]
            box_h = ab["y2"] - ab["y1"]
            target_cx = box_cx / w
            target_cy = box_cy / h
            target_found = True
            # Target scale logic (from Dart): max dimension proportion * 1.5 margin
            max_dim = max(box_w / w, box_h / h)
            target_scale = 1.0 / (max_dim * 1.5)
            # Clamp scale
            target_scale = max(1.0, min(target_scale, 3.0))
    else:  # single
        target_box = None
        for b in boxes:
            if b.get("is_target"):
                target_box = b
                break
        if target_box:
            box_cx = (target_box["x1"] + target_box["x2"]) / 2.0
            box_cy = (target_box["y1"] + target_box["y2"]) / 2.0
            box_w = target_box["x2"] - target_box["x1"]
            box_h = target_box["y2"] - target_box["y1"]
            target_cx = box_cx / w
            target_cy = box_cy / h
            target_found = True
            max_dim = max(box_w / w, box_h / h)
            # slightly tighter for single person
            target_scale = 1.0 / (max_dim * 2.0)
            target_scale = max(1.0, min(target_scale, 3.0))

    if target_found:
        # Apply user zoom multiplier
        target_scale = max(1.0, min(target_scale * zoom_multiplier, 10.0))
        # Calculate distance and angle from the frame center (w/2, h/2)
        # to the target bounding box center (box_cx, box_cy)
        center_x, center_y = w / 2.0, h / 2.0
        dx = box_cx - center_x
        dy = box_cy - center_y
        current_target_distance = math.hypot(dx, dy)
        # Convert atan2 result to 0-360 degrees
        angle = math.degrees(math.atan2(dy, dx))
        current_target_angle = angle % 360.0
    else:
        current_target_angle = None
        current_target_distance = None

    # Apply EMA smoothing
    current_cx += (target_cx - current_cx) * SMOOTHING_FACTOR
    current_cy += (target_cy - current_cy) * SMOOTHING_FACTOR
    current_scale += (target_scale - current_scale) * SMOOTHING_FACTOR

    # Calculate crop dimensions
    # When scale is S, the crop width is w / S
    crop_w = int(w / current_scale)
    crop_h = int(h / current_scale)

    # Enforce aspect ratio
    # If crop_w / crop_h is not 16:9, adjust one to match
    current_ar = crop_w / max(1, crop_h)
    if current_ar > TARGET_ASPECT_RATIO:
        # Too wide, shrink width
        crop_w = int(crop_h * TARGET_ASPECT_RATIO)
    else:
        # Too tall, shrink height
        crop_h = int(crop_w / TARGET_ASPECT_RATIO)

    # Calculate top-left point of crop, clamping to frame boundaries
    center_px_x = int(current_cx * w)
    center_px_y = int(current_cy * h)
    start_x = max(0, center_px_x - crop_w // 2)
    start_y = max(0, center_px_y - crop_h // 2)

    # Adjust if crop box goes out of bounds
    if start_x + crop_w > w:
        start_x = w - crop_w
    if start_y + crop_h > h:
        start_y = h - crop_h

    # Crop
    cropped = frame[start_y:start_y + crop_h, start_x:start_x + crop_w]
    return cropped


async def generate_obs_stream():
    """Generator for the MJPEG stream used by OBS."""
    global latest_obs_frame
    while True:
        with obs_frame_lock:
            if latest_obs_frame is not None:
                yield (b'--frame\r\n'
                       b'Content-Type: image/jpeg\r\n\r\n' + latest_obs_frame + b'\r\n')
            else:
                # If no frame yet, yield a blank frame or sleep
                await asyncio.sleep(0.1)
                continue
        # Use asyncio sleep to prevent blocking the event loop
        await asyncio.sleep(0.033)  # roughly 30 fps


@app.get("/obs_feed")
async def obs_feed():
    """Endpoint for OBS Media Source to connect to."""
    return StreamingResponse(generate_obs_stream(),
                             media_type="multipart/x-mixed-replace; boundary=frame")


async def vcam_generator_loop():
    """Background task to push frames to the virtual camera at 30fps.

    NOTE(review): no code path currently assigns ``vcam`` or
    ``latest_vcam_frame``, so this loop is effectively dormant until a
    virtual-camera backend is wired up.
    """
    global is_obs_active, vcam, latest_vcam_frame
    while True:
        try:
            if is_obs_active and vcam is not None and latest_vcam_frame is not None:
                vcam.send(latest_vcam_frame)
        except Exception as e:
            logger.error(f"vcam loop error: {e}")
        await asyncio.sleep(1 / 30)


@app.get("/")
async def health_check():
    """Health check endpoint."""
    status_db = "connected" if users_collection is not None else "disconnected"
    return {
        "status": "ok",
        "service": "AFS Tracking Backend",
        "mongodb": status_db
    }


async def mongodb_reconnect_loop():
    """Background task to attempt MongoDB reconnection if disconnected."""
    # FIX: the original ``global`` statement omitted audio_angles_collection,
    # so a successful reconnect only rebound a function-local name and the
    # module-level handle stayed None forever.
    global mongo_client, users_collection, audio_recordings_collection, \
        audio_settings_collection, audio_angles_collection
    while True:
        if users_collection is None:
            mongo_uri = os.getenv("MONGODB_URI", "mongodb://localhost:27017")
            mongo_db_name = os.getenv("MONGODB_DB", "afs")
            try:
                logger.info("Attempting to reconnect to MongoDB...")
                client = AsyncMongoClient(
                    mongo_uri, serverSelectionTimeoutMS=5000)
                # Ping to force connection verification
                await client.admin.command('ping')
                # Re-initialize
                mongo_client = client
                db = mongo_client[mongo_db_name]
                users_collection = db["users"]
                audio_recordings_collection = db["audio_recordings"]
                audio_settings_collection = db["audio_settings"]
                audio_angles_collection = db["audio_angles"]
                await users_collection.create_index("email", unique=True)
                logger.info("Successfully reconnected to MongoDB.")
            except Exception as e:
                logger.error(f"MongoDB reconnection failed: {e}")
                mongo_client = None
                users_collection = None
                audio_recordings_collection = None
                audio_settings_collection = None
                audio_angles_collection = None
        # Wait before next check (e.g., 10 seconds)
        await asyncio.sleep(10)


@app.on_event("startup")
async def startup_event():
    """Connect to MongoDB and launch the background loops."""
    global mongo_client, users_collection, audio_recordings_collection, \
        audio_settings_collection, audio_angles_collection
    mongo_uri = os.getenv("MONGODB_URI", "mongodb://localhost:27017")
    mongo_db_name = os.getenv("MONGODB_DB", "afs")
    try:
        mongo_client = AsyncMongoClient(
            mongo_uri, serverSelectionTimeoutMS=5000)
        # Ping to force connection verification
        await mongo_client.admin.command('ping')
        db = mongo_client[mongo_db_name]
        users_collection = db["users"]
        audio_recordings_collection = db["audio_recordings"]
        audio_settings_collection = db["audio_settings"]
        audio_angles_collection = db["audio_angles"]
        await users_collection.create_index("email", unique=True)
        logger.info("Connected to MongoDB and initialized collections.")
    except Exception as e:
        logger.warning(
            f"MongoDB connection failed on startup: {e}. Starting reconnection loop.")
        mongo_client = None
        users_collection = None
        audio_recordings_collection = None
        audio_settings_collection = None
        audio_angles_collection = None
    asyncio.create_task(vcam_generator_loop())
    asyncio.create_task(mongodb_reconnect_loop())


@app.on_event("shutdown")
async def shutdown_event():
    """Close the MongoDB connection on server shutdown."""
    global mongo_client
    if mongo_client is not None:
        # FIX: AsyncMongoClient.close() is a coroutine; the original call
        # without ``await`` never actually ran.
        await mongo_client.close()
        logger.info("MongoDB connection closed.")


@app.post("/auth/register", response_model=AuthResponse)
async def register(payload: RegisterRequest):
    """Create a new user account and return a bearer token."""
    collection = require_users_collection()
    email = normalize_email(payload.email)
    existing_user = await collection.find_one({"email": email})
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="An account with this email already exists.",
        )
    now = datetime.utcnow()
    user_doc = {
        "full_name": payload.full_name.strip(),
        "email": email,
        "password_hash": get_password_hash(payload.password),
        "created_at": now,
        "updated_at": now,
    }
    insert_result = await collection.insert_one(user_doc)
    user_id = str(insert_result.inserted_id)
    access_token = create_access_token(data={"sub": user_id})
    return AuthResponse(
        ok=True,
        message="Account created successfully.",
        user=UserPublic(
            id=user_id,
            full_name=user_doc["full_name"],
            email=user_doc["email"],
        ),
        token=access_token,
    )


@app.post("/auth/login", response_model=AuthResponse)
async def login(payload: LoginRequest):
    """Authenticate an existing user and return a bearer token."""
    collection = require_users_collection()
    email = normalize_email(payload.email)
    user_doc = await collection.find_one({"email": email})
    # Single generic message for both unknown email and wrong password
    if not user_doc or not verify_password(payload.password, user_doc["password_hash"]):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password.",
        )
    user_id = str(user_doc["_id"])
    access_token = create_access_token(data={"sub": user_id})
    return AuthResponse(
        ok=True,
        message="Login successful.",
        user=UserPublic(
            id=user_id,
            full_name=user_doc["full_name"],
            email=user_doc["email"],
        ),
        token=access_token,
    )


@app.get("/auth/verify", response_model=UserPublic)
async def verify_token(current_user: UserPublic = Depends(get_current_user)):
    """Verify JWT token and return user info"""
    return current_user


@app.post("/api/enroll_face")
async def enroll_face(
    video: UploadFile = File(...),
    current_user: UserPublic = Depends(get_current_user)
):
    """Extract face embeddings from an uploaded video and store them
    (pickled) on the current user's document."""
    collection = require_users_collection()
    temp_path = f"temp_enroll_{uuid.uuid4()}.mp4"
    try:
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(video.file, buffer)
        logger.info(f"Extracting embeddings for user {current_user.id}")

        def run_extraction():
            return face_service.extract_embeddings_from_video(temp_path)

        embeddings, num_frames = await asyncio.get_event_loop().run_in_executor(
            executor, run_extraction
        )
        pickled_embeddings = pickle.dumps(embeddings)
        await collection.update_one(
            {"_id": ObjectId(current_user.id)},
            {"$set": {"embeddings": pickled_embeddings}}
        )
        return {"ok": True, "message": "Face enrolled successfully",
                "frames_used": num_frames}
    except Exception as e:
        logger.error(f"Enrollment failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: cleanup moved to ``finally`` so the temp file is removed on
        # every path (the original leaked it if the error handler raised).
        if os.path.exists(temp_path):
            os.remove(temp_path)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Main tracking WebSocket.

    Text messages carry control JSON (mode switch, auth, zoom, recording and
    OBS commands); binary messages carry JPEG frames which are run through
    the active tracker, cropped, and fanned out to the OBS feed / recorder.
    """
    global is_recording, video_writer, recording_filename, latest_obs_frame, \
        is_obs_active, zoom_multiplier
    await websocket.accept()
    logger.info("New WebSocket connection established.")
    current_mode = "single"  # Default mode
    ws_user_embeddings = None
    try:
        while True:
            # Receive message (either text JSON or binary frame)
            message = await websocket.receive()
            if "text" in message:
                try:
                    payload = json.loads(message["text"])
                    if "mode" in payload and payload["mode"] != current_mode:
                        logger.info(
                            f"Switching mode from {current_mode} to {payload['mode']}")
                        current_mode = payload["mode"]
                        await websocket.send_json({"type": "mode_ack", "mode": current_mode})
                    elif "type" in payload and payload["type"] == "auth":
                        token = payload.get("token")
                        try:
                            token_data = jwt.decode(
                                token, SECRET_KEY, algorithms=[ALGORITHM])
                            user_id = token_data.get("sub")
                            if user_id:
                                user = await users_collection.find_one(
                                    {"_id": ObjectId(user_id)})
                                if user and "embeddings" in user and user["embeddings"]:
                                    # SECURITY NOTE: pickle.loads on data read
                                    # from our own DB — safe only as long as
                                    # the database contents are trusted.
                                    ws_user_embeddings = pickle.loads(
                                        user["embeddings"])
                                    logger.info(
                                        f"Loaded custom face embeddings for user {user_id}")
                                    await websocket.send_json(
                                        {"type": "auth_ack", "status": "enrolled"})
                                else:
                                    await websocket.send_json(
                                        {"type": "auth_ack", "status": "no_enrollment"})
                        except Exception as e:
                            logger.error(f"WS Auth failed: {e}")
                    elif "zoom_scale" in payload:
                        zoom_multiplier = float(payload["zoom_scale"])
                        logger.info(f"Updated zoom multiplier to {zoom_multiplier}")
                    elif "command" in payload:
                        # Handle recording commands
                        command = payload["command"]
                        if command == "start_recording":
                            if not is_recording:
                                is_recording = True
                                recording_filename = (
                                    f"capture_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4")
                                logger.info(f"Started recording to {recording_filename}")
                                await websocket.send_json(
                                    {"type": "recording_ack", "status": "started"})
                        elif command == "stop_recording":
                            if is_recording:
                                is_recording = False
                                if video_writer is not None:
                                    video_writer.release()
                                    video_writer = None
                                logger.info(
                                    f"Stopped recording. File saved as {recording_filename}")
                        elif command == "start_obs":
                            if not is_obs_active:
                                is_obs_active = True
                                logger.info("Started OBS MJPEG stream")
                                await websocket.send_json(
                                    {"type": "obs_ack", "status": "started"})
                        elif command == "stop_obs":
                            if is_obs_active:
                                is_obs_active = False
                                logger.info("Stopped OBS MJPEG stream")
                                await websocket.send_json(
                                    {"type": "obs_ack", "status": "stopped"})
                except json.JSONDecodeError:
                    logger.error("Invalid JSON received.")
                    continue
            elif "bytes" in message:
                frame_data = message["bytes"]
                frame = decode_binary_image(frame_data)
                if frame is None:
                    await websocket.send_json({"error": "Failed to decode binary frame"})
                    continue

                # Prepare inference function
                def run_inference(f, mode, embeddings=None):
                    if mode == "single":
                        return single_tracker.process_frame(
                            f, custom_embeddings=embeddings)
                    elif mode == "multi":
                        return multi_tracker.process_frame(f)
                    else:
                        return {"error": f"Unknown mode: {mode}"}

                # Process Frame in executor
                response_data = {}
                try:
                    response_data = await asyncio.get_event_loop().run_in_executor(
                        executor, run_inference, frame, current_mode, ws_user_embeddings
                    )
                except Exception as e:
                    logger.error(
                        f"Error processing frame in {current_mode} mode: {e}")
                    response_data = {"error": str(e)}

                # Send results back to client
                response_data["mode"] = current_mode
                await websocket.send_json(response_data)

                # Apply Crop and Handle OBS / Recording
                try:
                    cropped_frame = apply_center_stage_crop(
                        frame, response_data)
                    # 1. Update OBS Feed
                    if is_obs_active:
                        ret, buffer = cv2.imencode('.jpg', cropped_frame)
                        if ret:
                            with obs_frame_lock:
                                latest_obs_frame = buffer.tobytes()
                    # 2. Update Recording Output
                    if is_recording:
                        h, w = cropped_frame.shape[:2]
                        if video_writer is None:
                            # Initialize writer with the exact dimensions of
                            # the FIRST cropped frame
                            fourcc = cv2.VideoWriter_fourcc(*'avc1')
                            video_writer = cv2.VideoWriter(
                                recording_filename, fourcc, 5.0, (w, h))
                        # Ensure we try to resize cleanly if aspect ratio
                        # forces slight off-by-one errors over time
                        if video_writer is not None:
                            target_w = int(video_writer.get(
                                cv2.CAP_PROP_FRAME_WIDTH))
                            target_h = int(video_writer.get(
                                cv2.CAP_PROP_FRAME_HEIGHT))
                            if (w, h) != (target_w, target_h):
                                cropped_frame = cv2.resize(
                                    cropped_frame, (target_w, target_h))
                            video_writer.write(cropped_frame)
                except Exception as e:
                    logger.error(f"Error handling post-process crops: {e}")
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected.")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        is_obs_active = False
        # Cleanup Recording
        if video_writer is not None:
            video_writer.release()
            video_writer = None
        is_recording = False


# === FACE RECOGNITION ENDPOINTS ===

@app.post("/api/face/upload-video")
async def upload_reference_video(
    file: UploadFile = File(...),
    current_user: UserPublic = Depends(get_current_user)
):
    """Upload a 360-degree reference video for face recognition training."""
    if not file.filename.endswith(('.mp4', '.avi', '.mov', '.mkv')):
        raise HTTPException(
            status_code=400,
            detail="Invalid video format. Use mp4, avi, mov, or mkv")
    video_path = MODEL_DIR / "my_scan.mp4"
    try:
        with open(video_path, 'wb') as f:
            shutil.copyfileobj(file.file, f)
        embeddings, num_frames = await asyncio.get_event_loop().run_in_executor(
            executor, face_service.extract_embeddings_from_video, str(video_path)
        )
        face_service.save_embeddings_cache(
            embeddings, str(video_path), num_frames)
        return {
            "ok": True,
            "message": "Video processed successfully",
            "frames_used": num_frames,
            "embeddings_count": len(embeddings)
        }
    except Exception as e:
        logger.error(f"Error processing video: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/face/upload-image")
async def upload_reference_image(
    file: UploadFile = File(...),
    current_user: UserPublic = Depends(get_current_user)
):
    """Upload a reference image for face recognition."""
    if not file.filename.endswith(('.jpg', '.jpeg', '.png')):
        raise HTTPException(
            status_code=400,
            detail="Invalid image format. Use jpg, jpeg, or png")
    image_path = MODEL_DIR / f"ref_{file.filename}"
    try:
        with open(image_path, 'wb') as f:
            shutil.copyfileobj(file.file, f)
        embeddings = await asyncio.get_event_loop().run_in_executor(
            executor, face_service.extract_embeddings_from_image, str(image_path)
        )
        return {
            "ok": True,
            "message": "Image processed successfully",
            "embeddings_count": len(embeddings),
            "saved_path": str(image_path)
        }
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/face/cache-status")
async def get_cache_status(current_user: UserPublic = Depends(get_current_user)):
    """Get the current face recognition cache status."""
    cache_data = face_service.load_embeddings_cache()
    if cache_data:
        return {
            "ok": True,
            "cached": True,
            "video_path": cache_data.get('video_path'),
            "model_name": cache_data.get('model_name'),
            "num_frames_used": cache_data.get('num_frames_used'),
            "version": cache_data.get('version')
        }
    else:
        return {
            "ok": True,
            "cached": False,
            "message": "No cache found. Please upload a reference video or image."
        }


# === AUDIO STREAMING ENDPOINTS ===

@app.post("/api/audio/start-stream")
async def start_audio_stream(
    sample_rate: int = Form(16000),
    channels: int = Form(1),
    current_user: UserPublic = Depends(get_current_user)
):
    """Start a new audio recording stream."""
    session_id = str(uuid.uuid4())
    try:
        filename = audio_processor.create_audio_stream(
            session_id, sample_rate, channels)
        return {
            "ok": True,
            "session_id": session_id,
            "filename": filename,
            "sample_rate": sample_rate,
            "channels": channels
        }
    except Exception as e:
        logger.error(f"Error starting audio stream: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.websocket("/ws/audio/{session_id}")
async def websocket_audio_stream(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for streaming audio with angle data."""
    await websocket.accept()
    logger.info(
        f"Audio WebSocket connection established for session {session_id}")
    # Auto-create stream if not exists
    if session_id not in audio_processor.active_streams:
        audio_processor.create_audio_stream(session_id)
        logger.info(f"Auto-created audio stream for session {session_id}")
    try:
        while True:
            message = await websocket.receive()
            if "bytes" in message:
                # Raw PCM chunk without angle metadata
                audio_data = message["bytes"]
                audio_processor.write_audio_chunk(session_id, audio_data)
                await websocket.send_json(
                    {"status": "received", "bytes": len(audio_data)})
            elif "text" in message:
                try:
                    payload = json.loads(message["text"])
                    if "audio_data" in payload and "angle" in payload:
                        # Base64 audio chunk paired with a direction angle
                        audio_bytes = base64.b64decode(payload["audio_data"])
                        angle = float(payload["angle"])
                        audio_processor.write_audio_chunk(
                            session_id, audio_bytes, angle)
                        await websocket.send_json(
                            {"status": "received", "angle": angle})
                    elif payload.get("command") == "stop":
                        audio_processor.close_audio_stream(session_id)
                        await websocket.send_json(
                            {"status": "stopped", "message": "Stream closed"})
                        break
                except json.JSONDecodeError:
                    logger.error("Invalid JSON in audio stream")
    except WebSocketDisconnect:
        logger.info(
            f"Audio WebSocket client disconnected for session {session_id}")
        if session_id in audio_processor.active_streams:
            audio_processor.close_audio_stream(session_id)
    except Exception as e:
        logger.error(f"Audio WebSocket error: {e}")
        if session_id in audio_processor.active_streams:
            audio_processor.close_audio_stream(session_id)


@app.post("/api/audio/stop-stream/{session_id}")
async def stop_audio_stream(
    session_id: str,
    current_user: UserPublic = Depends(get_current_user)
):
    """Stop an active audio recording stream."""
    try:
        audio_processor.close_audio_stream(session_id)
        return {
            "ok": True,
            "message": "Audio stream stopped successfully"
        }
    except Exception as e:
        logger.error(f"Error stopping audio stream: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/audio/recordings")
async def list_audio_recordings(current_user: UserPublic = Depends(get_current_user)):
    """List all audio recordings."""
    try:
        recordings = audio_processor.get_audio_files()
        return {
            "ok": True,
            "recordings": recordings,
            "count": len(recordings)
        }
    except Exception as e:
        logger.error(f"Error listing recordings: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/audio/active-sessions")
async def get_active_sessions():
    """Get currently active audio recording sessions."""
    try:
        sessions = list(audio_processor.active_streams.keys())
        return {
            "ok": True,
            "active_sessions": sessions,
            "count": len(sessions)
        }
    except Exception as e:
        logger.error(f"Error getting active sessions: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/audio/angles")
async def get_audio_angles():
    """Get angle metadata for the latest audio session."""
    try:
        audio_dir = MODEL_DIR / "audio_recordings"
        metadata_files = list(audio_dir.glob("*_metadata.txt"))
        if not metadata_files:
            raise HTTPException(
                status_code=404,
                detail="No metadata found"
            )
        # Get the most recently modified metadata file
        # (os is already imported at module level)
        metadata_file = max(metadata_files, key=os.path.getmtime)
        angles_data = []
        with open(metadata_file, 'r') as f:
            lines = f.readlines()
            # Skip header if present
            start_idx = 1 if lines and 'timestamp' in lines[0] else 0
            for line in lines[start_idx:]:
                if line.strip():
                    parts = line.strip().split(',')
                    if len(parts) >= 2:
                        try:
                            timestamp = float(parts[0])
                            angle = float(parts[1])
                            angles_data.append(
                                {"timestamp": timestamp, "angle": angle})
                        except ValueError:
                            # Skip malformed rows rather than failing the request
                            continue
        return {
            "ok": True,
            "file": metadata_file.name,
            "angles": angles_data,
            "count": len(angles_data)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error retrieving angles: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/audio/upload")
async def upload_audio_file(
    file: UploadFile = File(...)
):
    """Upload recorded audio file from frontend and save to MongoDB."""
    try:
        # Read file content for DB persistence
        file_content = await file.read()
        if audio_recordings_collection is not None:
            await audio_recordings_collection.insert_one({
                "filename": file.filename,
                "content": file_content,  # Saved as binary in MongoDB
                "content_type": file.content_type,
                "timestamp": datetime.utcnow()
            })
        return {
            "ok": True,
            "message": "Audio file saved to database successfully",
            "filename": file.filename,
            "size": len(file_content)
        }
    except Exception as e:
        logger.error(f"Error saving audio to DB: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/audio/set-angle")
async def set_desired_angle(
    angle: float = Form(...)
):
    """Send a desired angle to the audio processing system and persist to MongoDB."""
    try:
        if not (0 <= angle <= 360):
            raise HTTPException(
                status_code=400,
                detail="Angle must be between 0 and 360 degrees"
            )
        if audio_angles_collection is not None:
            await audio_angles_collection.update_one(
                {"key": "latest_angle"},
                {"$set": {"value": angle, "updated_at": datetime.utcnow()}},
                upsert=True
            )
            logger.info(f"Set and persisted desired angle {angle}° to DB")
        return {
            "ok": True,
            "message": f"Desired angle set to {angle}° and saved to DB",
            "angle": angle
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error setting angle in DB: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/audio/get-angle")
async def get_current_angle():
    """
    Get the currently tracked angle of the target person.
    If no person is tracked, fallback to the angle previously set via set-angle.
    """
    try:
        global current_target_angle, current_target_distance
        # FIX: the original passed the values positionally as the logging
        # message, which breaks record formatting; use lazy %-style args.
        logger.info("Current target angle=%s distance=%s",
                    current_target_angle, current_target_distance)
        # If a person is actively being tracked, return their real-time angle
        if current_target_angle is not None:
            return {
                "ok": True,
                "source": "tracking",
                "angle": round(current_target_angle, 2),
                "distance": round(current_target_distance, 2)
            }
        # Fallback to the saved angle if no target is actively tracked
        if audio_angles_collection is not None:
            saved_angle_doc = await audio_angles_collection.find_one(
                {"key": "latest_angle"})
            if saved_angle_doc and "value" in saved_angle_doc:
                return {
                    "ok": True,
                    "source": "database",
                    "angle": float(saved_angle_doc["value"]),
                    "distance": None
                }
        return {
            "ok": False,
            "message": "No active tracking and no saved angle found",
            "angle": None,
            "distance": None
        }
    except Exception as e:
        logger.error(f"Error retrieving angle: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/audio/settings")
async def get_audio_settings():
    """Retrieve all audio settings from MongoDB."""
    try:
        if audio_settings_collection is None:
            return {"ok": False, "message": "Database not connected"}
        cursor = audio_settings_collection.find({}, {"_id": 0})
        settings_list = await cursor.to_list(length=100)
        # Convert list to dictionary
        settings_dict = {s["key"]: s["value"]
                         for s in settings_list if "key" in s}
        return {
            "ok": True,
            "settings": settings_dict
        }
    except Exception as e:
        logger.error(f"Error retrieving audio settings: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/audio/settings")
async def update_audio_settings(
    settings: dict = Body(...)
):
    """Update general audio settings in MongoDB."""
    try:
        if audio_settings_collection is None:
            raise HTTPException(
                status_code=503, detail="Database not connected")
        for key, value in settings.items():
            await audio_settings_collection.update_one(
                {"key": key},
                {"$set": {"value": value, "updated_at": datetime.utcnow()}},
                upsert=True
            )
        return {
            "ok": True,
            "message": "Audio settings updated successfully",
            "updated_keys": list(settings.keys())
        }
    except HTTPException:
        # FIX: without this, the deliberate 503 above was swallowed by the
        # generic handler and re-raised as a 500 (set_desired_angle already
        # used this pattern).
        raise
    except Exception as e:
        logger.error(f"Error updating audio settings: {e}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)