Spaces:
Running
Running
upload 32 files for the ml
Browse files- Dockerfile.hf +35 -0
- app/__init__.py +762 -0
- app/__pycache__/__init__.cpython-312.pyc +0 -0
- app/__pycache__/db_logger.cpython-312.pyc +0 -0
- app/__pycache__/error_handlers.cpython-312.pyc +0 -0
- app/__pycache__/image_storage.cpython-312.pyc +0 -0
- app/__pycache__/model_loader.cpython-312.pyc +0 -0
- app/__pycache__/rate_limiter.cpython-312.pyc +0 -0
- app/__pycache__/utils.cpython-312.pyc +0 -0
- app/__pycache__/validators.cpython-312.pyc +0 -0
- app/__pycache__/vit_utils.cpython-312.pyc +0 -0
- app/db_logger.py +286 -0
- app/error_handlers.py +64 -0
- app/image_cleanup.py +200 -0
- app/image_storage.py +124 -0
- app/model_loader.py +187 -0
- app/rate_limiter.py +89 -0
- app/utils.py +180 -0
- app/validators.py +116 -0
- app/vit_utils.py +323 -0
- entrypoint_hf.sh +79 -0
- main.py +41 -0
- requirements.txt +18 -0
Dockerfile.hf
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for Hugging Face Spaces
#
# Builds on python:3.11-slim, installs the packages from requirements.txt,
# copies the build context into /app and starts the service through an
# entrypoint script. Port 7860 is exposed.
FROM python:3.11-slim

# Unbuffered stdout/stderr so log lines show up immediately.
ENV PYTHONUNBUFFERED=1
WORKDIR /app

# System dependencies for opencv and runtime model download
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    libgl1 \
    libglib2.0-0 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layers are reused when only
# application code changes.
COPY requirements.txt /app/requirements.txt

# Upgrade pip (no pip cache kept in the layer, same as the install below)
RUN python -m pip install --no-cache-dir --upgrade pip setuptools wheel

# Install requirements
RUN pip install --no-cache-dir -r requirements.txt

# Copy app code
# NOTE(review): the app writes tmp/, images/, logs/ and SQLite/CSV files under
# /app; Spaces presumably runs the container as a non-root user, so confirm
# that /app is writable at runtime.
COPY . /app/

# Make entrypoint executable
# NOTE(review): the same commit adds `entrypoint_hf.sh` at the repository
# root; confirm that /app/scripts/entrypoint.sh really exists in the build
# context, otherwise this step (and ENTRYPOINT) will fail.
RUN chmod +x /app/scripts/entrypoint.sh

# Hugging Face Spaces uses port 7860
EXPOSE 7860

# Use entrypoint script
ENTRYPOINT ["/app/scripts/entrypoint.sh"]
|
| 35 |
+
|
app/__init__.py
ADDED
|
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/__init__.py
|
| 2 |
+
import os
|
| 3 |
+
import datetime
|
| 4 |
+
import csv
|
| 5 |
+
import traceback
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from flask import Flask, request, jsonify
|
| 9 |
+
from flask_cors import CORS
|
| 10 |
+
from werkzeug.utils import secure_filename
|
| 11 |
+
from werkzeug.exceptions import RequestEntityTooLarge
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ----------------------------
# Module-level config (deterministic)
# ----------------------------
# Every path below is anchored two directory levels above this file.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TMP_DIR_DEFAULT = os.path.join(PROJECT_ROOT, "tmp")
IMAGES_DIR_DEFAULT = os.path.join(PROJECT_ROOT, "images")
LOG_CSV = os.path.join(PROJECT_ROOT, "predictions_log.csv")
DB_PATH = os.path.join(PROJECT_ROOT, "predictions.db")

# App-level defaults (can be overridden via app.config)
DEFAULTS = {
    "MIN_CONFIDENCE": 0.18,  # Lowered to 0.18 for ambiguous cases (was 0.20, originally 0.5)
    "MAX_FILE_SIZE": 5 * 1024 * 1024,  # 5 MB
    "TMP_DIR": TMP_DIR_DEFAULT,
    "IMAGES_DIR": IMAGES_DIR_DEFAULT,
    "ALLOWED_EXT": (".jpg", ".jpeg", ".png"),
    "CORS_ORIGINS": "*",  # Can be overridden for production
}

# Ensure directories exist
os.makedirs(DEFAULTS["TMP_DIR"], exist_ok=True)
os.makedirs(DEFAULTS["IMAGES_DIR"], exist_ok=True)

# Ensure CSV header exists (helpful for older logs)
if not os.path.exists(LOG_CSV):
    try:
        with open(LOG_CSV, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(["timestamp", "filename", "emotion", "confidence"])
    except Exception:
        # Still non-fatal (module import must not break), but no longer
        # silent: record why the header file could not be written.
        logging.getLogger(__name__).warning(
            "Could not create CSV log header at %s", LOG_CSV, exc_info=True
        )
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ----------------------------
|
| 49 |
+
# Factory
|
| 50 |
+
# ----------------------------
|
| 51 |
+
def create_app(config: dict | None = None):
|
| 52 |
+
"""
|
| 53 |
+
Create and return the Flask application.
|
| 54 |
+
Heavy imports (model loading, db init) are performed inside this factory
|
| 55 |
+
so importing modules from scripts/tests doesn't trigger expensive work.
|
| 56 |
+
"""
|
| 57 |
+
# Merge defaults with provided config
|
| 58 |
+
cfg = DEFAULTS.copy()
|
| 59 |
+
if config:
|
| 60 |
+
cfg.update(config)
|
| 61 |
+
|
| 62 |
+
app = Flask(__name__)
|
| 63 |
+
|
| 64 |
+
# CORS configuration - allow config override
|
| 65 |
+
cors_origins = cfg.get("CORS_ORIGINS", DEFAULTS["CORS_ORIGINS"])
|
| 66 |
+
if cors_origins == "*":
|
| 67 |
+
CORS(app, resources={r"/*": {"origins": "*"}})
|
| 68 |
+
else:
|
| 69 |
+
# Allow list of origins
|
| 70 |
+
origins_list = cors_origins.split(",") if isinstance(cors_origins, str) else cors_origins
|
| 71 |
+
CORS(app, resources={r"/*": {"origins": origins_list}})
|
| 72 |
+
|
| 73 |
+
# ---------- file logging setup (after app created) ----------
|
| 74 |
+
LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
|
| 75 |
+
try:
|
| 76 |
+
os.makedirs(LOG_DIR, exist_ok=True)
|
| 77 |
+
except Exception:
|
| 78 |
+
# If logs dir cannot be created, continue; app.logger will still work to stdout
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
log_path = os.path.join(LOG_DIR, "app.log")
|
| 82 |
+
try:
|
| 83 |
+
file_handler = logging.FileHandler(log_path)
|
| 84 |
+
file_handler.setLevel(logging.INFO) # change to ERROR if you prefer
|
| 85 |
+
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(module)s: %(message)s")
|
| 86 |
+
file_handler.setFormatter(formatter)
|
| 87 |
+
|
| 88 |
+
# avoid adding duplicate handlers when reloading
|
| 89 |
+
abs_log_path = os.path.abspath(log_path)
|
| 90 |
+
if not any(
|
| 91 |
+
isinstance(h, logging.FileHandler) and getattr(h, "baseFilename", None) == abs_log_path
|
| 92 |
+
for h in app.logger.handlers
|
| 93 |
+
):
|
| 94 |
+
app.logger.addHandler(file_handler)
|
| 95 |
+
# set app logger level (don't lower if already configured higher)
|
| 96 |
+
app.logger.setLevel(logging.INFO)
|
| 97 |
+
except Exception:
|
| 98 |
+
# If logging can't be configured, keep going — logger will fallback to default handlers.
|
| 99 |
+
app.logger.exception("Failed to configure file logging")
|
| 100 |
+
|
| 101 |
+
# Apply config to app
|
| 102 |
+
app.config["MAX_CONTENT_LENGTH"] = cfg["MAX_FILE_SIZE"]
|
| 103 |
+
app.config["TMP_DIR"] = cfg["TMP_DIR"]
|
| 104 |
+
app.config["IMAGES_DIR"] = cfg.get("IMAGES_DIR", DEFAULTS["IMAGES_DIR"])
|
| 105 |
+
app.config["ALLOWED_EXT"] = cfg["ALLOWED_EXT"]
|
| 106 |
+
app.config["MIN_CONFIDENCE"] = cfg["MIN_CONFIDENCE"]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Ensure tmp directory exists (again, per app)
|
| 110 |
+
os.makedirs(app.config["TMP_DIR"], exist_ok=True)
|
| 111 |
+
|
| 112 |
+
# Local (deferred) imports — avoid import-time side effects
|
| 113 |
+
from .model_loader import load_emotion_model
|
| 114 |
+
from .db_logger import init_db, log_prediction, get_metrics, tail_rows, get_total_count, delete_prediction
|
| 115 |
+
from .utils import preprocess_face
|
| 116 |
+
from .image_storage import save_image, get_image_path, ensure_images_dir
|
| 117 |
+
from .validators import validate_image_file, validate_pagination_params, validate_confidence_range
|
| 118 |
+
from .rate_limiter import detect_limiter, logs_limiter, images_limiter, get_client_identifier
|
| 119 |
+
|
| 120 |
+
# Initialize DB
|
| 121 |
+
try:
|
| 122 |
+
init_db(DB_PATH)
|
| 123 |
+
app.logger.info("Initialized SQLite DB at %s", DB_PATH)
|
| 124 |
+
except Exception:
|
| 125 |
+
app.logger.exception("Failed to initialize DB at startup")
|
| 126 |
+
|
| 127 |
+
# Load model & labels. Keep these local to the factory (no module-level side effects).
|
| 128 |
+
# We'll load models on-demand based on request parameter
|
| 129 |
+
base_model = None
|
| 130 |
+
base_labels = None
|
| 131 |
+
base_model_version = "unknown"
|
| 132 |
+
base_model_type = "unknown"
|
| 133 |
+
finetuned_model = None
|
| 134 |
+
finetuned_labels = None
|
| 135 |
+
finetuned_model_version = "unknown"
|
| 136 |
+
finetuned_model_type = "unknown"
|
| 137 |
+
|
| 138 |
+
# Load base model by default
|
| 139 |
+
try:
|
| 140 |
+
# load_emotion_model returns (model, labels, version, model_type)
|
| 141 |
+
res = load_emotion_model(force_model='base')
|
| 142 |
+
if isinstance(res, tuple) and len(res) == 4:
|
| 143 |
+
base_model, base_labels, base_model_version, base_model_type = res
|
| 144 |
+
elif isinstance(res, tuple) and len(res) == 3:
|
| 145 |
+
base_model, base_labels, base_model_version = res
|
| 146 |
+
base_model_type = "keras" # Default for old format
|
| 147 |
+
elif isinstance(res, tuple) and len(res) == 2:
|
| 148 |
+
base_model, base_labels = res
|
| 149 |
+
base_model_type = "keras" # Default for old format
|
| 150 |
+
else:
|
| 151 |
+
# Unexpected return shape - try to be permissive
|
| 152 |
+
try:
|
| 153 |
+
base_model = res
|
| 154 |
+
base_labels = None
|
| 155 |
+
base_model_type = "keras"
|
| 156 |
+
except Exception:
|
| 157 |
+
base_model = None
|
| 158 |
+
base_labels = None
|
| 159 |
+
base_model_type = "unknown"
|
| 160 |
+
app.logger.info("Base model loaded: %s (version=%s, type=%s)", bool(base_model), base_model_version, base_model_type)
|
| 161 |
+
print(f"[APP] Base model loaded: type={base_model_type}, version={base_model_version}, labels={len(base_labels) if base_labels else 0}")
|
| 162 |
+
except Exception as exc:
|
| 163 |
+
app.logger.exception("Base model failed to load at startup: %s", exc)
|
| 164 |
+
base_model = None
|
| 165 |
+
base_labels = None
|
| 166 |
+
base_model_version = "failed"
|
| 167 |
+
base_model_type = "unknown"
|
| 168 |
+
|
| 169 |
+
# Try to load fine-tuned model
|
| 170 |
+
try:
|
| 171 |
+
res = load_emotion_model(force_model='fine-tuned')
|
| 172 |
+
if isinstance(res, tuple) and len(res) == 4:
|
| 173 |
+
finetuned_model, finetuned_labels, finetuned_model_version, finetuned_model_type = res
|
| 174 |
+
elif isinstance(res, tuple) and len(res) == 3:
|
| 175 |
+
finetuned_model, finetuned_labels, finetuned_model_version = res
|
| 176 |
+
finetuned_model_type = "keras"
|
| 177 |
+
elif isinstance(res, tuple) and len(res) == 2:
|
| 178 |
+
finetuned_model, finetuned_labels = res
|
| 179 |
+
finetuned_model_type = "keras"
|
| 180 |
+
app.logger.info("Asripa model loaded: %s (version=%s, type=%s)", bool(finetuned_model), finetuned_model_version, finetuned_model_type)
|
| 181 |
+
print(f"[APP] Asripa model loaded: type={finetuned_model_type}, version={finetuned_model_version}")
|
| 182 |
+
except Exception as exc:
|
| 183 |
+
app.logger.warning("Asripa model not available: %s", exc)
|
| 184 |
+
finetuned_model = None
|
| 185 |
+
finetuned_labels = None
|
| 186 |
+
finetuned_model_version = "not-available"
|
| 187 |
+
finetuned_model_type = "unknown"
|
| 188 |
+
|
| 189 |
+
# Store in app.config - default to base model
|
| 190 |
+
app.config["BASE_MODEL"] = base_model
|
| 191 |
+
app.config["BASE_LABELS"] = base_labels
|
| 192 |
+
app.config["BASE_MODEL_VERSION"] = base_model_version
|
| 193 |
+
app.config["BASE_MODEL_TYPE"] = base_model_type
|
| 194 |
+
app.config["FINETUNED_MODEL"] = finetuned_model
|
| 195 |
+
app.config["FINETUNED_LABELS"] = finetuned_labels
|
| 196 |
+
app.config["FINETUNED_MODEL_VERSION"] = finetuned_model_version
|
| 197 |
+
app.config["FINETUNED_MODEL_TYPE"] = finetuned_model_type
|
| 198 |
+
# Default to base model
|
| 199 |
+
app.config["MODEL"] = base_model
|
| 200 |
+
app.config["LABELS"] = base_labels
|
| 201 |
+
app.config["MODEL_VERSION"] = base_model_version
|
| 202 |
+
app.config["MODEL_TYPE"] = base_model_type
|
| 203 |
+
|
| 204 |
+
# ----------------------------
|
| 205 |
+
# Error handlers (import before routes to ensure proper handling)
|
| 206 |
+
# ----------------------------
|
| 207 |
+
from .error_handlers import register_error_handlers, APIError, ValidationError, NotFoundError, ServiceUnavailableError
|
| 208 |
+
|
| 209 |
+
register_error_handlers(app)
|
| 210 |
+
|
| 211 |
+
# Make these available in route scope
|
| 212 |
+
globals()['APIError'] = APIError
|
| 213 |
+
globals()['ValidationError'] = ValidationError
|
| 214 |
+
globals()['NotFoundError'] = NotFoundError
|
| 215 |
+
globals()['ServiceUnavailableError'] = ServiceUnavailableError
|
| 216 |
+
|
| 217 |
+
@app.errorhandler(RequestEntityTooLarge)
def handle_large_file(e):
    """Answer an oversized upload with HTTP 413 and the size limit in MB.

    The limit is read from ``app.config["MAX_CONTENT_LENGTH"]`` (5 MB when the
    key is absent) and converted from bytes to megabytes for the response.
    """
    limit_bytes = app.config.get("MAX_CONTENT_LENGTH", 5 * 1024 * 1024)
    payload = {
        "error": "File too large",
        "max_size_mb": limit_bytes / (1024 * 1024),
    }
    return jsonify(payload), 413
|
| 220 |
+
|
| 221 |
+
# ----------------------------
|
| 222 |
+
# Routes
|
| 223 |
+
# ----------------------------
|
| 224 |
+
@app.route("/")
def index():
    """Root endpoint: return a fixed JSON status message with HTTP 200."""
    body = {"status": "ok", "message": "Flask backend running"}
    return jsonify(body), 200
|
| 227 |
+
|
| 228 |
+
@app.route("/health", methods=["GET"])
def health():
    """
    Lightweight health probe.

    Reports what ``app.config`` says about the default model (loaded flag,
    type, version, number of labels). Always answers HTTP 200: if reading the
    config raises, the error is logged as a warning and a fallback payload
    with a ``warning`` field is returned instead.
    """
    try:
        # Only cheap config lookups here - nothing that touches the model itself.
        settings = app.config
        status = {
            "ok": True,
            "model_loaded": bool(settings.get("MODEL")),
            "model_type": settings.get("MODEL_TYPE", "unknown"),
            "model_version": settings.get("MODEL_VERSION", "unknown"),
        }

        labels = settings.get("LABELS")
        if labels and hasattr(labels, "__len__"):
            status["labels_count"] = len(labels)
        else:
            status["labels_count"] = 0

        return jsonify(status), 200
    except Exception as e:
        # Deliberately still 200: the process is up, so callers should not
        # flag the service as offline because of a reporting glitch.
        app.logger.warning(f"Health check error (non-fatal): {e}")
        fallback = {
            "ok": True,
            "model_loaded": False,
            "model_type": "unknown",
            "model_version": "unknown",
            "labels_count": 0,
            "warning": "Health check had minor issues but service is running",
        }
        return jsonify(fallback), 200
|
| 267 |
+
|
| 268 |
+
@app.route("/metrics")
def metrics():
    """Return ``get_metrics`` output plus the 10 rows from ``tail_rows``.

    Responds 200 with ``{"ok", "metrics", "recent"}``; any exception is logged
    and turned into a 500 carrying the exception text in ``details``.
    """
    try:
        summary = get_metrics(DB_PATH)
        latest_rows = tail_rows(DB_PATH, limit=10)
        payload = {"ok": True, "metrics": summary, "recent": latest_rows}
        return jsonify(payload), 200
    except Exception as exc:
        app.logger.exception("Failed to fetch metrics")
        failure = {"error": "Failed to fetch metrics", "details": str(exc)}
        return jsonify(failure), 500
|
| 277 |
+
|
| 278 |
+
@app.route("/logs", methods=["GET"])
def logs():
    """
    GET /logs?limit=20&offset=0&emotion=happy&min_confidence=0.5&max_confidence=1.0&date_from=2024-01-01&date_to=2024-12-31

    Return one page of prediction logs, optionally filtered.

    Responses:
        200 - ``{"ok", "logs", "pagination"}``
        400 - a pagination or confidence validator returned an error message
        429 - ``logs_limiter`` rejected the client
        500 - any other exception (logged; its text is returned in ``detail``)
    """

    def _as_record(row):
        """Map a row of 4, 5 or 6 columns onto the JSON record shape."""
        width = len(row)
        if width == 6:
            _id, ts, filename, image_path, emotion, confidence = row
            return {
                "id": _id,
                "ts": ts,
                "filename": filename,
                "image_path": image_path or filename,  # Fallback to filename if no image_path
                "emotion": emotion,
                "confidence": confidence,
            }
        if width == 5:
            _id, ts, filename, emotion, confidence = row
            return {
                "id": _id,
                "ts": ts,
                "filename": filename,
                "image_path": filename,  # Fallback
                "emotion": emotion,
                "confidence": confidence,
            }
        if width == 4:
            ts, filename, emotion, confidence = row
            return {
                "ts": ts,
                "filename": filename,
                "image_path": filename,  # Fallback
                "emotion": emotion,
                "confidence": confidence,
            }
        # Unknown layout: hand the raw row back rather than guessing.
        return {"row": row}

    # Rate limiting
    caller = get_client_identifier(request)
    allowed, _remaining = logs_limiter.is_allowed(caller)
    if not allowed:
        throttled = {
            "error": "Rate limit exceeded",
            "detail": f"Maximum {logs_limiter.max_requests} requests per {logs_limiter.window_seconds} seconds",
            "retry_after": logs_limiter.window_seconds,
        }
        return jsonify(throttled), 429

    try:
        query = request.args

        # Pagination
        limit, offset, pagination_error = validate_pagination_params(
            query.get("limit"),
            query.get("offset"),
        )
        if pagination_error:
            return jsonify({"error": pagination_error}), 400

        # Confidence bounds
        min_confidence, max_confidence, confidence_error = validate_confidence_range(
            query.get("min_confidence"),
            query.get("max_confidence"),
        )
        if confidence_error:
            return jsonify({"error": confidence_error}), 400

        # Filters: a missing or whitespace-only emotion means "no filter".
        emotion_filter = (query.get("emotion") or "").strip() or None
        date_from = query.get("date_from")
        date_to = query.get("date_to")

        # The same filter set feeds both the page query and the total count.
        filters = {
            "emotion_filter": emotion_filter,
            "min_confidence": min_confidence,
            "max_confidence": max_confidence,
            "date_from": date_from,
            "date_to": date_to,
        }
        rows = tail_rows(DB_PATH, limit=limit, offset=offset, **filters)
        total = get_total_count(DB_PATH, **filters)

        result = [_as_record(r) for r in rows]

        page_info = {
            "total": total,
            "limit": limit,
            "offset": offset,
            "has_more": (offset + limit) < total,
        }
        return jsonify({"ok": True, "logs": result, "pagination": page_info}), 200
    except Exception as exc:
        app.logger.exception("Failed to fetch logs")
        return jsonify({"error": "Failed to fetch logs", "detail": str(exc)}), 500
|
| 392 |
+
|
| 393 |
+
@app.route("/logs/<int:prediction_id>", methods=["DELETE"])
def delete_log(prediction_id: int):
    """
    DELETE /logs/<id>

    Delete a prediction by ID.

    Only the database row is removed; the stored image file is not touched.

    Responses:
        200 - ``delete_prediction`` reported a deletion
        404 - ``delete_prediction`` returned a falsy value (nothing deleted)
        429 - ``logs_limiter`` rejected the client
        500 - any other exception (logged; its text is returned in ``detail``)
    """
    # Rate limiting (shares the limiter used by GET /logs)
    client_id = get_client_identifier(request)
    is_allowed, _remaining = logs_limiter.is_allowed(client_id)
    if not is_allowed:
        return jsonify({
            "error": "Rate limit exceeded",
            "detail": f"Maximum {logs_limiter.max_requests} requests per {logs_limiter.window_seconds} seconds",
            "retry_after": logs_limiter.window_seconds,
        }), 429

    try:
        # Delete from database
        deleted = delete_prediction(DB_PATH, prediction_id)
        if not deleted:
            return jsonify({"error": "Prediction not found"}), 404

        # The previous unused `delete_image` import was dropped: it ran after
        # the row was already gone, so an import failure would have reported
        # a 500 for a delete that had in fact succeeded.
        # TODO: look up image_path before deleting the row and remove the
        # image file as well.
        return jsonify({"ok": True, "message": "Prediction deleted successfully"}), 200
    except Exception as exc:
        app.logger.exception("Failed to delete prediction %s", prediction_id)
        return jsonify({"error": "Failed to delete prediction", "detail": str(exc)}), 500
|
| 426 |
+
|
| 427 |
+
@app.route("/detect", methods=["POST"])
|
| 428 |
+
def detect():
|
| 429 |
+
"""
|
| 430 |
+
POST form-data: image file under key 'image'
|
| 431 |
+
Returns: JSON {emotion, confidence} or error JSON
|
| 432 |
+
"""
|
| 433 |
+
# Rate limiting
|
| 434 |
+
client_id = get_client_identifier(request)
|
| 435 |
+
is_allowed, remaining = detect_limiter.is_allowed(client_id)
|
| 436 |
+
if not is_allowed:
|
| 437 |
+
return jsonify({
|
| 438 |
+
"error": "Rate limit exceeded",
|
| 439 |
+
"detail": f"Maximum {detect_limiter.max_requests} requests per {detect_limiter.window_seconds} seconds",
|
| 440 |
+
"retry_after": detect_limiter.window_seconds,
|
| 441 |
+
}), 429
|
| 442 |
+
|
| 443 |
+
# Get model selection from query parameter (default: 'base')
|
| 444 |
+
model_selection = request.args.get("model", "base").lower()
|
| 445 |
+
if model_selection == "fine-tuned" or model_selection == "finetuned":
|
| 446 |
+
model_local = app.config.get("FINETUNED_MODEL")
|
| 447 |
+
labels_local = app.config.get("FINETUNED_LABELS") or []
|
| 448 |
+
model_type = app.config.get("FINETUNED_MODEL_TYPE", "keras")
|
| 449 |
+
model_version = app.config.get("FINETUNED_MODEL_VERSION", "unknown")
|
| 450 |
+
if model_local is None:
|
| 451 |
+
app.logger.warning("Asripa model requested but not available, using base model")
|
| 452 |
+
model_local = app.config.get("BASE_MODEL")
|
| 453 |
+
labels_local = app.config.get("BASE_LABELS") or []
|
| 454 |
+
model_type = app.config.get("BASE_MODEL_TYPE", "keras")
|
| 455 |
+
model_version = app.config.get("BASE_MODEL_VERSION", "unknown")
|
| 456 |
+
else:
|
| 457 |
+
# Use base model (default)
|
| 458 |
+
model_local = app.config.get("BASE_MODEL")
|
| 459 |
+
labels_local = app.config.get("BASE_LABELS") or []
|
| 460 |
+
model_type = app.config.get("BASE_MODEL_TYPE", "keras")
|
| 461 |
+
model_version = app.config.get("BASE_MODEL_VERSION", "unknown")
|
| 462 |
+
|
| 463 |
+
app.logger.info(f"Using model: {model_selection} (version: {model_version})")
|
| 464 |
+
|
| 465 |
+
if model_local is None:
|
| 466 |
+
app.logger.error("Detect called but model not loaded")
|
| 467 |
+
raise ServiceUnavailableError("Model not loaded on server")
|
| 468 |
+
|
| 469 |
+
print(f"[DETECT] Using model type: {model_type}")
|
| 470 |
+
|
| 471 |
+
# Validate upload presence
|
| 472 |
+
if "image" not in request.files:
|
| 473 |
+
raise ValidationError("No image provided")
|
| 474 |
+
|
| 475 |
+
file = request.files["image"]
|
| 476 |
+
|
| 477 |
+
# Comprehensive validation
|
| 478 |
+
is_valid, error_msg, filename = validate_image_file(
|
| 479 |
+
file,
|
| 480 |
+
max_size=app.config.get("MAX_CONTENT_LENGTH", DEFAULTS["MAX_FILE_SIZE"]),
|
| 481 |
+
allowed_extensions=app.config.get("ALLOWED_EXT", DEFAULTS["ALLOWED_EXT"]),
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if not is_valid:
|
| 485 |
+
raise ValidationError(error_msg)
|
| 486 |
+
|
| 487 |
+
tmp_dir = app.config.get("TMP_DIR", TMP_DIR_DEFAULT)
|
| 488 |
+
tmp_path = os.path.join(tmp_dir, filename)
|
| 489 |
+
used_filename = filename
|
| 490 |
+
|
| 491 |
+
try:
|
| 492 |
+
# Save file and verify it was saved
|
| 493 |
+
file.save(tmp_path)
|
| 494 |
+
if not os.path.exists(tmp_path):
|
| 495 |
+
app.logger.error("Failed to save uploaded file to %s", tmp_path)
|
| 496 |
+
raise ValidationError("Failed to save uploaded image")
|
| 497 |
+
|
| 498 |
+
file_size = os.path.getsize(tmp_path)
|
| 499 |
+
if file_size == 0:
|
| 500 |
+
app.logger.error("Saved file is empty: %s", tmp_path)
|
| 501 |
+
raise ValidationError("Uploaded image is empty")
|
| 502 |
+
|
| 503 |
+
print(f"[DETECT] Saved file: {tmp_path}, size: {file_size} bytes")
|
| 504 |
+
app.logger.info("Saved file: %s, size: %d bytes", tmp_path, file_size)
|
| 505 |
+
|
| 506 |
+
# Import numpy for both paths
|
| 507 |
+
import numpy as np
|
| 508 |
+
|
| 509 |
+
# Handle ViT and Keras models differently
|
| 510 |
+
if model_type == "vit":
|
| 511 |
+
# Vision Transformer model - needs RGB PIL Image
|
| 512 |
+
from app.vit_utils import preprocess_face_for_vit, predict_with_vit
|
| 513 |
+
from PIL import Image
|
| 514 |
+
|
| 515 |
+
face_image, used_filename = preprocess_face_for_vit(tmp_path)
|
| 516 |
+
if face_image is None:
|
| 517 |
+
app.logger.warning("No face detected for file %s (size: %d bytes)", filename, file_size)
|
| 518 |
+
raise ValidationError("No face detected in image. Please ensure your face is clearly visible, well-lit, and facing the camera.")
|
| 519 |
+
|
| 520 |
+
# Run ViT prediction
|
| 521 |
+
idx, confidence, all_probs = predict_with_vit(model_local, face_image, labels_local)
|
| 522 |
+
emotion = labels_local[idx] if idx < len(labels_local) else str(idx)
|
| 523 |
+
|
| 524 |
+
# Debug output
|
| 525 |
+
sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
|
| 526 |
+
app.logger.info(f"Prediction probabilities for {filename} (sorted): {sorted_probs}")
|
| 527 |
+
print(f"[DETECT] All emotion probabilities (sorted by confidence):")
|
| 528 |
+
for emo, prob in sorted_probs:
|
| 529 |
+
marker = " <-- SELECTED" if emo == emotion else ""
|
| 530 |
+
print(f" {emo}: {prob:.3f}{marker}")
|
| 531 |
+
print(f"[DETECT] Predicted emotion: {emotion}, confidence: {confidence:.3f}")
|
| 532 |
+
|
| 533 |
+
# Warn if happy probability is suspiciously low (potential misclassification)
|
| 534 |
+
happy_prob = all_probs.get('happy', 0.0)
|
| 535 |
+
if happy_prob < 0.15 and confidence > 0.3 and emotion != 'happy':
|
| 536 |
+
app.logger.warning(f"⚠️ Low happy probability ({happy_prob:.3f}) but high confidence ({confidence:.3f}) for {emotion}. Possible misclassification.")
|
| 537 |
+
print(f"[DETECT] ⚠️ WARNING: Happy probability is very low ({happy_prob:.3f}) - possible misclassification")
|
| 538 |
+
|
| 539 |
+
# Convert to numpy array format for compatibility with rest of code
|
| 540 |
+
probs = np.array([all_probs.get(labels_local[i] if i < len(labels_local) else f"class_{i}", 0.0)
|
| 541 |
+
for i in range(len(labels_local))])
|
| 542 |
+
else:
|
| 543 |
+
# Keras model - existing code path
|
| 544 |
+
# Preprocess face - preprocess_face is imported above in factory scope
|
| 545 |
+
res = preprocess_face(tmp_path)
|
| 546 |
+
if isinstance(res, tuple):
|
| 547 |
+
face_array, used_filename = res
|
| 548 |
+
else:
|
| 549 |
+
face_array = res
|
| 550 |
+
|
| 551 |
+
if face_array is None:
|
| 552 |
+
app.logger.warning("No face detected for file %s (size: %d bytes)", filename, file_size)
|
| 553 |
+
raise ValidationError("No face detected in image. Please ensure your face is clearly visible, well-lit, and facing the camera.")
|
| 554 |
+
|
| 555 |
+
# Defensive conversion and validations (numpy already imported above)
|
| 556 |
+
try:
|
| 557 |
+
face_input = np.asarray(face_array)
|
| 558 |
+
except Exception as exc:
|
| 559 |
+
app.logger.exception("Failed converting preprocessed face to numpy array")
|
| 560 |
+
return jsonify({"error": "Invalid preprocessed face data."}), 500
|
| 561 |
+
|
| 562 |
+
if getattr(face_input, "dtype", None) == object:
|
| 563 |
+
app.logger.error("face_input has object dtype (likely contains None) for file %s", filename)
|
| 564 |
+
return jsonify({"error": "Invalid preprocessed face data (object dtype)."}), 500
|
| 565 |
+
|
| 566 |
+
# Ensure batch dim and channel dim
|
| 567 |
+
if face_input.ndim == 2:
|
| 568 |
+
# (H, W) -> (1, H, W, 1)
|
| 569 |
+
face_input = np.expand_dims(np.expand_dims(face_input, axis=-1), axis=0)
|
| 570 |
+
elif face_input.ndim == 3:
|
| 571 |
+
# (H, W, C) -> (1, H, W, C)
|
| 572 |
+
face_input = np.expand_dims(face_input, axis=0)
|
| 573 |
+
elif face_input.ndim == 4:
|
| 574 |
+
# already batched
|
| 575 |
+
pass
|
| 576 |
+
else:
|
| 577 |
+
app.logger.error("Unsupported preprocessed face ndim %s for file %s", getattr(face_input, "ndim", None), filename)
|
| 578 |
+
return jsonify({"error": "Unsupported preprocessed face shape."}), 500
|
| 579 |
+
|
| 580 |
+
# sanity checks
|
| 581 |
+
if face_input.shape[0] < 1:
|
| 582 |
+
return jsonify({"error": "Empty batch sent to model."}), 500
|
| 583 |
+
try:
|
| 584 |
+
if not np.isfinite(face_input.astype("float32")).all():
|
| 585 |
+
app.logger.error("face_input contains non-finite values for file %s", filename)
|
| 586 |
+
return jsonify({"error": "Preprocessed face contains non-finite values."}), 500
|
| 587 |
+
except Exception:
|
| 588 |
+
app.logger.exception("Failed checking finiteness of face_input")
|
| 589 |
+
return jsonify({"error": "Preprocessed face contains invalid numeric values."}), 500
|
| 590 |
+
|
| 591 |
+
# Run prediction
|
| 592 |
+
try:
|
| 593 |
+
preds = model_local.predict(face_input, verbose=0)
|
| 594 |
+
except Exception as exc:
|
| 595 |
+
app.logger.exception("Model predict failed for file %s", filename)
|
| 596 |
+
return jsonify({"error": "Prediction failed", "detail": str(exc)}), 500
|
| 597 |
+
|
| 598 |
+
if preds is None:
|
| 599 |
+
return jsonify({"error": "Prediction returned no output"}), 500
|
| 600 |
+
|
| 601 |
+
arr = np.asarray(preds)
|
| 602 |
+
if arr.ndim == 2:
|
| 603 |
+
probs = arr[0]
|
| 604 |
+
elif arr.ndim == 1:
|
| 605 |
+
probs = arr
|
| 606 |
+
else:
|
| 607 |
+
app.logger.error("Unexpected prediction shape %s for file %s", getattr(arr, "shape", None), filename)
|
| 608 |
+
return jsonify({"error": "Unexpected prediction shape", "shape": list(getattr(arr, "shape", []))}), 500
|
| 609 |
+
|
| 610 |
+
if probs.size == 0:
|
| 611 |
+
return jsonify({"error": "Empty prediction probabilities"}), 500
|
| 612 |
+
|
| 613 |
+
# Verify model output matches expected number of classes
|
| 614 |
+
expected_classes = len(labels_local) if isinstance(labels_local, (list, dict)) else 7
|
| 615 |
+
if len(probs) != expected_classes:
|
| 616 |
+
app.logger.warning(f"Model output has {len(probs)} classes but labels have {expected_classes}. Labels: {labels_local}")
|
| 617 |
+
print(f"[WARNING] Model output shape mismatch: {len(probs)} classes vs {expected_classes} labels")
|
| 618 |
+
|
| 619 |
+
idx = int(np.argmax(probs))
|
| 620 |
+
confidence = float(probs[idx])
|
| 621 |
+
|
| 622 |
+
# Debug: Log all prediction probabilities to understand model behavior
|
| 623 |
+
all_probs = {}
|
| 624 |
+
for i in range(len(probs)):
|
| 625 |
+
if isinstance(labels_local, list) and i < len(labels_local):
|
| 626 |
+
all_probs[labels_local[i]] = float(probs[i])
|
| 627 |
+
elif isinstance(labels_local, dict):
|
| 628 |
+
label_key = str(i) if str(i) in labels_local else i if i in labels_local else f"class_{i}"
|
| 629 |
+
all_probs[label_key] = float(probs[i])
|
| 630 |
+
else:
|
| 631 |
+
all_probs[str(i)] = float(probs[i])
|
| 632 |
+
|
| 633 |
+
# Sort by probability (highest first) for easier debugging
|
| 634 |
+
sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
|
| 635 |
+
app.logger.info(f"Prediction probabilities for {filename} (sorted): {sorted_probs}")
|
| 636 |
+
print(f"[DETECT] All emotion probabilities (sorted by confidence):")
|
| 637 |
+
for emotion, prob in sorted_probs:
|
| 638 |
+
marker = " <-- SELECTED" if emotion == (labels_local[idx] if isinstance(labels_local, list) and idx < len(labels_local) else str(idx)) else ""
|
| 639 |
+
print(f" {emotion}: {prob:.3f}{marker}")
|
| 640 |
+
print(f"[DETECT] Predicted emotion index: {idx}, confidence: {confidence:.3f}")
|
| 641 |
+
print(f"[DETECT] Available labels: {labels_local}")
|
| 642 |
+
|
| 643 |
+
# Resolve label safely
|
| 644 |
+
if isinstance(labels_local, dict):
|
| 645 |
+
emotion = labels_local.get(str(idx)) or labels_local.get(idx) or list(labels_local.values())[idx]
|
| 646 |
+
elif isinstance(labels_local, list):
|
| 647 |
+
emotion = labels_local[idx] if 0 <= idx < len(labels_local) else str(idx)
|
| 648 |
+
else:
|
| 649 |
+
emotion = str(idx)
|
| 650 |
+
|
| 651 |
+
print(f"[DETECT] Mapped emotion label: {emotion}")
|
| 652 |
+
|
| 653 |
+
# Save image even for low confidence (for debugging/analysis)
|
| 654 |
+
images_dir = app.config.get("IMAGES_DIR", IMAGES_DIR_DEFAULT)
|
| 655 |
+
stored_filename = None
|
| 656 |
+
try:
|
| 657 |
+
stored_filename = save_image(tmp_path, images_dir, used_filename)
|
| 658 |
+
except Exception:
|
| 659 |
+
app.logger.exception("Failed to save image, continuing without storage")
|
| 660 |
+
|
| 661 |
+
# Confidence threshold - slightly lower for better detection in challenging conditions
|
| 662 |
+
# But still maintain quality standards
|
| 663 |
+
min_conf = app.config.get("MIN_CONFIDENCE", DEFAULTS["MIN_CONFIDENCE"])
|
| 664 |
+
# Allow slightly lower confidence (0.45) but warn user
|
| 665 |
+
if confidence < min_conf:
|
| 666 |
+
try:
|
| 667 |
+
log_prediction(DB_PATH, used_filename, "low_confidence", confidence, stored_filename)
|
| 668 |
+
except Exception:
|
| 669 |
+
app.logger.exception("Failed logging low-confidence prediction")
|
| 670 |
+
return jsonify({
|
| 671 |
+
"error": "low confidence",
|
| 672 |
+
"confidence": round(confidence, 3),
|
| 673 |
+
"filename": stored_filename or used_filename,
|
| 674 |
+
}), 422
|
| 675 |
+
|
| 676 |
+
# Log and respond (image already saved above)
|
| 677 |
+
try:
|
| 678 |
+
log_prediction(DB_PATH, used_filename, emotion, confidence, stored_filename)
|
| 679 |
+
except Exception:
|
| 680 |
+
app.logger.exception("Failed to log prediction to DB")
|
| 681 |
+
|
| 682 |
+
# Return all probabilities for debugging (frontend can use this to show top emotions)
|
| 683 |
+
all_emotion_probs = {}
|
| 684 |
+
if model_type == "vit":
|
| 685 |
+
# For ViT, all_probs already contains the dict
|
| 686 |
+
all_emotion_probs = {k: round(v, 4) for k, v in all_probs.items()}
|
| 687 |
+
else:
|
| 688 |
+
# For Keras, build from probs array
|
| 689 |
+
for i in range(len(probs)):
|
| 690 |
+
if isinstance(labels_local, list) and i < len(labels_local):
|
| 691 |
+
all_emotion_probs[labels_local[i]] = round(float(probs[i]), 4)
|
| 692 |
+
elif isinstance(labels_local, dict):
|
| 693 |
+
label_key = str(i) if str(i) in labels_local else i if i in labels_local else f"class_{i}"
|
| 694 |
+
all_emotion_probs[label_key] = round(float(probs[i]), 4)
|
| 695 |
+
|
| 696 |
+
return jsonify({
|
| 697 |
+
"emotion": emotion,
|
| 698 |
+
"confidence": round(confidence, 3),
|
| 699 |
+
"filename": stored_filename or used_filename,
|
| 700 |
+
"all_probabilities": all_emotion_probs, # Include all probabilities for debugging
|
| 701 |
+
"model": model_selection,
|
| 702 |
+
"model_version": model_version,
|
| 703 |
+
}), 200
|
| 704 |
+
|
| 705 |
+
except (ValidationError, APIError, NotFoundError, ServiceUnavailableError) as exc:
|
| 706 |
+
# Let Flask's error handler process these
|
| 707 |
+
raise
|
| 708 |
+
except Exception as exc:
|
| 709 |
+
app.logger.exception("detection error for file %s", filename)
|
| 710 |
+
tb = traceback.format_exc()
|
| 711 |
+
return jsonify({"error": "internal error", "detail": str(exc), "trace": tb}), 500
|
| 712 |
+
|
| 713 |
+
finally:
|
| 714 |
+
# cleanup tmp file (image is already saved to images/ if successful)
|
| 715 |
+
try:
|
| 716 |
+
if os.path.exists(tmp_path):
|
| 717 |
+
os.remove(tmp_path)
|
| 718 |
+
except Exception:
|
| 719 |
+
app.logger.exception("failed removing tmp file")
|
| 720 |
+
|
| 721 |
+
# ----------------------------
|
| 722 |
+
# Image serving endpoint
|
| 723 |
+
# ----------------------------
|
| 724 |
+
@app.route("/images/<filename>", methods=["GET"])
def serve_image(filename: str):
    """
    Serve a stored image.
    GET /images/{filename}

    Responses:
        200 - the file located by ``get_image_path``
        404 - ``get_image_path`` returned nothing for ``filename``
        429 - the per-client image rate limit was exceeded
        500 - resolving or sending the file raised an unexpected error
    """
    from flask import send_from_directory, abort

    # Rate limiting
    client_id = get_client_identifier(request)
    is_allowed, _ = images_limiter.is_allowed(client_id)
    if not is_allowed:
        return jsonify({
            "error": "Rate limit exceeded",
            "detail": f"Maximum {images_limiter.max_requests} requests per {images_limiter.window_seconds} seconds",
            "retry_after": images_limiter.window_seconds,
        }), 429

    try:
        images_dir = app.config.get("IMAGES_DIR", IMAGES_DIR_DEFAULT)
        image_path = get_image_path(images_dir, filename)

        if image_path:
            # Use the basename of the resolved path, in case the lookup
            # normalised the requested name (e.g. via secure_filename).
            # No explicit mimetype: Flask then guesses it from the file
            # name, so PNG/GIF files are no longer labelled image/jpeg.
            return send_from_directory(images_dir, os.path.basename(image_path))
    except Exception as exc:
        app.logger.exception("Failed to serve image %s", filename)
        return jsonify({"error": "Failed to serve image", "detail": str(exc)}), 500

    # abort() signals the 404 by raising, so it must stay outside the
    # try block above; inside it, `except Exception` turned every missing
    # image into a 500 response.
    app.logger.warning("Image not found: %s (checked in %s)", filename, images_dir)
    abort(404)
|
| 761 |
+
|
| 762 |
+
return app
|
app/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (38.7 kB). View file
|
|
|
app/__pycache__/db_logger.cpython-312.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
app/__pycache__/error_handlers.cpython-312.pyc
ADDED
|
Binary file (3.69 kB). View file
|
|
|
app/__pycache__/image_storage.cpython-312.pyc
ADDED
|
Binary file (4.48 kB). View file
|
|
|
app/__pycache__/model_loader.cpython-312.pyc
ADDED
|
Binary file (7.95 kB). View file
|
|
|
app/__pycache__/rate_limiter.cpython-312.pyc
ADDED
|
Binary file (3.83 kB). View file
|
|
|
app/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (6.33 kB). View file
|
|
|
app/__pycache__/validators.cpython-312.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
app/__pycache__/vit_utils.cpython-312.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
app/db_logger.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import sqlite3
|
| 3 |
+
import os
|
| 4 |
+
import datetime
|
| 5 |
+
from typing import Dict, Tuple, List, Optional
|
| 6 |
+
import threading
|
| 7 |
+
|
| 8 |
+
# DDL script run by init_db() through executescript(). The CREATE
# statements use IF NOT EXISTS, so re-running it on an existing database
# does not fail on already-present objects.
SCHEMA = """
PRAGMA foreign_keys = ON;
CREATE TABLE IF NOT EXISTS predictions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    ts TEXT NOT NULL,
    filename TEXT,
    image_path TEXT,
    emotion TEXT,
    confidence REAL
);

-- Indexes for better query performance
CREATE INDEX IF NOT EXISTS idx_predictions_ts ON predictions(ts DESC);
CREATE INDEX IF NOT EXISTS idx_predictions_emotion ON predictions(emotion);
CREATE INDEX IF NOT EXISTS idx_predictions_confidence ON predictions(confidence);
"""

# Cache of open connections rather than a bounded pool: get_connection()
# stores one connection per "<db_path>_<thread id>" key and reuses it.
# Accesses to the dict in this module happen while holding _db_lock.
_db_lock = threading.Lock()
_connection_pool: Dict[str, sqlite3.Connection] = {}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_connection(db_path: str, timeout: int = 10) -> sqlite3.Connection:
    """
    Return the cached SQLite connection for this database and thread.

    The cache key combines ``db_path`` with the calling thread's ident.
    On a cache miss a new connection is opened, tuned with a few PRAGMAs
    and stored; ``timeout`` is only used when a connection is opened.
    """
    pool_key = f"{db_path}_{threading.get_ident()}"

    with _db_lock:
        cached = _connection_pool.get(pool_key)
        if cached is not None:
            return cached

        fresh = sqlite3.connect(db_path, timeout=timeout, check_same_thread=False)
        # Optimize SQLite settings
        for pragma in (
            "PRAGMA journal_mode=WAL;",
            "PRAGMA synchronous=NORMAL;",
            "PRAGMA cache_size=10000;",
            "PRAGMA temp_store=MEMORY;",
        ):
            fresh.execute(pragma)
        _connection_pool[pool_key] = fresh
        return fresh
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def init_db(db_path: str):
    """
    Create the database file's parent directory (if any) and apply SCHEMA.

    Uses a short-lived connection of its own, which is always closed
    before returning.
    """
    parent_dir = os.path.dirname(db_path)
    if parent_dir and not os.path.exists(parent_dir):
        os.makedirs(parent_dir, exist_ok=True)

    setup_pragmas = (
        "PRAGMA journal_mode=WAL;",
        "PRAGMA synchronous=NORMAL;",
        "PRAGMA cache_size=10000;",
    )
    db = sqlite3.connect(db_path, timeout=10)
    try:
        for statement in setup_pragmas:
            db.execute(statement)
        db.executescript(SCHEMA)
        db.commit()
    finally:
        db.close()
|
| 64 |
+
|
| 65 |
+
def log_prediction(db_path: str, filename: str, emotion: str, confidence: float, image_path: Optional[str] = None) -> Optional[int]:
    """
    Insert one row into the predictions table.

    All values are coerced to primitive types before being bound, and the
    timestamp is always a timezone-aware UTC ISO-8601 string so that the
    string comparisons on ``ts`` used by the query helpers stay consistent.

    Args:
        db_path: Path to SQLite database
        filename: Original filename (None is stored as "")
        emotion: Detected emotion (None is stored as "")
        confidence: Confidence score (falsy or non-numeric values become 0.0)
        image_path: Path to stored image file (optional, None is stored as "")

    Returns:
        The ``lastrowid`` reported by the cursor for the inserted row.

    Raises:
        Exception: any database error is re-raised after the thread's
            cached connection has been closed and evicted.
    """
    # timezone.utc exists on every supported Python version, so no
    # fallback to a naive, differently formatted timestamp is needed.
    ts = datetime.datetime.now(datetime.timezone.utc).isoformat()

    filename = "" if filename is None else str(filename)
    emotion = "" if emotion is None else str(emotion)
    image_path = "" if image_path is None else str(image_path)

    try:
        confidence_val = float(confidence or 0.0)
    except (TypeError, ValueError, OverflowError):
        confidence_val = 0.0

    conn = get_connection(db_path)
    try:
        cur = conn.cursor()
        # Databases created before image_path existed are migrated lazily.
        cur.execute("PRAGMA table_info(predictions)")
        columns = [row[1] for row in cur.fetchall()]

        if "image_path" not in columns:
            cur.execute("ALTER TABLE predictions ADD COLUMN image_path TEXT")
            conn.commit()

        cur.execute(
            "INSERT INTO predictions (ts, filename, image_path, emotion, confidence) VALUES (?, ?, ?, ?, ?)",
            (ts, filename, image_path, emotion, confidence_val),
        )
        conn.commit()
        return cur.lastrowid
    except Exception:
        # Drop this thread's cached connection so the next call opens a
        # fresh one, then let the caller see the original error.
        key = f"{db_path}_{threading.get_ident()}"
        with _db_lock:
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except sqlite3.Error:
                    # Best effort: the connection is being discarded anyway.
                    pass
        raise
|
| 134 |
+
|
| 135 |
+
def get_metrics(db_path: str) -> Dict:
    """
    Return aggregate counts for the predictions table.

    Returns:
        ``{"total": <row count>, "by_label": {<emotion>: <row count>}}``

    Raises:
        Exception: any database error is re-raised after the thread's
            cached connection has been closed and evicted.
    """
    conn = get_connection(db_path)
    try:
        cur = conn.cursor()
        cur.execute("SELECT COUNT(*) FROM predictions")
        total = cur.fetchone()[0] or 0
        cur.execute("SELECT emotion, COUNT(*) FROM predictions GROUP BY emotion")
        # Each row is an (emotion, count) pair, so dict() builds the mapping.
        by_label = dict(cur.fetchall())
        return {"total": total, "by_label": by_label}
    except Exception:
        # Drop this thread's cached connection so the next call opens a
        # fresh one, then let the caller see the original error.
        key = f"{db_path}_{threading.get_ident()}"
        with _db_lock:
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except sqlite3.Error:
                    # Best effort: the connection is being discarded anyway.
                    pass
        raise
|
| 156 |
+
|
| 157 |
+
def tail_rows(db_path: str, limit: int = 10, offset: int = 0, emotion_filter: Optional[str] = None,
              min_confidence: Optional[float] = None, max_confidence: Optional[float] = None,
              date_from: Optional[str] = None, date_to: Optional[str] = None) -> List[Tuple]:
    """
    Fetch rows from predictions table with filtering and pagination.

    Rows are ordered newest first (``id DESC``). Filters left as None (or,
    for ``emotion_filter``/``date_from``/``date_to``, any falsy value) are
    not applied. Date bounds are compared against the stored ``ts`` text.

    Returns:
        List of tuples: (id, ts, filename, image_path, emotion, confidence)

    Raises:
        Exception: any database error is re-raised after the thread's
            cached connection has been closed and evicted.
    """
    conn = get_connection(db_path)
    try:
        cur = conn.cursor()

        # Build query with filters; values are always bound, never inlined.
        query = "SELECT id, ts, filename, image_path, emotion, confidence FROM predictions WHERE 1=1"
        params = []

        if emotion_filter:
            query += " AND emotion = ?"
            params.append(emotion_filter)

        if min_confidence is not None:
            query += " AND confidence >= ?"
            params.append(min_confidence)

        if max_confidence is not None:
            query += " AND confidence <= ?"
            params.append(max_confidence)

        if date_from:
            query += " AND ts >= ?"
            params.append(date_from)

        if date_to:
            query += " AND ts <= ?"
            params.append(date_to)

        query += " ORDER BY id DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        cur.execute(query, params)
        return cur.fetchall()
    except Exception:
        # Drop this thread's cached connection so the next call opens a
        # fresh one, then let the caller see the original error.
        key = f"{db_path}_{threading.get_ident()}"
        with _db_lock:
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except sqlite3.Error:
                    # Best effort: the connection is being discarded anyway.
                    pass
        raise
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def delete_prediction(db_path: str, prediction_id: int) -> bool:
    """
    Delete a prediction by ID.

    Args:
        db_path: Path to SQLite database
        prediction_id: ID of prediction to delete

    Returns:
        True if at least one row was deleted, False otherwise

    Raises:
        Exception: any database error is re-raised after the thread's
            cached connection has been closed and evicted.
    """
    conn = get_connection(db_path)
    try:
        cur = conn.cursor()
        cur.execute("DELETE FROM predictions WHERE id = ?", (prediction_id,))
        conn.commit()
        return cur.rowcount > 0
    except Exception:
        # Drop this thread's cached connection so the next call opens a
        # fresh one, then let the caller see the original error.
        key = f"{db_path}_{threading.get_ident()}"
        with _db_lock:
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except sqlite3.Error:
                    # Best effort: the connection is being discarded anyway.
                    pass
        raise
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def get_total_count(db_path: str, emotion_filter: Optional[str] = None,
                    min_confidence: Optional[float] = None, max_confidence: Optional[float] = None,
                    date_from: Optional[str] = None, date_to: Optional[str] = None) -> int:
    """
    Get total count of predictions matching filters.

    Filters left as None (or, for ``emotion_filter``/``date_from``/
    ``date_to``, any falsy value) are not applied.

    Raises:
        Exception: any database error is re-raised after the thread's
            cached connection has been closed and evicted.
    """
    conn = get_connection(db_path)
    try:
        cur = conn.cursor()

        # Values are always bound as parameters, never inlined.
        query = "SELECT COUNT(*) FROM predictions WHERE 1=1"
        params = []

        if emotion_filter:
            query += " AND emotion = ?"
            params.append(emotion_filter)

        if min_confidence is not None:
            query += " AND confidence >= ?"
            params.append(min_confidence)

        if max_confidence is not None:
            query += " AND confidence <= ?"
            params.append(max_confidence)

        if date_from:
            query += " AND ts >= ?"
            params.append(date_from)

        if date_to:
            query += " AND ts <= ?"
            params.append(date_to)

        cur.execute(query, params)
        return cur.fetchone()[0] or 0
    except Exception:
        # Drop this thread's cached connection so the next call opens a
        # fresh one, then let the caller see the original error.
        key = f"{db_path}_{threading.get_ident()}"
        with _db_lock:
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except sqlite3.Error:
                    # Best effort: the connection is being discarded anyway.
                    pass
        raise
|
app/error_handlers.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Structured error handling for API responses.
|
| 3 |
+
"""
|
| 4 |
+
from flask import jsonify
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class APIError(Exception):
    """
    Base exception for API errors.

    Subclasses override the class-level ``status_code`` and ``message``
    defaults; both can also be overridden per instance.
    """
    status_code = 500
    message = "An error occurred"

    def __init__(self, message: str = None, status_code: int = None, details: Dict[str, Any] = None):
        """
        Args:
            message: Error text; falls back to the class default when falsy.
            status_code: HTTP status; falls back to the class default when falsy.
            details: Extra key/value pairs merged into the ``to_dict`` payload.
        """
        self.message = message or self.message
        self.status_code = status_code or self.status_code
        self.details = details or {}
        # Pass the message to Exception so str(exc), logs and tracebacks
        # show it instead of an empty string.
        super().__init__(self.message)

    def to_dict(self) -> Dict[str, Any]:
        """Return the JSON-serialisable payload: the message under "error" plus the details."""
        return {
            "error": self.message,
            **self.details,
        }


class ValidationError(APIError):
    """Validation error (400)."""
    status_code = 400
    message = "Validation error"


class NotFoundError(APIError):
    """Resource not found (404)."""
    status_code = 404
    message = "Resource not found"


class ServiceUnavailableError(APIError):
    """Service unavailable (503)."""
    status_code = 503
    message = "Service unavailable"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def register_error_handlers(app):
    """Attach JSON-producing error handlers to the given Flask app."""

    @app.errorhandler(APIError)
    def handle_api_error(error: APIError):
        # The exception carries both its payload and its HTTP status.
        payload = jsonify(error.to_dict())
        return payload, error.status_code

    @app.errorhandler(404)
    def handle_not_found(_exc):
        body = {"error": "Endpoint not found"}
        return jsonify(body), 404

    @app.errorhandler(405)
    def handle_method_not_allowed(_exc):
        body = {"error": "Method not allowed"}
        return jsonify(body), 405

    @app.errorhandler(500)
    def handle_internal_error(_exc):
        # Log the failure before answering with a generic message.
        app.logger.exception("Internal server error")
        body = {"error": "Internal server error"}
        return jsonify(body), 500
|
app/image_cleanup.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image cleanup utility to remove orphaned images (not referenced in database).
|
| 3 |
+
Can be run as a scheduled job or manually.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import sqlite3
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Set
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_referenced_images(db_path: str) -> Set[str]:
    """
    Collect the image filenames that the predictions table points at.

    Opens its own connection and closes it before returning. If the
    predictions table has no ``image_path`` column, nothing is referenced.

    Returns:
        Set of image filenames (basenames only)
    """
    db = sqlite3.connect(db_path, timeout=10)
    try:
        cursor = db.cursor()

        # Older databases may predate the image_path column.
        cursor.execute("PRAGMA table_info(predictions)")
        column_names = [info[1] for info in cursor.fetchall()]
        if "image_path" not in column_names:
            return set()

        # Only rows that actually carry a path are of interest.
        cursor.execute("SELECT DISTINCT image_path FROM predictions WHERE image_path IS NOT NULL AND image_path != ''")
        basenames = (os.path.basename(stored) for (stored,) in cursor.fetchall() if stored)

        # A path ending in a separator has an empty basename; skip those.
        return {name for name in basenames if name}
    finally:
        db.close()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def cleanup_orphaned_images(images_dir: str, db_path: str, dry_run: bool = True) -> dict:
    """
    Delete image files in ``images_dir`` that no database row references.

    Args:
        images_dir: Directory containing images
        db_path: Path to SQLite database
        dry_run: If True, only report what would be deleted without actually deleting

    Returns:
        Dict with the keys ``total_images``, ``referenced``, ``orphaned``,
        ``deleted`` and ``errors``. In a dry run ``deleted`` counts the files
        that would have been removed.
    """
    if not os.path.exists(images_dir):
        logger.warning(f"Images directory does not exist: {images_dir}")
        return dict.fromkeys(
            ("total_images", "referenced", "orphaned", "deleted", "errors"), 0
        )

    # Filenames the database still points at
    referenced = get_referenced_images(db_path)
    logger.info(f"Found {len(referenced)} referenced images in database")

    # Image files actually present on disk (top level of the directory only)
    image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp"}
    all_images = [
        entry.name
        for entry in Path(images_dir).iterdir()
        if entry.is_file() and entry.suffix.lower() in image_extensions
    ]
    total_images = len(all_images)
    logger.info(f"Found {total_images} image files in directory")

    # Anything on disk that the database does not know about is an orphan
    orphaned = [name for name in all_images if name not in referenced]

    stats = {
        "total_images": total_images,
        "referenced": len(referenced),
        "orphaned": len(orphaned),
        "deleted": 0,
        "errors": 0,
    }

    if not orphaned:
        logger.info("No orphaned images found")
        return stats

    logger.info(f"Found {len(orphaned)} orphaned images")

    for filename in orphaned:
        file_path = os.path.join(images_dir, filename)
        try:
            if dry_run:
                logger.debug(f"Would delete orphaned image: {filename}")
            else:
                os.remove(file_path)
                logger.debug(f"Deleted orphaned image: {filename}")
            stats["deleted"] += 1
        except Exception as e:
            # One undeletable file must not abort the rest of the sweep
            logger.error(f"Failed to delete {filename}: {e}")
            stats["errors"] += 1

    if dry_run:
        logger.info(f"DRY RUN: Would delete {stats['deleted']} orphaned images")
    else:
        logger.info(f"Deleted {stats['deleted']} orphaned images")

    return stats
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def cleanup_old_images(images_dir: str, db_path: str, days_old: int = 30, dry_run: bool = True) -> dict:
    """
    Remove images older than ``days_old`` days that no recent prediction references.

    A file is selected when its modification time is before the cutoff
    (now minus ``days_old`` days, in UTC) and its basename is not the
    ``image_path`` of any ``predictions`` row whose ``ts`` is at or after
    the cutoff.

    Args:
        images_dir: Directory containing images
        db_path: Path to SQLite database
        days_old: Remove images older than this many days
        dry_run: If True, only report what would be deleted

    Returns:
        Dict with cleanup statistics:
        ``total_images`` - image files found in ``images_dir``;
        ``old_images`` - files selected for removal;
        ``deleted`` - files removed (in a dry run: files that would be removed);
        ``errors`` - files whose removal raised.
    """
    import datetime

    if not os.path.exists(images_dir):
        return {
            "total_images": 0,
            "old_images": 0,
            "deleted": 0,
            "errors": 0,
        }

    # timezone.utc is the same object as datetime.UTC but exists on every
    # Python 3 version, not only 3.11+.
    utc = datetime.timezone.utc
    cutoff_date = datetime.datetime.now(utc) - datetime.timedelta(days=days_old)
    cutoff_iso = cutoff_date.isoformat()

    # Basenames of images referenced at or after the cutoff.
    # NOTE(review): the comparison is done on text, so this presumes `ts`
    # is stored as ISO-8601 UTC strings - confirm against the writer.
    conn = sqlite3.connect(db_path, timeout=10)
    try:
        cur = conn.cursor()
        cur.execute("""
            SELECT DISTINCT image_path
            FROM predictions
            WHERE image_path IS NOT NULL
            AND image_path != ''
            AND ts >= ?
        """, (cutoff_iso,))
        recent_images = {os.path.basename(row[0]) for row in cur.fetchall() if row[0]}
    finally:
        conn.close()

    # Single pass over the directory: count image files and pick the old,
    # unreferenced ones. Only image files are counted, so `total_images`
    # is not inflated by sub-directories or unrelated files.
    image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp"}
    total_images = 0
    old_images = []

    for file_path in Path(images_dir).iterdir():
        if not (file_path.is_file() and file_path.suffix.lower() in image_extensions):
            continue
        total_images += 1
        mtime = datetime.datetime.fromtimestamp(file_path.stat().st_mtime, tz=utc)
        if mtime < cutoff_date and file_path.name not in recent_images:
            old_images.append(file_path.name)

    stats = {
        "total_images": total_images,
        "old_images": len(old_images),
        "deleted": 0,
        "errors": 0,
    }

    for filename in old_images:
        file_path = os.path.join(images_dir, filename)
        try:
            if not dry_run:
                os.remove(file_path)
            # Counted in dry runs too, as "would delete"
            stats["deleted"] += 1
        except Exception as e:
            # Keep going: one failure should not stop the remaining removals
            logger.error(f"Failed to delete {filename}: {e}")
            stats["errors"] += 1

    return stats
|
| 199 |
+
|
| 200 |
+
|
app/image_storage.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image storage utilities for saving and serving uploaded images.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import uuid
|
| 6 |
+
import shutil
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
from werkzeug.utils import secure_filename
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def ensure_images_dir(images_dir: str) -> str:
    """
    Create ``images_dir`` (including missing parents) when it is absent.

    Returns:
        The same ``images_dir`` string that was passed in.
    """
    # exist_ok keeps repeated calls harmless
    os.makedirs(name=images_dir, exist_ok=True)
    return images_dir
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def generate_unique_filename(original_filename: str) -> str:
    """
    Build a collision-resistant storage name for an uploaded file.

    The result has the form ``{uuid4}_{sanitized stem}{ext}``. The original
    name is sanitized with ``secure_filename``; when nothing usable remains,
    ``upload.jpg`` stands in. A missing extension, or one other than
    ``.jpg`` / ``.jpeg`` / ``.png``, is replaced with ``.jpg``.
    """
    sanitized = secure_filename(original_filename) or "upload.jpg"

    stem, suffix = os.path.splitext(sanitized)
    if suffix.lower() not in ('.jpg', '.jpeg', '.png'):
        suffix = '.jpg'

    # Full UUID prefix so two uploads with the same name never collide
    return f"{uuid.uuid4()}_{stem}{suffix}"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def save_image(source_path: str, images_dir: str, original_filename: str) -> Optional[str]:
    """
    Copy ``source_path`` into ``images_dir`` under a freshly generated unique name.

    Args:
        source_path: Path to source image file
        images_dir: Directory to save images to (created if missing)
        original_filename: Original filename, used to derive the stored name

    Returns:
        Stored filename (relative to images_dir), or None when anything fails.
        Failures are logged with a traceback and never raised.
    """
    try:
        ensure_images_dir(images_dir)
        stored_filename = generate_unique_filename(original_filename)
        # copy2 also carries over file metadata such as timestamps
        shutil.copy2(source_path, os.path.join(images_dir, stored_filename))
    except Exception as e:
        # Best effort: a storage failure must not fail the calling request
        import logging
        logging.getLogger(__name__).exception(f"Failed to save image: {e}")
        return None
    return stored_filename
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_image_path(images_dir: str, filename: str) -> Optional[str]:
    """
    Resolve ``filename`` to an existing file inside ``images_dir``.

    Args:
        images_dir: Base images directory
        filename: Image filename

    Returns:
        Full path to the image, or None when the name is empty, sanitizes to
        nothing, or no matching regular file exists.
    """
    if not filename:
        return None

    # Security: only ever look at the final path component, then sanitize it
    base_filename = os.path.basename(filename)
    safe_filename = secure_filename(base_filename)
    if not safe_filename:
        return None

    # The sanitized name is always tried first.
    candidates = [safe_filename]

    # If sanitizing altered the name, the untouched original is tried as well,
    # but only when it was a bare name with no separators or parent references.
    if safe_filename != base_filename:
        original_is_plain = base_filename == filename and not any(
            token in base_filename for token in ('/', '\\', '..')
        )
        if original_is_plain:
            candidates.append(base_filename)

    for candidate in candidates:
        full_path = os.path.join(images_dir, candidate)
        if os.path.exists(full_path) and os.path.isfile(full_path):
            return full_path

    return None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def delete_image(images_dir: str, filename: str) -> bool:
    """
    Remove a stored image, resolving its location through ``get_image_path``.

    Args:
        images_dir: Base images directory
        filename: Image filename to delete

    Returns:
        True if a file was removed; False if it could not be resolved,
        no longer exists, or any error occurred.
    """
    try:
        image_path = get_image_path(images_dir, filename)
        if not image_path or not os.path.exists(image_path):
            return False
        os.remove(image_path)
        return True
    except Exception:
        # Deletion is best effort; callers only get a success flag
        return False
|
app/model_loader.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/model_loader.py
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Tuple, Any, Optional, Dict
|
| 6 |
+
|
| 7 |
+
DEFAULT_LABELS = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
|
| 8 |
+
# HardlyHumans model uses 8 emotions (adds contempt)
|
| 9 |
+
HARDLYHUMANS_LABELS = ['anger', 'contempt', 'sad', 'happy', 'neutral', 'disgust', 'fear', 'surprise']
|
| 10 |
+
|
| 11 |
+
def load_emotion_model(force_model: Optional[str] = None):
    """
    Load emotion detection model. Supports both Keras and Vision Transformer models.

    Load order:
      1. Local fine-tuned ViT in ``<repo>/models/fine_tuned_vit`` (skipped when
         ``force_model == 'base'``).
      2. Base ViT ``HardlyHumans/Facial-expression-detection`` from the
         Hugging Face hub, cached under ``<repo>/models``.
      3. A Keras model file (``emotion_model.keras`` / ``.h5`` / ``.hdf5``)
         in ``<repo>/models``.

    Args:
        force_model: 'base' to force base model, 'fine-tuned' to force fine-tuned, None for auto

    Returns: (model_dict, labels, model_version, model_type)
        model_dict: For ViT: {'model': model, 'processor': processor, 'type': 'vit'}
                    For Keras: model object
        model_type: 'keras' or 'vit' (Vision Transformer)

    Raises:
        FileNotFoundError: fine-tuned model was forced but is missing, or no
            Keras model file exists when the Keras fallback is reached.
        ImportError: neither transformers nor tensorflow.keras can be imported.
        Exception: any loading error of the fine-tuned model when it was forced.
    """
    this_dir = Path(__file__).resolve().parent  # app/
    repo_root = this_dir.parent  # project root (/app in container)
    models_dir = repo_root / "models"
    fine_tuned_dir = models_dir / "fine_tuned_vit"

    # Try to load fine-tuned model first (trained on FER2013 for better happy/surprise detection)
    # Unless force_model is 'base'
    if force_model != 'base':
        try:
            from transformers import AutoImageProcessor, AutoModelForImageClassification

            # Check if fine-tuned model exists
            if fine_tuned_dir.exists() and (fine_tuned_dir / "model.safetensors").exists():
                print(f"[MODEL] 🎯 Loading Asripa model (FER2013 Enhanced): {fine_tuned_dir}")
                print(f"[MODEL] Accuracy: 78.26% (fine-tuned on FER2013)")
                print(f"[MODEL] Optimized for happy/surprise detection!")

                processor = AutoImageProcessor.from_pretrained(
                    str(fine_tuned_dir),
                    local_files_only=True
                )
                model = AutoModelForImageClassification.from_pretrained(
                    str(fine_tuned_dir),
                    local_files_only=True,
                    low_cpu_mem_usage=True
                )

                # Get labels from model config
                raw_labels = [model.config.id2label[i] for i in range(len(model.config.id2label))]
                print(f"[MODEL] Raw labels from model config: {raw_labels}")

                labels = _normalize_emotion_labels(raw_labels)
                print(f"[MODEL] Normalized labels: {labels}")

                print(f"[MODEL] ✅ Fine-tuned ViT model loaded successfully!")
                return {
                    'model': model,
                    'processor': processor,
                    'type': 'vit'
                }, labels, "asripa-vit-78.26%", 'vit'
            else:
                if force_model == 'fine-tuned':
                    print(f"[MODEL] ⚠️ Fine-tuned model requested but not found!")
                    raise FileNotFoundError("Fine-tuned model not found")
                print(f"[MODEL] Fine-tuned model not found, using base model...")
        except Exception as e:
            # A forced fine-tuned load must fail loudly; otherwise fall through
            if force_model == 'fine-tuned':
                print(f"[MODEL] ⚠️ Failed to load fine-tuned model: {e}")
                raise
            print(f"[MODEL] ⚠️ Failed to load fine-tuned model: {e}")
            print(f"[MODEL] Falling back to base HardlyHumans model...")

    # Fall back to base HardlyHumans ViT model (best accuracy - 92.2%)
    try:
        from transformers import AutoImageProcessor, AutoModelForImageClassification

        model_id = "HardlyHumans/Facial-expression-detection"
        print(f"[MODEL] Loading Base Model: {model_id}")
        print(f"[MODEL] Accuracy: 92.2% - BASE MODEL")
        print(f"[MODEL] Downloading from HuggingFace if not cached...")

        # Load from HuggingFace - will download and cache automatically
        # Use low_cpu_mem_usage to reduce memory footprint during loading
        processor = AutoImageProcessor.from_pretrained(
            model_id,
            cache_dir=str(models_dir),
            local_files_only=False  # Allow download if not cached
        )
        model = AutoModelForImageClassification.from_pretrained(
            model_id,
            cache_dir=str(models_dir),
            local_files_only=False,  # Allow download if not cached
            low_cpu_mem_usage=True  # Reduce memory usage during loading
        )

        # Get labels from model config
        raw_labels = [model.config.id2label[i] for i in range(len(model.config.id2label))]
        print(f"[MODEL] Raw labels from model config: {raw_labels}")
        print(f"[MODEL] Label mapping (id2label): {model.config.id2label}")

        labels = _normalize_emotion_labels(raw_labels)
        print(f"[MODEL] Normalized labels: {labels}")

        print(f"[MODEL] ✅ ViT model loaded successfully!")
        return {
            'model': model,
            'processor': processor,
            'type': 'vit'
        }, labels, "base-vit-92.2%", 'vit'
    except ImportError as e:
        print(f"[MODEL] ❌ transformers library not installed: {e}")
        print("[MODEL] Install with: pip install transformers torch")
        print("[MODEL] Falling back to Keras model...")
    except Exception as e:
        print(f"[MODEL] ❌ Failed to load ViT model: {e}")
        print(f"[MODEL] Error type: {type(e).__name__}")
        print(f"[MODEL] Error message: {str(e)}")
        import traceback
        print(f"[MODEL] Full traceback:")
        print(traceback.format_exc())
        print("[MODEL] ⚠️ Falling back to Keras model (lower accuracy)...")

    # Fall back to Keras models
    try:
        from tensorflow.keras.models import load_model
    except ImportError:
        raise ImportError("Neither transformers nor tensorflow.keras available. Install one of them.")

    # First existing candidate wins
    candidate_names = ["emotion_model.keras", "emotion_model.h5", "emotion_model.hdf5"]
    model_path = None
    for name in candidate_names:
        p = models_dir / name
        if p.exists():
            model_path = str(p)
            break

    if model_path is None:
        raise FileNotFoundError(f"No model file found in {models_dir}. Please add emotion_model.keras or emotion_model.h5")

    print(f"[MODEL] Loading Keras model: {model_path}")
    model = load_model(model_path)

    # Load labels if available; an unreadable labels.json falls back to the defaults
    labels_path = models_dir / "labels.json"
    labels = DEFAULT_LABELS
    if labels_path.exists():
        try:
            with labels_path.open("r", encoding="utf-8") as f:
                labels = json.load(f)
        except Exception:
            labels = DEFAULT_LABELS

    # Model version; stays "v_unknown" when the file is missing or unreadable
    version_path = models_dir / "MODEL_VERSION.txt"
    version = "v_unknown"
    if os.path.exists(version_path):
        try:
            with open(version_path, "r", encoding="utf-8") as f:
                version = f.read().strip()
        except Exception:
            pass

    return model, labels, version, 'keras'


def _normalize_emotion_labels(raw_labels):
    """Lower-case ViT config labels and rename 'anger' to the app's 'angry'."""
    # Every other label (disgust, fear, happy, neutral, sad, surprise,
    # contempt, or anything unknown) is kept as its lower-cased self.
    renames = {'anger': 'angry'}
    return [renames.get(label.lower(), label.lower()) for label in raw_labels]
|
app/rate_limiter.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple in-memory rate limiter for API endpoints.
|
| 3 |
+
For production, consider using Redis-based rate limiting.
|
| 4 |
+
"""
|
| 5 |
+
import time
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from typing import Dict, Tuple
|
| 8 |
+
from threading import Lock
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RateLimiter:
    """
    Simple sliding-window rate limiter.

    For every identifier the timestamps of its accepted requests are kept;
    a request is allowed while fewer than ``max_requests`` of them fall
    inside the last ``window_seconds`` seconds. All state is read and
    modified while holding ``self.lock``.
    """

    def __init__(self, max_requests: int = 100, window_seconds: int = 60):
        """
        Args:
            max_requests: Maximum requests allowed in the time window
            window_seconds: Time window in seconds
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # identifier -> timestamps (time.time()) of accepted requests
        self.requests: Dict[str, list] = defaultdict(list)
        self.lock = Lock()
        # Time of the last stale-identifier sweep
        self._last_prune = time.time()

    def _prune_stale(self, current_time: float, window_start: float) -> None:
        """Drop identifiers with no request inside the window (caller holds the lock)."""
        elapsed = current_time - self._last_prune
        # Sweep at most once per window so the cost stays amortised
        if 0 <= elapsed < self.window_seconds:
            return
        self._last_prune = current_time
        stale = [
            key
            for key, stamps in self.requests.items()
            if not any(stamp > window_start for stamp in stamps)
        ]
        for key in stale:
            del self.requests[key]

    def is_allowed(self, identifier: str) -> Tuple[bool, int]:
        """
        Check if a request is allowed and, if so, record it.

        Args:
            identifier: Unique identifier (e.g., IP address, user ID)

        Returns:
            Tuple of (is_allowed, remaining_requests)
        """
        current_time = time.time()
        window_start = current_time - self.window_seconds

        with self.lock:
            # Without this, identifiers that never come back would stay in
            # the dict forever and memory would grow with every new client.
            self._prune_stale(current_time, window_start)

            # Keep only this identifier's requests that are still in the window
            recent = [
                req_time
                for req_time in self.requests.get(identifier, ())
                if req_time > window_start
            ]

            # Check if limit exceeded
            if len(recent) >= self.max_requests:
                if recent:
                    self.requests[identifier] = recent
                else:
                    # Nothing to remember (only possible when max_requests <= 0)
                    self.requests.pop(identifier, None)
                return False, 0

            # Add current request
            recent.append(current_time)
            self.requests[identifier] = recent
            return True, self.max_requests - len(recent)

    def reset(self, identifier: str = None):
        """Reset rate limit for an identifier, or for all identifiers when None."""
        with self.lock:
            # Compare against None: an empty-string identifier must not
            # wipe every other client's history.
            if identifier is None:
                self.requests.clear()
            else:
                self.requests.pop(identifier, None)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# Global rate limiters for different endpoints
|
| 69 |
+
detect_limiter = RateLimiter(max_requests=30, window_seconds=60) # 30 requests per minute
|
| 70 |
+
logs_limiter = RateLimiter(max_requests=100, window_seconds=60) # 100 requests per minute
|
| 71 |
+
images_limiter = RateLimiter(max_requests=200, window_seconds=60) # 200 requests per minute
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_client_identifier(request) -> str:
    """
    Get a unique identifier for rate limiting.

    Resolution order: first non-empty entry of ``X-Forwarded-For``, then
    ``X-Real-IP``, then ``request.remote_addr``, then the literal
    ``"unknown"``.

    Args:
        request: Object exposing ``headers.get(name)`` and ``remote_addr``.

    Returns:
        The chosen identifier with surrounding whitespace removed.
    """
    # NOTE(security): both headers are supplied by the client unless a trusted
    # reverse proxy overwrites them. A caller can vary them to dodge per-IP
    # limits, so only rely on them when the app is deployed behind such a proxy.
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # Take the first usable IP in the chain; blank entries
        # (e.g. ", 1.2.3.4") must not become the empty identifier.
        for hop in forwarded_for.split(","):
            hop = hop.strip()
            if hop:
                return hop

    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    if real_ip:
        return real_ip

    return request.remote_addr or "unknown"
|
app/utils.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/utils.py
|
| 2 |
+
import os
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
def _enhance_for_detection(gray: np.ndarray) -> np.ndarray:
    """
    Lightly preprocess a grayscale image before face detection.

    Runs CLAHE (clip limit 2.0 on an 8x8 tile grid) to lift contrast, then a
    mild bilateral filter to cut noise while keeping edges.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    contrast_boosted = equalizer.apply(gray)
    return cv2.bilateralFilter(contrast_boosted, d=5, sigmaColor=75, sigmaSpace=75)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def preprocess_face(
    image_path: str,
    target_size: Tuple[int, int] = (48, 48),
    detect_max_dim: int = 800,
    pad_ratio: float = 0.25,  # Increased from 0.15 to 0.25 to preserve more context (eyes, eyebrows, mouth area)
) -> Tuple[Optional[np.ndarray], Optional[str]]:
    """
    Read the image at image_path, locate the largest face and return it ready for the model:
      - shape: (1, H, W, 1)
      - dtype: np.float32
      - values scaled to [0,1]

    Returns (None, None) when the image cannot be read, no face is found,
    or any error occurs.

    Parameters:
      - target_size: size expected by the model (height, width).
      - detect_max_dim: longest side allowed for the detection pass; larger images are downscaled first.
      - pad_ratio: fraction of the face box added as padding on each side (helps avoid tight crops).

    Returns:
      - (face_array, used_filename)
    """
    try:
        img = cv2.imread(image_path)
        if img is None:
            return None, None

        h0, w0 = img.shape[:2]
        # Detection and the final crop both work on the grayscale copy
        gray_full = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Shrink big images so detection stays fast; `scale` maps back later
        scale = 1.0
        max_side = max(w0, h0)
        if max_side > detect_max_dim:
            scale = detect_max_dim / float(max_side)
            small = cv2.resize(gray_full, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR)
        else:
            small = gray_full.copy()

        # Contrast/noise enhancement helps on blurry or flat photos
        small_enh = _enhance_for_detection(small)

        cascade_names = [
            "haarcascade_frontalface_default.xml",
            "haarcascade_frontalface_alt.xml",
            "haarcascade_frontalface_alt2.xml",
        ]

        # (scaleFactor, minNeighbors, minSize), from strict to most permissive
        detection_passes = (
            (1.1, 5, (30, 30)),
            (1.05, 3, (20, 20)),
            (1.03, 2, (15, 15)),
        )

        faces = []

        for cascade_name in cascade_names:
            if len(faces) > 0:
                break

            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
                if face_cascade.empty():
                    continue

                # Stop at the first pass that finds anything
                for scale_factor, min_neighbors, min_size in detection_passes:
                    faces = face_cascade.detectMultiScale(
                        small_enh,
                        scaleFactor=scale_factor,
                        minNeighbors=min_neighbors,
                        minSize=min_size,
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
                    if len(faces) > 0:
                        break

            except Exception:
                continue

        # Last resort: enhancement sometimes hurts, so retry on the plain image
        if len(faces) == 0:
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
                if not face_cascade.empty():
                    faces = face_cascade.detectMultiScale(
                        small,
                        scaleFactor=1.05,
                        minNeighbors=3,
                        minSize=(20, 20),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
            except Exception:
                pass

        if len(faces) == 0:
            return None, None

        # Largest box by area (first one wins on ties)
        x_s, y_s, w_s, h_s = max(faces, key=lambda box: box[2] * box[3])

        # Back to original-image coordinates
        x, y, w, h = (int(value / scale) for value in (x_s, y_s, w_s, h_s))

        # Grow the box by pad_ratio on every side, clamped to the image
        pad_w = int(w * pad_ratio)
        pad_h = int(h * pad_ratio)
        left = max(0, x - pad_w)
        top = max(0, y - pad_h)
        right = min(w0, x + w + pad_w)
        bottom = min(h0, y + h + pad_h)

        face_crop = gray_full[top:bottom, left:right]

        # cv2.resize takes (width, height); INTER_CUBIC keeps more detail when upscaling small faces
        face_resized = cv2.resize(face_crop, (target_size[1], target_size[0]), interpolation=cv2.INTER_CUBIC)

        # float32 in [0, 1]
        face_arr = np.asarray(face_resized, dtype=np.float32)
        face_arr = face_arr / 255.0

        # Add channel and batch dims -> (1, H, W, 1)
        if face_arr.ndim == 2:
            face_arr = np.expand_dims(face_arr, axis=-1)
        face_arr = np.expand_dims(face_arr, axis=0)

        # Final sanity checks
        if face_arr.dtype != np.float32:
            face_arr = face_arr.astype(np.float32)
        if not np.isfinite(face_arr).all():
            return None, None

        used_filename = os.path.basename(image_path) or "upload.jpg"
        return face_arr, used_filename

    except Exception:
        # don't leak internals to caller; let app log exceptions if needed
        return None, None
|
app/validators.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Request validation utilities.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from typing import Tuple, Optional
|
| 6 |
+
from werkzeug.utils import secure_filename
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def validate_image_file(file, max_size: int, allowed_extensions: tuple) -> Tuple[bool, Optional[str], Optional[str]]:
    """
    Validate uploaded image file.

    Checks, in order: a file with a filename was supplied, the sanitized
    filename is non-empty, the extension is allowed, the size is within
    bounds (best effort), and Pillow can open and verify the content.

    Args:
        file: FileStorage object from Flask
        max_size: Maximum file size in bytes
        allowed_extensions: Iterable of allowed extensions. Matching is
            case-insensitive and the leading dot is optional, so
            (".jpg", ".png") and ("JPG", "png") are equivalent.

    Returns:
        Tuple of (is_valid, error_message, sanitized_filename)
        If valid: (True, None, filename)
        If invalid: (False, error_message, None)
    """
    if not file or not file.filename:
        return False, "No file provided", None

    # Check filename
    filename = secure_filename(file.filename)
    if not filename:
        return False, "Invalid filename", None

    # Check extension. Normalize the allow-list so configuration written as
    # "JPG" or "png" (no dot / upper case) still matches the lower-cased,
    # dot-prefixed extension extracted from the filename.
    ext = os.path.splitext(filename)[1].lower()
    normalized_allowed = {
        (item if (not item or item.startswith(".")) else "." + item).lower()
        for item in allowed_extensions
    }
    if ext not in normalized_allowed:
        return False, f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}", None

    # Check file size (if available)
    try:
        file.seek(0, os.SEEK_END)
        file_size = file.tell()
        file.seek(0)  # Reset to beginning

        if file_size > max_size:
            max_mb = max_size / (1024 * 1024)
            return False, f"File too large. Maximum size: {max_mb:.1f}MB", None

        if file_size == 0:
            return False, "File is empty", None
    except Exception:
        # Deliberately best effort: if the stream cannot be measured we
        # continue (oversized bodies will be caught by MAX_CONTENT_LENGTH).
        pass

    # Validate it's actually an image by trying to open it
    try:
        file.seek(0)
        img = Image.open(file)
        img.verify()  # Verify it's a valid image
        file.seek(0)  # Reset after verification so callers can read/save it
    except Exception as e:
        return False, f"Invalid image file: {str(e)}", None

    return True, None, filename
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def validate_pagination_params(
    limit: Optional[str],
    offset: Optional[str],
    *,
    default_limit: int = 20,
    max_limit: int = 200,
) -> Tuple[int, int, Optional[str]]:
    """
    Validate pagination parameters.

    Args:
        limit: Raw ``limit`` query value; empty/None selects ``default_limit``.
        offset: Raw ``offset`` query value; empty/None selects 0.
        default_limit: Limit used when none is supplied (default 20).
        max_limit: Upper clamp applied to the limit (default 200).

    Returns:
        Tuple of (limit, offset, error_message). The limit is clamped to
        [1, max_limit] and the offset to >= 0. On a parse failure the
        error message is set and safe fallback values are returned.
    """
    try:
        limit_val = int(limit) if limit else default_limit
    except (TypeError, ValueError):
        return default_limit, 0, "Invalid limit parameter. Must be an integer."
    limit_val = max(1, min(max_limit, limit_val))

    try:
        offset_val = int(offset) if offset else 0
    except (TypeError, ValueError):
        return limit_val, 0, "Invalid offset parameter. Must be an integer."
    offset_val = max(0, offset_val)

    return limit_val, offset_val, None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _parse_confidence(raw: Optional[str], name: str) -> Tuple[Optional[float], Optional[str]]:
    """Parse one confidence bound; return (value, error_message)."""
    if not raw:
        # Missing or empty parameter means "no bound".
        return None, None
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return None, f"Invalid {name} parameter. Must be a number."
    # NaN fails this chained comparison as well, so it is rejected here.
    if not 0 <= value <= 1:
        return None, f"{name} must be between 0 and 1"
    return value, None


def validate_confidence_range(min_conf: Optional[str], max_conf: Optional[str]) -> Tuple[Optional[float], Optional[float], Optional[str]]:
    """
    Validate confidence range parameters.

    Each bound is optional; when present it must parse as a number in
    [0, 1], and the minimum may not exceed the maximum. The minimum is
    validated before the maximum, so its error is reported first.

    Returns:
        Tuple of (min_confidence, max_confidence, error_message).
        On any error both bounds are returned as None.
    """
    min_val, error = _parse_confidence(min_conf, "min_confidence")
    if error is not None:
        return None, None, error

    max_val, error = _parse_confidence(max_conf, "max_confidence")
    if error is not None:
        return None, None, error

    if min_val is not None and max_val is not None and min_val > max_val:
        return None, None, "min_confidence cannot be greater than max_confidence"

    return min_val, max_val, None
|
app/vit_utils.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/vit_utils.py
|
| 2 |
+
"""
|
| 3 |
+
Utilities for Vision Transformer (ViT) model preprocessing and prediction.
|
| 4 |
+
"""
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing import Optional, Tuple, Dict, Any
|
| 9 |
+
from app.utils import preprocess_face # Reuse face detection
|
| 10 |
+
|
| 11 |
+
def _detect_faces_with_cascades(gray_image, cascade_names, param_sets):
    """Return the first non-empty Haar detection on *gray_image*, else [].

    Cascades are tried in order; for each cascade the
    (scaleFactor, minNeighbors, minSize) tuples in *param_sets* are tried
    in order. A cascade that fails to load or raises is skipped.
    """
    for cascade_name in cascade_names:
        try:
            face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
            if face_cascade.empty():
                continue
            for scale_factor, min_neighbors, min_size in param_sets:
                found = face_cascade.detectMultiScale(
                    gray_image,
                    scaleFactor=scale_factor,
                    minNeighbors=min_neighbors,
                    minSize=min_size,
                    flags=cv2.CASCADE_SCALE_IMAGE,
                )
                if len(found) > 0:
                    return found
        except Exception:
            # Best effort: one broken cascade must not abort detection.
            continue
    return []


def preprocess_face_for_vit(
    image_path: str,
    detect_max_dim: int = 800,
    pad_ratio: float = 0.35,  # Increased to 0.35 to include more facial context - helps with happy detection (smile needs more context)
) -> Tuple[Optional[Image.Image], Optional[str]]:
    """
    Preprocess face for Vision Transformer model.
    ViT needs RGB images at 224x224, not grayscale 48x48.

    Detection strategy (stops at the first hit):
      1. Three frontal-face cascades x two parameter sets on the enhanced,
         possibly downscaled grayscale image.
      2. The default cascade on the non-enhanced downscaled image.
      3. The default cascade on the full-resolution image, only when the
         image was downscaled by more than 2x (scale < 0.5).

    Args:
        image_path: Path of the image file to read.
        detect_max_dim: Longest side (pixels) used for detection; larger
            images are downscaled for detection only.
        pad_ratio: Fraction of the face width/height added on each side
            of the detected box before cropping.

    Returns: (PIL Image, filename) or (None, None) if no face detected
    or any processing step fails (failures are logged).
    """
    import logging
    import os

    try:
        img = cv2.imread(image_path)
        if img is None:
            return None, None

        h0, w0 = img.shape[:2]
        # Keep RGB for ViT (not grayscale); grayscale is only for detection.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gray_full = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Downscale for faster detection if image is huge
        scale = 1.0
        max_side = max(w0, h0)
        if max_side > detect_max_dim:
            scale = detect_max_dim / float(max_side)
            small = cv2.resize(gray_full, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR)
        else:
            small = gray_full.copy()

        # Enhance for detection (imported here, as in the original code)
        from app.utils import _enhance_for_detection
        small_enh = _enhance_for_detection(small)

        strict_params = (1.05, 3, (20, 20))      # catches the common case
        permissive_params = (1.03, 2, (15, 15))  # catches challenging cases

        # Tracks which image produced the detection so the coordinates are
        # mapped back to full resolution exactly once.
        detected_on_full = False

        # Primary + last-resort cascades on the enhanced image.
        faces = _detect_faces_with_cascades(
            small_enh,
            [
                "haarcascade_frontalface_default.xml",  # Most reliable
                "haarcascade_frontalface_alt.xml",      # Good fallback
                "haarcascade_frontalface_alt2.xml",     # Last resort
            ],
            [strict_params, permissive_params],
        )

        # Fallback 1: original (non-enhanced) image, single attempt.
        if len(faces) == 0:
            faces = _detect_faces_with_cascades(
                small, ["haarcascade_frontalface_default.xml"], [strict_params]
            )

        # Fallback 2: full-size image, only when the image was downscaled
        # by more than 2x (full-size detection is slow).
        if len(faces) == 0 and max_side > detect_max_dim and scale < 0.5:
            faces = _detect_faces_with_cascades(
                gray_full,
                ["haarcascade_frontalface_default.xml"],
                [(1.05, 2, (30, 30))],  # Larger min size for full-res
            )
            detected_on_full = len(faces) > 0

        if len(faces) == 0:
            return None, None

        # Choose largest face
        x_s, y_s, w_s, h_s = max(faces, key=lambda r: r[2] * r[3])

        if scale < 1.0 and not detected_on_full:
            # Detection ran on the downscaled image: map back to original.
            x = int(x_s / scale)
            y = int(y_s / scale)
            w = int(w_s / scale)
            h = int(h_s / scale)
        else:
            # Coordinates are already in full-resolution pixels. Previously
            # a full-size hit was divided by `scale` again, shifting and
            # enlarging the crop.
            x, y, w, h = int(x_s), int(y_s), int(w_s), int(h_s)

        # Pad bounding box, clamped to the image bounds
        pad_w = int(w * pad_ratio)
        pad_h = int(h * pad_ratio)
        x1 = max(0, x - pad_w)
        y1 = max(0, y - pad_h)
        x2 = min(w0, x + w + pad_w)
        y2 = min(h0, y + h + pad_h)
        if x2 <= x1 or y2 <= y1:
            # Degenerate box: nothing to crop.
            return None, None

        # Crop face from RGB image (not grayscale)
        face_crop = img_rgb[y1:y2, x1:x2]

        # Convert to PIL and resize to 224x224 (ViT input size) with BICUBIC.
        # No CLAHE here: the ViT processor handles normalization.
        face_pil = Image.fromarray(face_crop)
        face_pil = face_pil.resize((224, 224), Image.Resampling.BICUBIC)

        used_filename = os.path.basename(image_path) or "upload.jpg"
        return face_pil, used_filename

    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.exception("Exception in preprocess_face_for_vit for %s: %s", image_path, e)
        return None, None
|
| 205 |
+
|
| 206 |
+
def predict_with_vit(
    model_dict: Dict[str, Any],
    image: Image.Image,
    labels: list
) -> Tuple[int, float, Dict[str, float]]:
    """
    Run prediction using Vision Transformer model.

    After the forward pass, a heuristic boosts the "happy" probability
    when it is in the top 3 (> 0.05) while "contempt" or "neutral"
    exceeds 0.4, then re-normalizes and re-selects the top class.

    Args:
        model_dict: {'model': model, 'processor': processor, 'type': 'vit'}
        image: PIL Image (224x224 RGB)
        labels: List of emotion labels

    Returns:
        (predicted_index, confidence, all_probabilities_dict)

        NOTE(review): without the boost, predicted_index is the model's
        argmax index; when the boost fires it is an index into `labels`.
        These agree only if `labels` follows the model's id2label order -
        confirm against the caller.
    """
    import torch
    import torch.nn.functional as F

    processor = model_dict['processor']
    model = model_dict['model']

    # Ensure image is RGB (some images might be RGBA or grayscale)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Preprocess image for ViT (processor handles normalization)
    inputs = processor(image, return_tensors="pt")

    model.eval()
    with torch.inference_mode():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=-1)
        probs_np = probs[0].cpu().numpy()
        predicted_idx = int(torch.argmax(logits, dim=-1).item())
    confidence = float(probs_np[predicted_idx])

    # Resolve the model's own label table once (may be absent).
    id2label = getattr(getattr(model, 'config', None), 'id2label', None)
    # Only 'anger' is renamed; every other label maps to its lower-case self.
    label_aliases = {'anger': 'angry'}

    all_probs = {}
    for i, prob in enumerate(probs_np):
        if id2label is not None:
            # Use model's id2label for accurate label mapping
            raw_label = id2label.get(i, f"class_{i}").lower()
            all_probs[label_aliases.get(raw_label, raw_label)] = float(prob)
        elif i < len(labels):
            all_probs[labels[i]] = float(prob)
        else:
            all_probs[f"class_{i}"] = float(prob)

    # Post-processing: the model has a known happy/contempt confusion.
    happy_prob = all_probs.get('happy', 0.0)
    contempt_prob = all_probs.get('contempt', 0.0)
    neutral_prob = all_probs.get('neutral', 0.0)

    sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
    top_3_emotions = [e[0] for e in sorted_probs[:3]]

    if (
        'happy' in top_3_emotions
        and happy_prob > 0.05
        and (contempt_prob > 0.4 or neutral_prob > 0.4)
    ):
        # Boost happy by 30% (helps correct misclassifications)
        boost_factor = 1.3
        boosted_happy = min(1.0, happy_prob * boost_factor)

        # Take half of the added mass from each of contempt and neutral.
        reduction = (boosted_happy - happy_prob) / 2
        all_probs['happy'] = boosted_happy
        for name, old_prob in (('contempt', contempt_prob), ('neutral', neutral_prob)):
            # Only touch classes the model actually produced; writing
            # unconditionally would add a spurious 0.0 entry.
            if name in all_probs:
                all_probs[name] = max(0.0, old_prob - reduction)

        # Re-normalize to ensure sum is ~1.0
        total = sum(all_probs.values())
        if total > 0:
            all_probs = {k: v / total for k, v in all_probs.items()}

        # Recalculate predicted class after boosting
        new_top_emotion = max(all_probs.items(), key=lambda x: x[1])[0]

        if new_top_emotion in labels:
            predicted_idx = labels.index(new_top_emotion)
            confidence = all_probs[new_top_emotion]
            print(f"[VIT] Post-processing: Boosted happy from {happy_prob:.3f} to {all_probs.get('happy', 0.0):.3f}, new prediction: {new_top_emotion}")
        else:
            # Keep the original prediction if the label is not in `labels`
            print(f"[VIT] Post-processing: Boosted happy but couldn't find label {new_top_emotion} in labels list")

    # Guarded lookup: the model may not expose config.id2label.
    raw_model_label = id2label.get(predicted_idx, 'unknown') if id2label is not None else 'unknown'
    print(f"[VIT] Predicted index: {predicted_idx}, Raw label from model: {raw_model_label}")

    return predicted_idx, confidence, all_probs
|
| 323 |
+
|
entrypoint_hf.sh
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
# Container entrypoint: fetch model artifacts if missing, then start gunicorn.
set -eu

# Where the app expects the model inside the container
MODEL_PATH="/app/models/emotion_model.keras"

# Public release URL (change if you host elsewhere)
MODEL_URL="https://github.com/iyinoluwAA/Emotion-detection/releases/download/v1.0.0/emotion_model.keras"

# Ensure models dir exists
mkdir -p "$(dirname "$MODEL_PATH")"

if [ ! -f "$MODEL_PATH" ]; then
  echo "Model not found at $MODEL_PATH — attempting download from $MODEL_URL"
  # Download to a temporary name and rename on success, so an interrupted
  # download is never mistaken for a complete model on the next start.
  MODEL_TMP="${MODEL_PATH}.part"
  if command -v curl >/dev/null 2>&1; then
    curl -fSL "$MODEL_URL" -o "$MODEL_TMP" || {
      echo "curl failed to download model"; rm -f "$MODEL_TMP"; ls -la "$(dirname "$MODEL_PATH")"; exit 1;
    }
  elif command -v wget >/dev/null 2>&1; then
    wget -O "$MODEL_TMP" "$MODEL_URL" || {
      echo "wget failed to download model"; rm -f "$MODEL_TMP"; ls -la "$(dirname "$MODEL_PATH")"; exit 1;
    }
  else
    echo "No curl or wget available in the image. Install one in Dockerfile."; exit 1
  fi
  mv "$MODEL_TMP" "$MODEL_PATH"
else
  echo "Model already present at $MODEL_PATH"
fi

# ensure readable
chmod a+r "$MODEL_PATH" || true

# Download Asripa model (fine-tuned) if not present.
# "-" (not ":-") so that an explicitly empty ASRIPA_MODEL_ID disables the
# download; with ":-" the "not set" branch below could never run.
ASRIPA_MODEL_DIR="/app/models/fine_tuned_vit"
ASRIPA_MODEL_ID="${ASRIPA_MODEL_ID-HimAJ/asripa-emotion-detection}"

if [ -n "$ASRIPA_MODEL_ID" ] && [ ! -f "$ASRIPA_MODEL_DIR/model.safetensors" ]; then
  echo "📥 Downloading Asripa model from HuggingFace..."
  echo " Model ID: $ASRIPA_MODEL_ID"
  mkdir -p "$ASRIPA_MODEL_DIR"

  # Use Python to download (huggingface_hub is in requirements).
  # Values are passed through the environment instead of being pasted
  # into the Python source, so quotes in them cannot break the script.
  ASRIPA_MODEL_ID="$ASRIPA_MODEL_ID" ASRIPA_MODEL_DIR="$ASRIPA_MODEL_DIR" python3 -c '
import os
import shutil
import sys

from huggingface_hub import snapshot_download

repo_id = os.environ["ASRIPA_MODEL_ID"]
local_dir = os.environ["ASRIPA_MODEL_DIR"]
try:
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )
    print("✅ Asripa model downloaded successfully!")
except Exception as e:
    print(f"⚠️ Failed to download Asripa model: {e}")
    print(" App will use base model only")
    if os.path.exists(local_dir):
        shutil.rmtree(local_dir)
    sys.exit(0)  # Exit gracefully, not an error
' || {
    echo "⚠️ Asripa model download skipped"
    echo " App will use base model only"
    rm -rf "$ASRIPA_MODEL_DIR" 2>/dev/null || true
  }
elif [ -f "$ASRIPA_MODEL_DIR/model.safetensors" ]; then
  echo "✅ Asripa model already present"
elif [ -z "$ASRIPA_MODEL_ID" ]; then
  echo "ℹ️ ASRIPA_MODEL_ID not set - skipping Asripa model download"
fi

# Hugging Face Spaces uses port 7860 by default
# But we'll use PORT env var if set, otherwise default to 7860
PORT="${PORT:-7860}"
echo "Starting gunicorn on 0.0.0.0:${PORT}"
# Suppress protobuf warnings
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
exec gunicorn main:app --bind 0.0.0.0:"${PORT}" --workers 1 --threads 1 --timeout 120 --worker-class gthread
|
| 79 |
+
|
main.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# main.py
"""Application bootstrap: configure file logging, then build the Flask app."""
import os
import logging
import warnings

# Suppress protobuf version warnings (they're harmless but noisy)
warnings.filterwarnings("ignore", category=UserWarning, module="google.protobuf")
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Make PROJECT_ROOT explicit so module-level code in the container works reliably
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Ensure logs dir exists
LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
os.makedirs(LOG_DIR, exist_ok=True)

# Configure file logging (persists errors to logs/app.log)
logfile = os.path.join(LOG_DIR, "app.log")
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(module)s: %(message)s")

root_logger = logging.getLogger()
# Reuse an already attached handler for this file (avoids duplicates in dev
# reload). Looking it up *before* constructing a FileHandler avoids opening
# a second handle on the log file that would never be closed.
handler = next(
    (
        h
        for h in root_logger.handlers
        if isinstance(h, logging.FileHandler) and getattr(h, "baseFilename", "") == logfile
    ),
    None,
)
if handler is None:
    # Explicit encoding so non-ASCII log messages do not depend on the locale.
    handler = logging.FileHandler(logfile, encoding="utf-8")
    root_logger.addHandler(handler)
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)
# NOTE(review): the root logger level is not changed here, so unless
# create_app() configures it, records below the root's effective level
# never reach this handler - confirm this is intended.

# Import factory after logging and directory setup so imports don't crash during bootstrap
from app import create_app

# Create app (allow env-driven config if needed)
app = create_app()

if __name__ == "__main__":
    # allow overriding host/port via env (useful in Docker)
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", os.environ.get("FLASK_RUN_PORT", 5000)))
    # Case-insensitive so "TRUE"/"true"/"True"/"1" all enable debug.
    debug = os.environ.get("FLASK_DEBUG", "0").strip().lower() in ("1", "true")
    app.logger.info("Starting app on %s:%s (debug=%s)", host, port, debug)
    app.run(host=host, port=port, debug=debug)
|
requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Flask==3.1.1
|
| 2 |
+
flask-cors==4.0.0
|
| 3 |
+
|
| 4 |
+
# ML helpers (TensorFlow is not listed here; numpy is pinned below)
|
| 5 |
+
# Note: numpy<2 required for opencv-python-headless compatibility
|
| 6 |
+
numpy>=1.26.0,<2.0.0
|
| 7 |
+
h5py>=3.7.0
|
| 8 |
+
Pillow>=9.0.0
|
| 9 |
+
opencv-python-headless==4.9.0.80
|
| 10 |
+
|
| 11 |
+
# Vision Transformer support (for HardlyHumans model - 92.2% accuracy)
|
| 12 |
+
transformers>=4.30.0
|
| 13 |
+
torch>=2.0.0
|
| 14 |
+
huggingface_hub>=0.20.0 # For downloading Asripa model
|
| 15 |
+
|
| 16 |
+
# utilities & production
|
| 17 |
+
requests>=2.28.0
|
| 18 |
+
gunicorn>=23.0.0
|