Spaces:
Running
Running
Alief Gilang Permana Putra commited on
Commit ·
af35098
1
Parent(s): fb55838
feat: Add files for inference
Browse files- .dockerignore +22 -0
- .env.example +7 -0
- Dockerfile +52 -0
- api/endpoints/predict.py +18 -0
- api/endpoints/system.py +25 -0
- api/router.py +7 -0
- assets/blaze_face_short_range.tflite +3 -0
- config/metadata.json +6 -0
- config/models.json +26 -0
- core/config.py +16 -0
- core/exceptions.py +16 -0
- main.py +56 -0
- requirements.txt +76 -0
- schemas/predict.py +32 -0
- schemas/system.py +31 -0
- services/face_extractor.py +74 -0
- services/inference.py +27 -0
- services/model_manager.py +104 -0
.dockerignore
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore python virtual environments and dependencies
|
| 2 |
+
pytorch-cuda/
|
| 3 |
+
venv/
|
| 4 |
+
.venv/
|
| 5 |
+
env/
|
| 6 |
+
|
| 7 |
+
# Ignore python cache files
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.pyc
|
| 10 |
+
*.pyo
|
| 11 |
+
*.pyd
|
| 12 |
+
|
| 13 |
+
# Ignore local environment file (production secrets should be injected via env variables)
|
| 14 |
+
.env
|
| 15 |
+
|
| 16 |
+
# Ignore git folder
|
| 17 |
+
.git/
|
| 18 |
+
.gitignore
|
| 19 |
+
|
| 20 |
+
# Ignore docker files
|
| 21 |
+
Dockerfile
|
| 22 |
+
.dockerignore
|
.env.example
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Server configuration
|
| 2 |
+
HOST=0.0.0.0
|
| 3 |
+
PORT=8000
|
| 4 |
+
DEBUG_MODE=True
|
| 5 |
+
|
| 6 |
+
# Hugging Face Token
|
| 7 |
+
HF_TOKEN=your_huggingface_access_token_here
|
Dockerfile
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use a stable, official Python base image
|
| 2 |
+
FROM python:3.10-slim
|
| 3 |
+
|
| 4 |
+
# Set environment variables
|
| 5 |
+
# PYTHONUNBUFFERED=1 ensures console logs are printed immediately
|
| 6 |
+
# PYTHONDONTWRITEBYTECODE=1 prevents python from writing .pyc files
|
| 7 |
+
# PORT=7860 is the default port for Hugging Face Spaces
|
| 8 |
+
# HOME=/home/user sets the home folder for the non-root user
|
| 9 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 10 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 11 |
+
PORT=7860 \
|
| 12 |
+
HOST=0.0.0.0 \
|
| 13 |
+
HOME=/home/user
|
| 14 |
+
|
| 15 |
+
# Install system dependencies required by OpenCV, MediaPipe, and other libraries
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 17 |
+
build-essential \
|
| 18 |
+
libgl1 \
|
| 19 |
+
libglib2.0-0 \
|
| 20 |
+
libgomp1 \
|
| 21 |
+
sed \
|
| 22 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 23 |
+
|
| 24 |
+
# Create a non-root user with UID 1000 (Hugging Face Spaces runs as UID 1000)
|
| 25 |
+
RUN useradd -m -u 1000 user
|
| 26 |
+
WORKDIR /app
|
| 27 |
+
|
| 28 |
+
# Copy requirements.txt first for build caching
|
| 29 |
+
COPY --chown=user:user requirements.txt /app/
|
| 30 |
+
|
| 31 |
+
# Remove the custom local PyTorch wheels from requirements.txt to install standard stable versions
|
| 32 |
+
# from PyPI, supporting both CPU and GPU workloads automatically.
|
| 33 |
+
RUN sed -i '/torch==/d' requirements.txt && \
|
| 34 |
+
sed -i '/torchvision==/d' requirements.txt && \
|
| 35 |
+
pip install --no-cache-dir --upgrade pip && \
|
| 36 |
+
pip install --no-cache-dir torch torchvision && \
|
| 37 |
+
pip install --no-cache-dir -r requirements.txt
|
| 38 |
+
|
| 39 |
+
# Copy the rest of the application files
|
| 40 |
+
COPY --chown=user:user . /app/
|
| 41 |
+
|
| 42 |
+
# Setup a writeable Hugging Face cache directory inside the home folder of user 1000
|
| 43 |
+
RUN mkdir -p /home/user/.cache/huggingface && chown -R user:user /home/user
|
| 44 |
+
|
| 45 |
+
# Switch to the non-root user
|
| 46 |
+
USER user
|
| 47 |
+
|
| 48 |
+
# Expose the default port (Hugging Face Spaces automatically forwards traffic to 7860)
|
| 49 |
+
EXPOSE 7860
|
| 50 |
+
|
| 51 |
+
# Run the FastAPI server
|
| 52 |
+
CMD ["python", "main.py"]
|
api/endpoints/predict.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
from schemas.predict import InferenceRequest, PredictionResponse
|
| 3 |
+
from schemas.system import ModelsListResponse
|
| 4 |
+
from services.model_manager import model_manager
|
| 5 |
+
|
| 6 |
+
router = APIRouter(tags=["Inference"])
|
| 7 |
+
|
| 8 |
+
@router.get("/models", response_model=ModelsListResponse)
|
| 9 |
+
async def list_models():
|
| 10 |
+
"""List all available models loaded in memory for inference."""
|
| 11 |
+
return {
|
| 12 |
+
"available_models": list(model_manager.model_configs.values())
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
@router.post("/predict", response_model=PredictionResponse)
|
| 16 |
+
async def predict_personality(request: InferenceRequest):
|
| 17 |
+
"""Predict Big Five personality traits from a base64 encoded face image."""
|
| 18 |
+
return model_manager.predict(request.model_type, request.image_base64)
|
api/endpoints/system.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from fastapi import APIRouter
|
| 3 |
+
from schemas.system import MetadataResponse, HealthResponse
|
| 4 |
+
from services.model_manager import model_manager, DEVICE
|
| 5 |
+
|
| 6 |
+
router = APIRouter(tags=["System"])
|
| 7 |
+
|
| 8 |
+
@router.get("/", response_model=MetadataResponse)
|
| 9 |
+
async def root():
|
| 10 |
+
"""Standard root endpoint providing API metadata."""
|
| 11 |
+
with open("config/metadata.json", "r") as f:
|
| 12 |
+
metadata = json.load(f)
|
| 13 |
+
metadata["documentation"] = "/docs"
|
| 14 |
+
return metadata
|
| 15 |
+
|
| 16 |
+
@router.get("/health", response_model=HealthResponse)
|
| 17 |
+
async def health_check():
|
| 18 |
+
"""API Health check"""
|
| 19 |
+
return {
|
| 20 |
+
"status": "healthy",
|
| 21 |
+
"device": DEVICE,
|
| 22 |
+
"models_loaded": list(model_manager.models.keys()),
|
| 23 |
+
"port": "auto"
|
| 24 |
+
}
|
| 25 |
+
|
api/router.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
from api.endpoints import system, predict
|
| 3 |
+
|
| 4 |
+
api_router = APIRouter()
|
| 5 |
+
|
| 6 |
+
api_router.include_router(system.router)
|
| 7 |
+
api_router.include_router(predict.router)
|
assets/blaze_face_short_range.tflite
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f
|
| 3 |
+
size 229746
|
config/metadata.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"api_name": "Big Five Personality Inference API",
|
| 3 |
+
"description": "API for predicting Big Five personality traits (OCEAN) from facial images using deep learning vision models",
|
| 4 |
+
"version": "1.2.0",
|
| 5 |
+
"status": "online"
|
| 6 |
+
}
|
config/models.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"id": "vit-b16-augreg-in21k",
|
| 4 |
+
"name": "ViT-B/16 AugReg (IN21k) - Arch Tuning + Augmentation",
|
| 5 |
+
"description": "ViT Base patch 16 ArchTuning + AugReg model Trained in 21k images",
|
| 6 |
+
"repo_id": "lyfesan/vit_base_patch16_224_augreg_in21k_Run_D_The_Ultimate_bigfive"
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"id": "vit-b16-augreg-in21k-ft1k",
|
| 10 |
+
"name": "ViT-B/16 AugReg (IN21k+1k) - Arch Tuning + Augmentation",
|
| 11 |
+
"description": "ViT Base patch 16 ArchTuning + AugReg model Trained in 21k images finetuned in 1k images",
|
| 12 |
+
"repo_id": "lyfesan/vit_base_patch16_224_augreg_in21k_ft_in1k_Run_D_The_Ultimate_bigfive"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"id": "swinv2-w12-16-archtuning-in22k-ft1k",
|
| 16 |
+
"name": "SwinV2-B w12-16 (IN22k+1k) - Arch Tuning",
|
| 17 |
+
"description": "SwinV2 window12-16 ArchTuning model Trained in 22k images finetuned in 1k images",
|
| 18 |
+
"repo_id": "lyfesan/swinv2_base_window12to16_192to256_ms_in22k_ft_in1k_Run_B_Arch_Tuning_bigfive"
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"id": "swinv2-w16-archtuning-in1k",
|
| 22 |
+
"name": "SwinV2-B w16 (IN1k) - Arch Tuning",
|
| 23 |
+
"description": "SwinV2 window16 ArchTuning model trained in 1k images",
|
| 24 |
+
"repo_id": "lyfesan/swinv2_base_window16_256_ms_in1k_Run_B_Arch_Tuning_bigfive"
|
| 25 |
+
}
|
| 26 |
+
]
|
core/config.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from huggingface_hub import login
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
# Load environment variables
|
| 6 |
+
load_dotenv(override=True)
|
| 7 |
+
|
| 8 |
+
HOST = os.getenv("HOST", "0.0.0.0")
|
| 9 |
+
PORT = int(os.getenv("PORT", 8000))
|
| 10 |
+
DEBUG_MODE = os.getenv("DEBUG_MODE", "False").lower() in ("true", "1", "t")
|
| 11 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 12 |
+
|
| 13 |
+
# Authenticate with Hugging Face
|
| 14 |
+
if HF_TOKEN and HF_TOKEN != "your_huggingface_access_token_here":
|
| 15 |
+
print("Logging into Hugging Face Hub...")
|
| 16 |
+
login(token=HF_TOKEN)
|
core/exceptions.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
| 4 |
+
|
| 5 |
+
async def custom_404_handler(request: Request, exc: StarletteHTTPException):
|
| 6 |
+
"""Custom handler to format 404 errors cleanly."""
|
| 7 |
+
if exc.status_code == 404:
|
| 8 |
+
return JSONResponse(
|
| 9 |
+
status_code=404,
|
| 10 |
+
content={
|
| 11 |
+
"error": "Endpoint not found",
|
| 12 |
+
"path": request.url.path,
|
| 13 |
+
"message": "Please check the URL or visit /docs for available endpoints."
|
| 14 |
+
}
|
| 15 |
+
)
|
| 16 |
+
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
|
main.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import asynccontextmanager
|
| 2 |
+
import json
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
| 5 |
+
|
| 6 |
+
from core.config import HOST, PORT, DEBUG_MODE
|
| 7 |
+
from core.exceptions import custom_404_handler
|
| 8 |
+
from api.router import api_router
|
| 9 |
+
from services.model_manager import model_manager, DEVICE
|
| 10 |
+
|
| 11 |
+
# Load metadata from config file
|
| 12 |
+
with open("config/metadata.json", "r") as f:
|
| 13 |
+
METADATA = json.load(f)
|
| 14 |
+
|
| 15 |
+
@asynccontextmanager
|
| 16 |
+
async def lifespan(app: FastAPI):
|
| 17 |
+
print("Downloading/Loading models into VRAM (this takes a moment on first run)...")
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
with open("config/models.json", "r") as f:
|
| 21 |
+
models_config = json.load(f)
|
| 22 |
+
|
| 23 |
+
for model_info in models_config:
|
| 24 |
+
model_manager.load_hf_model_pipeline(
|
| 25 |
+
model_info["id"],
|
| 26 |
+
model_info["repo_id"],
|
| 27 |
+
model_info=model_info
|
| 28 |
+
)
|
| 29 |
+
except FileNotFoundError:
|
| 30 |
+
print("⚠️ models.json not found in config/. No models loaded automatically.")
|
| 31 |
+
|
| 32 |
+
yield
|
| 33 |
+
print("Shutting down API and releasing resources...")
|
| 34 |
+
model_manager.models.clear()
|
| 35 |
+
model_manager.transforms_dict.clear()
|
| 36 |
+
if hasattr(model_manager, "model_configs"):
|
| 37 |
+
model_manager.model_configs.clear()
|
| 38 |
+
|
| 39 |
+
app = FastAPI(
|
| 40 |
+
title=METADATA["api_name"],
|
| 41 |
+
description=METADATA["description"],
|
| 42 |
+
version=METADATA["version"],
|
| 43 |
+
debug=DEBUG_MODE,
|
| 44 |
+
lifespan=lifespan,
|
| 45 |
+
)
|
| 46 |
+
print(f"API Engine initialized on: {DEVICE.upper()}")
|
| 47 |
+
|
| 48 |
+
# Register exception handlers
|
| 49 |
+
app.add_exception_handler(StarletteHTTPException, custom_404_handler)
|
| 50 |
+
|
| 51 |
+
# Include routers
|
| 52 |
+
app.include_router(api_router)
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
|
| 55 |
+
import uvicorn
|
| 56 |
+
uvicorn.run(app, host=HOST, port=PORT)
|
requirements.txt
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohappyeyeballs==2.6.1
|
| 2 |
+
aiohttp==3.13.5
|
| 3 |
+
aiosignal==1.4.0
|
| 4 |
+
annotated-doc==0.0.4
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
anyio==4.13.0
|
| 7 |
+
attrs==26.1.0
|
| 8 |
+
certifi==2026.4.22
|
| 9 |
+
charset-normalizer==3.4.7
|
| 10 |
+
click==8.3.3
|
| 11 |
+
colorama==0.4.6
|
| 12 |
+
contourpy==1.3.3
|
| 13 |
+
cycler==0.12.1
|
| 14 |
+
datasets==4.8.5
|
| 15 |
+
dill==0.4.1
|
| 16 |
+
fastapi==0.136.1
|
| 17 |
+
filelock==3.25.2
|
| 18 |
+
fonttools==4.62.1
|
| 19 |
+
frozenlist==1.8.0
|
| 20 |
+
fsspec==2026.2.0
|
| 21 |
+
h11==0.16.0
|
| 22 |
+
hf-xet==1.5.0
|
| 23 |
+
httpcore==1.0.9
|
| 24 |
+
httpx==0.28.1
|
| 25 |
+
huggingface_hub==1.14.0
|
| 26 |
+
idna==3.13
|
| 27 |
+
Jinja2==3.1.6
|
| 28 |
+
joblib==1.5.3
|
| 29 |
+
kiwisolver==1.5.0
|
| 30 |
+
markdown-it-py==4.1.0
|
| 31 |
+
MarkupSafe==3.0.3
|
| 32 |
+
matplotlib==3.10.9
|
| 33 |
+
mdurl==0.1.2
|
| 34 |
+
mediapipe==0.10.35
|
| 35 |
+
mpmath==1.3.0
|
| 36 |
+
multidict==6.7.1
|
| 37 |
+
multiprocess==0.70.19
|
| 38 |
+
networkx==3.6.1
|
| 39 |
+
numpy==2.4.3
|
| 40 |
+
packaging==26.2
|
| 41 |
+
pandas==3.0.2
|
| 42 |
+
pillow==12.1.1
|
| 43 |
+
propcache==0.4.1
|
| 44 |
+
pyarrow==24.0.0
|
| 45 |
+
pydantic==2.13.4
|
| 46 |
+
pydantic_core==2.46.4
|
| 47 |
+
Pygments==2.20.0
|
| 48 |
+
pyparsing==3.3.2
|
| 49 |
+
python-dateutil==2.9.0.post0
|
| 50 |
+
python-dotenv==1.2.2
|
| 51 |
+
python-multipart==0.0.27
|
| 52 |
+
PyYAML==6.0.3
|
| 53 |
+
requests==2.33.1
|
| 54 |
+
rich==15.0.0
|
| 55 |
+
safetensors==0.7.0
|
| 56 |
+
scikit-learn==1.8.0
|
| 57 |
+
scipy==1.17.1
|
| 58 |
+
seaborn==0.13.2
|
| 59 |
+
setuptools==70.2.0
|
| 60 |
+
shellingham==1.5.4
|
| 61 |
+
six==1.17.0
|
| 62 |
+
starlette==1.0.0
|
| 63 |
+
sympy==1.14.0
|
| 64 |
+
threadpoolctl==3.6.0
|
| 65 |
+
timm==1.0.26
|
| 66 |
+
torch==2.11.0+cu130
|
| 67 |
+
torchvision==0.26.0+cu130
|
| 68 |
+
tqdm==4.67.3
|
| 69 |
+
typer==0.25.1
|
| 70 |
+
typing-inspection==0.4.2
|
| 71 |
+
typing_extensions==4.15.0
|
| 72 |
+
tzdata==2026.2
|
| 73 |
+
urllib3==2.6.3
|
| 74 |
+
uvicorn==0.46.0
|
| 75 |
+
xxhash==3.7.0
|
| 76 |
+
yarl==1.23.0
|
schemas/predict.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class InferenceRequest(BaseModel):
|
| 6 |
+
"""Request body for Big Five personality trait prediction."""
|
| 7 |
+
model_type: str = Field(
|
| 8 |
+
...,
|
| 9 |
+
description="The ID of the vision model to use for inference",
|
| 10 |
+
examples=["swinv2", "vit", "pvtv2"],
|
| 11 |
+
)
|
| 12 |
+
image_base64: str = Field(
|
| 13 |
+
...,
|
| 14 |
+
description="Base64-encoded image string (JPEG/PNG). Data URI prefix is optional.",
|
| 15 |
+
examples=["iVBORw0KGgoAAAANSUhEUg..."],
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class OCEANTraits(BaseModel):
|
| 20 |
+
"""Big Five (OCEAN) personality trait scores, each ranging from 0.0 to 1.0."""
|
| 21 |
+
Openness: float = Field(..., ge=0.0, le=1.0, description="Openness to experience", examples=[0.62])
|
| 22 |
+
Conscientiousness: float = Field(..., ge=0.0, le=1.0, description="Conscientiousness", examples=[0.63])
|
| 23 |
+
Extraversion: float = Field(..., ge=0.0, le=1.0, description="Extraversion", examples=[0.54])
|
| 24 |
+
Agreeableness: float = Field(..., ge=0.0, le=1.0, description="Agreeableness", examples=[0.63])
|
| 25 |
+
Neuroticism: float = Field(..., ge=0.0, le=1.0, description="Neuroticism", examples=[0.60])
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PredictionResponse(BaseModel):
|
| 29 |
+
"""Response containing the model used, predicted OCEAN traits, and the cropped face image."""
|
| 30 |
+
model_used: str = Field(..., description="The ID of the model that produced the prediction", examples=["swinv2"])
|
| 31 |
+
predictions: OCEANTraits = Field(..., description="Predicted Big Five personality trait scores")
|
| 32 |
+
cropped_face_base64: Optional[str] = Field(None, description="Base64 encoded cropped face image, if face extraction was used.", examples=["/9j/4AAQSkZJRgABAQEASABIAAD/4..."])
|
schemas/system.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class MetadataResponse(BaseModel):
|
| 6 |
+
"""API metadata returned by the root endpoint."""
|
| 7 |
+
api_name: str = Field(..., description="Name of the API", examples=["Big Five Personality Inference API"])
|
| 8 |
+
description: str = Field(..., description="Brief description of the API", examples=["API for predicting Big Five personality traits"])
|
| 9 |
+
version: str = Field(..., description="Current API version", examples=["1.2.0"])
|
| 10 |
+
status: str = Field(..., description="Current API status", examples=["online"])
|
| 11 |
+
documentation: str = Field(..., description="Path to interactive API docs", examples=["/docs"])
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class HealthResponse(BaseModel):
|
| 15 |
+
"""Health check response."""
|
| 16 |
+
status: str = Field(..., description="Health status of the API", examples=["healthy"])
|
| 17 |
+
device: str = Field(..., description="Compute device in use", examples=["cuda"])
|
| 18 |
+
models_loaded: List[str] = Field(..., description="List of model IDs currently loaded in memory", examples=[["swinv2", "vit", "pvtv2"]])
|
| 19 |
+
port: str = Field(..., description="Port configuration", examples=["auto"])
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ModelDetail(BaseModel):
|
| 23 |
+
"""Details of an available inference model."""
|
| 24 |
+
id: str = Field(..., description="The ID of the model to be used in predictions", examples=["swinv2"])
|
| 25 |
+
name: str = Field(..., description="Human-readable name of the model", examples=["Swin Transformer V2"])
|
| 26 |
+
description: str = Field(..., description="Description of the model", examples=["SwinV2 Base model optimized for Big Five personality traits prediction"])
|
| 27 |
+
repo_id: str = Field(..., description="Hugging Face model repository ID", examples=["lyfesan/swinv2_base_..."])
|
| 28 |
+
|
| 29 |
+
class ModelsListResponse(BaseModel):
|
| 30 |
+
"""List of available inference models."""
|
| 31 |
+
available_models: List[ModelDetail] = Field(..., description="List of models available for inference")
|
services/face_extractor.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import mediapipe as mp
|
| 5 |
+
from mediapipe.tasks import python
|
| 6 |
+
from mediapipe.tasks.python import vision
|
| 7 |
+
|
| 8 |
+
class FaceExtractor:
|
| 9 |
+
def __init__(self, model_path: str = "assets/blaze_face_short_range.tflite"):
|
| 10 |
+
self.model_path = model_path
|
| 11 |
+
base_options = python.BaseOptions(model_asset_path=self.model_path)
|
| 12 |
+
options = vision.FaceDetectorOptions(
|
| 13 |
+
base_options=base_options,
|
| 14 |
+
running_mode=vision.RunningMode.IMAGE,
|
| 15 |
+
min_detection_confidence=0.70
|
| 16 |
+
)
|
| 17 |
+
self.detector = vision.FaceDetector.create_from_options(options)
|
| 18 |
+
self.offset_percentage = 0.30
|
| 19 |
+
|
| 20 |
+
def extract_main_face(self, pil_image: Image.Image) -> Image.Image:
|
| 21 |
+
"""
|
| 22 |
+
Detects faces in the given PIL Image, scores them to find the main face,
|
| 23 |
+
and returns the cropped main face. Returns None if no face is detected.
|
| 24 |
+
"""
|
| 25 |
+
# Convert PIL Image to numpy array (RGB)
|
| 26 |
+
frame = np.array(pil_image)
|
| 27 |
+
|
| 28 |
+
img_h, img_w, _ = frame.shape
|
| 29 |
+
frame_cx, frame_cy = img_w / 2, img_h / 2
|
| 30 |
+
|
| 31 |
+
# Mediapipe requires the image to be in ImageFormat.SRGB
|
| 32 |
+
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
|
| 33 |
+
results = self.detector.detect(mp_image)
|
| 34 |
+
|
| 35 |
+
if not results.detections:
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
best_face_bbox = None
|
| 39 |
+
highest_score = -float('inf')
|
| 40 |
+
|
| 41 |
+
for detection in results.detections:
|
| 42 |
+
bbox = detection.bounding_box
|
| 43 |
+
confidence = detection.categories[0].score
|
| 44 |
+
x, y, w, h = bbox.origin_x, bbox.origin_y, bbox.width, bbox.height
|
| 45 |
+
face_cx, face_cy = x + (w / 2), y + (h / 2)
|
| 46 |
+
|
| 47 |
+
area = w * h
|
| 48 |
+
distance_to_center = math.sqrt((frame_cx - face_cx)**2 + (frame_cy - face_cy)**2)
|
| 49 |
+
score = (area * confidence) - (distance_to_center * 50)
|
| 50 |
+
|
| 51 |
+
if score > highest_score:
|
| 52 |
+
highest_score = score
|
| 53 |
+
best_face_bbox = (x, y, w, h)
|
| 54 |
+
|
| 55 |
+
if not best_face_bbox:
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
# Crop with offset
|
| 59 |
+
x, y, w, h = best_face_bbox
|
| 60 |
+
offset_w = int(w * self.offset_percentage)
|
| 61 |
+
offset_h = int(h * self.offset_percentage)
|
| 62 |
+
|
| 63 |
+
new_x = max(0, x - offset_w)
|
| 64 |
+
new_y = max(0, y - offset_h)
|
| 65 |
+
new_w = min(img_w - new_x, w + (2 * offset_w))
|
| 66 |
+
new_h = min(img_h - new_y, h + (2 * offset_h))
|
| 67 |
+
|
| 68 |
+
cropped_face_np = frame[new_y:new_y+new_h, new_x:new_x+new_w]
|
| 69 |
+
|
| 70 |
+
# Convert back to PIL Image
|
| 71 |
+
if cropped_face_np.size > 0:
|
| 72 |
+
return Image.fromarray(cropped_face_np)
|
| 73 |
+
|
| 74 |
+
return None
|
services/inference.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import timm
|
| 3 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 4 |
+
|
| 5 |
+
class BigFiveRegressor(nn.Module, PyTorchModelHubMixin):
|
| 6 |
+
def __init__(self, timm_name, use_complex_head=True):
|
| 7 |
+
super().__init__()
|
| 8 |
+
self.backbone = timm.create_model(timm_name, pretrained=False, num_classes=0)
|
| 9 |
+
num_features = self.backbone.num_features
|
| 10 |
+
|
| 11 |
+
if use_complex_head:
|
| 12 |
+
self.regression_head = nn.Sequential(
|
| 13 |
+
nn.Linear(num_features, 512),
|
| 14 |
+
nn.GELU(),
|
| 15 |
+
nn.Dropout(0.3),
|
| 16 |
+
nn.Linear(512, 5),
|
| 17 |
+
nn.Sigmoid()
|
| 18 |
+
)
|
| 19 |
+
else:
|
| 20 |
+
self.regression_head = nn.Sequential(
|
| 21 |
+
nn.Linear(num_features, 5),
|
| 22 |
+
nn.Sigmoid()
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
features = self.backbone(x)
|
| 27 |
+
return self.regression_head(features)
|
services/model_manager.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import base64
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from fastapi import HTTPException
|
| 7 |
+
from services.inference import BigFiveRegressor
|
| 8 |
+
from schemas.predict import OCEANTraits, PredictionResponse
|
| 9 |
+
from services.face_extractor import FaceExtractor
|
| 10 |
+
|
| 11 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
+
# DEVICE = "cpu"
|
| 13 |
+
class ModelManager:
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.models = {}
|
| 16 |
+
self.transforms_dict = {}
|
| 17 |
+
self.model_configs = {}
|
| 18 |
+
try:
|
| 19 |
+
self.face_extractor = FaceExtractor()
|
| 20 |
+
except Exception as e:
|
| 21 |
+
print(f"Warning: Failed to initialize FaceExtractor: {e}")
|
| 22 |
+
self.face_extractor = None
|
| 23 |
+
|
| 24 |
+
def load_hf_model_pipeline(self, model_key: str, repo_id: str, model_info: dict = None):
|
| 25 |
+
"""Loads model from Hugging Face and creates its specific preprocessing transform."""
|
| 26 |
+
try:
|
| 27 |
+
model = BigFiveRegressor.from_pretrained(repo_id)
|
| 28 |
+
model.to(DEVICE)
|
| 29 |
+
model.eval()
|
| 30 |
+
|
| 31 |
+
# SwinV2 uses 256x256, ViT/PVTv2 use 224x224
|
| 32 |
+
IMG_SIZE = 256 if 'swinv2' in model_key else 224
|
| 33 |
+
transform = transforms.Compose([
|
| 34 |
+
transforms.Resize((IMG_SIZE, IMG_SIZE)),
|
| 35 |
+
transforms.ToTensor(),
|
| 36 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 37 |
+
])
|
| 38 |
+
|
| 39 |
+
self.models[model_key] = model
|
| 40 |
+
self.transforms_dict[model_key] = transform
|
| 41 |
+
if model_info:
|
| 42 |
+
self.model_configs[model_key] = model_info
|
| 43 |
+
print(f"✅ Loaded {model_key.upper()} from {repo_id}")
|
| 44 |
+
except Exception as e:
|
| 45 |
+
print(f"⚠️ Failed to load {model_key} from {repo_id}. Error: {e}")
|
| 46 |
+
|
| 47 |
+
def predict(self, model_type: str, image_base64: str) -> PredictionResponse:
|
| 48 |
+
model_type_lower = model_type.lower()
|
| 49 |
+
if model_type_lower not in self.models:
|
| 50 |
+
raise HTTPException(status_code=400, detail=f"Invalid model type. Choose from: {list(self.models.keys())}")
|
| 51 |
+
|
| 52 |
+
# Decode Base64 to Image
|
| 53 |
+
try:
|
| 54 |
+
# Strip header if frontend accidentally includes "data:image/jpeg;base64,"
|
| 55 |
+
base64_data = image_base64.split(",")[-1]
|
| 56 |
+
image_data = base64.b64decode(base64_data)
|
| 57 |
+
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
| 58 |
+
except Exception:
|
| 59 |
+
raise HTTPException(status_code=400, detail="Invalid Base64 image payload.")
|
| 60 |
+
|
| 61 |
+
# Face Extraction
|
| 62 |
+
cropped_base64 = None
|
| 63 |
+
if self.face_extractor:
|
| 64 |
+
image = self.face_extractor.extract_main_face(image)
|
| 65 |
+
if image is None:
|
| 66 |
+
raise HTTPException(status_code=400, detail="No face detected in the image.")
|
| 67 |
+
|
| 68 |
+
# Convert back to base64 for response
|
| 69 |
+
buffered = io.BytesIO()
|
| 70 |
+
image.save(buffered, format="JPEG")
|
| 71 |
+
cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 72 |
+
|
| 73 |
+
# Transform and Infer
|
| 74 |
+
transform = self.transforms_dict[model_type_lower]
|
| 75 |
+
input_tensor = transform(image).unsqueeze(0).to(DEVICE)
|
| 76 |
+
|
| 77 |
+
model = self.models[model_type_lower]
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
with torch.amp.autocast('cuda' if DEVICE == 'cuda' else 'cpu'):
|
| 80 |
+
output = model(input_tensor)
|
| 81 |
+
probabilities = output.squeeze().cpu().to(torch.float32).numpy()
|
| 82 |
+
|
| 83 |
+
# 1. Map the raw array to the order the model was trained on
|
| 84 |
+
raw_traits = ['Extraversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness']
|
| 85 |
+
raw_results = {trait: float(score) for trait, score in zip(raw_traits, probabilities)}
|
| 86 |
+
|
| 87 |
+
# 2. Standardize to the OCEAN format using Pydantic
|
| 88 |
+
standardized_ocean = OCEANTraits(
|
| 89 |
+
Openness=raw_results['Openness'],
|
| 90 |
+
Conscientiousness=raw_results['Conscientiousness'],
|
| 91 |
+
Extraversion=raw_results['Extraversion'],
|
| 92 |
+
Agreeableness=raw_results['Agreeableness'],
|
| 93 |
+
Neuroticism=raw_results['Neuroticism']
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# 3. Return the strictly formatted Pydantic object
|
| 97 |
+
return PredictionResponse(
|
| 98 |
+
model_used=model_type_lower,
|
| 99 |
+
predictions=standardized_ocean,
|
| 100 |
+
cropped_face_base64=cropped_base64
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Global instance to be used across the application
|
| 104 |
+
model_manager = ModelManager()
|