Spaces:
Running
Running
import asyncio
import io
import os
import tempfile
import threading
import urllib.request
import uuid
from datetime import date
from typing import Optional

import numpy as np
import psycopg
from deepface import DeepFace
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pgvector.psycopg import register_vector
from PIL import Image
from pydantic import BaseModel
load_dotenv()

app = FastAPI(title="VisionAttend API")

# Configure CORS.
# CORS_ALLOW_ORIGINS may hold a comma-separated list of allowed origins
# (e.g. "https://attend.example.com,https://admin.example.com"). When it is
# unset or blank every origin is allowed, exactly as before. Set it in
# production: a wildcard combined with allow_credentials=True effectively
# admits credentialed requests from any website.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

DATABASE_URL = os.getenv("DATABASE_URL")

# Local cache location for the FaceNet weights file.
FACENET_WEIGHTS_PATH = os.path.join(
    os.path.expanduser("~"),
    ".deepface",
    "weights",
    "facenet_weights.h5",
)
FACENET_WEIGHTS_URL = "https://huggingface.co/junjiang/GestureFace/resolve/main/facenet_weights.h5"
# Weights files smaller than this are treated as truncated/corrupt downloads.
FACENET_MIN_BYTES = 80 * 1024 * 1024

# Serialises one-time FaceNet initialisation; the flag records that it succeeded.
_facenet_model_lock = threading.Lock()
_facenet_model_ready = False
class MatchResponse(BaseModel):
    """Response body sent after a face has been matched to a registered student."""

    message: str  # human-readable outcome of the attendance attempt
    student_name: str
    roll_number: str
| from psycopg_pool import ConnectionPool | |
def configure_conn(conn):
    """Connection-pool ``configure`` hook.

    Registers the pgvector adapters on ``conn`` so ``vector`` values can be
    exchanged through psycopg on every pooled connection.
    """
    return register_vector(conn)
# A ConnectionPool keeps TCP/TLS connections to Neon alive between requests,
# avoiding the multi-second TLS/SSL handshake a fresh connection per request
# would cost.
db_pool = None
if DATABASE_URL:
    try:
        db_pool = ConnectionPool(
            conninfo=DATABASE_URL,
            configure=configure_conn,
            min_size=1,
            max_size=5,
            timeout=30.0,
            # Open explicitly: callers use db_pool.connection() without ever
            # calling db_pool.open(). psycopg_pool >= 3.2 warns when `open` is
            # left implicit, and its default is slated to become False.
            open=True,
        )
    except Exception as e:
        print(f"WARNING: Failed to initialize database pool: {e}")
else:
    print("WARNING: DATABASE_URL is not set. Database connections will fail.")
def get_db_connection():
    """Borrow a connection from the shared pool.

    Returns the pool's ``connection()`` context manager; use it in a ``with``
    statement so the connection is handed back to the pool afterwards.

    Raises:
        HTTPException: 500 when there is no pool (DATABASE_URL unset or pool
            creation failed).
    """
    if db_pool:
        return db_pool.connection()
    raise HTTPException(
        status_code=500,
        detail="Database is not configured or failed to connect. Please set DATABASE_URL in Hugging Face Spaces Settings -> Variables and secrets.",
    )
def ensure_facenet_weights_file(
    path: Optional[str] = None,
    url: Optional[str] = None,
    min_bytes: Optional[int] = None,
) -> None:
    """Ensure a complete FaceNet weights file exists on disk, downloading it if needed.

    An existing file at least ``min_bytes`` long is kept as-is. Otherwise the
    file is streamed into a temporary file in the destination directory and
    moved into place with ``os.replace`` only after it passes the size check,
    so a failed or truncated download never appears at ``path``.

    Args:
        path: Destination file. Defaults to ``FACENET_WEIGHTS_PATH``.
        url: Download source. Defaults to ``FACENET_WEIGHTS_URL``.
        min_bytes: Smallest size accepted as complete, both for an existing
            file and for a fresh download. Defaults to ``FACENET_MIN_BYTES``.

    Raises:
        ValueError: The downloaded file is smaller than ``min_bytes``.
        OSError: Filesystem or network failure (``urllib.error.URLError`` is
            an ``OSError`` subclass).
    """
    path = FACENET_WEIGHTS_PATH if path is None else path
    url = FACENET_WEIGHTS_URL if url is None else url
    min_bytes = FACENET_MIN_BYTES if min_bytes is None else min_bytes

    # abspath() so a bare filename still yields a real directory for makedirs().
    weights_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(weights_dir, exist_ok=True)
    if os.path.exists(path) and os.path.getsize(path) >= min_bytes:
        return

    # Stage next to the target so os.replace() is a same-filesystem rename.
    staging = tempfile.NamedTemporaryFile(delete=False, dir=weights_dir, suffix=".tmp")
    temp_path = staging.name
    try:
        with staging, urllib.request.urlopen(url, timeout=120) as response:
            while chunk := response.read(1024 * 1024):
                staging.write(chunk)
        if os.path.getsize(temp_path) < min_bytes:
            raise ValueError("Downloaded FaceNet weights file is incomplete.")
        os.replace(temp_path, path)
    finally:
        # Only still present when the download or the size check failed.
        if os.path.exists(temp_path):
            os.remove(temp_path)
def ensure_facenet_model_loaded() -> None:
    """Build DeepFace's FaceNet model once per process, fetching weights if needed.

    The ready flag is checked without the lock (fast path once loaded) and
    again while holding ``_facenet_model_lock`` so only one thread performs
    the load. If the first attempt fails for any reason, the cached weights
    file is deleted and download + build is retried exactly once. Errors from
    the retry propagate. If the cached file cannot be deleted, the first
    error is re-raised instead.
    """
    global _facenet_model_ready

    if _facenet_model_ready:
        return

    def _fetch_and_build() -> None:
        ensure_facenet_weights_file()
        DeepFace.build_model("Facenet")

    with _facenet_model_lock:
        if _facenet_model_ready:
            return
        try:
            _fetch_and_build()
        except Exception as first_error:
            # Treat the cached weights as suspect: discard them and retry once.
            if os.path.exists(FACENET_WEIGHTS_PATH):
                try:
                    os.remove(FACENET_WEIGHTS_PATH)
                except OSError:
                    raise first_error
            _fetch_and_build()
        _facenet_model_ready = True
def extract_face_embedding(file_bytes: bytes) -> tuple[list[float], str]:
    """Compute a FaceNet embedding for the single face in an uploaded image.

    Never raises. Returns ``(embedding, "success")`` on success. On failure
    returns ``([], reason)``, where ``reason`` is either:

    - the text of a ``ValueError`` (no face / several faces, or a detection
      error DeepFace raises because ``enforce_detection=True``), or
    - a generic prefix plus the exception text for anything else (unreadable
      image bytes, model loading failures, ...).
    """
    try:
        ensure_facenet_model_loaded()
        pixels = np.array(Image.open(io.BytesIO(file_bytes)).convert("RGB"))
        # NOTE(review): DeepFace documents ndarray input as BGR (OpenCV order),
        # but this passes RGB. Embeddings already stored in the database were
        # produced the same way, so switching to BGR would require re-enrolling
        # every student; confirm before changing.
        faces = DeepFace.represent(
            img_path=pixels,
            enforce_detection=True,
            model_name="Facenet",
            detector_backend="opencv",
        )
        if not faces:
            raise ValueError("No face detected.")
        if len(faces) > 1:
            raise ValueError("Multiple faces detected. Please show only one face.")
        return faces[0]["embedding"], "success"
    except ValueError as exc:
        return [], str(exc)
    except Exception as exc:
        return [], f"An error occurred during face extraction: {str(exc)}"
def _register_student_in_db(name: str, roll_number: str, embedding: list[float]) -> dict:
    """Insert a new student after duplicate checks (blocking; run in a worker thread).

    Raises:
        HTTPException: 400 if the roll number is already taken or the face is
            within cosine distance 0.40 of an already-registered face; the 500
            from ``get_db_connection`` when no pool exists; 500 for any other
            database failure.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT id FROM students WHERE roll_number = %s", (roll_number,))
                if cur.fetchone():
                    raise HTTPException(status_code=400, detail="Student with this roll number already exists.")

                # Nearest registered face (`<=>` is pgvector's cosine distance).
                # Rows without an encoding are skipped so `distance` is never NULL.
                cur.execute(
                    """
                    SELECT name, roll_number, (face_encoding <=> %s::vector) AS distance
                    FROM students
                    WHERE face_encoding IS NOT NULL
                    ORDER BY distance ASC
                    LIMIT 1;
                    """,
                    (str(embedding),),
                )
                nearest = cur.fetchone()
                if nearest:
                    existing_name, existing_roll_number, distance = nearest
                    if distance <= 0.40:  # FaceNet match threshold
                        raise HTTPException(status_code=400, detail=f"Face already registered to {existing_name} under roll number {existing_roll_number}.")

                cur.execute(
                    "INSERT INTO students (name, roll_number, face_encoding) VALUES (%s, %s, %s) RETURNING id",
                    (name, roll_number, str(embedding)),
                )
                student_id = cur.fetchone()[0]
                conn.commit()
                return {"message": "Student registered successfully", "student_id": student_id}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")


async def register_student(
    name: str = Form(...),
    roll_number: str = Form(...),
    image: UploadFile = File(...),
):
    """Register a student from a name, roll number and face photo.

    Face extraction is CPU-bound and may download the model weights on first
    use. The psycopg queries are synchronous. Both therefore run in worker
    threads via ``asyncio.to_thread`` instead of blocking the event loop for
    every other request.

    Raises:
        HTTPException: 400 when no single usable face is found, plus the errors
            raised by ``_register_student_in_db``.
    """
    file_bytes = await image.read()
    embedding, error = await asyncio.to_thread(extract_face_embedding, file_bytes)
    if not embedding:
        raise HTTPException(status_code=400, detail=error)
    return await asyncio.to_thread(_register_student_in_db, name, roll_number, embedding)
def _mark_attendance_for_embedding(embedding: list[float]) -> "MatchResponse":
    """Match ``embedding`` to the nearest student and record today's attendance (blocking).

    Raises:
        HTTPException: 404 when there is no candidate or the nearest face is
            farther than cosine distance 0.40; the 500 from
            ``get_db_connection`` when no pool exists; 500 for other database
            errors.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Rows without an encoding are skipped so `distance` is never NULL.
                cur.execute(
                    """
                    SELECT id, name, roll_number, (face_encoding <=> %s::vector) AS distance
                    FROM students
                    WHERE face_encoding IS NOT NULL
                    ORDER BY distance ASC
                    LIMIT 1;
                    """,
                    (str(embedding),),
                )
                match = cur.fetchone()
                if not match:
                    raise HTTPException(status_code=404, detail="No match found in database.")
                student_id, name, roll_number, distance = match
                print(f"Discovered {name} with distance {distance}")
                if distance > 0.40:  # Threshold for Facenet
                    raise HTTPException(status_code=404, detail="Face recognized but distance exceeds confidence threshold.")

                # NOTE(review): date.today() is this process's local date, while
                # timestamp::date is computed by Postgres in the session time
                # zone; confirm both sides use the same zone.
                cur.execute(
                    """
                    SELECT id FROM attendance
                    WHERE student_id = %s AND timestamp::date = %s
                    """,
                    (student_id, date.today()),
                )
                if cur.fetchone():
                    message = "Attendance already marked for today."
                else:
                    cur.execute(
                        "INSERT INTO attendance (student_id, status) VALUES (%s, 'Present')",
                        (student_id,),
                    )
                    conn.commit()
                    message = "Attendance marked."
                return MatchResponse(message=message, student_name=name, roll_number=roll_number)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")


async def recognize_student(image: UploadFile = File(...)):
    """Identify the student in ``image`` and mark them present for today.

    Face extraction (CPU-bound) and the synchronous psycopg work run in worker
    threads via ``asyncio.to_thread`` so the event loop stays responsive.

    Raises:
        HTTPException: 400 when no single usable face is found, plus the errors
            raised by ``_mark_attendance_for_embedding``.
    """
    file_bytes = await image.read()
    embedding, error = await asyncio.to_thread(extract_face_embedding, file_bytes)
    if not embedding:
        raise HTTPException(status_code=400, detail=error)
    return await asyncio.to_thread(_mark_attendance_for_embedding, embedding)
def _fetch_attendance_logs(date_filter) -> list:
    """Run the attendance query and shape each row into a log dict (blocking).

    With a truthy ``date_filter``, up to 100 students are listed with status
    'Present' or 'Absent' for that day, present students first. In that mode
    "id" is the student id and "timestamp" is None for absent students.
    Without a filter, the 100 most recent attendance rows are returned,
    newest first.
    """
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            if date_filter:
                cur.execute("""
                    SELECT s.id, s.name, s.roll_number, a.timestamp,
                           CASE WHEN a.id IS NOT NULL THEN 'Present' ELSE 'Absent' END as status,
                           s.id as student_id
                    FROM students s
                    LEFT JOIN attendance a ON s.id = a.student_id AND DATE(a.timestamp) = %s
                    ORDER BY
                        CASE WHEN a.id IS NOT NULL THEN 0 ELSE 1 END,
                        s.name ASC
                    LIMIT 100;
                """, (date_filter,))
            else:
                cur.execute("""
                    SELECT a.id, s.name, s.roll_number, a.timestamp, a.status, s.id as student_id
                    FROM attendance a
                    JOIN students s ON a.student_id = s.id
                    ORDER BY a.timestamp DESC
                    LIMIT 100;
                """)
            rows = cur.fetchall()
    # Both queries select the same six columns in the same order.
    return [
        {
            "id": str(row[0]),
            "name": row[1],
            "roll_number": row[2],
            "timestamp": row[3],
            "status": row[4],
            "student_id": str(row[5]),
        }
        for row in rows
    ]


async def get_attendance(date_filter: str = None):
    """Return attendance logs, optionally as a per-student roll call for one day.

    The synchronous database work runs in a worker thread via
    ``asyncio.to_thread`` so the event loop is not blocked.

    Raises:
        HTTPException: 400 when Postgres rejects ``date_filter`` as a date;
            the 500 from ``get_db_connection`` when no pool exists (passed
            through unchanged rather than re-wrapped); 500 for other database
            errors.
    """
    try:
        logs = await asyncio.to_thread(_fetch_attendance_logs, date_filter)
    except HTTPException:
        raise
    except psycopg.DataError as e:
        # Class 22 SQLSTATE: date_filter could not be interpreted as a date.
        raise HTTPException(status_code=400, detail=f"Invalid date_filter: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    return {"logs": logs}
class StudentUpdate(BaseModel):
    """Request body for editing a student's identifying details."""

    name: str  # new display name
    roll_number: str  # new roll number
def _update_student_in_db(student_id: uuid.UUID, data: "StudentUpdate") -> dict:
    """Rename a student / change their roll number (blocking; run in a worker thread).

    Raises:
        HTTPException: 400 if another student already has the roll number;
            404 if ``student_id`` does not exist; the 500 from
            ``get_db_connection`` when no pool exists; 500 for other database
            errors.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT id FROM students WHERE roll_number = %s AND id != %s",
                    (data.roll_number, student_id),
                )
                if cur.fetchone():
                    raise HTTPException(status_code=400, detail="Another student with this roll number already exists.")
                cur.execute(
                    "UPDATE students SET name = %s, roll_number = %s WHERE id = %s RETURNING id",
                    (data.name, data.roll_number, student_id),
                )
                if not cur.fetchone():
                    raise HTTPException(status_code=404, detail="Student not found.")
                conn.commit()
                return {"message": "Student updated successfully."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")


async def update_student(student_id: uuid.UUID, data: StudentUpdate):
    """Update a student's name and roll number.

    The synchronous psycopg work runs in a worker thread via
    ``asyncio.to_thread`` so it does not block the event loop.
    """
    return await asyncio.to_thread(_update_student_in_db, student_id, data)
def _delete_student_in_db(student_id: uuid.UUID) -> dict:
    """Delete one student row (blocking; run in a worker thread).

    Raises:
        HTTPException: 404 if ``student_id`` does not exist; the 500 from
            ``get_db_connection`` when no pool exists; 500 for other database
            errors.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("DELETE FROM students WHERE id = %s RETURNING id", (student_id,))
                if not cur.fetchone():
                    raise HTTPException(status_code=404, detail="Student not found.")
                conn.commit()
                return {"message": "Student deleted successfully."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")


async def delete_student(student_id: uuid.UUID):
    """Delete a student.

    The synchronous psycopg work runs in a worker thread via
    ``asyncio.to_thread`` so it does not block the event loop.
    """
    return await asyncio.to_thread(_delete_student_in_db, student_id)
def _store_student_face_in_db(student_id: uuid.UUID, embedding: list[float]) -> dict:
    """Replace a student's face encoding after a duplicate check (blocking).

    Raises:
        HTTPException: 400 if the face is within cosine distance 0.40 of a
            *different* student; 404 if ``student_id`` does not exist; the 500
            from ``get_db_connection`` when no pool exists; 500 for other
            database errors.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Nearest other student's face; rows without an encoding are
                # skipped so `distance` is never NULL.
                cur.execute(
                    """
                    SELECT name, roll_number, (face_encoding <=> %s::vector) AS distance
                    FROM students
                    WHERE id != %s AND face_encoding IS NOT NULL
                    ORDER BY distance ASC
                    LIMIT 1;
                    """,
                    (str(embedding), student_id),
                )
                nearest = cur.fetchone()
                if nearest:
                    existing_name, existing_roll_number, distance = nearest
                    if distance <= 0.40:  # FaceNet match threshold
                        raise HTTPException(status_code=400, detail=f"This face is already registered to {existing_name} (Roll: {existing_roll_number}).")
                cur.execute(
                    "UPDATE students SET face_encoding = %s WHERE id = %s RETURNING id",
                    (str(embedding), student_id),
                )
                if not cur.fetchone():
                    raise HTTPException(status_code=404, detail="Student not found.")
                conn.commit()
                return {"message": "Student photo updated successfully."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")


async def update_student_photo(student_id: uuid.UUID, image: UploadFile = File(...)):
    """Re-enrol a student's face from a new photo.

    Face extraction (CPU-bound) and the synchronous psycopg work run in worker
    threads via ``asyncio.to_thread`` so the event loop stays responsive.

    Raises:
        HTTPException: 400 when no single usable face is found, plus the errors
            raised by ``_store_student_face_in_db``.
    """
    file_bytes = await image.read()
    embedding, error = await asyncio.to_thread(extract_face_embedding, file_bytes)
    if not embedding:
        raise HTTPException(status_code=400, detail=error)
    return await asyncio.to_thread(_store_student_face_in_db, student_id, embedding)
class ManualAttendance(BaseModel):
    """Request body for manually setting one student's attendance on one day."""

    student_id: uuid.UUID
    date: str  # calendar day, expected as YYYY-MM-DD
    status: str  # "Present" or "Absent"
def _apply_manual_attendance(data: "ManualAttendance") -> dict:
    """Insert or delete the attendance row for ``data`` (blocking; run in a worker thread).

    "Present" inserts a noon timestamp on ``data.date`` unless a row already
    exists for that day. "Absent" deletes every row for that day.

    Raises:
        HTTPException: 400 for an unknown status or a ``data.date`` Postgres
            cannot parse; the 500 from ``get_db_connection`` when no pool
            exists; 500 for other database errors.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                if data.status == "Present":
                    cur.execute("""
                        SELECT id FROM attendance
                        WHERE student_id = %s AND timestamp::date = %s
                    """, (data.student_id, data.date))
                    if cur.fetchone():
                        return {"message": "Already marked present for this date."}
                    cur.execute(
                        "INSERT INTO attendance (student_id, status, timestamp) VALUES (%s, 'Present', %s)",
                        (data.student_id, f"{data.date} 12:00:00"),  # Default to noon for manual entries
                    )
                    conn.commit()
                    return {"message": "Marked present successfully."}
                elif data.status == "Absent":
                    cur.execute("""
                        DELETE FROM attendance
                        WHERE student_id = %s AND timestamp::date = %s
                    """, (data.student_id, data.date))
                    conn.commit()
                    return {"message": "Marked absent successfully."}
                else:
                    raise HTTPException(status_code=400, detail="Invalid status.")
    except HTTPException:
        raise
    except psycopg.DataError as e:
        # Class 22 SQLSTATE: data.date could not be interpreted as a date.
        raise HTTPException(status_code=400, detail=f"Invalid date: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")


async def mark_manual_attendance(data: ManualAttendance):
    """Manually mark a student present or absent for a given day.

    The synchronous psycopg work runs in a worker thread via
    ``asyncio.to_thread`` so it does not block the event loop.
    """
    return await asyncio.to_thread(_apply_manual_attendance, data)
def read_root():
    """Health-check payload confirming the API process is up."""
    payload = dict(status="ok", message="VisionAttend API is active")
    return payload