HimAJ committed on
Commit
1e4fc28
·
verified ·
1 Parent(s): 96b2061

upload 32 files for the ml

Browse files
Dockerfile.hf ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Dockerfile for Hugging Face Spaces
FROM python:3.11-slim

# Unbuffered stdout/stderr so logs appear immediately in the Spaces console
ENV PYTHONUNBUFFERED=1
WORKDIR /app

# System dependencies for opencv and runtime model download
# (libgl1/libglib2.0-0 are required by opencv-python; curl for model fetch)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    libgl1 \
    libglib2.0-0 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first (separate layer) so dependency installs are
# cached and not re-run on every source-code change
COPY requirements.txt /app/requirements.txt

# Upgrade pip
RUN python -m pip install --upgrade pip setuptools wheel

# Install requirements
RUN pip install --no-cache-dir -r requirements.txt

# Copy app code
COPY . /app/

# Make entrypoint executable
RUN chmod +x /app/scripts/entrypoint.sh

# Hugging Face Spaces uses port 7860
EXPOSE 7860

# Use entrypoint script
ENTRYPOINT ["/app/scripts/entrypoint.sh"]
app/__init__.py ADDED
@@ -0,0 +1,762 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app/__init__.py
#
# Flask application package for the emotion-detection service.
# This module keeps import-time work deliberately light: only path
# constants, default config, directory creation and a CSV header check.
# All heavy work (model loading, DB init) happens inside create_app().
import os
import datetime
import csv
import traceback
import logging

from flask import Flask, request, jsonify
from flask_cors import CORS
from werkzeug.utils import secure_filename
from werkzeug.exceptions import RequestEntityTooLarge


# ----------------------------
# Module-level config (deterministic)
# ----------------------------
# PROJECT_ROOT is the parent of the app/ package directory.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TMP_DIR_DEFAULT = os.path.join(PROJECT_ROOT, "tmp")          # scratch space for uploads
IMAGES_DIR_DEFAULT = os.path.join(PROJECT_ROOT, "images")    # persisted prediction images
LOG_CSV = os.path.join(PROJECT_ROOT, "predictions_log.csv")  # legacy CSV log
DB_PATH = os.path.join(PROJECT_ROOT, "predictions.db")       # SQLite prediction log

# App-level defaults (can be overridden via app.config)
DEFAULTS = {
    "MIN_CONFIDENCE": 0.18,  # Lowered to 0.18 for ambiguous cases (was 0.20, originally 0.5)
    "MAX_FILE_SIZE": 5 * 1024 * 1024,  # 5 MB
    "TMP_DIR": TMP_DIR_DEFAULT,
    "IMAGES_DIR": IMAGES_DIR_DEFAULT,
    "ALLOWED_EXT": (".jpg", ".jpeg", ".png"),
    "CORS_ORIGINS": "*",  # Can be overridden for production
}

# Ensure directories exist
os.makedirs(DEFAULTS["TMP_DIR"], exist_ok=True)
os.makedirs(DEFAULTS["IMAGES_DIR"], exist_ok=True)

# Ensure CSV header exists (helpful for older logs)
if not os.path.exists(LOG_CSV):
    try:
        with open(LOG_CSV, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["timestamp", "filename", "emotion", "confidence"])
    except Exception:
        # Non-fatal — keep module import light.
        pass
# ----------------------------
# Factory
# ----------------------------
def create_app(config: dict | None = None):
    """
    Create and return the Flask application.

    Heavy imports (model loading, db init) are performed inside this factory
    so importing modules from scripts/tests doesn't trigger expensive work.

    Parameters
    ----------
    config : dict | None
        Optional overrides merged on top of module-level DEFAULTS
        (MIN_CONFIDENCE, MAX_FILE_SIZE, TMP_DIR, IMAGES_DIR, ALLOWED_EXT,
        CORS_ORIGINS).

    Routes registered:
        GET    /                      liveness message
        GET    /health                lightweight health probe (always 200)
        GET    /metrics               aggregate stats + 10 most recent rows
        GET    /logs                  paginated/filtered prediction log
        DELETE /logs/<id>             delete one prediction row
        POST   /detect                run emotion prediction on an uploaded image
        GET    /images/<filename>     serve a stored prediction image
    """
    # Merge defaults with provided config
    cfg = DEFAULTS.copy()
    if config:
        cfg.update(config)

    app = Flask(__name__)

    # CORS configuration - allow config override
    cors_origins = cfg.get("CORS_ORIGINS", DEFAULTS["CORS_ORIGINS"])
    if cors_origins == "*":
        CORS(app, resources={r"/*": {"origins": "*"}})
    else:
        # Allow list of origins (comma-separated string or an iterable)
        origins_list = cors_origins.split(",") if isinstance(cors_origins, str) else cors_origins
        CORS(app, resources={r"/*": {"origins": origins_list}})

    # ---------- file logging setup (after app created) ----------
    LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
    try:
        os.makedirs(LOG_DIR, exist_ok=True)
    except Exception:
        # If logs dir cannot be created, continue; app.logger will still work to stdout
        pass

    log_path = os.path.join(LOG_DIR, "app.log")
    try:
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.INFO)  # change to ERROR if you prefer
        formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(module)s: %(message)s")
        file_handler.setFormatter(formatter)

        # avoid adding duplicate handlers when reloading (e.g. Flask debug reloader)
        abs_log_path = os.path.abspath(log_path)
        if not any(
            isinstance(h, logging.FileHandler) and getattr(h, "baseFilename", None) == abs_log_path
            for h in app.logger.handlers
        ):
            app.logger.addHandler(file_handler)
        # set app logger level (don't lower if already configured higher)
        app.logger.setLevel(logging.INFO)
    except Exception:
        # If logging can't be configured, keep going — logger will fallback to default handlers.
        app.logger.exception("Failed to configure file logging")

    # Apply config to app
    app.config["MAX_CONTENT_LENGTH"] = cfg["MAX_FILE_SIZE"]  # Werkzeug enforces this; 413 on excess
    app.config["TMP_DIR"] = cfg["TMP_DIR"]
    app.config["IMAGES_DIR"] = cfg.get("IMAGES_DIR", DEFAULTS["IMAGES_DIR"])
    app.config["ALLOWED_EXT"] = cfg["ALLOWED_EXT"]
    app.config["MIN_CONFIDENCE"] = cfg["MIN_CONFIDENCE"]

    # Ensure tmp directory exists (again, per app)
    os.makedirs(app.config["TMP_DIR"], exist_ok=True)

    # Local (deferred) imports — avoid import-time side effects
    from .model_loader import load_emotion_model
    from .db_logger import init_db, log_prediction, get_metrics, tail_rows, get_total_count, delete_prediction
    from .utils import preprocess_face
    from .image_storage import save_image, get_image_path, ensure_images_dir
    from .validators import validate_image_file, validate_pagination_params, validate_confidence_range
    from .rate_limiter import detect_limiter, logs_limiter, images_limiter, get_client_identifier

    # Initialize DB
    try:
        init_db(DB_PATH)
        app.logger.info("Initialized SQLite DB at %s", DB_PATH)
    except Exception:
        app.logger.exception("Failed to initialize DB at startup")

    # Load model & labels. Keep these local to the factory (no module-level side effects).
    # We'll load models on-demand based on request parameter
    base_model = None
    base_labels = None
    base_model_version = "unknown"
    base_model_type = "unknown"
    finetuned_model = None
    finetuned_labels = None
    finetuned_model_version = "unknown"
    finetuned_model_type = "unknown"

    # Load base model by default
    try:
        # load_emotion_model returns (model, labels, version, model_type);
        # older versions of the loader returned 3- or 2-tuples, handled below.
        res = load_emotion_model(force_model='base')
        if isinstance(res, tuple) and len(res) == 4:
            base_model, base_labels, base_model_version, base_model_type = res
        elif isinstance(res, tuple) and len(res) == 3:
            base_model, base_labels, base_model_version = res
            base_model_type = "keras"  # Default for old format
        elif isinstance(res, tuple) and len(res) == 2:
            base_model, base_labels = res
            base_model_type = "keras"  # Default for old format
        else:
            # Unexpected return shape - try to be permissive
            try:
                base_model = res
                base_labels = None
                base_model_type = "keras"
            except Exception:
                base_model = None
                base_labels = None
                base_model_type = "unknown"
        app.logger.info("Base model loaded: %s (version=%s, type=%s)", bool(base_model), base_model_version, base_model_type)
        print(f"[APP] Base model loaded: type={base_model_type}, version={base_model_version}, labels={len(base_labels) if base_labels else 0}")
    except Exception as exc:
        app.logger.exception("Base model failed to load at startup: %s", exc)
        base_model = None
        base_labels = None
        base_model_version = "failed"
        base_model_type = "unknown"

    # Try to load fine-tuned model (optional; failure is non-fatal)
    try:
        res = load_emotion_model(force_model='fine-tuned')
        if isinstance(res, tuple) and len(res) == 4:
            finetuned_model, finetuned_labels, finetuned_model_version, finetuned_model_type = res
        elif isinstance(res, tuple) and len(res) == 3:
            finetuned_model, finetuned_labels, finetuned_model_version = res
            finetuned_model_type = "keras"
        elif isinstance(res, tuple) and len(res) == 2:
            finetuned_model, finetuned_labels = res
            finetuned_model_type = "keras"
        app.logger.info("Asripa model loaded: %s (version=%s, type=%s)", bool(finetuned_model), finetuned_model_version, finetuned_model_type)
        print(f"[APP] Asripa model loaded: type={finetuned_model_type}, version={finetuned_model_version}")
    except Exception as exc:
        app.logger.warning("Asripa model not available: %s", exc)
        finetuned_model = None
        finetuned_labels = None
        finetuned_model_version = "not-available"
        finetuned_model_type = "unknown"

    # Store in app.config - default to base model
    app.config["BASE_MODEL"] = base_model
    app.config["BASE_LABELS"] = base_labels
    app.config["BASE_MODEL_VERSION"] = base_model_version
    app.config["BASE_MODEL_TYPE"] = base_model_type
    app.config["FINETUNED_MODEL"] = finetuned_model
    app.config["FINETUNED_LABELS"] = finetuned_labels
    app.config["FINETUNED_MODEL_VERSION"] = finetuned_model_version
    app.config["FINETUNED_MODEL_TYPE"] = finetuned_model_type
    # Default to base model
    app.config["MODEL"] = base_model
    app.config["LABELS"] = base_labels
    app.config["MODEL_VERSION"] = base_model_version
    app.config["MODEL_TYPE"] = base_model_type

    # ----------------------------
    # Error handlers (import before routes to ensure proper handling)
    # ----------------------------
    from .error_handlers import register_error_handlers, APIError, ValidationError, NotFoundError, ServiceUnavailableError

    register_error_handlers(app)

    # Make these available in route scope
    # NOTE(review): the route closures below already see these names via the
    # factory's local scope; the globals() assignment looks redundant — confirm
    # before removing.
    globals()['APIError'] = APIError
    globals()['ValidationError'] = ValidationError
    globals()['NotFoundError'] = NotFoundError
    globals()['ServiceUnavailableError'] = ServiceUnavailableError

    @app.errorhandler(RequestEntityTooLarge)
    def handle_large_file(e):
        # Triggered by MAX_CONTENT_LENGTH; report the limit back in MB.
        return jsonify({"error": "File too large", "max_size_mb": app.config.get("MAX_CONTENT_LENGTH", 5 * 1024 * 1024) / (1024 * 1024)}), 413

    # ----------------------------
    # Routes
    # ----------------------------
    @app.route("/")
    def index():
        """Root liveness endpoint."""
        return jsonify({"status": "ok", "message": "Flask backend running"}), 200

    @app.route("/health", methods=["GET"])
    def health():
        """
        Lightweight health check endpoint.
        Optimized for speed - minimal checks to avoid timeouts.
        Always returns HTTP 200 so the service is never reported "offline"
        merely because of a health-check-internal error.
        """
        try:
            # Quick check - don't do expensive operations
            model_loaded = bool(app.config.get("MODEL"))
            model_type = app.config.get("MODEL_TYPE", "unknown")
            model_version = app.config.get("MODEL_VERSION", "unknown")

            # Get labels count quickly
            labels_obj = app.config.get("LABELS")
            labels_count = len(labels_obj) if labels_obj and hasattr(labels_obj, "__len__") else 0

            return jsonify(
                {
                    "ok": True,
                    "model_loaded": model_loaded,
                    "model_type": model_type,
                    "model_version": model_version,
                    "labels_count": labels_count,
                }
            ), 200
        except Exception as e:
            # Even if there's an error, return 200 to indicate service is running
            # This prevents false "offline" status
            app.logger.warning(f"Health check error (non-fatal): {e}")
            return jsonify(
                {
                    "ok": True,
                    "model_loaded": False,
                    "model_type": "unknown",
                    "model_version": "unknown",
                    "labels_count": 0,
                    "warning": "Health check had minor issues but service is running",
                }
            ), 200

    @app.route("/metrics")
    def metrics():
        """Return aggregate prediction metrics plus the 10 most recent rows."""
        try:
            m = get_metrics(DB_PATH)
            recent = tail_rows(DB_PATH, limit=10)
            return jsonify({"ok": True, "metrics": m, "recent": recent}), 200
        except Exception as exc:
            app.logger.exception("Failed to fetch metrics")
            return jsonify({"error": "Failed to fetch metrics", "details": str(exc)}), 500

    @app.route("/logs", methods=["GET"])
    def logs():
        """
        GET /logs?limit=20&offset=0&emotion=happy&min_confidence=0.5&max_confidence=1.0&date_from=2024-01-01&date_to=2024-12-31

        Returns paginated and filtered logs.
        Rate limited per client via logs_limiter (429 when exceeded).
        """
        # Rate limiting
        client_id = get_client_identifier(request)
        is_allowed, remaining = logs_limiter.is_allowed(client_id)
        if not is_allowed:
            return jsonify({
                "error": "Rate limit exceeded",
                "detail": f"Maximum {logs_limiter.max_requests} requests per {logs_limiter.window_seconds} seconds",
                "retry_after": logs_limiter.window_seconds,
            }), 429

        try:
            # Validate pagination
            limit, offset, pagination_error = validate_pagination_params(
                request.args.get("limit"),
                request.args.get("offset"),
            )
            if pagination_error:
                return jsonify({"error": pagination_error}), 400

            # Validate confidence range
            min_confidence, max_confidence, confidence_error = validate_confidence_range(
                request.args.get("min_confidence"),
                request.args.get("max_confidence"),
            )
            if confidence_error:
                return jsonify({"error": confidence_error}), 400

            # Filters (blank emotion filter is treated as absent)
            emotion_filter = request.args.get("emotion", None)
            if emotion_filter and emotion_filter.strip():
                emotion_filter = emotion_filter.strip()
            else:
                emotion_filter = None

            date_from = request.args.get("date_from", None)
            date_to = request.args.get("date_to", None)

            # Fetch data
            rows = tail_rows(
                DB_PATH,
                limit=limit,
                offset=offset,
                emotion_filter=emotion_filter,
                min_confidence=min_confidence,
                max_confidence=max_confidence,
                date_from=date_from,
                date_to=date_to,
            )

            total = get_total_count(
                DB_PATH,
                emotion_filter=emotion_filter,
                min_confidence=min_confidence,
                max_confidence=max_confidence,
                date_from=date_from,
                date_to=date_to,
            )

            # Convert to list of dicts. Row width varies across schema
            # versions (6/5/4 columns), so branch on length defensively.
            result = []
            for r in rows:
                if len(r) == 6:
                    _id, ts, filename, image_path, emotion, confidence = r
                    record = {
                        "id": _id,
                        "ts": ts,
                        "filename": filename,
                        "image_path": image_path or filename,  # Fallback to filename if no image_path
                        "emotion": emotion,
                        "confidence": confidence,
                    }
                elif len(r) == 5:
                    _id, ts, filename, emotion, confidence = r
                    record = {
                        "id": _id,
                        "ts": ts,
                        "filename": filename,
                        "image_path": filename,  # Fallback
                        "emotion": emotion,
                        "confidence": confidence,
                    }
                elif len(r) == 4:
                    ts, filename, emotion, confidence = r
                    record = {
                        "ts": ts,
                        "filename": filename,
                        "image_path": filename,  # Fallback
                        "emotion": emotion,
                        "confidence": confidence,
                    }
                else:
                    record = {"row": r}
                result.append(record)

            return jsonify({
                "ok": True,
                "logs": result,
                "pagination": {
                    "total": total,
                    "limit": limit,
                    "offset": offset,
                    "has_more": (offset + limit) < total,
                },
            }), 200
        except Exception as exc:
            app.logger.exception("Failed to fetch logs")
            return jsonify({"error": "Failed to fetch logs", "detail": str(exc)}), 500

    @app.route("/logs/<int:prediction_id>", methods=["DELETE"])
    def delete_log(prediction_id: int):
        """
        DELETE /logs/<id>

        Delete a prediction by ID. Returns 404 if the row does not exist.
        Rate limited via logs_limiter.
        """
        # Rate limiting
        client_id = get_client_identifier(request)
        is_allowed, remaining = logs_limiter.is_allowed(client_id)
        if not is_allowed:
            return jsonify({
                "error": "Rate limit exceeded",
                "detail": f"Maximum {logs_limiter.max_requests} requests per {logs_limiter.window_seconds} seconds",
                "retry_after": logs_limiter.window_seconds,
            }), 429

        try:
            # Delete from database
            deleted = delete_prediction(DB_PATH, prediction_id)

            if not deleted:
                return jsonify({"error": "Prediction not found"}), 404

            # Optionally delete associated image file
            from .image_storage import delete_image
            # Note: We'd need to fetch the image_path first, but for now just delete from DB
            # You can enhance this later to also delete the image file

            return jsonify({"ok": True, "message": "Prediction deleted successfully"}), 200
        except Exception as exc:
            app.logger.exception(f"Failed to delete prediction {prediction_id}")
            return jsonify({"error": "Failed to delete prediction", "detail": str(exc)}), 500

    @app.route("/detect", methods=["POST"])
    def detect():
        """
        POST form-data: image file under key 'image'
        Returns: JSON {emotion, confidence, filename, all_probabilities,
        model, model_version} or error JSON.

        Query params:
            model=base|fine-tuned  selects which loaded model to use
                                   (falls back to base if fine-tuned absent).

        Responses: 200 on success, 422 on low confidence, 400/413 on bad
        upload, 429 when rate-limited, 503 when no model is loaded.
        The uploaded temp file is always removed in the finally block.
        """
        # Rate limiting
        client_id = get_client_identifier(request)
        is_allowed, remaining = detect_limiter.is_allowed(client_id)
        if not is_allowed:
            return jsonify({
                "error": "Rate limit exceeded",
                "detail": f"Maximum {detect_limiter.max_requests} requests per {detect_limiter.window_seconds} seconds",
                "retry_after": detect_limiter.window_seconds,
            }), 429

        # Get model selection from query parameter (default: 'base')
        model_selection = request.args.get("model", "base").lower()
        if model_selection == "fine-tuned" or model_selection == "finetuned":
            model_local = app.config.get("FINETUNED_MODEL")
            labels_local = app.config.get("FINETUNED_LABELS") or []
            model_type = app.config.get("FINETUNED_MODEL_TYPE", "keras")
            model_version = app.config.get("FINETUNED_MODEL_VERSION", "unknown")
            if model_local is None:
                app.logger.warning("Asripa model requested but not available, using base model")
                model_local = app.config.get("BASE_MODEL")
                labels_local = app.config.get("BASE_LABELS") or []
                model_type = app.config.get("BASE_MODEL_TYPE", "keras")
                model_version = app.config.get("BASE_MODEL_VERSION", "unknown")
        else:
            # Use base model (default)
            model_local = app.config.get("BASE_MODEL")
            labels_local = app.config.get("BASE_LABELS") or []
            model_type = app.config.get("BASE_MODEL_TYPE", "keras")
            model_version = app.config.get("BASE_MODEL_VERSION", "unknown")

        app.logger.info(f"Using model: {model_selection} (version: {model_version})")

        if model_local is None:
            app.logger.error("Detect called but model not loaded")
            raise ServiceUnavailableError("Model not loaded on server")

        print(f"[DETECT] Using model type: {model_type}")

        # Validate upload presence
        if "image" not in request.files:
            raise ValidationError("No image provided")

        file = request.files["image"]

        # Comprehensive validation (size, extension, sanitized filename)
        is_valid, error_msg, filename = validate_image_file(
            file,
            max_size=app.config.get("MAX_CONTENT_LENGTH", DEFAULTS["MAX_FILE_SIZE"]),
            allowed_extensions=app.config.get("ALLOWED_EXT", DEFAULTS["ALLOWED_EXT"]),
        )

        if not is_valid:
            raise ValidationError(error_msg)

        tmp_dir = app.config.get("TMP_DIR", TMP_DIR_DEFAULT)
        tmp_path = os.path.join(tmp_dir, filename)
        used_filename = filename

        try:
            # Save file and verify it was saved
            file.save(tmp_path)
            if not os.path.exists(tmp_path):
                app.logger.error("Failed to save uploaded file to %s", tmp_path)
                raise ValidationError("Failed to save uploaded image")

            file_size = os.path.getsize(tmp_path)
            if file_size == 0:
                app.logger.error("Saved file is empty: %s", tmp_path)
                raise ValidationError("Uploaded image is empty")

            print(f"[DETECT] Saved file: {tmp_path}, size: {file_size} bytes")
            app.logger.info("Saved file: %s, size: %d bytes", tmp_path, file_size)

            # Import numpy for both paths
            import numpy as np

            # Handle ViT and Keras models differently
            if model_type == "vit":
                # Vision Transformer model - needs RGB PIL Image
                from app.vit_utils import preprocess_face_for_vit, predict_with_vit
                from PIL import Image

                face_image, used_filename = preprocess_face_for_vit(tmp_path)
                if face_image is None:
                    app.logger.warning("No face detected for file %s (size: %d bytes)", filename, file_size)
                    raise ValidationError("No face detected in image. Please ensure your face is clearly visible, well-lit, and facing the camera.")

                # Run ViT prediction
                idx, confidence, all_probs = predict_with_vit(model_local, face_image, labels_local)
                emotion = labels_local[idx] if idx < len(labels_local) else str(idx)

                # Debug output
                # NOTE(review): "(unknown)" in the log message below appears
                # to be a templating artifact (likely {filename} originally)
                # — preserved verbatim; confirm against the source repo.
                sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
                app.logger.info(f"Prediction probabilities for (unknown) (sorted): {sorted_probs}")
                print(f"[DETECT] All emotion probabilities (sorted by confidence):")
                for emo, prob in sorted_probs:
                    marker = " <-- SELECTED" if emo == emotion else ""
                    print(f"  {emo}: {prob:.3f}{marker}")
                print(f"[DETECT] Predicted emotion: {emotion}, confidence: {confidence:.3f}")

                # Warn if happy probability is suspiciously low (potential misclassification)
                happy_prob = all_probs.get('happy', 0.0)
                if happy_prob < 0.15 and confidence > 0.3 and emotion != 'happy':
                    app.logger.warning(f"⚠️ Low happy probability ({happy_prob:.3f}) but high confidence ({confidence:.3f}) for {emotion}. Possible misclassification.")
                    print(f"[DETECT] ⚠️ WARNING: Happy probability is very low ({happy_prob:.3f}) - possible misclassification")

                # Convert to numpy array format for compatibility with rest of code
                probs = np.array([all_probs.get(labels_local[i] if i < len(labels_local) else f"class_{i}", 0.0)
                                  for i in range(len(labels_local))])
            else:
                # Keras model - existing code path
                # Preprocess face - preprocess_face is imported above in factory scope
                res = preprocess_face(tmp_path)
                if isinstance(res, tuple):
                    face_array, used_filename = res
                else:
                    face_array = res

                if face_array is None:
                    app.logger.warning("No face detected for file %s (size: %d bytes)", filename, file_size)
                    raise ValidationError("No face detected in image. Please ensure your face is clearly visible, well-lit, and facing the camera.")

                # Defensive conversion and validations (numpy already imported above)
                try:
                    face_input = np.asarray(face_array)
                except Exception as exc:
                    app.logger.exception("Failed converting preprocessed face to numpy array")
                    return jsonify({"error": "Invalid preprocessed face data."}), 500

                if getattr(face_input, "dtype", None) == object:
                    app.logger.error("face_input has object dtype (likely contains None) for file %s", filename)
                    return jsonify({"error": "Invalid preprocessed face data (object dtype)."}), 500

                # Ensure batch dim and channel dim
                if face_input.ndim == 2:
                    # (H, W) -> (1, H, W, 1)
                    face_input = np.expand_dims(np.expand_dims(face_input, axis=-1), axis=0)
                elif face_input.ndim == 3:
                    # (H, W, C) -> (1, H, W, C)
                    face_input = np.expand_dims(face_input, axis=0)
                elif face_input.ndim == 4:
                    # already batched
                    pass
                else:
                    app.logger.error("Unsupported preprocessed face ndim %s for file %s", getattr(face_input, "ndim", None), filename)
                    return jsonify({"error": "Unsupported preprocessed face shape."}), 500

                # sanity checks
                if face_input.shape[0] < 1:
                    return jsonify({"error": "Empty batch sent to model."}), 500
                try:
                    if not np.isfinite(face_input.astype("float32")).all():
                        app.logger.error("face_input contains non-finite values for file %s", filename)
                        return jsonify({"error": "Preprocessed face contains non-finite values."}), 500
                except Exception:
                    app.logger.exception("Failed checking finiteness of face_input")
                    return jsonify({"error": "Preprocessed face contains invalid numeric values."}), 500

                # Run prediction
                try:
                    preds = model_local.predict(face_input, verbose=0)
                except Exception as exc:
                    app.logger.exception("Model predict failed for file %s", filename)
                    return jsonify({"error": "Prediction failed", "detail": str(exc)}), 500

                if preds is None:
                    return jsonify({"error": "Prediction returned no output"}), 500

                arr = np.asarray(preds)
                if arr.ndim == 2:
                    probs = arr[0]
                elif arr.ndim == 1:
                    probs = arr
                else:
                    app.logger.error("Unexpected prediction shape %s for file %s", getattr(arr, "shape", None), filename)
                    return jsonify({"error": "Unexpected prediction shape", "shape": list(getattr(arr, "shape", []))}), 500

                if probs.size == 0:
                    return jsonify({"error": "Empty prediction probabilities"}), 500

                # Verify model output matches expected number of classes
                expected_classes = len(labels_local) if isinstance(labels_local, (list, dict)) else 7
                if len(probs) != expected_classes:
                    app.logger.warning(f"Model output has {len(probs)} classes but labels have {expected_classes}. Labels: {labels_local}")
                    print(f"[WARNING] Model output shape mismatch: {len(probs)} classes vs {expected_classes} labels")

                idx = int(np.argmax(probs))
                confidence = float(probs[idx])

                # Debug: Log all prediction probabilities to understand model behavior
                all_probs = {}
                for i in range(len(probs)):
                    if isinstance(labels_local, list) and i < len(labels_local):
                        all_probs[labels_local[i]] = float(probs[i])
                    elif isinstance(labels_local, dict):
                        label_key = str(i) if str(i) in labels_local else i if i in labels_local else f"class_{i}"
                        all_probs[label_key] = float(probs[i])
                    else:
                        all_probs[str(i)] = float(probs[i])

                # Sort by probability (highest first) for easier debugging
                # NOTE(review): "(unknown)" preserved verbatim — likely a
                # templating artifact; confirm against the source repo.
                sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
                app.logger.info(f"Prediction probabilities for (unknown) (sorted): {sorted_probs}")
                print(f"[DETECT] All emotion probabilities (sorted by confidence):")
                for emotion, prob in sorted_probs:
                    marker = " <-- SELECTED" if emotion == (labels_local[idx] if isinstance(labels_local, list) and idx < len(labels_local) else str(idx)) else ""
                    print(f"  {emotion}: {prob:.3f}{marker}")
                print(f"[DETECT] Predicted emotion index: {idx}, confidence: {confidence:.3f}")
                print(f"[DETECT] Available labels: {labels_local}")

                # Resolve label safely (labels may be a dict keyed by str/int, or a list)
                if isinstance(labels_local, dict):
                    emotion = labels_local.get(str(idx)) or labels_local.get(idx) or list(labels_local.values())[idx]
                elif isinstance(labels_local, list):
                    emotion = labels_local[idx] if 0 <= idx < len(labels_local) else str(idx)
                else:
                    emotion = str(idx)

            print(f"[DETECT] Mapped emotion label: {emotion}")

            # Save image even for low confidence (for debugging/analysis)
            images_dir = app.config.get("IMAGES_DIR", IMAGES_DIR_DEFAULT)
            stored_filename = None
            try:
                stored_filename = save_image(tmp_path, images_dir, used_filename)
            except Exception:
                app.logger.exception("Failed to save image, continuing without storage")

            # Confidence threshold - slightly lower for better detection in challenging conditions
            # But still maintain quality standards
            min_conf = app.config.get("MIN_CONFIDENCE", DEFAULTS["MIN_CONFIDENCE"])
            # Below threshold: log as "low_confidence" and return 422
            if confidence < min_conf:
                try:
                    log_prediction(DB_PATH, used_filename, "low_confidence", confidence, stored_filename)
                except Exception:
                    app.logger.exception("Failed logging low-confidence prediction")
                return jsonify({
                    "error": "low confidence",
                    "confidence": round(confidence, 3),
                    "filename": stored_filename or used_filename,
                }), 422

            # Log and respond (image already saved above)
            try:
                log_prediction(DB_PATH, used_filename, emotion, confidence, stored_filename)
            except Exception:
                app.logger.exception("Failed to log prediction to DB")

            # Return all probabilities for debugging (frontend can use this to show top emotions)
            all_emotion_probs = {}
            if model_type == "vit":
                # For ViT, all_probs already contains the dict
                all_emotion_probs = {k: round(v, 4) for k, v in all_probs.items()}
            else:
                # For Keras, build from probs array
                for i in range(len(probs)):
                    if isinstance(labels_local, list) and i < len(labels_local):
                        all_emotion_probs[labels_local[i]] = round(float(probs[i]), 4)
                    elif isinstance(labels_local, dict):
                        label_key = str(i) if str(i) in labels_local else i if i in labels_local else f"class_{i}"
                        all_emotion_probs[label_key] = round(float(probs[i]), 4)

            return jsonify({
                "emotion": emotion,
                "confidence": round(confidence, 3),
                "filename": stored_filename or used_filename,
                "all_probabilities": all_emotion_probs,  # Include all probabilities for debugging
                "model": model_selection,
                "model_version": model_version,
            }), 200

        except (ValidationError, APIError, NotFoundError, ServiceUnavailableError) as exc:
            # Let Flask's error handler process these
            raise
        except Exception as exc:
            app.logger.exception("detection error for file %s", filename)
            tb = traceback.format_exc()
            return jsonify({"error": "internal error", "detail": str(exc), "trace": tb}), 500

        finally:
            # cleanup tmp file (image is already saved to images/ if successful)
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except Exception:
                app.logger.exception("failed removing tmp file")

    # ----------------------------
    # Image serving endpoint
    # ----------------------------
    @app.route("/images/<filename>", methods=["GET"])
    def serve_image(filename: str):
        """
        Serve stored images.
        GET /images/<filename>
        Rate limited via images_limiter; 404 if the image is not found.
        """
        from flask import send_from_directory, abort

        # Rate limiting
        client_id = get_client_identifier(request)
        is_allowed, remaining = images_limiter.is_allowed(client_id)
        if not is_allowed:
            return jsonify({
                "error": "Rate limit exceeded",
                "detail": f"Maximum {images_limiter.max_requests} requests per {images_limiter.window_seconds} seconds",
                "retry_after": images_limiter.window_seconds,
            }), 429

        try:
            images_dir = app.config.get("IMAGES_DIR", IMAGES_DIR_DEFAULT)
            image_path = get_image_path(images_dir, filename)

            if not image_path:
                app.logger.warning("Image not found: %s (checked in %s)", filename, images_dir)
                abort(404)

            # Extract the actual filename from the path (in case secure_filename changed it)
            actual_filename = os.path.basename(image_path)

            return send_from_directory(
                images_dir,
                actual_filename,
                mimetype="image/jpeg",  # Default, will be auto-detected
            )
        except Exception as exc:
            app.logger.exception("Failed to serve image %s", filename)
            return jsonify({"error": "Failed to serve image", "detail": str(exc)}), 500

    return app
app/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (38.7 kB). View file
 
app/__pycache__/db_logger.cpython-312.pyc ADDED
Binary file (11.5 kB). View file
 
app/__pycache__/error_handlers.cpython-312.pyc ADDED
Binary file (3.69 kB). View file
 
app/__pycache__/image_storage.cpython-312.pyc ADDED
Binary file (4.48 kB). View file
 
app/__pycache__/model_loader.cpython-312.pyc ADDED
Binary file (7.95 kB). View file
 
app/__pycache__/rate_limiter.cpython-312.pyc ADDED
Binary file (3.83 kB). View file
 
app/__pycache__/utils.cpython-312.pyc ADDED
Binary file (6.33 kB). View file
 
app/__pycache__/validators.cpython-312.pyc ADDED
Binary file (4.52 kB). View file
 
app/__pycache__/vit_utils.cpython-312.pyc ADDED
Binary file (11.2 kB). View file
 
app/db_logger.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sqlite3
3
+ import os
4
+ import datetime
5
+ from typing import Dict, Tuple, List, Optional
6
+ import threading
7
+
8
+ SCHEMA = """
9
+ PRAGMA foreign_keys = ON;
10
+ CREATE TABLE IF NOT EXISTS predictions (
11
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
12
+ ts TEXT NOT NULL,
13
+ filename TEXT,
14
+ image_path TEXT,
15
+ emotion TEXT,
16
+ confidence REAL
17
+ );
18
+
19
+ -- Indexes for better query performance
20
+ CREATE INDEX IF NOT EXISTS idx_predictions_ts ON predictions(ts DESC);
21
+ CREATE INDEX IF NOT EXISTS idx_predictions_emotion ON predictions(emotion);
22
+ CREATE INDEX IF NOT EXISTS idx_predictions_confidence ON predictions(confidence);
23
+ """
24
+
25
# Module-wide connection pool: one connection per (db_path, thread) pair,
# guarded by a single lock.
_db_lock = threading.Lock()
_connection_pool: Dict[str, sqlite3.Connection] = {}


def get_connection(db_path: str, timeout: int = 10) -> sqlite3.Connection:
    """
    Return a pooled SQLite connection for the calling thread.

    A connection is created lazily per (db_path, thread) pair and reused on
    subsequent calls; SQLite pragmas are applied once at creation time to
    favor concurrent readers (WAL) and in-memory temp storage.
    """
    pool_key = f"{db_path}_{threading.get_ident()}"

    with _db_lock:
        conn = _connection_pool.get(pool_key)
        if conn is None:
            conn = sqlite3.connect(db_path, timeout=timeout, check_same_thread=False)
            # One-time tuning for this connection.
            for pragma in (
                "PRAGMA journal_mode=WAL;",
                "PRAGMA synchronous=NORMAL;",
                "PRAGMA cache_size=10000;",
                "PRAGMA temp_store=MEMORY;",
            ):
                conn.execute(pragma)
            _connection_pool[pool_key] = conn
        return conn
48
+
49
+
50
def init_db(db_path: str):
    """
    Create the predictions table and its indexes if they do not exist.

    The parent directory of db_path is created on demand. Uses a throwaway
    connection (not the pool) so initialization is independent of callers.
    """
    parent = os.path.dirname(db_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent, exist_ok=True)

    conn = sqlite3.connect(db_path, timeout=10)
    try:
        for pragma in (
            "PRAGMA journal_mode=WAL;",
            "PRAGMA synchronous=NORMAL;",
            "PRAGMA cache_size=10000;",
        ):
            conn.execute(pragma)
        conn.executescript(SCHEMA)
        conn.commit()
    finally:
        conn.close()
64
+
65
def log_prediction(db_path: str, filename: str, emotion: str, confidence: float, image_path: Optional[str] = None):
    """
    Logs a prediction row and returns its rowid.

    Ensures ts is a string and that values bound to SQLite are primitive
    types (no functions or callables). Also lazily migrates databases that
    predate the image_path column.

    Args:
        db_path: Path to SQLite database
        filename: Original filename
        emotion: Detected emotion
        confidence: Confidence score
        image_path: Path to stored image file (optional)

    Returns:
        rowid of the inserted prediction.
    """
    # Timezone-aware UTC timestamp (datetime.UTC requires Python 3.11+);
    # fall back to the legacy naive form on older interpreters.
    try:
        ts = datetime.datetime.now(datetime.UTC).isoformat()
    except Exception:
        ts = str(datetime.datetime.utcnow())

    # Defensive conversions so only primitives reach the SQLite driver.
    filename = "" if filename is None else str(filename)
    emotion = "" if emotion is None else str(emotion)
    image_path = "" if image_path is None else str(image_path)

    try:
        confidence_val = float(confidence or 0.0)
    except Exception:
        confidence_val = 0.0

    conn = get_connection(db_path)
    try:
        cur = conn.cursor()
        # Lazy schema migration: older databases lack image_path.
        cur.execute("PRAGMA table_info(predictions)")
        columns = [row[1] for row in cur.fetchall()]
        if "image_path" not in columns:
            cur.execute("ALTER TABLE predictions ADD COLUMN image_path TEXT")
            conn.commit()

        cur.execute(
            "INSERT INTO predictions (ts, filename, image_path, emotion, confidence) VALUES (?, ?, ?, ?, ?)",
            (ts, filename, image_path, emotion, confidence_val)
        )
        conn.commit()
        return cur.lastrowid
    except Exception:
        # The pooled connection may be wedged: evict it so the next call
        # opens a fresh one, then surface the original error to the caller.
        with _db_lock:
            key = f"{db_path}_{threading.get_ident()}"
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except Exception:  # fixed: was a bare `except:`
                    pass
        raise
134
+
135
def get_metrics(db_path: str) -> Dict:
    """
    Return aggregate stats: {"total": <row count>, "by_label": {emotion: count}}.

    On any database error the pooled connection for this thread is evicted
    (it may be broken) and the error is re-raised.
    """
    conn = get_connection(db_path)
    try:
        cur = conn.cursor()
        cur.execute("SELECT COUNT(*) FROM predictions")
        total = cur.fetchone()[0] or 0
        cur.execute("SELECT emotion, COUNT(*) FROM predictions GROUP BY emotion")
        by_label = {emotion: count for emotion, count in cur.fetchall()}
        return {"total": total, "by_label": by_label}
    except Exception:
        with _db_lock:
            key = f"{db_path}_{threading.get_ident()}"
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except Exception:  # fixed: was a bare `except:`
                    pass
        raise
156
+
157
def tail_rows(db_path: str, limit: int = 10, offset: int = 0, emotion_filter: Optional[str] = None,
              min_confidence: Optional[float] = None, max_confidence: Optional[float] = None,
              date_from: Optional[str] = None, date_to: Optional[str] = None) -> List:
    """
    Fetch rows from predictions table with filtering and pagination.

    Filters are optional and combined with AND; results are newest-first
    (ORDER BY id DESC). Timestamps are compared as ISO-8601 strings.

    Returns:
        List of tuples: (id, ts, filename, image_path, emotion, confidence)
    """
    # fixed: return annotation was Tuple, but fetchall() yields a list of rows
    conn = get_connection(db_path)
    try:
        cur = conn.cursor()

        # Build query dynamically; values always go through placeholders.
        query = "SELECT id, ts, filename, image_path, emotion, confidence FROM predictions WHERE 1=1"
        params = []

        if emotion_filter:
            query += " AND emotion = ?"
            params.append(emotion_filter)
        if min_confidence is not None:
            query += " AND confidence >= ?"
            params.append(min_confidence)
        if max_confidence is not None:
            query += " AND confidence <= ?"
            params.append(max_confidence)
        if date_from:
            query += " AND ts >= ?"
            params.append(date_from)
        if date_to:
            query += " AND ts <= ?"
            params.append(date_to)

        query += " ORDER BY id DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        cur.execute(query, params)
        return cur.fetchall()
    except Exception:
        # Evict the possibly-broken pooled connection, then re-raise.
        with _db_lock:
            key = f"{db_path}_{threading.get_ident()}"
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except Exception:  # fixed: was a bare `except:`
                    pass
        raise
211
+
212
+
213
def delete_prediction(db_path: str, prediction_id: int) -> bool:
    """
    Delete a prediction by ID.

    Args:
        db_path: Path to SQLite database
        prediction_id: ID of prediction to delete

    Returns:
        True if a row was deleted, False otherwise
    """
    conn = get_connection(db_path)
    try:
        cur = conn.cursor()
        cur.execute("DELETE FROM predictions WHERE id = ?", (prediction_id,))
        conn.commit()
        # rowcount reflects the number of rows the DELETE removed.
        return cur.rowcount > 0
    except Exception:
        # Evict the possibly-broken pooled connection, then re-raise.
        with _db_lock:
            key = f"{db_path}_{threading.get_ident()}"
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except Exception:  # fixed: was a bare `except:`
                    pass
        raise
241
+
242
+
243
def get_total_count(db_path: str, emotion_filter: Optional[str] = None,
                    min_confidence: Optional[float] = None, max_confidence: Optional[float] = None,
                    date_from: Optional[str] = None, date_to: Optional[str] = None) -> int:
    """Get total count of predictions matching filters.

    Filter semantics mirror tail_rows(): all provided filters are ANDed,
    and timestamps compare as ISO-8601 strings.
    """
    conn = get_connection(db_path)
    try:
        cur = conn.cursor()

        query = "SELECT COUNT(*) FROM predictions WHERE 1=1"
        params = []

        if emotion_filter:
            query += " AND emotion = ?"
            params.append(emotion_filter)
        if min_confidence is not None:
            query += " AND confidence >= ?"
            params.append(min_confidence)
        if max_confidence is not None:
            query += " AND confidence <= ?"
            params.append(max_confidence)
        if date_from:
            query += " AND ts >= ?"
            params.append(date_from)
        if date_to:
            query += " AND ts <= ?"
            params.append(date_to)

        cur.execute(query, params)
        return cur.fetchone()[0] or 0
    except Exception:
        # Evict the possibly-broken pooled connection, then re-raise.
        with _db_lock:
            key = f"{db_path}_{threading.get_ident()}"
            stale = _connection_pool.pop(key, None)
            if stale is not None:
                try:
                    stale.close()
                except Exception:  # fixed: was a bare `except:`
                    pass
        raise
app/error_handlers.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Structured error handling for API responses.
3
+ """
4
+ from flask import jsonify
5
+ from typing import Dict, Any
6
+
7
+
8
class APIError(Exception):
    """Base exception for API errors.

    Subclasses override `status_code` and `message`; both can also be set
    per-instance. `details` entries are merged into the JSON payload.
    """
    status_code = 500
    message = "An error occurred"

    def __init__(self, message: str = None, status_code: int = None, details: Dict[str, Any] = None):
        super().__init__()
        # Fall back to the class-level defaults when not provided.
        self.message = message or self.message
        self.status_code = status_code or self.status_code
        self.details = details or {}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the JSON error payload (details override "error")."""
        payload: Dict[str, Any] = {"error": self.message}
        payload.update(self.details)
        return payload
24
+
25
+
26
class ValidationError(APIError):
    """Validation error (400): request input failed validation."""
    status_code = 400
    message = "Validation error"
30
+
31
+
32
class NotFoundError(APIError):
    """Resource not found (404): the requested entity does not exist."""
    status_code = 404
    message = "Resource not found"
36
+
37
+
38
class ServiceUnavailableError(APIError):
    """Service unavailable (503): a required dependency is not ready."""
    status_code = 503
    message = "Service unavailable"
42
+
43
+
44
def register_error_handlers(app):
    """Register error handlers for the Flask app.

    Installs a JSON renderer for APIError subclasses plus JSON fallbacks for
    common HTTP errors, so API clients never receive HTML error pages.
    """

    @app.errorhandler(APIError)
    def handle_api_error(error: APIError):
        # Render the structured payload with the exception's own status code.
        response = jsonify(error.to_dict())
        response.status_code = error.status_code
        return response

    @app.errorhandler(404)
    def handle_not_found(e):
        return jsonify({"error": "Endpoint not found"}), 404

    @app.errorhandler(405)
    def handle_method_not_allowed(e):
        return jsonify({"error": "Method not allowed"}), 405

    @app.errorhandler(500)
    def handle_internal_error(e):
        # Full traceback goes to the log; the client sees a generic message.
        app.logger.exception("Internal server error")
        return jsonify({"error": "Internal server error"}), 500
app/image_cleanup.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image cleanup utility to remove orphaned images (not referenced in database).
3
+ Can be run as a scheduled job or manually.
4
+ """
5
+ import os
6
+ import sqlite3
7
+ from pathlib import Path
8
+ from typing import Set
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def get_referenced_images(db_path: str) -> Set[str]:
    """
    Collect the basenames of every image referenced by a prediction row.

    Returns:
        Set of image filenames (basenames only); empty when the schema
        predates the image_path column.
    """
    conn = sqlite3.connect(db_path, timeout=10)
    try:
        cur = conn.cursor()

        # Older databases may lack the image_path column entirely.
        cur.execute("PRAGMA table_info(predictions)")
        if not any(info[1] == "image_path" for info in cur.fetchall()):
            return set()

        cur.execute("SELECT DISTINCT image_path FROM predictions WHERE image_path IS NOT NULL AND image_path != ''")

        names: Set[str] = set()
        for (stored_path,) in cur.fetchall():
            if not stored_path:
                continue
            base = os.path.basename(stored_path)
            if base:
                names.add(base)
        return names
    finally:
        conn.close()
48
+
49
+
50
def cleanup_orphaned_images(images_dir: str, db_path: str, dry_run: bool = True) -> dict:
    """
    Remove image files that are not referenced in the database.

    Args:
        images_dir: Directory containing images
        db_path: Path to SQLite database
        dry_run: If True, only report what would be deleted without actually deleting

    Returns:
        Dict with cleanup statistics. NOTE: "deleted" counts would-be
        deletions when dry_run is True (matches the summary log below).
    """
    if not os.path.exists(images_dir):
        logger.warning(f"Images directory does not exist: {images_dir}")
        return {
            "total_images": 0,
            "referenced": 0,
            "orphaned": 0,
            "deleted": 0,
            "errors": 0,
        }

    # Get referenced images from database
    referenced = get_referenced_images(db_path)
    logger.info(f"Found {len(referenced)} referenced images in database")

    # Get all image files in directory
    image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp"}
    all_images = [
        entry.name
        for entry in Path(images_dir).iterdir()
        if entry.is_file() and entry.suffix.lower() in image_extensions
    ]

    total_images = len(all_images)
    logger.info(f"Found {total_images} image files in directory")

    # Orphans = on disk but never referenced by a prediction row.
    orphaned = [img for img in all_images if img not in referenced]

    stats = {
        "total_images": total_images,
        "referenced": len(referenced),
        "orphaned": len(orphaned),
        "deleted": 0,
        "errors": 0,
    }

    if not orphaned:
        logger.info("No orphaned images found")
        return stats

    logger.info(f"Found {len(orphaned)} orphaned images")

    # Delete orphaned images
    for filename in orphaned:
        file_path = os.path.join(images_dir, filename)
        try:
            if not dry_run:
                os.remove(file_path)
                # fixed: message previously had no placeholder ("(unknown)")
                logger.debug(f"Deleted orphaned image: {file_path}")
            else:
                logger.debug(f"Would delete orphaned image: {file_path}")
            stats["deleted"] += 1
        except Exception as e:
            # fixed: include the failing path in the error message
            logger.error(f"Failed to delete {file_path}: {e}")
            stats["errors"] += 1

    if dry_run:
        logger.info(f"DRY RUN: Would delete {stats['deleted']} orphaned images")
    else:
        logger.info(f"Deleted {stats['deleted']} orphaned images")

    return stats
124
+
125
+
126
def cleanup_old_images(images_dir: str, db_path: str, days_old: int = 30, dry_run: bool = True) -> dict:
    """
    Remove images older than specified days that are not referenced in recent predictions.

    Args:
        images_dir: Directory containing images
        db_path: Path to SQLite database
        days_old: Remove images older than this many days
        dry_run: If True, only report what would be deleted

    Returns:
        Dict with cleanup statistics. NOTE(review): "total_images" counts
        every directory entry, not only image files — preserved behavior.
    """
    import datetime

    if not os.path.exists(images_dir):
        return {
            "total_images": 0,
            "old_images": 0,
            "deleted": 0,
            "errors": 0,
        }

    # Calculate cutoff date (timezone-aware UTC; datetime.UTC needs 3.11+).
    cutoff_date = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days_old)
    cutoff_iso = cutoff_date.isoformat()

    # Images referenced by predictions on/after the cutoff must be kept.
    conn = sqlite3.connect(db_path, timeout=10)
    try:
        cur = conn.cursor()
        cur.execute("""
            SELECT DISTINCT image_path
            FROM predictions
            WHERE image_path IS NOT NULL
            AND image_path != ''
            AND ts >= ?
        """, (cutoff_iso,))
        recent_images = {os.path.basename(row[0]) for row in cur.fetchall() if row[0]}
    finally:
        conn.close()

    # Find old images (by file mtime) that no recent prediction references.
    image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp"}
    old_images = []

    for file_path in Path(images_dir).iterdir():
        if file_path.is_file() and file_path.suffix.lower() in image_extensions:
            mtime = datetime.datetime.fromtimestamp(file_path.stat().st_mtime, tz=datetime.UTC)
            if mtime < cutoff_date and file_path.name not in recent_images:
                old_images.append(file_path.name)

    stats = {
        "total_images": len(list(Path(images_dir).iterdir())),
        "old_images": len(old_images),
        "deleted": 0,
        "errors": 0,
    }

    for filename in old_images:
        file_path = os.path.join(images_dir, filename)
        try:
            if not dry_run:
                os.remove(file_path)
            # Counts would-be deletions in dry-run mode, mirroring
            # cleanup_orphaned_images.
            stats["deleted"] += 1
        except Exception as e:
            # fixed: message previously had no placeholder ("(unknown)")
            logger.error(f"Failed to delete {file_path}: {e}")
            stats["errors"] += 1

    return stats
199
+
200
+
app/image_storage.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image storage utilities for saving and serving uploaded images.
3
+ """
4
+ import os
5
+ import uuid
6
+ import shutil
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple
9
+ from werkzeug.utils import secure_filename
10
+
11
+
12
def ensure_images_dir(images_dir: str) -> str:
    """Create the images directory (if missing) and return its path unchanged."""
    os.makedirs(images_dir, exist_ok=True)
    return images_dir
16
+
17
+
18
def generate_unique_filename(original_filename: str) -> str:
    """
    Build a collision-resistant stored name: "{uuid4}_{sanitized}{ext}".

    The original name is sanitized with werkzeug's secure_filename; if
    nothing survives sanitizing, "upload.jpg" is used. Any extension other
    than .jpg/.jpeg/.png is replaced with .jpg.
    """
    sanitized = secure_filename(original_filename)
    if not sanitized:
        sanitized = "upload.jpg"

    stem, ext = os.path.splitext(sanitized)
    if ext.lower() not in ('.jpg', '.jpeg', '.png'):
        # Covers both a missing extension and an unsupported one.
        ext = '.jpg'

    # Full UUID prefix for uniqueness across uploads of the same name.
    return f"{uuid.uuid4()}_{stem}{ext}"
34
+
35
+
36
def save_image(source_path: str, images_dir: str, original_filename: str) -> Optional[str]:
    """
    Copy the file at source_path into images_dir under a fresh unique name.

    Args:
        source_path: Path to source image file
        images_dir: Directory to save images to
        original_filename: Original filename for reference

    Returns:
        Stored filename (relative to images_dir), or None on failure —
        persistence is best-effort and must not fail the request.
    """
    try:
        ensure_images_dir(images_dir)
        stored_name = generate_unique_filename(original_filename)
        shutil.copy2(source_path, os.path.join(images_dir, stored_name))
        return stored_name
    except Exception as e:
        # Log and signal failure instead of raising.
        import logging
        logging.getLogger(__name__).exception(f"Failed to save image: {e}")
        return None
64
+
65
+
66
def get_image_path(images_dir: str, filename: str) -> Optional[str]:
    """
    Get full path to an image file if it exists.

    Args:
        images_dir: Base images directory
        filename: Image filename

    Returns:
        Full path to image or None if not found
    """
    if not filename:
        return None

    # Security: ensure filename doesn't contain path traversal.
    # basename() strips any directory components first; secure_filename()
    # then sanitizes the remaining name.
    base_filename = os.path.basename(filename)
    safe_filename = secure_filename(base_filename)

    if not safe_filename:
        return None

    # Prefer the sanitized name — stored files were named via
    # secure_filename, so this is the expected on-disk name.
    image_path = os.path.join(images_dir, safe_filename)

    if os.path.exists(image_path) and os.path.isfile(image_path):
        return image_path

    # Fall back to the caller-supplied name only when it is provably safe:
    # identical to its own basename and free of separators / parent-dir
    # references (so it cannot escape images_dir).
    if safe_filename != base_filename:
        if base_filename == filename and '/' not in base_filename and '\\' not in base_filename and '..' not in base_filename:
            alt_path = os.path.join(images_dir, base_filename)
            if os.path.exists(alt_path) and os.path.isfile(alt_path):
                return alt_path

    return None
104
+
105
+
106
def delete_image(images_dir: str, filename: str) -> bool:
    """
    Delete a stored image file (best-effort).

    Args:
        images_dir: Base images directory
        filename: Image filename to delete

    Returns:
        True only when a file was actually removed; False on a missing file
        or any error.
    """
    try:
        resolved = get_image_path(images_dir, filename)
        if resolved is None or not os.path.exists(resolved):
            return False
        os.remove(resolved)
        return True
    except Exception:
        # Deletion failures are reported, never raised.
        return False
app/model_loader.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/model_loader.py
2
+ import os
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Tuple, Any, Optional, Dict
6
+
7
+ DEFAULT_LABELS = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
8
+ # HardlyHumans model uses 8 emotions (adds contempt)
9
+ HARDLYHUMANS_LABELS = ['anger', 'contempt', 'sad', 'happy', 'neutral', 'disgust', 'fear', 'surprise']
10
+
11
def load_emotion_model(force_model: Optional[str] = None):
    """
    Load emotion detection model. Supports both Keras and Vision Transformer models.

    Tries, in order:
      1. a local fine-tuned ViT under models/fine_tuned_vit (skipped when
         force_model == 'base'),
      2. the base HardlyHumans ViT downloaded/cached from HuggingFace,
      3. a local Keras model under models/.

    Args:
        force_model: 'base' to force base model, 'fine-tuned' to force fine-tuned, None for auto

    Returns: (model_dict, labels, model_version, model_type)
        model_dict: For ViT: {'model': model, 'processor': processor, 'type': 'vit'}
                    For Keras: model object
        model_type: 'keras' or 'vit' (Vision Transformer)

    Raises:
        FileNotFoundError: fine-tuned model forced but absent, or no Keras
            model file found in the final fallback.
        ImportError: neither transformers nor tensorflow.keras is installed.
    """
    # Canonical label mapping shared by both ViT branches (was duplicated).
    label_map = {
        'anger': 'angry',
        'disgust': 'disgust',
        'fear': 'fear',
        'happy': 'happy',
        'neutral': 'neutral',
        'sad': 'sad',
        'surprise': 'surprise',
        'contempt': 'contempt'  # extra emotion present in the ViT models
    }

    def _normalize_labels(raw_labels):
        """Lowercase model labels and map known synonyms (e.g. 'anger' -> 'angry')."""
        return [label_map.get(label.lower(), label.lower()) for label in raw_labels]

    this_dir = Path(__file__).resolve().parent  # app/
    repo_root = this_dir.parent                 # project root (/app in container)
    models_dir = repo_root / "models"
    fine_tuned_dir = models_dir / "fine_tuned_vit"

    # 1) Fine-tuned model (trained on FER2013 for better happy/surprise
    #    detection), unless the caller explicitly forces the base model.
    if force_model != 'base':
        try:
            from transformers import AutoImageProcessor, AutoModelForImageClassification

            if fine_tuned_dir.exists() and (fine_tuned_dir / "model.safetensors").exists():
                print(f"[MODEL] 🎯 Loading Asripa model (FER2013 Enhanced): {fine_tuned_dir}")
                print(f"[MODEL] Accuracy: 78.26% (fine-tuned on FER2013)")
                print(f"[MODEL] Optimized for happy/surprise detection!")

                processor = AutoImageProcessor.from_pretrained(
                    str(fine_tuned_dir),
                    local_files_only=True
                )
                model = AutoModelForImageClassification.from_pretrained(
                    str(fine_tuned_dir),
                    local_files_only=True,
                    low_cpu_mem_usage=True
                )

                raw_labels = [model.config.id2label[i] for i in range(len(model.config.id2label))]
                print(f"[MODEL] Raw labels from model config: {raw_labels}")
                labels = _normalize_labels(raw_labels)
                print(f"[MODEL] Normalized labels: {labels}")

                print(f"[MODEL] ✅ Fine-tuned ViT model loaded successfully!")
                return {
                    'model': model,
                    'processor': processor,
                    'type': 'vit'
                }, labels, "asripa-vit-78.26%", 'vit'

            if force_model == 'fine-tuned':
                print(f"[MODEL] ⚠️ Fine-tuned model requested but not found!")
                # Caught just below and re-raised after logging.
                raise FileNotFoundError("Fine-tuned model not found")
            print(f"[MODEL] Fine-tuned model not found, using base model...")
        except Exception as e:
            if force_model == 'fine-tuned':
                print(f"[MODEL] ⚠️ Failed to load fine-tuned model: {e}")
                raise
            print(f"[MODEL] ⚠️ Failed to load fine-tuned model: {e}")
            print(f"[MODEL] Falling back to base HardlyHumans model...")

    # 2) Base HardlyHumans ViT model (best accuracy - 92.2%).
    try:
        from transformers import AutoImageProcessor, AutoModelForImageClassification

        model_id = "HardlyHumans/Facial-expression-detection"
        print(f"[MODEL] Loading Base Model: {model_id}")
        print(f"[MODEL] Accuracy: 92.2% - BASE MODEL")
        print(f"[MODEL] Downloading from HuggingFace if not cached...")

        # Load from HuggingFace — downloads and caches automatically;
        # low_cpu_mem_usage reduces the memory footprint during loading.
        processor = AutoImageProcessor.from_pretrained(
            model_id,
            cache_dir=str(models_dir),
            local_files_only=False
        )
        model = AutoModelForImageClassification.from_pretrained(
            model_id,
            cache_dir=str(models_dir),
            local_files_only=False,
            low_cpu_mem_usage=True
        )

        raw_labels = [model.config.id2label[i] for i in range(len(model.config.id2label))]
        print(f"[MODEL] Raw labels from model config: {raw_labels}")
        print(f"[MODEL] Label mapping (id2label): {model.config.id2label}")
        labels = _normalize_labels(raw_labels)
        print(f"[MODEL] Normalized labels: {labels}")

        print(f"[MODEL] ✅ ViT model loaded successfully!")
        return {
            'model': model,
            'processor': processor,
            'type': 'vit'
        }, labels, "base-vit-92.2%", 'vit'
    except ImportError as e:
        print(f"[MODEL] ❌ transformers library not installed: {e}")
        print("[MODEL] Install with: pip install transformers torch")
        print("[MODEL] Falling back to Keras model...")
    except Exception as e:
        print(f"[MODEL] ❌ Failed to load ViT model: {e}")
        print(f"[MODEL] Error type: {type(e).__name__}")
        print(f"[MODEL] Error message: {str(e)}")
        import traceback
        print(f"[MODEL] Full traceback:")
        print(traceback.format_exc())
        print("[MODEL] ⚠️ Falling back to Keras model (lower accuracy)...")

    # 3) Keras fallback.
    try:
        from tensorflow.keras.models import load_model
    except ImportError:
        raise ImportError("Neither transformers nor tensorflow.keras available. Install one of them.")

    candidate_names = ["emotion_model.keras", "emotion_model.h5", "emotion_model.hdf5"]
    model_path = None
    for name in candidate_names:
        candidate = models_dir / name
        if candidate.exists():
            model_path = str(candidate)
            break

    if model_path is None:
        raise FileNotFoundError(f"No model file found in {models_dir}. Please add emotion_model.keras or emotion_model.h5")

    print(f"[MODEL] Loading Keras model: {model_path}")
    model = load_model(model_path)

    # Optional label list shipped next to the model.
    labels_path = models_dir / "labels.json"
    labels = DEFAULT_LABELS
    if labels_path.exists():
        try:
            with labels_path.open("r", encoding="utf-8") as f:
                labels = json.load(f)
        except Exception:
            labels = DEFAULT_LABELS

    # Optional version marker (fixed: use Path.exists() on the Path object).
    version_path = models_dir / "MODEL_VERSION.txt"
    version = "v_unknown"
    if version_path.exists():
        try:
            with open(version_path, "r", encoding="utf-8") as f:
                version = f.read().strip()
        except Exception:
            pass

    return model, labels, version, 'keras'
app/rate_limiter.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple in-memory rate limiter for API endpoints.
3
+ For production, consider using Redis-based rate limiting.
4
+ """
5
+ import time
6
+ from collections import defaultdict
7
+ from typing import Dict, Tuple
8
+ from threading import Lock
9
+
10
+
11
class RateLimiter:
    """
    Sliding-window rate limiter keeping per-identifier request timestamps.

    A single lock guards the timestamp map, so it is safe for basic
    multi-threaded use within one process.
    """

    def __init__(self, max_requests: int = 100, window_seconds: int = 60):
        """
        Args:
            max_requests: Maximum requests allowed in the time window
            window_seconds: Time window in seconds
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.requests: Dict[str, list] = defaultdict(list)
        self.lock = Lock()

    def is_allowed(self, identifier: str) -> Tuple[bool, int]:
        """
        Check if a request is allowed and record it when so.

        Args:
            identifier: Unique identifier (e.g. IP address, user ID)

        Returns:
            Tuple of (is_allowed, remaining_requests)
        """
        now = time.time()
        cutoff = now - self.window_seconds

        with self.lock:
            # Drop timestamps that have aged out of the window.
            recent = [stamp for stamp in self.requests[identifier] if stamp > cutoff]
            self.requests[identifier] = recent

            if len(recent) >= self.max_requests:
                return False, 0

            # Record this request and report the remaining budget.
            recent.append(now)
            return True, self.max_requests - len(recent)

    def reset(self, identifier: str = None):
        """Reset rate limit for one identifier, or all when none given."""
        with self.lock:
            if identifier:
                self.requests.pop(identifier, None)
            else:
                self.requests.clear()
66
+
67
+
68
# Global rate limiters for different endpoints.
# Limits apply per client identifier (see get_client_identifier) and are
# per-process: state lives in memory and is not shared between workers.
detect_limiter = RateLimiter(max_requests=30, window_seconds=60)  # 30 requests per minute
logs_limiter = RateLimiter(max_requests=100, window_seconds=60)  # 100 requests per minute
images_limiter = RateLimiter(max_requests=200, window_seconds=60)  # 200 requests per minute
72
+
73
+
74
def get_client_identifier(request) -> str:
    """
    Derive a rate-limiting key for a request (client IP by default).

    Proxy headers take precedence: X-Forwarded-For (first hop in the chain
    is the originating client), then X-Real-IP, then the socket address.
    Note: these headers are client-controllable unless a trusted proxy
    sets them.
    """
    forwarded_chain = request.headers.get("X-Forwarded-For")
    if forwarded_chain:
        # Comma-separated chain; the first entry is the original client.
        return forwarded_chain.split(",")[0].strip()

    return (
        request.headers.get("X-Real-IP")
        or request.remote_addr
        or "unknown"
    )
app/utils.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/utils.py
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ from typing import Optional, Tuple
6
+
7
def _enhance_for_detection(gray: np.ndarray) -> np.ndarray:
    """
    Lightly preprocess a grayscale frame to help Haar-cascade face detection
    on low-contrast or slightly blurry images.

    CLAHE (adaptive histogram equalization) boosts local contrast; a mild
    bilateral filter then suppresses noise while keeping facial edges intact.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = equalizer.apply(gray)
    return cv2.bilateralFilter(enhanced, d=5, sigmaColor=75, sigmaSpace=75)
19
+
20
+
21
def preprocess_face(
    image_path: str,
    target_size: Tuple[int, int] = (48, 48),
    detect_max_dim: int = 800,
    pad_ratio: float = 0.25,  # Increased from 0.15 to 0.25 to preserve more context (eyes, eyebrows, mouth area)
) -> Tuple[Optional[np.ndarray], Optional[str]]:
    """
    Load the image at *image_path*, detect the most prominent face, and return
    a model-ready array:
      - shape: (1, H, W, 1)
      - dtype: np.float32
      - values scaled to [0, 1]

    If no face is detected, the file can't be read, or any error occurs,
    returns (None, None) — errors are deliberately swallowed so callers can
    log at the app level without leaking internals.

    Parameters:
        target_size: (height, width) expected by the model.
        detect_max_dim: longest side used for the detection pass; larger
            images are downscaled first to speed up cascade detection.
        pad_ratio: fraction of the face box padded on each side to avoid
            overly tight crops.

    Returns:
        (face_array, used_filename) on success, otherwise (None, None).
    """
    try:
        img = cv2.imread(image_path)
        if img is None:
            # Unreadable path / unsupported format.
            return None, None

        h0, w0 = img.shape[:2]
        # Grayscale copy for detection (Haar cascades operate on grayscale).
        gray_full = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Downscale for faster detection if the image is huge; `scale` is kept
        # so detected coordinates can be mapped back to the original image.
        scale = 1.0
        max_side = max(w0, h0)
        if max_side > detect_max_dim:
            scale = detect_max_dim / float(max_side)
            small = cv2.resize(gray_full, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR)
        else:
            small = gray_full.copy()

        # Enhance the small image for better detection on blurry photos.
        small_enh = _enhance_for_detection(small)

        # Try multiple cascade classifiers for better detection coverage.
        cascade_paths = [
            "haarcascade_frontalface_default.xml",
            "haarcascade_frontalface_alt.xml",
            "haarcascade_frontalface_alt2.xml",
        ]

        faces = []

        # Try each cascade with progressively more permissive parameters;
        # stop at the first cascade that finds anything.
        for cascade_name in cascade_paths:
            if len(faces) > 0:
                break

            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
                if face_cascade.empty():
                    # Cascade file missing/corrupt in this OpenCV install.
                    continue

                # Attempt 1: Standard detection parameters.
                faces = face_cascade.detectMultiScale(
                    small_enh,
                    scaleFactor=1.1,
                    minNeighbors=5,
                    minSize=(30, 30),
                    flags=cv2.CASCADE_SCALE_IMAGE,
                )

                # Attempt 2: More permissive (helps blurry / odd-angle photos).
                if len(faces) == 0:
                    faces = face_cascade.detectMultiScale(
                        small_enh,
                        scaleFactor=1.05,
                        minNeighbors=3,
                        minSize=(20, 20),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )

                # Attempt 3: Even more permissive (for challenging conditions).
                if len(faces) == 0:
                    faces = face_cascade.detectMultiScale(
                        small_enh,
                        scaleFactor=1.03,
                        minNeighbors=2,
                        minSize=(15, 15),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )

            except Exception:
                # A failing cascade shouldn't abort the whole search.
                continue

        # If still nothing, try the original (non-enhanced) image — the
        # contrast enhancement occasionally hurts detection.
        if len(faces) == 0:
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
                if not face_cascade.empty():
                    faces = face_cascade.detectMultiScale(
                        small,
                        scaleFactor=1.05,
                        minNeighbors=3,
                        minSize=(20, 20),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
            except Exception:
                pass

        if len(faces) == 0:
            return None, None

        # Choose the largest detected face (usually the main subject).
        faces = sorted(faces, key=lambda r: r[2] * r[3], reverse=True)
        (x_s, y_s, w_s, h_s) = faces[0]

        # Map coordinates back to the original image scale (detection always
        # ran on the possibly-downscaled `small` image; scale==1.0 is a no-op).
        x = int(x_s / scale)
        y = int(y_s / scale)
        w = int(w_s / scale)
        h = int(h_s / scale)

        # Pad the bounding box by pad_ratio of the face size, clamped to the image.
        pad_w = int(w * pad_ratio)
        pad_h = int(h * pad_ratio)
        x1 = max(0, x - pad_w)
        y1 = max(0, y - pad_h)
        x2 = min(w0, x + w + pad_w)
        y2 = min(h0, y + h + pad_h)

        face_crop = gray_full[y1:y2, x1:x2]

        # Final resize to the model input size. INTER_CUBIC preserves more
        # detail when upscaling small faces (helps emotion recognition).
        face_resized = cv2.resize(face_crop, (target_size[1], target_size[0]), interpolation=cv2.INTER_CUBIC)

        # Ensure a numeric ndarray with float32 dtype, scaled to [0, 1].
        face_arr = np.asarray(face_resized, dtype=np.float32)
        face_arr = face_arr / 255.0

        # Add channel & batch dims -> (1, H, W, 1), as the model expects.
        if face_arr.ndim == 2:
            face_arr = np.expand_dims(face_arr, axis=-1)
        face_arr = np.expand_dims(face_arr, axis=0)

        # Final sanity checks: correct dtype, no NaN/Inf values.
        if face_arr.dtype != np.float32:
            face_arr = face_arr.astype(np.float32)
        if not np.isfinite(face_arr).all():
            return None, None

        used_filename = os.path.basename(image_path) or "upload.jpg"
        return face_arr, used_filename

    except Exception:
        # Don't leak internals to the caller; let the app log exceptions if needed.
        return None, None
app/validators.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Request validation utilities.
3
+ """
4
+ import os
5
+ from typing import Tuple, Optional
6
+ from werkzeug.utils import secure_filename
7
+ from PIL import Image
8
+
9
+
10
def validate_image_file(file, max_size: int, allowed_extensions: tuple) -> Tuple[bool, Optional[str], Optional[str]]:
    """
    Validate an uploaded image file.

    Args:
        file: FileStorage object from Flask.
        max_size: Maximum file size in bytes.
        allowed_extensions: Tuple of allowed extensions (e.g., (".jpg", ".png")).

    Returns:
        (is_valid, error_message, sanitized_filename):
        (True, None, filename) when valid, (False, message, None) otherwise.
    """
    if not file or not file.filename:
        return False, "No file provided", None

    # Sanitize the client-supplied name; an empty result means it was unusable.
    filename = secure_filename(file.filename)
    if not filename:
        return False, "Invalid filename", None

    # Extension whitelist check.
    ext = os.path.splitext(filename)[1].lower()
    if ext not in allowed_extensions:
        return False, f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}", None

    # Size check via seek/tell (best-effort: some streams don't support it).
    try:
        file.seek(0, os.SEEK_END)
        nbytes = file.tell()
        file.seek(0)  # Reset to beginning for later readers
        if nbytes > max_size:
            max_mb = max_size / (1024 * 1024)
            return False, f"File too large. Maximum size: {max_mb:.1f}MB", None
        if nbytes == 0:
            return False, "File is empty", None
    except Exception:
        # Size unknown here; Flask's MAX_CONTENT_LENGTH still enforces a cap.
        pass

    # Verify the payload really decodes as an image, not just by extension.
    try:
        file.seek(0)
        probe = Image.open(file)
        probe.verify()
        file.seek(0)  # Reset after verification so callers can re-read
    except Exception as e:
        return False, f"Invalid image file: {str(e)}", None

    return True, None, filename
63
+
64
+
65
def validate_pagination_params(limit: Optional[str], offset: Optional[str]) -> Tuple[int, int, Optional[str]]:
    """
    Validate query-string pagination parameters.

    Falsy inputs default to limit=20, offset=0. The limit is clamped to
    [1, 200] and the offset to >= 0.

    Returns:
        (limit, offset, error_message) — error_message is None when both
        values parsed successfully.
    """
    try:
        parsed_limit = int(limit) if limit else 20
    except ValueError:
        return 20, 0, "Invalid limit parameter. Must be an integer."
    parsed_limit = min(200, max(1, parsed_limit))

    try:
        parsed_offset = int(offset) if offset else 0
    except ValueError:
        return parsed_limit, 0, "Invalid offset parameter. Must be an integer."
    parsed_offset = max(0, parsed_offset)

    return parsed_limit, parsed_offset, None
85
+
86
+
87
def validate_confidence_range(min_conf: Optional[str], max_conf: Optional[str]) -> Tuple[Optional[float], Optional[float], Optional[str]]:
    """
    Validate optional confidence-range query parameters.

    Each bound, when provided, must parse as a float in [0, 1], and the lower
    bound must not exceed the upper bound.

    Returns:
        (min_confidence, max_confidence, error_message) — on any validation
        failure both values are None and error_message describes the problem.
    """
    lower: Optional[float] = None
    upper: Optional[float] = None

    if min_conf:
        try:
            lower = float(min_conf)
        except ValueError:
            return None, None, "Invalid min_confidence parameter. Must be a number."
        if not 0 <= lower <= 1:
            return None, None, "min_confidence must be between 0 and 1"

    if max_conf:
        try:
            upper = float(max_conf)
        except ValueError:
            return None, None, "Invalid max_confidence parameter. Must be a number."
        if not 0 <= upper <= 1:
            return None, None, "max_confidence must be between 0 and 1"

    if lower is not None and upper is not None and lower > upper:
        return None, None, "min_confidence cannot be greater than max_confidence"

    return lower, upper, None
app/vit_utils.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/vit_utils.py
2
+ """
3
+ Utilities for Vision Transformer (ViT) model preprocessing and prediction.
4
+ """
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ from typing import Optional, Tuple, Dict, Any
9
+ from app.utils import preprocess_face # Reuse face detection
10
+
11
def preprocess_face_for_vit(
    image_path: str,
    detect_max_dim: int = 800,
    pad_ratio: float = 0.35,  # 0.35 includes more facial context — helps happy detection (smile needs surrounding area)
) -> Tuple[Optional[Image.Image], Optional[str]]:
    """
    Detect and crop a face for the Vision Transformer model.

    Unlike the Keras path, ViT needs RGB images at 224x224 (not grayscale
    48x48), so detection runs on grayscale but the crop is taken from RGB.

    Parameters:
        image_path: path to the image file on disk.
        detect_max_dim: longest side used for the (downscaled) detection pass.
        pad_ratio: fraction of the face box padded on each side.

    Returns:
        (PIL.Image 224x224 RGB, filename) or (None, None) if no face is
        detected or any error occurs.
    """
    try:
        img = cv2.imread(image_path)
        if img is None:
            return None, None

        h0, w0 = img.shape[:2]
        # Keep an RGB copy for the final crop; detect on grayscale.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gray_full = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Downscale for faster detection if the image is huge.
        scale = 1.0
        max_side = max(w0, h0)
        if max_side > detect_max_dim:
            scale = detect_max_dim / float(max_side)
            small = cv2.resize(gray_full, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR)
        else:
            small = gray_full.copy()

        from app.utils import _enhance_for_detection
        small_enh = _enhance_for_detection(small)

        # Optimized detection: 2 primary cascades × 2 param sets (fast path),
        # then a 3rd cascade only if needed, then two last-resort fallbacks.
        cascade_paths_primary = [
            "haarcascade_frontalface_default.xml",  # Most reliable
            "haarcascade_frontalface_alt.xml",      # Good fallback
        ]
        cascade_paths_fallback = [
            "haarcascade_frontalface_alt2.xml",     # Last resort
        ]

        faces = []
        # BUG FIX: track whether the winning detection ran on the full-size
        # image. Previously, faces found by the full-resolution fallback were
        # still divided by `scale`, producing a wildly wrong crop box.
        detected_on_fullsize = False

        # Primary: 2 cascades with 2 param sets each (4 attempts max).
        for cascade_name in cascade_paths_primary:
            if len(faces) > 0:
                break
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
                if face_cascade.empty():
                    continue

                # Attempt 1: most common successful params.
                faces = face_cascade.detectMultiScale(
                    small_enh,
                    scaleFactor=1.05,
                    minNeighbors=3,
                    minSize=(20, 20),
                    flags=cv2.CASCADE_SCALE_IMAGE,
                )

                # Attempt 2: more permissive (challenging cases).
                if len(faces) == 0:
                    faces = face_cascade.detectMultiScale(
                        small_enh,
                        scaleFactor=1.03,
                        minNeighbors=2,
                        minSize=(15, 15),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
            except Exception:
                continue

        # Secondary: 3rd cascade only if the primaries failed.
        if len(faces) == 0:
            for cascade_name in cascade_paths_fallback:
                try:
                    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_name)
                    if face_cascade.empty():
                        continue
                    for scale_factor, min_neighbors, min_size in [
                        (1.05, 3, (20, 20)),
                        (1.03, 2, (15, 15)),
                    ]:
                        faces = face_cascade.detectMultiScale(
                            small_enh,
                            scaleFactor=scale_factor,
                            minNeighbors=min_neighbors,
                            minSize=min_size,
                            flags=cv2.CASCADE_SCALE_IMAGE,
                        )
                        if len(faces) > 0:
                            break
                except Exception:
                    continue
                if len(faces) > 0:
                    break

        # Fallback 1: enhancement sometimes hurts — try the non-enhanced small image once.
        if len(faces) == 0:
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
                if not face_cascade.empty():
                    faces = face_cascade.detectMultiScale(
                        small,
                        scaleFactor=1.05,
                        minNeighbors=3,
                        minSize=(20, 20),
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
            except Exception:
                pass

        # Fallback 2: full-size pass, only when the image was downscaled 2x+
        # (full-resolution detection is slow; skip it for mildly large images).
        if len(faces) == 0 and max_side > detect_max_dim and scale < 0.5:
            try:
                face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
                if not face_cascade.empty():
                    faces = face_cascade.detectMultiScale(
                        gray_full,
                        scaleFactor=1.05,
                        minNeighbors=2,
                        minSize=(30, 30),  # Larger min size for full-res
                        flags=cv2.CASCADE_SCALE_IMAGE,
                    )
                    if len(faces) > 0:
                        detected_on_fullsize = True
            except Exception:
                pass

        if len(faces) == 0:
            return None, None

        # Choose the largest detected face (usually the main subject).
        faces = sorted(faces, key=lambda r: r[2] * r[3], reverse=True)
        (x_s, y_s, w_s, h_s) = faces[0]

        # Map coordinates back to the original scale only when detection ran
        # on the downscaled image; full-size detections are already correct.
        if scale < 1.0 and not detected_on_fullsize:
            x = int(x_s / scale)
            y = int(y_s / scale)
            w = int(w_s / scale)
            h = int(h_s / scale)
        else:
            x, y, w, h = int(x_s), int(y_s), int(w_s), int(h_s)

        # Pad the bounding box, clamped to image bounds.
        pad_w = int(w * pad_ratio)
        pad_h = int(h * pad_ratio)
        x1 = max(0, x - pad_w)
        y1 = max(0, y - pad_h)
        x2 = min(w0, x + w + pad_w)
        y2 = min(h0, y + h + pad_h)

        # Crop from the RGB image (not grayscale) — ViT expects color input.
        face_crop = img_rgb[y1:y2, x1:x2]

        # Resize to 224x224 with BICUBIC for quality. The ViT processor handles
        # normalization later, so no CLAHE here (it would shift the model's
        # expected input distribution).
        face_pil = Image.fromarray(face_crop)
        face_pil = face_pil.resize((224, 224), Image.Resampling.BICUBIC)

        import os
        used_filename = os.path.basename(image_path) or "upload.jpg"
        return face_pil, used_filename

    except Exception as e:
        import logging
        logger = logging.getLogger(__name__)
        logger.exception(f"Exception in preprocess_face_for_vit for {image_path}: {e}")
        return None, None
205
+
206
def predict_with_vit(
    model_dict: Dict[str, Any],
    image: Image.Image,
    labels: list
) -> Tuple[int, float, Dict[str, float]]:
    """
    Run emotion prediction using a Vision Transformer model.

    Args:
        model_dict: {'model': model, 'processor': processor, 'type': 'vit'}
        image: PIL Image of the cropped face (converted to RGB if needed).
        labels: Ordered list of emotion labels used by the API.

    Returns:
        (predicted_index, confidence, all_probabilities_dict)

    Notes:
        Applies a heuristic post-processing step that boosts the 'happy'
        probability when 'contempt'/'neutral' dominate but 'happy' is in the
        top 3 — the model has a known happy/contempt confusion.
    """
    import logging
    import torch
    import torch.nn.functional as F

    logger = logging.getLogger(__name__)

    processor = model_dict['processor']
    model = model_dict['model']

    # Ensure RGB (inputs might be RGBA or grayscale).
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # The HF processor handles resizing/normalization to the model's spec.
    inputs = processor(image, return_tensors="pt")

    model.eval()
    # inference_mode() is faster than no_grad() for pure inference.
    with torch.inference_mode():
        outputs = model(**inputs)
        logits = outputs.logits

    probs_np = F.softmax(logits, dim=-1)[0].cpu().numpy()
    # softmax is monotone, so the argmax over probabilities equals the logits argmax.
    predicted_idx = int(probs_np.argmax())
    confidence = float(probs_np[predicted_idx])

    # Normalize the model's raw label names to the API's vocabulary.
    # Hoisted out of the loop (previously rebuilt per class).
    label_map = {
        'anger': 'angry',
        'disgust': 'disgust',
        'fear': 'fear',
        'happy': 'happy',
        'neutral': 'neutral',
        'sad': 'sad',
        'surprise': 'surprise',
        'contempt': 'contempt',
    }
    # Prefer the model's own id2label mapping when available (checked once).
    id2label = getattr(getattr(model, 'config', None), 'id2label', None)

    all_probs: Dict[str, float] = {}
    for i, prob in enumerate(probs_np):
        if id2label is not None:
            raw_label = id2label.get(i, f"class_{i}")
            key = label_map.get(raw_label.lower(), raw_label.lower())
        elif i < len(labels):
            key = labels[i]
        else:
            key = f"class_{i}"
        all_probs[key] = float(prob)

    # Heuristic: boost 'happy' when it sits in the top 3 but contempt/neutral
    # dominate (known model confusion between those classes).
    happy_prob = all_probs.get('happy', 0.0)
    contempt_prob = all_probs.get('contempt', 0.0)
    neutral_prob = all_probs.get('neutral', 0.0)

    sorted_probs = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)
    top_3_emotions = [name for name, _ in sorted_probs[:3]]

    if 'happy' in top_3_emotions and happy_prob > 0.05 and (contempt_prob > 0.4 or neutral_prob > 0.4):
        # Boost happy by 30% and take the difference out of contempt/neutral.
        boosted_happy = min(1.0, happy_prob * 1.3)
        reduction = (boosted_happy - happy_prob) / 2
        all_probs['happy'] = boosted_happy
        all_probs['contempt'] = max(0.0, contempt_prob - reduction)
        all_probs['neutral'] = max(0.0, neutral_prob - reduction)

        # Re-normalize so the distribution still sums to ~1.0.
        total = sum(all_probs.values())
        if total > 0:
            all_probs = {k: v / total for k, v in all_probs.items()}

        # Re-derive the prediction from the adjusted distribution.
        new_top_emotion = max(all_probs.items(), key=lambda item: item[1])[0]
        if new_top_emotion in labels:
            predicted_idx = labels.index(new_top_emotion)
            confidence = all_probs[new_top_emotion]
            logger.info(
                "[VIT] Post-processing: boosted happy from %.3f to %.3f, new prediction: %s",
                happy_prob, all_probs.get('happy', 0.0), new_top_emotion,
            )
        else:
            # Keep the original prediction if the adjusted winner isn't a known label.
            logger.warning(
                "[VIT] Post-processing: boosted happy but label %s not in labels list",
                new_top_emotion,
            )

    # Guarded lookup (the previous version dereferenced model.config.id2label
    # unconditionally here, unlike the guarded loop above).
    raw_name = id2label.get(predicted_idx, 'unknown') if id2label is not None else 'unknown'
    logger.debug("[VIT] Predicted index: %s, raw label from model: %s", predicted_idx, raw_name)

    return predicted_idx, confidence, all_probs
323
+
entrypoint_hf.sh ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/sh
# Entrypoint for Hugging Face Spaces: downloads model artifacts at container
# start (keeps the image small), then launches gunicorn on the Spaces port.
set -eu

# Where the app expects the Keras model inside the container
MODEL_PATH="/app/models/emotion_model.keras"

# Public release URL (change if you host elsewhere)
MODEL_URL="https://github.com/iyinoluwAA/Emotion-detection/releases/download/v1.0.0/emotion_model.keras"

# Ensure models dir exists
mkdir -p "$(dirname "$MODEL_PATH")"

# Download the Keras model if it isn't baked into the image.
# Missing model is fatal (exit 1) — the app can't serve without it.
if [ ! -f "$MODEL_PATH" ]; then
  echo "Model not found at $MODEL_PATH — attempting download from $MODEL_URL"
  if command -v curl >/dev/null 2>&1; then
    curl -fSL "$MODEL_URL" -o "$MODEL_PATH" || {
      echo "curl failed to download model"; ls -la "$(dirname "$MODEL_PATH")"; exit 1;
    }
  elif command -v wget >/dev/null 2>&1; then
    wget -O "$MODEL_PATH" "$MODEL_URL" || {
      echo "wget failed to download model"; ls -la "$(dirname "$MODEL_PATH")"; exit 1;
    }
  else
    echo "No curl or wget available in the image. Install one in Dockerfile."; exit 1
  fi
else
  echo "Model already present at $MODEL_PATH"
fi

# ensure readable (non-fatal if chmod is not permitted)
chmod a+r "$MODEL_PATH" || true

# Download the Asripa (fine-tuned ViT) model if not present.
# Unlike the Keras model, this download is best-effort: on failure the app
# falls back to the base model only.
ASRIPA_MODEL_DIR="/app/models/fine_tuned_vit"
ASRIPA_MODEL_ID="${ASRIPA_MODEL_ID:-HimAJ/asripa-emotion-detection}"

if [ -n "$ASRIPA_MODEL_ID" ] && [ ! -f "$ASRIPA_MODEL_DIR/model.safetensors" ]; then
  echo "📥 Downloading Asripa model from HuggingFace..."
  echo "   Model ID: $ASRIPA_MODEL_ID"
  mkdir -p "$ASRIPA_MODEL_DIR"

  # Use Python to download (huggingface_hub is in requirements).
  # NOTE(review): $ASRIPA_MODEL_ID / $ASRIPA_MODEL_DIR are interpolated by the
  # shell into the Python source — safe here because both come from trusted env.
  python3 -c "
from huggingface_hub import snapshot_download
import os
import sys
try:
    snapshot_download(
        repo_id='$ASRIPA_MODEL_ID',
        local_dir='$ASRIPA_MODEL_DIR',
        local_dir_use_symlinks=False
    )
    print('✅ Asripa model downloaded successfully!')
except Exception as e:
    print(f'⚠️ Failed to download Asripa model: {e}')
    print('   App will use base model only')
    import shutil
    if os.path.exists('$ASRIPA_MODEL_DIR'):
        shutil.rmtree('$ASRIPA_MODEL_DIR')
    sys.exit(0)  # Exit gracefully, not an error
" || {
    echo "⚠️ Asripa model download skipped"
    echo "   App will use base model only"
    rm -rf "$ASRIPA_MODEL_DIR" 2>/dev/null || true
  }
elif [ -f "$ASRIPA_MODEL_DIR/model.safetensors" ]; then
  echo "✅ Asripa model already present"
elif [ -z "$ASRIPA_MODEL_ID" ]; then
  echo "ℹ️ ASRIPA_MODEL_ID not set - skipping Asripa model download"
fi

# Hugging Face Spaces uses port 7860 by default,
# but honor the PORT env var if set.
PORT="${PORT:-7860}"
echo "Starting gunicorn on 0.0.0.0:${PORT}"
# Suppress protobuf warnings
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
exec gunicorn main:app --bind 0.0.0.0:"${PORT}" --workers 1 --threads 1 --timeout 120 --worker-class gthread
79
+
main.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# main.py
"""App bootstrap: configure file logging, then build the Flask app via the factory."""
import os
import logging
import warnings

# Suppress protobuf version warnings (they're harmless but noisy)
warnings.filterwarnings("ignore", category=UserWarning, module="google.protobuf")
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Make PROJECT_ROOT explicit so module-level code in the container works reliably
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Ensure logs dir exists
LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
os.makedirs(LOG_DIR, exist_ok=True)

# Configure file logging (keeps container stdout clean and persists errors)
logfile = os.path.join(LOG_DIR, "app.log")
handler = logging.FileHandler(logfile)
handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(module)s: %(message)s")
handler.setFormatter(formatter)

root_logger = logging.getLogger()
# BUG FIX: the root logger defaults to WARNING, which filtered out INFO
# records before they ever reached the INFO-level file handler above.
root_logger.setLevel(logging.INFO)
# Add handler only if not already added (avoids duplicates in dev reload)
if not any(isinstance(h, logging.FileHandler) and getattr(h, "baseFilename", "") == logfile for h in root_logger.handlers):
    root_logger.addHandler(handler)

# Import factory after logging and directory setup so imports don't crash during bootstrap
from app import create_app

# Create app (allow env-driven config if needed)
app = create_app()

if __name__ == "__main__":
    # Allow overriding host/port via env (useful in Docker)
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", os.environ.get("FLASK_RUN_PORT", 5000)))
    debug = os.environ.get("FLASK_DEBUG", "0") in ("1", "true", "True")
    app.logger.info("Starting app on %s:%s (debug=%s)", host, port, debug)
    app.run(host=host, port=port, debug=debug)
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Flask==3.1.1
2
+ flask-cors==4.0.0
3
+
4
+ # ML helpers (no TF/numpy here)
5
+ # Note: numpy<2 required for opencv-python-headless compatibility
6
+ numpy>=1.26.0,<2.0.0
7
+ h5py>=3.7.0
8
+ Pillow>=9.0.0
9
+ opencv-python-headless==4.9.0.80
10
+
11
+ # Vision Transformer support (for HardlyHumans model - 92.2% accuracy)
12
+ transformers>=4.30.0
13
+ torch>=2.0.0
14
+ huggingface_hub>=0.20.0 # For downloading Asripa model
15
+
16
+ # utilities & production
17
+ requests>=2.28.0
18
+ gunicorn>=23.0.0