# visionattend-api / main.py
# Last commit by Shevilll (4e661a4): feat: add keep-alive GitHub Action for
# Hugging Face Space and remove redundant startup model warm-up
import io
import os
import tempfile
import threading
import urllib.request
import uuid
from datetime import date

import numpy as np
import psycopg
from deepface import DeepFace
from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pgvector.psycopg import register_vector
from PIL import Image
from pydantic import BaseModel
# Load variables from a .env file (if present) so os.getenv() below can see them.
load_dotenv()

app = FastAPI(title="VisionAttend API")

# CORS is wide open: any origin, method and header is accepted.
# In production, replace the wildcard origin with the frontend URL.
# NOTE(review): allow_credentials=True combined with a wildcard origin is very
# permissive; confirm this is intended against the Starlette CORS docs.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Postgres connection string; None when the variable is not set.
DATABASE_URL = os.getenv("DATABASE_URL")

# Where the FaceNet weights live on disk: ~/.deepface/weights/facenet_weights.h5
FACENET_WEIGHTS_PATH = os.path.join(os.path.expanduser("~"), ".deepface", "weights", "facenet_weights.h5")
# Download source used when the local weights file is missing or too small.
FACENET_WEIGHTS_URL = "https://huggingface.co/junjiang/GestureFace/resolve/main/facenet_weights.h5"
# A weights file below 80 MiB is treated as an incomplete download.
FACENET_MIN_BYTES = 80 << 20

# Serialises one-time FaceNet model initialisation across threads.
_facenet_model_lock = threading.Lock()
# Set to True once DeepFace has successfully built the FaceNet model.
_facenet_model_ready = False
class MatchResponse(BaseModel):
    """Response body of POST /api/recognize when a stored face matches."""

    message: str  # outcome text: attendance newly marked or already marked today
    student_name: str  # name of the matched student
    roll_number: str  # roll number of the matched student
from psycopg_pool import ConnectionPool
def configure_conn(conn) -> None:
    """Prepare a newly opened pooled connection for pgvector.

    Passed to the ConnectionPool as ``configure=``, so every connection the
    pool creates has pgvector's type adapters registered on it.
    """
    register_vector(conn)
# By using a ConnectionPool, we keep the TCP connection alive to Neon.
# This eliminates the ~6-8 second TLS/SSL handshake delay on every request.
# db_pool stays None when DATABASE_URL is unset or pool construction fails;
# get_db_connection() turns that into an HTTP 500.
db_pool = None
if DATABASE_URL:
    try:
        db_pool = ConnectionPool(
            conninfo=DATABASE_URL,
            configure=configure_conn,
            min_size=1,
            max_size=5,
            timeout=30.0,
            # Open explicitly. psycopg_pool documents that the implicit
            # open=True default will become False in a future release, which
            # would leave this pool closed and make every request fail.
            open=True,
            # Long-lived pooled connections can be dropped server-side while
            # idle (e.g. when Neon suspends compute). With a check callback,
            # the pool validates a connection before handing it out and
            # replaces a dead one instead of failing the request. The callback
            # is only available in psycopg_pool >= 3.2, so it is passed only
            # when present.
            **(
                {"check": ConnectionPool.check_connection}
                if hasattr(ConnectionPool, "check_connection")
                else {}
            ),
        )
    except Exception as e:
        print(f"WARNING: Failed to initialize database pool: {e}")
else:
    print("WARNING: DATABASE_URL is not set. Database connections will fail.")
def get_db_connection():
    """Return the pool's connection context manager, for use in a ``with`` block.

    Despite the name, this returns ``db_pool.connection()`` (a context manager
    yielding a connection), not a bare connection.

    Raises:
        HTTPException: 500 when no pool exists (DATABASE_URL unset, or pool
            construction failed at import time).
    """
    if db_pool:
        return db_pool.connection()
    raise HTTPException(
        status_code=500,
        detail="Database is not configured or failed to connect. Please set DATABASE_URL in Hugging Face Spaces Settings -> Variables and secrets.",
    )
def ensure_facenet_weights_file() -> None:
    """Make sure a complete FaceNet weights file exists at FACENET_WEIGHTS_PATH.

    The download goes to a temporary file first. It is moved into place only
    after it passes the FACENET_MIN_BYTES size check, so a partial download
    never becomes the weights file.

    Raises:
        ValueError: the downloaded file is smaller than FACENET_MIN_BYTES.
        Any error from ``urllib.request.urlopen`` (network/HTTP) propagates.
    """
    target_dir = os.path.dirname(FACENET_WEIGHTS_PATH)
    os.makedirs(target_dir, exist_ok=True)

    already_complete = (
        os.path.exists(FACENET_WEIGHTS_PATH)
        and os.path.getsize(FACENET_WEIGHTS_PATH) >= FACENET_MIN_BYTES
    )
    if already_complete:
        return

    # Reserve a unique temp name in the same directory, so os.replace below is
    # a same-filesystem rename.
    with tempfile.NamedTemporaryFile(delete=False, dir=target_dir, suffix=".tmp") as placeholder:
        partial_path = placeholder.name

    try:
        with urllib.request.urlopen(FACENET_WEIGHTS_URL, timeout=120) as remote, open(partial_path, "wb") as sink:
            # Stream in 1 MiB pieces rather than loading the whole file in memory.
            while piece := remote.read(1 << 20):
                sink.write(piece)
        if os.path.getsize(partial_path) < FACENET_MIN_BYTES:
            raise ValueError("Downloaded FaceNet weights file is incomplete.")
        os.replace(partial_path, FACENET_WEIGHTS_PATH)
    finally:
        # After a successful replace the temp path is gone; otherwise clean up.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def ensure_facenet_model_loaded() -> None:
    """Build DeepFace's FaceNet model once per process.

    The function returns immediately once ``_facenet_model_ready`` is set.
    Otherwise it takes ``_facenet_model_lock`` and re-checks the flag
    (double-checked locking), so concurrent callers load the model only once.

    If the first attempt fails for any reason, the weights file on disk is
    deleted (presumably to recover from a corrupt download) and the load is
    tried exactly once more. An error from that second attempt propagates.
    If the weights file cannot be deleted, the first error is re-raised.
    """
    global _facenet_model_ready
    if _facenet_model_ready:
        return

    def _load() -> None:
        # Ensure the weights file first, then have DeepFace build the model.
        ensure_facenet_weights_file()
        DeepFace.build_model("Facenet")

    with _facenet_model_lock:
        if _facenet_model_ready:
            return
        try:
            _load()
        except Exception as first_error:
            if os.path.exists(FACENET_WEIGHTS_PATH):
                try:
                    os.remove(FACENET_WEIGHTS_PATH)
                except OSError:
                    raise first_error
            # Second and final attempt; failures here reach the caller.
            _load()
        _facenet_model_ready = True
def extract_face_embedding(file_bytes: bytes) -> tuple[list[float], str]:
    """Decode an uploaded image and compute the FaceNet embedding of its face.

    Args:
        file_bytes: Raw bytes of the uploaded image file.

    Returns:
        ``(embedding, "success")`` when exactly one face is found. Otherwise
        ``([], message)``: this function reports failures through the message
        instead of raising. Any ValueError message (including "No face
        detected." / "Multiple faces detected. ...") is returned verbatim.
        Every other exception, including model-loading failures, is prefixed
        with "An error occurred during face extraction: ".
    """
    try:
        ensure_facenet_model_loaded()
        # NOTE(review): PIL produces an RGB array, while DeepFace documents
        # numpy inputs as BGR. Every stored embedding was computed through this
        # same path, so changing the channel order would require re-registering
        # all students. Confirm before touching.
        rgb_pixels = np.array(Image.open(io.BytesIO(file_bytes)).convert("RGB"))
        faces = DeepFace.represent(
            img_path=rgb_pixels,
            model_name="Facenet",
            detector_backend="opencv",
            enforce_detection=True,
        )
        if not faces:
            raise ValueError("No face detected.")
        if len(faces) > 1:
            raise ValueError("Multiple faces detected. Please show only one face.")
        return faces[0]["embedding"], "success"
    except ValueError as exc:
        return [], str(exc)
    except Exception as exc:
        return [], f"An error occurred during face extraction: {str(exc)}"
@app.post("/api/register")
async def register_student(
    name: str = Form(...),
    roll_number: str = Form(...),
    image: UploadFile = File(...),
):
    """Register a new student from form fields plus a single-face photo.

    Args:
        name: Student's display name.
        roll_number: Roll number; must not already exist in ``students``.
        image: Uploaded photo that must contain exactly one face.

    Returns:
        ``{"message": ..., "student_id": ...}`` for the inserted row.

    Raises:
        HTTPException: 400 when no usable face is found, the roll number is
            taken, or the nearest stored face is within distance 0.40;
            500 on any other database error.
    """
    file_bytes = await image.read()
    # Face detection and FaceNet inference are synchronous and CPU-heavy.
    # Calling them directly in this ``async def`` would block the event loop,
    # and with it every other request, for the whole inference. Run them in
    # the worker thread pool instead.
    # NOTE(review): this allows several inferences to run concurrently in
    # worker threads; confirm DeepFace behaves well under that load.
    embedding, error = await run_in_threadpool(extract_face_embedding, file_bytes)
    if not embedding:
        raise HTTPException(status_code=400, detail=error)
    # NOTE: the psycopg calls below are still synchronous and run on the loop.
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # NOTE(review): check-then-insert can race under concurrent
                # requests; only a UNIQUE constraint on roll_number (schema not
                # visible here) fully prevents duplicates.
                cur.execute("SELECT id FROM students WHERE roll_number = %s", (roll_number,))
                if cur.fetchone():
                    raise HTTPException(status_code=400, detail="Student with this roll number already exists.")
                # Nearest stored face by pgvector's `<=>` distance; refuse to
                # register the same face under a second roll number.
                cur.execute(
                    (
                        "\n"
                        "SELECT name, roll_number, (face_encoding <=> %s::vector) AS distance\n"
                        "FROM students\n"
                        "ORDER BY distance ASC\n"
                        "LIMIT 1;\n"
                    ),
                    (str(embedding),),
                )
                result = cur.fetchone()
                if result:
                    existing_name, existing_roll_number, distance = result
                    if distance <= 0.40:
                        raise HTTPException(status_code=400, detail=f"Face already registered to {existing_name} under roll number {existing_roll_number}.")
                cur.execute(
                    "INSERT INTO students (name, roll_number, face_encoding) VALUES (%s, %s, %s) RETURNING id",
                    (name, roll_number, str(embedding)),
                )
                student_id = cur.fetchone()[0]
                conn.commit()
                return {"message": "Student registered successfully", "student_id": student_id}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
@app.post("/api/recognize", response_model=MatchResponse)
async def recognize_student(image: UploadFile = File(...)):
    """Identify the student in a photo and mark them present for today.

    The uploaded face is compared with the nearest stored embedding. A match
    within distance 0.40 records at most one attendance row per student per
    day.

    Returns:
        MatchResponse naming the matched student. Its message says whether
        attendance was newly marked or was already marked today.

    Raises:
        HTTPException: 400 when no usable face is found; 404 when no student
            exists or the nearest face is farther than 0.40; 500 on database
            errors.
    """
    file_bytes = await image.read()
    # Face detection and FaceNet inference are synchronous and CPU-heavy.
    # Calling them directly in this ``async def`` would block the event loop,
    # and with it every other request, for the whole inference. Run them in
    # the worker thread pool instead.
    embedding, error = await run_in_threadpool(extract_face_embedding, file_bytes)
    if not embedding:
        raise HTTPException(status_code=400, detail=error)
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    (
                        "\n"
                        "SELECT id, name, roll_number, (face_encoding <=> %s::vector) AS distance\n"
                        "FROM students\n"
                        "ORDER BY distance ASC\n"
                        "LIMIT 1;\n"
                    ),
                    (str(embedding),),
                )
                result = cur.fetchone()
                if not result:
                    raise HTTPException(status_code=404, detail="No match found in database.")
                student_id, name, roll_number, distance = result
                print(f"Discovered {name} with distance {distance}")
                if distance > 0.40:  # Threshold for Facenet
                    raise HTTPException(status_code=404, detail="Face recognized but distance exceeds confidence threshold.")
                # Use the database's CURRENT_DATE, not the API host's
                # date.today(). The INSERT below leaves the timestamp to the
                # database (presumably a now()-style default; confirm in the
                # schema). Using CURRENT_DATE makes both sides of this "already
                # marked today?" check use the same clock and session time zone.
                cur.execute(
                    (
                        "\n"
                        "SELECT id FROM attendance\n"
                        "WHERE student_id = %s AND timestamp::date = CURRENT_DATE\n"
                    ),
                    (student_id,),
                )
                if not cur.fetchone():
                    cur.execute(
                        "INSERT INTO attendance (student_id, status) VALUES (%s, 'Present')",
                        (student_id,),
                    )
                    conn.commit()
                    message = "Attendance marked."
                else:
                    message = "Attendance already marked for today."
                return MatchResponse(message=message, student_name=name, roll_number=roll_number)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
@app.get("/api/attendance")
async def get_attendance(date_filter: str = None):
    """List attendance, either as a one-day roster or as a recent-activity log.

    Args:
        date_filter: Optional date string. When given, students are listed as
            'Present' if they have an attendance row with
            ``DATE(timestamp)`` equal to it, else 'Absent'. Present students
            come first, then the list is ordered by name. When omitted or
            empty, the newest attendance rows are returned instead.

    Returns:
        ``{"logs": [...]}`` with at most 100 entries. Each entry has ``id``,
        ``name``, ``roll_number``, ``timestamp``, ``status`` and
        ``student_id``. In roster mode ``id`` is the student's id, because
        absent students have no attendance row.

    Raises:
        HTTPException: 500 when the database is not configured (passed
            through unchanged) or a query fails.
    """
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                if date_filter:
                    # NOTE(review): LIMIT 100 truncates the roster for more
                    # than 100 students.
                    cur.execute(
                        (
                            "\n"
                            "SELECT s.id, s.name, s.roll_number, a.timestamp,\n"
                            "CASE WHEN a.id IS NOT NULL THEN 'Present' ELSE 'Absent' END as status,\n"
                            "s.id as student_id\n"
                            "FROM students s\n"
                            "LEFT JOIN attendance a ON s.id = a.student_id AND DATE(a.timestamp) = %s\n"
                            "ORDER BY\n"
                            "CASE WHEN a.id IS NOT NULL THEN 0 ELSE 1 END,\n"
                            "s.name ASC\n"
                            "LIMIT 100;\n"
                        ),
                        (date_filter,),
                    )
                else:
                    cur.execute(
                        "\n"
                        "SELECT a.id, s.name, s.roll_number, a.timestamp, a.status, s.id as student_id\n"
                        "FROM attendance a\n"
                        "JOIN students s ON a.student_id = s.id\n"
                        "ORDER BY a.timestamp DESC\n"
                        "LIMIT 100;\n"
                    )
                rows = cur.fetchall()
        # Both queries select the same six columns in the same order, so one
        # mapping serves both modes.
        logs = [
            {
                "id": str(row[0]),
                "name": row[1],
                "roll_number": row[2],
                "timestamp": row[3],
                "status": row[4],
                "student_id": str(row[5]),
            }
            for row in rows
        ]
        return {"logs": logs}
    except HTTPException:
        # get_db_connection() already raises a complete 500; don't re-wrap it
        # as "Database error: 500: ..." (matches the other endpoints).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
class StudentUpdate(BaseModel):
    """Request body for PUT /api/students/{student_id}."""

    name: str  # new display name
    roll_number: str  # new roll number; must not belong to another student
@app.put("/api/students/{student_id}")
async def update_student(student_id: uuid.UUID, data: StudentUpdate):
    """Overwrite a student's name and roll number.

    Raises:
        HTTPException: 400 when a *different* student already has
            ``data.roll_number``; 404 when ``student_id`` matches no row;
            500 for any other database failure.
    """
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            # Keep roll numbers unique across students, excluding this one.
            cur.execute(
                "SELECT id FROM students WHERE roll_number = %s AND id != %s",
                (data.roll_number, student_id),
            )
            if cur.fetchone() is not None:
                raise HTTPException(status_code=400, detail="Another student with this roll number already exists.")

            cur.execute(
                "UPDATE students SET name = %s, roll_number = %s WHERE id = %s RETURNING id",
                (data.name, data.roll_number, student_id),
            )
            # RETURNING yields no row when the id does not exist.
            if cur.fetchone() is None:
                raise HTTPException(status_code=404, detail="Student not found.")
            conn.commit()
        return {"message": "Student updated successfully."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
@app.delete("/api/students/{student_id}")
async def delete_student(student_id: uuid.UUID):
    """Delete one student by id.

    NOTE(review): whether this fails or cascades when the student has
    attendance rows depends on the foreign-key setup (schema not visible here).

    Raises:
        HTTPException: 404 when no student has ``student_id``; 500 on
            database errors.
    """
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            cur.execute("DELETE FROM students WHERE id = %s RETURNING id", (student_id,))
            deleted = cur.fetchone()
            if deleted is None:
                raise HTTPException(status_code=404, detail="Student not found.")
            conn.commit()
        return {"message": "Student deleted successfully."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
@app.put("/api/students/{student_id}/photo")
async def update_student_photo(student_id: uuid.UUID, image: UploadFile = File(...)):
    """Replace a student's stored face embedding with one from a new photo.

    Raises:
        HTTPException: 400 when no usable face is found, or when the face is
            within distance 0.40 of a *different* student's stored face;
            404 when ``student_id`` does not exist; 500 on database errors.
    """
    file_bytes = await image.read()
    # Face detection and FaceNet inference are synchronous and CPU-heavy.
    # Calling them directly in this ``async def`` would block the event loop,
    # and with it every other request, for the whole inference. Run them in
    # the worker thread pool instead.
    embedding, error = await run_in_threadpool(extract_face_embedding, file_bytes)
    if not embedding:
        raise HTTPException(status_code=400, detail=error)
    try:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Nearest face among *other* students; the student's own
                # current embedding is excluded so re-uploading is allowed.
                cur.execute(
                    (
                        "\n"
                        "SELECT name, roll_number, (face_encoding <=> %s::vector) AS distance\n"
                        "FROM students\n"
                        "WHERE id != %s\n"
                        "ORDER BY distance ASC\n"
                        "LIMIT 1;\n"
                    ),
                    (str(embedding), student_id),
                )
                result = cur.fetchone()
                if result:
                    existing_name, existing_roll_number, distance = result
                    if distance <= 0.40:
                        raise HTTPException(status_code=400, detail=f"This face is already registered to {existing_name} (Roll: {existing_roll_number}).")
                cur.execute(
                    "UPDATE students SET face_encoding = %s WHERE id = %s RETURNING id",
                    (str(embedding), student_id),
                )
                if not cur.fetchone():
                    raise HTTPException(status_code=404, detail="Student not found.")
                conn.commit()
                return {"message": "Student photo updated successfully."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
class ManualAttendance(BaseModel):
    """Request body for POST /api/attendance/manual."""

    student_id: uuid.UUID  # student whose attendance is being set
    date: str  # day to set, expected as YYYY-MM-DD
    status: str  # "Present" or "Absent"; any other value is rejected with 400
@app.post("/api/attendance/manual")
async def mark_manual_attendance(data: ManualAttendance):
    """Set a student's attendance for one day by hand.

    "Present" inserts an attendance row stamped at 12:00:00 on ``data.date``,
    unless one already exists for that day. "Absent" deletes every row for
    that student on that day.

    Raises:
        HTTPException: 400 for any other status; 500 on database errors
            (including an unconfigured database).
    """
    # Shared WHERE clause for both the "already present?" lookup and the delete.
    same_day_filter = "WHERE student_id = %s AND timestamp::date = %s\n"
    day_params = (data.student_id, data.date)
    try:
        with get_db_connection() as conn, conn.cursor() as cur:
            if data.status == "Absent":
                cur.execute("\nDELETE FROM attendance\n" + same_day_filter, day_params)
                conn.commit()
                return {"message": "Marked absent successfully."}

            if data.status != "Present":
                raise HTTPException(status_code=400, detail="Invalid status.")

            cur.execute("\nSELECT id FROM attendance\n" + same_day_filter, day_params)
            if cur.fetchone():
                return {"message": "Already marked present for this date."}

            # Manual entries have no capture time, so they default to noon.
            cur.execute(
                "INSERT INTO attendance (student_id, status, timestamp) VALUES (%s, 'Present', %s)",
                (data.student_id, f"{data.date} 12:00:00"),
            )
            conn.commit()
            return {"message": "Marked present successfully."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
@app.get("/")
def read_root():
    """Liveness endpoint: reports the API process is up (touches neither DB nor model)."""
    payload = {"status": "ok", "message": "VisionAttend API is active"}
    return payload