Alief Gilang Permana Putra commited on
Commit
af35098
·
1 Parent(s): fb55838

feat: Add files for inference

Browse files
.dockerignore ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore python virtual environments and dependencies
2
+ pytorch-cuda/
3
+ venv/
4
+ .venv/
5
+ env/
6
+
7
+ # Ignore python cache files
8
+ __pycache__/
9
+ *.pyc
10
+ *.pyo
11
+ *.pyd
12
+
13
+ # Ignore local environment file (production secrets should be injected via env variables)
14
+ .env
15
+
16
+ # Ignore git folder
17
+ .git/
18
+ .gitignore
19
+
20
+ # Ignore docker files
21
+ Dockerfile
22
+ .dockerignore
.env.example ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Server configuration
2
+ HOST=0.0.0.0
3
+ PORT=8000
4
+ DEBUG_MODE=True
5
+
6
+ # Hugging Face Token
7
+ HF_TOKEN=your_huggingface_access_token_here
Dockerfile ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use a stable, official Python base image
2
+ FROM python:3.10-slim
3
+
4
+ # Set environment variables
5
+ # PYTHONUNBUFFERED=1 ensures console logs are printed immediately
6
+ # PYTHONDONTWRITEBYTECODE=1 prevents python from writing .pyc files
7
+ # PORT=7860 is the default port for Hugging Face Spaces
8
+ # HOME=/home/user sets the home folder for the non-root user
9
+ ENV PYTHONUNBUFFERED=1 \
10
+ PYTHONDONTWRITEBYTECODE=1 \
11
+ PORT=7860 \
12
+ HOST=0.0.0.0 \
13
+ HOME=/home/user
14
+
15
+ # Install system dependencies required by OpenCV, MediaPipe, and other libraries
16
+ RUN apt-get update && apt-get install -y --no-install-recommends \
17
+ build-essential \
18
+ libgl1 \
19
+ libglib2.0-0 \
20
+ libgomp1 \
21
+ sed \
22
+ && rm -rf /var/lib/apt/lists/*
23
+
24
+ # Create a non-root user with UID 1000 (Hugging Face Spaces runs as UID 1000)
25
+ RUN useradd -m -u 1000 user
26
+ WORKDIR /app
27
+
28
+ # Copy requirements.txt first for build caching
29
+ COPY --chown=user:user requirements.txt /app/
30
+
31
+ # Remove the custom local PyTorch wheels from requirements.txt to install standard stable versions
32
+ # from PyPI, supporting both CPU and GPU workloads automatically.
33
+ RUN sed -i '/torch==/d' requirements.txt && \
34
+ sed -i '/torchvision==/d' requirements.txt && \
35
+ pip install --no-cache-dir --upgrade pip && \
36
+ pip install --no-cache-dir torch torchvision && \
37
+ pip install --no-cache-dir -r requirements.txt
38
+
39
+ # Copy the rest of the application files
40
+ COPY --chown=user:user . /app/
41
+
42
+ # Setup a writeable Hugging Face cache directory inside the home folder of user 1000
43
+ RUN mkdir -p /home/user/.cache/huggingface && chown -R user:user /home/user
44
+
45
+ # Switch to the non-root user
46
+ USER user
47
+
48
+ # Expose the default port (Hugging Face Spaces automatically forwards traffic to 7860)
49
+ EXPOSE 7860
50
+
51
+ # Run the FastAPI server
52
+ CMD ["python", "main.py"]
api/endpoints/predict.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from schemas.predict import InferenceRequest, PredictionResponse
3
+ from schemas.system import ModelsListResponse
4
+ from services.model_manager import model_manager
5
+
6
+ router = APIRouter(tags=["Inference"])
7
+
8
+ @router.get("/models", response_model=ModelsListResponse)
9
+ async def list_models():
10
+ """List all available models loaded in memory for inference."""
11
+ return {
12
+ "available_models": list(model_manager.model_configs.values())
13
+ }
14
+
15
+ @router.post("/predict", response_model=PredictionResponse)
16
+ async def predict_personality(request: InferenceRequest):
17
+ """Predict Big Five personality traits from a base64 encoded face image."""
18
+ return model_manager.predict(request.model_type, request.image_base64)
api/endpoints/system.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from fastapi import APIRouter
3
+ from schemas.system import MetadataResponse, HealthResponse
4
+ from services.model_manager import model_manager, DEVICE
5
+
6
+ router = APIRouter(tags=["System"])
7
+
8
+ @router.get("/", response_model=MetadataResponse)
9
+ async def root():
10
+ """Standard root endpoint providing API metadata."""
11
+ with open("config/metadata.json", "r") as f:
12
+ metadata = json.load(f)
13
+ metadata["documentation"] = "/docs"
14
+ return metadata
15
+
16
+ @router.get("/health", response_model=HealthResponse)
17
+ async def health_check():
18
+ """API Health check"""
19
+ return {
20
+ "status": "healthy",
21
+ "device": DEVICE,
22
+ "models_loaded": list(model_manager.models.keys()),
23
+ "port": "auto"
24
+ }
25
+
api/router.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from api.endpoints import system, predict
3
+
4
+ api_router = APIRouter()
5
+
6
+ api_router.include_router(system.router)
7
+ api_router.include_router(predict.router)
assets/blaze_face_short_range.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f
3
+ size 229746
config/metadata.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "api_name": "Big Five Personality Inference API",
3
+ "description": "API for predicting Big Five personality traits (OCEAN) from facial images using deep learning vision models",
4
+ "version": "1.2.0",
5
+ "status": "online"
6
+ }
config/models.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "vit-b16-augreg-in21k",
4
+ "name": "ViT-B/16 AugReg (IN21k) - Arch Tuning + Augmentation",
5
+ "description": "ViT Base patch 16 ArchTuning + AugReg model Trained in 21k images",
6
+ "repo_id": "lyfesan/vit_base_patch16_224_augreg_in21k_Run_D_The_Ultimate_bigfive"
7
+ },
8
+ {
9
+ "id": "vit-b16-augreg-in21k-ft1k",
10
+ "name": "ViT-B/16 AugReg (IN21k+1k) - Arch Tuning + Augmentation",
11
+ "description": "ViT Base patch 16 ArchTuning + AugReg model Trained in 21k images finetuned in 1k images",
12
+ "repo_id": "lyfesan/vit_base_patch16_224_augreg_in21k_ft_in1k_Run_D_The_Ultimate_bigfive"
13
+ },
14
+ {
15
+ "id": "swinv2-w12-16-archtuning-in22k-ft1k",
16
+ "name": "SwinV2-B w12-16 (IN22k+1k) - Arch Tuning",
17
+ "description": "SwinV2 window12-16 ArchTuning model Trained in 22k images finetuned in 1k images",
18
+ "repo_id": "lyfesan/swinv2_base_window12to16_192to256_ms_in22k_ft_in1k_Run_B_Arch_Tuning_bigfive"
19
+ },
20
+ {
21
+ "id": "swinv2-w16-archtuning-in1k",
22
+ "name": "SwinV2-B w16 (IN1k) - Arch Tuning",
23
+ "description": "SwinV2 window16 ArchTuning model trained in 1k images",
24
+ "repo_id": "lyfesan/swinv2_base_window16_256_ms_in1k_Run_B_Arch_Tuning_bigfive"
25
+ }
26
+ ]
core/config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from huggingface_hub import login
3
+ from dotenv import load_dotenv
4
+
5
+ # Load environment variables
6
+ load_dotenv(override=True)
7
+
8
+ HOST = os.getenv("HOST", "0.0.0.0")
9
+ PORT = int(os.getenv("PORT", 8000))
10
+ DEBUG_MODE = os.getenv("DEBUG_MODE", "False").lower() in ("true", "1", "t")
11
+ HF_TOKEN = os.getenv("HF_TOKEN")
12
+
13
+ # Authenticate with Hugging Face
14
+ if HF_TOKEN and HF_TOKEN != "your_huggingface_access_token_here":
15
+ print("Logging into Hugging Face Hub...")
16
+ login(token=HF_TOKEN)
core/exceptions.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Request
2
+ from fastapi.responses import JSONResponse
3
+ from starlette.exceptions import HTTPException as StarletteHTTPException
4
+
5
+ async def custom_404_handler(request: Request, exc: StarletteHTTPException):
6
+ """Custom handler to format 404 errors cleanly."""
7
+ if exc.status_code == 404:
8
+ return JSONResponse(
9
+ status_code=404,
10
+ content={
11
+ "error": "Endpoint not found",
12
+ "path": request.url.path,
13
+ "message": "Please check the URL or visit /docs for available endpoints."
14
+ }
15
+ )
16
+ return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
main.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+ import json
3
+ from fastapi import FastAPI
4
+ from starlette.exceptions import HTTPException as StarletteHTTPException
5
+
6
+ from core.config import HOST, PORT, DEBUG_MODE
7
+ from core.exceptions import custom_404_handler
8
+ from api.router import api_router
9
+ from services.model_manager import model_manager, DEVICE
10
+
11
+ # Load metadata from config file
12
+ with open("config/metadata.json", "r") as f:
13
+ METADATA = json.load(f)
14
+
15
+ @asynccontextmanager
16
+ async def lifespan(app: FastAPI):
17
+ print("Downloading/Loading models into VRAM (this takes a moment on first run)...")
18
+
19
+ try:
20
+ with open("config/models.json", "r") as f:
21
+ models_config = json.load(f)
22
+
23
+ for model_info in models_config:
24
+ model_manager.load_hf_model_pipeline(
25
+ model_info["id"],
26
+ model_info["repo_id"],
27
+ model_info=model_info
28
+ )
29
+ except FileNotFoundError:
30
+ print("⚠️ models.json not found in config/. No models loaded automatically.")
31
+
32
+ yield
33
+ print("Shutting down API and releasing resources...")
34
+ model_manager.models.clear()
35
+ model_manager.transforms_dict.clear()
36
+ if hasattr(model_manager, "model_configs"):
37
+ model_manager.model_configs.clear()
38
+
39
+ app = FastAPI(
40
+ title=METADATA["api_name"],
41
+ description=METADATA["description"],
42
+ version=METADATA["version"],
43
+ debug=DEBUG_MODE,
44
+ lifespan=lifespan,
45
+ )
46
+ print(f"API Engine initialized on: {DEVICE.upper()}")
47
+
48
+ # Register exception handlers
49
+ app.add_exception_handler(StarletteHTTPException, custom_404_handler)
50
+
51
+ # Include routers
52
+ app.include_router(api_router)
53
+
54
+ if __name__ == "__main__":
55
+ import uvicorn
56
+ uvicorn.run(app, host=HOST, port=PORT)
requirements.txt ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.6.1
2
+ aiohttp==3.13.5
3
+ aiosignal==1.4.0
4
+ annotated-doc==0.0.4
5
+ annotated-types==0.7.0
6
+ anyio==4.13.0
7
+ attrs==26.1.0
8
+ certifi==2026.4.22
9
+ charset-normalizer==3.4.7
10
+ click==8.3.3
11
+ colorama==0.4.6
12
+ contourpy==1.3.3
13
+ cycler==0.12.1
14
+ datasets==4.8.5
15
+ dill==0.4.1
16
+ fastapi==0.136.1
17
+ filelock==3.25.2
18
+ fonttools==4.62.1
19
+ frozenlist==1.8.0
20
+ fsspec==2026.2.0
21
+ h11==0.16.0
22
+ hf-xet==1.5.0
23
+ httpcore==1.0.9
24
+ httpx==0.28.1
25
+ huggingface_hub==1.14.0
26
+ idna==3.13
27
+ Jinja2==3.1.6
28
+ joblib==1.5.3
29
+ kiwisolver==1.5.0
30
+ markdown-it-py==4.1.0
31
+ MarkupSafe==3.0.3
32
+ matplotlib==3.10.9
33
+ mdurl==0.1.2
34
+ mediapipe==0.10.35
35
+ mpmath==1.3.0
36
+ multidict==6.7.1
37
+ multiprocess==0.70.19
38
+ networkx==3.6.1
39
+ numpy==2.4.3
40
+ packaging==26.2
41
+ pandas==3.0.2
42
+ pillow==12.1.1
43
+ propcache==0.4.1
44
+ pyarrow==24.0.0
45
+ pydantic==2.13.4
46
+ pydantic_core==2.46.4
47
+ Pygments==2.20.0
48
+ pyparsing==3.3.2
49
+ python-dateutil==2.9.0.post0
50
+ python-dotenv==1.2.2
51
+ python-multipart==0.0.27
52
+ PyYAML==6.0.3
53
+ requests==2.33.1
54
+ rich==15.0.0
55
+ safetensors==0.7.0
56
+ scikit-learn==1.8.0
57
+ scipy==1.17.1
58
+ seaborn==0.13.2
59
+ setuptools==70.2.0
60
+ shellingham==1.5.4
61
+ six==1.17.0
62
+ starlette==1.0.0
63
+ sympy==1.14.0
64
+ threadpoolctl==3.6.0
65
+ timm==1.0.26
66
+ torch==2.11.0+cu130
67
+ torchvision==0.26.0+cu130
68
+ tqdm==4.67.3
69
+ typer==0.25.1
70
+ typing-inspection==0.4.2
71
+ typing_extensions==4.15.0
72
+ tzdata==2026.2
73
+ urllib3==2.6.3
74
+ uvicorn==0.46.0
75
+ xxhash==3.7.0
76
+ yarl==1.23.0
schemas/predict.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional
3
+
4
+
5
+ class InferenceRequest(BaseModel):
6
+ """Request body for Big Five personality trait prediction."""
7
+ model_type: str = Field(
8
+ ...,
9
+ description="The ID of the vision model to use for inference",
10
+ examples=["swinv2", "vit", "pvtv2"],
11
+ )
12
+ image_base64: str = Field(
13
+ ...,
14
+ description="Base64-encoded image string (JPEG/PNG). Data URI prefix is optional.",
15
+ examples=["iVBORw0KGgoAAAANSUhEUg..."],
16
+ )
17
+
18
+
19
+ class OCEANTraits(BaseModel):
20
+ """Big Five (OCEAN) personality trait scores, each ranging from 0.0 to 1.0."""
21
+ Openness: float = Field(..., ge=0.0, le=1.0, description="Openness to experience", examples=[0.62])
22
+ Conscientiousness: float = Field(..., ge=0.0, le=1.0, description="Conscientiousness", examples=[0.63])
23
+ Extraversion: float = Field(..., ge=0.0, le=1.0, description="Extraversion", examples=[0.54])
24
+ Agreeableness: float = Field(..., ge=0.0, le=1.0, description="Agreeableness", examples=[0.63])
25
+ Neuroticism: float = Field(..., ge=0.0, le=1.0, description="Neuroticism", examples=[0.60])
26
+
27
+
28
+ class PredictionResponse(BaseModel):
29
+ """Response containing the model used, predicted OCEAN traits, and the cropped face image."""
30
+ model_used: str = Field(..., description="The ID of the model that produced the prediction", examples=["swinv2"])
31
+ predictions: OCEANTraits = Field(..., description="Predicted Big Five personality trait scores")
32
+ cropped_face_base64: Optional[str] = Field(None, description="Base64 encoded cropped face image, if face extraction was used.", examples=["/9j/4AAQSkZJRgABAQEASABIAAD/4..."])
schemas/system.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List
3
+
4
+
5
+ class MetadataResponse(BaseModel):
6
+ """API metadata returned by the root endpoint."""
7
+ api_name: str = Field(..., description="Name of the API", examples=["Big Five Personality Inference API"])
8
+ description: str = Field(..., description="Brief description of the API", examples=["API for predicting Big Five personality traits"])
9
+ version: str = Field(..., description="Current API version", examples=["1.2.0"])
10
+ status: str = Field(..., description="Current API status", examples=["online"])
11
+ documentation: str = Field(..., description="Path to interactive API docs", examples=["/docs"])
12
+
13
+
14
+ class HealthResponse(BaseModel):
15
+ """Health check response."""
16
+ status: str = Field(..., description="Health status of the API", examples=["healthy"])
17
+ device: str = Field(..., description="Compute device in use", examples=["cuda"])
18
+ models_loaded: List[str] = Field(..., description="List of model IDs currently loaded in memory", examples=[["swinv2", "vit", "pvtv2"]])
19
+ port: str = Field(..., description="Port configuration", examples=["auto"])
20
+
21
+
22
+ class ModelDetail(BaseModel):
23
+ """Details of an available inference model."""
24
+ id: str = Field(..., description="The ID of the model to be used in predictions", examples=["swinv2"])
25
+ name: str = Field(..., description="Human-readable name of the model", examples=["Swin Transformer V2"])
26
+ description: str = Field(..., description="Description of the model", examples=["SwinV2 Base model optimized for Big Five personality traits prediction"])
27
+ repo_id: str = Field(..., description="Hugging Face model repository ID", examples=["lyfesan/swinv2_base_..."])
28
+
29
+ class ModelsListResponse(BaseModel):
30
+ """List of available inference models."""
31
+ available_models: List[ModelDetail] = Field(..., description="List of models available for inference")
services/face_extractor.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from PIL import Image
4
+ import mediapipe as mp
5
+ from mediapipe.tasks import python
6
+ from mediapipe.tasks.python import vision
7
+
8
+ class FaceExtractor:
9
+ def __init__(self, model_path: str = "assets/blaze_face_short_range.tflite"):
10
+ self.model_path = model_path
11
+ base_options = python.BaseOptions(model_asset_path=self.model_path)
12
+ options = vision.FaceDetectorOptions(
13
+ base_options=base_options,
14
+ running_mode=vision.RunningMode.IMAGE,
15
+ min_detection_confidence=0.70
16
+ )
17
+ self.detector = vision.FaceDetector.create_from_options(options)
18
+ self.offset_percentage = 0.30
19
+
20
+ def extract_main_face(self, pil_image: Image.Image) -> Image.Image:
21
+ """
22
+ Detects faces in the given PIL Image, scores them to find the main face,
23
+ and returns the cropped main face. Returns None if no face is detected.
24
+ """
25
+ # Convert PIL Image to numpy array (RGB)
26
+ frame = np.array(pil_image)
27
+
28
+ img_h, img_w, _ = frame.shape
29
+ frame_cx, frame_cy = img_w / 2, img_h / 2
30
+
31
+ # Mediapipe requires the image to be in ImageFormat.SRGB
32
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
33
+ results = self.detector.detect(mp_image)
34
+
35
+ if not results.detections:
36
+ return None
37
+
38
+ best_face_bbox = None
39
+ highest_score = -float('inf')
40
+
41
+ for detection in results.detections:
42
+ bbox = detection.bounding_box
43
+ confidence = detection.categories[0].score
44
+ x, y, w, h = bbox.origin_x, bbox.origin_y, bbox.width, bbox.height
45
+ face_cx, face_cy = x + (w / 2), y + (h / 2)
46
+
47
+ area = w * h
48
+ distance_to_center = math.sqrt((frame_cx - face_cx)**2 + (frame_cy - face_cy)**2)
49
+ score = (area * confidence) - (distance_to_center * 50)
50
+
51
+ if score > highest_score:
52
+ highest_score = score
53
+ best_face_bbox = (x, y, w, h)
54
+
55
+ if not best_face_bbox:
56
+ return None
57
+
58
+ # Crop with offset
59
+ x, y, w, h = best_face_bbox
60
+ offset_w = int(w * self.offset_percentage)
61
+ offset_h = int(h * self.offset_percentage)
62
+
63
+ new_x = max(0, x - offset_w)
64
+ new_y = max(0, y - offset_h)
65
+ new_w = min(img_w - new_x, w + (2 * offset_w))
66
+ new_h = min(img_h - new_y, h + (2 * offset_h))
67
+
68
+ cropped_face_np = frame[new_y:new_y+new_h, new_x:new_x+new_w]
69
+
70
+ # Convert back to PIL Image
71
+ if cropped_face_np.size > 0:
72
+ return Image.fromarray(cropped_face_np)
73
+
74
+ return None
services/inference.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import timm
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+
5
+ class BigFiveRegressor(nn.Module, PyTorchModelHubMixin):
6
+ def __init__(self, timm_name, use_complex_head=True):
7
+ super().__init__()
8
+ self.backbone = timm.create_model(timm_name, pretrained=False, num_classes=0)
9
+ num_features = self.backbone.num_features
10
+
11
+ if use_complex_head:
12
+ self.regression_head = nn.Sequential(
13
+ nn.Linear(num_features, 512),
14
+ nn.GELU(),
15
+ nn.Dropout(0.3),
16
+ nn.Linear(512, 5),
17
+ nn.Sigmoid()
18
+ )
19
+ else:
20
+ self.regression_head = nn.Sequential(
21
+ nn.Linear(num_features, 5),
22
+ nn.Sigmoid()
23
+ )
24
+
25
+ def forward(self, x):
26
+ features = self.backbone(x)
27
+ return self.regression_head(features)
services/model_manager.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import base64
3
+ import torch
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from fastapi import HTTPException
7
+ from services.inference import BigFiveRegressor
8
+ from schemas.predict import OCEANTraits, PredictionResponse
9
+ from services.face_extractor import FaceExtractor
10
+
11
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
+ # DEVICE = "cpu"
13
+ class ModelManager:
14
+ def __init__(self):
15
+ self.models = {}
16
+ self.transforms_dict = {}
17
+ self.model_configs = {}
18
+ try:
19
+ self.face_extractor = FaceExtractor()
20
+ except Exception as e:
21
+ print(f"Warning: Failed to initialize FaceExtractor: {e}")
22
+ self.face_extractor = None
23
+
24
+ def load_hf_model_pipeline(self, model_key: str, repo_id: str, model_info: dict = None):
25
+ """Loads model from Hugging Face and creates its specific preprocessing transform."""
26
+ try:
27
+ model = BigFiveRegressor.from_pretrained(repo_id)
28
+ model.to(DEVICE)
29
+ model.eval()
30
+
31
+ # SwinV2 uses 256x256, ViT/PVTv2 use 224x224
32
+ IMG_SIZE = 256 if 'swinv2' in model_key else 224
33
+ transform = transforms.Compose([
34
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
37
+ ])
38
+
39
+ self.models[model_key] = model
40
+ self.transforms_dict[model_key] = transform
41
+ if model_info:
42
+ self.model_configs[model_key] = model_info
43
+ print(f"✅ Loaded {model_key.upper()} from {repo_id}")
44
+ except Exception as e:
45
+ print(f"⚠️ Failed to load {model_key} from {repo_id}. Error: {e}")
46
+
47
+ def predict(self, model_type: str, image_base64: str) -> PredictionResponse:
48
+ model_type_lower = model_type.lower()
49
+ if model_type_lower not in self.models:
50
+ raise HTTPException(status_code=400, detail=f"Invalid model type. Choose from: {list(self.models.keys())}")
51
+
52
+ # Decode Base64 to Image
53
+ try:
54
+ # Strip header if frontend accidentally includes "data:image/jpeg;base64,"
55
+ base64_data = image_base64.split(",")[-1]
56
+ image_data = base64.b64decode(base64_data)
57
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
58
+ except Exception:
59
+ raise HTTPException(status_code=400, detail="Invalid Base64 image payload.")
60
+
61
+ # Face Extraction
62
+ cropped_base64 = None
63
+ if self.face_extractor:
64
+ image = self.face_extractor.extract_main_face(image)
65
+ if image is None:
66
+ raise HTTPException(status_code=400, detail="No face detected in the image.")
67
+
68
+ # Convert back to base64 for response
69
+ buffered = io.BytesIO()
70
+ image.save(buffered, format="JPEG")
71
+ cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
72
+
73
+ # Transform and Infer
74
+ transform = self.transforms_dict[model_type_lower]
75
+ input_tensor = transform(image).unsqueeze(0).to(DEVICE)
76
+
77
+ model = self.models[model_type_lower]
78
+ with torch.no_grad():
79
+ with torch.amp.autocast('cuda' if DEVICE == 'cuda' else 'cpu'):
80
+ output = model(input_tensor)
81
+ probabilities = output.squeeze().cpu().to(torch.float32).numpy()
82
+
83
+ # 1. Map the raw array to the order the model was trained on
84
+ raw_traits = ['Extraversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness']
85
+ raw_results = {trait: float(score) for trait, score in zip(raw_traits, probabilities)}
86
+
87
+ # 2. Standardize to the OCEAN format using Pydantic
88
+ standardized_ocean = OCEANTraits(
89
+ Openness=raw_results['Openness'],
90
+ Conscientiousness=raw_results['Conscientiousness'],
91
+ Extraversion=raw_results['Extraversion'],
92
+ Agreeableness=raw_results['Agreeableness'],
93
+ Neuroticism=raw_results['Neuroticism']
94
+ )
95
+
96
+ # 3. Return the strictly formatted Pydantic object
97
+ return PredictionResponse(
98
+ model_used=model_type_lower,
99
+ predictions=standardized_ocean,
100
+ cropped_face_base64=cropped_base64
101
+ )
102
+
103
+ # Global instance to be used across the application
104
+ model_manager = ModelManager()