# app/__init__.py
import os
import datetime
import csv
import traceback
import logging

from flask import Flask, request, jsonify
from flask_cors import CORS
from werkzeug.utils import secure_filename
from werkzeug.exceptions import RequestEntityTooLarge

# ----------------------------
# Module-level config (deterministic)
# ----------------------------
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TMP_DIR_DEFAULT = os.path.join(PROJECT_ROOT, "tmp")
IMAGES_DIR_DEFAULT = os.path.join(PROJECT_ROOT, "images")
LOG_CSV = os.path.join(PROJECT_ROOT, "predictions_log.csv")
DB_PATH = os.path.join(PROJECT_ROOT, "predictions.db")

# App-level defaults (can be overridden via app.config)
DEFAULTS = {
    "MIN_CONFIDENCE": 0.18,  # Lowered to 0.18 for ambiguous cases (was 0.20, originally 0.5)
    "MAX_FILE_SIZE": 5 * 1024 * 1024,  # 5 MB
    "TMP_DIR": TMP_DIR_DEFAULT,
    "IMAGES_DIR": IMAGES_DIR_DEFAULT,
    "ALLOWED_EXT": (".jpg", ".jpeg", ".png"),
    "CORS_ORIGINS": "*",  # Can be overridden for production
}

# Ensure directories exist
os.makedirs(DEFAULTS["TMP_DIR"], exist_ok=True)
os.makedirs(DEFAULTS["IMAGES_DIR"], exist_ok=True)

# Ensure CSV header exists (helpful for older logs)
if not os.path.exists(LOG_CSV):
    try:
        with open(LOG_CSV, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["timestamp", "filename", "emotion", "confidence"])
    except Exception:
        # Non-fatal — keep module import light.
        pass
# ----------------------------
# Factory
# ----------------------------
def create_app(config: dict | None = None):
    """
    Create and return the Flask application.

    Heavy imports (model loading, db init) are performed inside this factory
    so importing modules from scripts/tests doesn't trigger expensive work.
    """
    # Merge defaults with provided config
    cfg = DEFAULTS.copy()
    if config:
        cfg.update(config)

    app = Flask(__name__)

    # CORS configuration - allow config override
    cors_origins = cfg.get("CORS_ORIGINS", DEFAULTS["CORS_ORIGINS"])
    if cors_origins == "*":
        CORS(app, resources={r"/*": {"origins": "*"}})
    else:
        # Allow a list of origins
        origins_list = cors_origins.split(",") if isinstance(cors_origins, str) else cors_origins
        CORS(app, resources={r"/*": {"origins": origins_list}})

    # ---------- file logging setup (after app created) ----------
    LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
    try:
        os.makedirs(LOG_DIR, exist_ok=True)
    except Exception:
        # If the logs dir cannot be created, continue; app.logger will still write to stdout
        pass
    log_path = os.path.join(LOG_DIR, "app.log")
    try:
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.INFO)  # change to ERROR if you prefer
        formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(module)s: %(message)s")
        file_handler.setFormatter(formatter)
        # avoid adding duplicate handlers when reloading
        abs_log_path = os.path.abspath(log_path)
        if not any(
            isinstance(h, logging.FileHandler) and getattr(h, "baseFilename", None) == abs_log_path
            for h in app.logger.handlers
        ):
            app.logger.addHandler(file_handler)
        # set app logger level (don't lower if already configured higher)
        app.logger.setLevel(logging.INFO)
    except Exception:
        # If logging can't be configured, keep going — the logger will fall back to its default handlers.
        app.logger.exception("Failed to configure file logging")

    # Apply config to app
    app.config["MAX_CONTENT_LENGTH"] = cfg["MAX_FILE_SIZE"]
    app.config["TMP_DIR"] = cfg["TMP_DIR"]
    app.config["IMAGES_DIR"] = cfg.get("IMAGES_DIR", DEFAULTS["IMAGES_DIR"])
    app.config["ALLOWED_EXT"] = cfg["ALLOWED_EXT"]
    app.config["MIN_CONFIDENCE"] = cfg["MIN_CONFIDENCE"]

    # Ensure tmp directory exists (again, per app)
    os.makedirs(app.config["TMP_DIR"], exist_ok=True)

    # Local (deferred) imports — avoid import-time side effects
    from .model_loader import load_emotion_model
    from .db_logger import init_db, log_prediction, get_metrics, tail_rows, get_total_count, delete_prediction
    from .utils import preprocess_face
    from .image_storage import save_image, get_image_path, ensure_images_dir
    from .validators import validate_image_file, validate_pagination_params, validate_confidence_range
    from .rate_limiter import detect_limiter, logs_limiter, images_limiter, get_client_identifier

    # Initialize DB
    try:
        init_db(DB_PATH)
        app.logger.info("Initialized SQLite DB at %s", DB_PATH)
    except Exception:
        app.logger.exception("Failed to initialize DB at startup")
    # Load models & labels. Keep these local to the factory (no module-level side effects).
    # Both models are loaded up front here; the detect route selects one per request via a query parameter.
    base_model = None
    base_labels = None
    base_model_version = "unknown"
    base_model_type = "unknown"
    finetuned_model = None
    finetuned_labels = None
    finetuned_model_version = "unknown"
    finetuned_model_type = "unknown"

    # Load base model by default
    try:
        # load_emotion_model returns (model, labels, version, model_type)
        res = load_emotion_model(force_model='base')
        if isinstance(res, tuple) and len(res) == 4:
            base_model, base_labels, base_model_version, base_model_type = res
        elif isinstance(res, tuple) and len(res) == 3:
            base_model, base_labels, base_model_version = res
            base_model_type = "keras"  # Default for old format
        elif isinstance(res, tuple) and len(res) == 2:
            base_model, base_labels = res
            base_model_type = "keras"  # Default for old format
        else:
            # Unexpected return shape - be permissive and treat the result as a bare model
            base_model = res
            base_labels = None
            base_model_type = "keras"
        app.logger.info("Base model loaded: %s (version=%s, type=%s)", bool(base_model), base_model_version, base_model_type)
        print(f"[APP] Base model loaded: type={base_model_type}, version={base_model_version}, labels={len(base_labels) if base_labels else 0}")
    except Exception as exc:
        app.logger.exception("Base model failed to load at startup: %s", exc)
        base_model = None
        base_labels = None
        base_model_version = "failed"
        base_model_type = "unknown"
    # Try to load fine-tuned model
    try:
        res = load_emotion_model(force_model='fine-tuned')
        if isinstance(res, tuple) and len(res) == 4:
            finetuned_model, finetuned_labels, finetuned_model_version, finetuned_model_type = res
        elif isinstance(res, tuple) and len(res) == 3:
            finetuned_model, finetuned_labels, finetuned_model_version = res
            finetuned_model_type = "keras"
        elif isinstance(res, tuple) and len(res) == 2:
            finetuned_model, finetuned_labels = res
            finetuned_model_type = "keras"
        app.logger.info("Asripa model loaded: %s (version=%s, type=%s)", bool(finetuned_model), finetuned_model_version, finetuned_model_type)
        print(f"[APP] Asripa model loaded: type={finetuned_model_type}, version={finetuned_model_version}")
    except Exception as exc:
        app.logger.warning("Asripa model not available: %s", exc)
        finetuned_model = None
        finetuned_labels = None
        finetuned_model_version = "not-available"
        finetuned_model_type = "unknown"

    # Store in app.config - default to base model
    app.config["BASE_MODEL"] = base_model
    app.config["BASE_LABELS"] = base_labels
    app.config["BASE_MODEL_VERSION"] = base_model_version
    app.config["BASE_MODEL_TYPE"] = base_model_type
    app.config["FINETUNED_MODEL"] = finetuned_model
    app.config["FINETUNED_LABELS"] = finetuned_labels
    app.config["FINETUNED_MODEL_VERSION"] = finetuned_model_version
    app.config["FINETUNED_MODEL_TYPE"] = finetuned_model_type

    # Default to base model
    app.config["MODEL"] = base_model
    app.config["LABELS"] = base_labels
    app.config["MODEL_VERSION"] = base_model_version
    app.config["MODEL_TYPE"] = base_model_type
    # ----------------------------
    # Error handlers (import before routes to ensure proper handling)
    # ----------------------------
    from .error_handlers import register_error_handlers, APIError, ValidationError, NotFoundError, ServiceUnavailableError
    register_error_handlers(app)

    # Make these available in route scope
    globals()['APIError'] = APIError
    globals()['ValidationError'] = ValidationError
    globals()['NotFoundError'] = NotFoundError
    globals()['ServiceUnavailableError'] = ServiceUnavailableError

    @app.errorhandler(RequestEntityTooLarge)
    def handle_large_file(e):
        return jsonify({"error": "File too large", "max_size_mb": app.config.get("MAX_CONTENT_LENGTH", 5 * 1024 * 1024) / (1024 * 1024)}), 413
    # ----------------------------
    # Routes
    # ----------------------------
    @app.route("/")
    def index():
        return jsonify({"status": "ok", "message": "Flask backend running"}), 200

    @app.route("/health")
    def health():
        """
        Lightweight health check endpoint.

        Optimized for speed - minimal checks to avoid timeouts.
        """
        try:
            # Quick check - don't do expensive operations
            model_loaded = bool(app.config.get("MODEL"))
            model_type = app.config.get("MODEL_TYPE", "unknown")
            model_version = app.config.get("MODEL_VERSION", "unknown")
            # Get labels count quickly
            labels_obj = app.config.get("LABELS")
            labels_count = len(labels_obj) if labels_obj and hasattr(labels_obj, "__len__") else 0
            return jsonify(
                {
                    "ok": True,
                    "model_loaded": model_loaded,
                    "model_type": model_type,
                    "model_version": model_version,
                    "labels_count": labels_count,
                }
            ), 200
        except Exception as e:
            # Even if there's an error, return 200 to indicate the service is running.
            # This prevents a false "offline" status.
            app.logger.warning(f"Health check error (non-fatal): {e}")
            return jsonify(
                {
                    "ok": True,
                    "model_loaded": False,
                    "model_type": "unknown",
                    "model_version": "unknown",
                    "labels_count": 0,
                    "warning": "Health check had minor issues but service is running",
                }
            ), 200
    @app.route("/metrics")
    def metrics():
        try:
            m = get_metrics(DB_PATH)
            recent = tail_rows(DB_PATH, limit=10)
            return jsonify({"ok": True, "metrics": m, "recent": recent}), 200
        except Exception as exc:
            app.logger.exception("Failed to fetch metrics")
            return jsonify({"error": "Failed to fetch metrics", "details": str(exc)}), 500
    @app.route("/logs")
    def logs():
        """
        GET /logs?limit=20&offset=0&emotion=happy&min_confidence=0.5&max_confidence=1.0&date_from=2024-01-01&date_to=2024-12-31

        Returns paginated and filtered logs.
        """
        # Rate limiting
        client_id = get_client_identifier(request)
        is_allowed, remaining = logs_limiter.is_allowed(client_id)
        if not is_allowed:
            return jsonify({
                "error": "Rate limit exceeded",
                "detail": f"Maximum {logs_limiter.max_requests} requests per {logs_limiter.window_seconds} seconds",
                "retry_after": logs_limiter.window_seconds,
            }), 429
        try:
            # Validate pagination
            limit, offset, pagination_error = validate_pagination_params(
                request.args.get("limit"),
                request.args.get("offset"),
            )
            if pagination_error:
                return jsonify({"error": pagination_error}), 400
            # Validate confidence range
            min_confidence, max_confidence, confidence_error = validate_confidence_range(
                request.args.get("min_confidence"),
                request.args.get("max_confidence"),
            )
            if confidence_error:
                return jsonify({"error": confidence_error}), 400
            # Filters
            emotion_filter = request.args.get("emotion", None)
            if emotion_filter and emotion_filter.strip():
                emotion_filter = emotion_filter.strip()
            else:
                emotion_filter = None
            date_from = request.args.get("date_from", None)
            date_to = request.args.get("date_to", None)
            # Fetch data
            rows = tail_rows(
                DB_PATH,
                limit=limit,
                offset=offset,
                emotion_filter=emotion_filter,
                min_confidence=min_confidence,
                max_confidence=max_confidence,
                date_from=date_from,
                date_to=date_to,
            )
            total = get_total_count(
                DB_PATH,
                emotion_filter=emotion_filter,
                min_confidence=min_confidence,
                max_confidence=max_confidence,
                date_from=date_from,
                date_to=date_to,
            )
            # Convert to list of dicts
            result = []
            for r in rows:
                if len(r) == 6:
                    _id, ts, filename, image_path, emotion, confidence = r
                    record = {
                        "id": _id,
                        "ts": ts,
                        "filename": filename,
                        "image_path": image_path or filename,  # Fall back to filename if no image_path
                        "emotion": emotion,
                        "confidence": confidence,
                    }
                elif len(r) == 5:
                    _id, ts, filename, emotion, confidence = r
                    record = {
                        "id": _id,
                        "ts": ts,
                        "filename": filename,
                        "image_path": filename,  # Fallback
                        "emotion": emotion,
                        "confidence": confidence,
                    }
                elif len(r) == 4:
                    ts, filename, emotion, confidence = r
                    record = {
                        "ts": ts,
                        "filename": filename,
                        "image_path": filename,  # Fallback
                        "emotion": emotion,
                        "confidence": confidence,
                    }
                else:
                    record = {"row": r}
                result.append(record)
            return jsonify({
                "ok": True,
                "logs": result,
                "pagination": {
                    "total": total,
                    "limit": limit,
                    "offset": offset,
                    "has_more": (offset + limit) < total,
                },
            }), 200
        except Exception as exc:
            app.logger.exception("Failed to fetch logs")
            return jsonify({"error": "Failed to fetch logs", "detail": str(exc)}), 500
    @app.route("/logs/<int:prediction_id>", methods=["DELETE"])
    def delete_log(prediction_id: int):
        """
        DELETE /logs/<id>

        Delete a prediction by ID.
        """
        # Rate limiting
        client_id = get_client_identifier(request)
        is_allowed, remaining = logs_limiter.is_allowed(client_id)
        if not is_allowed:
            return jsonify({
                "error": "Rate limit exceeded",
                "detail": f"Maximum {logs_limiter.max_requests} requests per {logs_limiter.window_seconds} seconds",
                "retry_after": logs_limiter.window_seconds,
            }), 429
        try:
            # Delete from database
            deleted = delete_prediction(DB_PATH, prediction_id)
            if not deleted:
                return jsonify({"error": "Prediction not found"}), 404
            # Optionally delete the associated image file
            from .image_storage import delete_image
            # Note: we'd need to fetch the image_path first, so for now only the DB row is deleted.
            # This can be extended later to also remove the image file.
            return jsonify({"ok": True, "message": "Prediction deleted successfully"}), 200
        except Exception as exc:
            app.logger.exception(f"Failed to delete prediction {prediction_id}")
            return jsonify({"error": "Failed to delete prediction", "detail": str(exc)}), 500
    @app.route("/detect", methods=["POST"])
    def detect():
        """
        POST form-data: image file under key 'image'

        Returns: JSON {emotion, confidence} or an error JSON.
        """
        # Rate limiting
        client_id = get_client_identifier(request)
        is_allowed, remaining = detect_limiter.is_allowed(client_id)
        if not is_allowed:
            return jsonify({
                "error": "Rate limit exceeded",
                "detail": f"Maximum {detect_limiter.max_requests} requests per {detect_limiter.window_seconds} seconds",
                "retry_after": detect_limiter.window_seconds,
            }), 429
        # Get model selection from query parameter (default: 'base')
        model_selection = request.args.get("model", "base").lower()
        if model_selection in ("fine-tuned", "finetuned"):
            model_local = app.config.get("FINETUNED_MODEL")
            labels_local = app.config.get("FINETUNED_LABELS") or []
            model_type = app.config.get("FINETUNED_MODEL_TYPE", "keras")
            model_version = app.config.get("FINETUNED_MODEL_VERSION", "unknown")
            if model_local is None:
                app.logger.warning("Asripa model requested but not available, using base model")
                model_local = app.config.get("BASE_MODEL")
                labels_local = app.config.get("BASE_LABELS") or []
                model_type = app.config.get("BASE_MODEL_TYPE", "keras")
                model_version = app.config.get("BASE_MODEL_VERSION", "unknown")
        else:
            # Use base model (default)
            model_local = app.config.get("BASE_MODEL")
            labels_local = app.config.get("BASE_LABELS") or []
            model_type = app.config.get("BASE_MODEL_TYPE", "keras")
            model_version = app.config.get("BASE_MODEL_VERSION", "unknown")
        app.logger.info(f"Using model: {model_selection} (version: {model_version})")
        if model_local is None:
            app.logger.error("Detect called but model not loaded")
            raise ServiceUnavailableError("Model not loaded on server")
        print(f"[DETECT] Using model type: {model_type}")
        # Validate upload presence
        if "image" not in request.files:
            raise ValidationError("No image provided")
        file = request.files["image"]
        # Comprehensive validation
        is_valid, error_msg, filename = validate_image_file(
            file,
            max_size=app.config.get("MAX_CONTENT_LENGTH", DEFAULTS["MAX_FILE_SIZE"]),
            allowed_extensions=app.config.get("ALLOWED_EXT", DEFAULTS["ALLOWED_EXT"]),
        )
        if not is_valid:
            raise ValidationError(error_msg)
        tmp_dir = app.config.get("TMP_DIR", TMP_DIR_DEFAULT)
        tmp_path = os.path.join(tmp_dir, filename)
        used_filename = filename
        try:
            # Save file and verify it was saved
            file.save(tmp_path)
            if not os.path.exists(tmp_path):
                app.logger.error("Failed to save uploaded file to %s", tmp_path)
                raise ValidationError("Failed to save uploaded image")
            file_size = os.path.getsize(tmp_path)
            if file_size == 0:
                app.logger.error("Saved file is empty: %s", tmp_path)
                raise ValidationError("Uploaded image is empty")
            print(f"[DETECT] Saved file: {tmp_path}, size: {file_size} bytes")
            app.logger.info("Saved file: %s, size: %d bytes", tmp_path, file_size)

            # Import numpy for both paths
            import numpy as np

            # Handle ViT and Keras models differently
            if model_type == "vit":
                # Vision Transformer model - needs RGB PIL Image
                from app.vit_utils import preprocess_face_for_vit, predict_with_vit
                from PIL import Image

                face_image, used_filename = preprocess_face_for_vit(tmp_path)
                if face_image is None:
                    app.logger.warning("No face detected for file %s (size: %d bytes)", filename, file_size)
                    raise ValidationError("No face detected in image. Please ensure your face is clearly visible, well-lit, and facing the camera.")
                # Run ViT prediction
                idx, confidence, all_probs = predict_with_vit(model_local, face_image, labels_local)
                emotion = labels_local[idx] if idx < len(labels_local) else str(idx)
                # Debug output
                sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
                app.logger.info(f"Prediction probabilities for {filename} (sorted): {sorted_probs}")
                print(f"[DETECT] All emotion probabilities (sorted by confidence):")
                for emo, prob in sorted_probs:
                    marker = " <-- SELECTED" if emo == emotion else ""
                    print(f" {emo}: {prob:.3f}{marker}")
                print(f"[DETECT] Predicted emotion: {emotion}, confidence: {confidence:.3f}")
                # Warn if happy probability is suspiciously low (potential misclassification)
                happy_prob = all_probs.get('happy', 0.0)
                if happy_prob < 0.15 and confidence > 0.3 and emotion != 'happy':
                    app.logger.warning(f"⚠️ Low happy probability ({happy_prob:.3f}) but high confidence ({confidence:.3f}) for {emotion}. Possible misclassification.")
                    print(f"[DETECT] ⚠️ WARNING: Happy probability is very low ({happy_prob:.3f}) - possible misclassification")
                # Convert to numpy array format for compatibility with rest of code
                probs = np.array([all_probs.get(labels_local[i] if i < len(labels_local) else f"class_{i}", 0.0)
                                  for i in range(len(labels_local))])
            else:
                # Keras model - existing code path
                # Preprocess face - preprocess_face is imported above in factory scope
                res = preprocess_face(tmp_path)
                if isinstance(res, tuple):
                    face_array, used_filename = res
                else:
                    face_array = res
                if face_array is None:
                    app.logger.warning("No face detected for file %s (size: %d bytes)", filename, file_size)
                    raise ValidationError("No face detected in image. Please ensure your face is clearly visible, well-lit, and facing the camera.")
                # Defensive conversion and validations (numpy already imported above)
                try:
                    face_input = np.asarray(face_array)
                except Exception as exc:
                    app.logger.exception("Failed converting preprocessed face to numpy array")
                    return jsonify({"error": "Invalid preprocessed face data."}), 500
                if getattr(face_input, "dtype", None) == object:
                    app.logger.error("face_input has object dtype (likely contains None) for file %s", filename)
                    return jsonify({"error": "Invalid preprocessed face data (object dtype)."}), 500
                # Ensure batch dim and channel dim
                if face_input.ndim == 2:
                    # (H, W) -> (1, H, W, 1)
                    face_input = np.expand_dims(np.expand_dims(face_input, axis=-1), axis=0)
                elif face_input.ndim == 3:
                    # (H, W, C) -> (1, H, W, C)
                    face_input = np.expand_dims(face_input, axis=0)
                elif face_input.ndim == 4:
                    # already batched
                    pass
                else:
                    app.logger.error("Unsupported preprocessed face ndim %s for file %s", getattr(face_input, "ndim", None), filename)
                    return jsonify({"error": "Unsupported preprocessed face shape."}), 500
                # sanity checks
                if face_input.shape[0] < 1:
                    return jsonify({"error": "Empty batch sent to model."}), 500
                try:
                    if not np.isfinite(face_input.astype("float32")).all():
                        app.logger.error("face_input contains non-finite values for file %s", filename)
                        return jsonify({"error": "Preprocessed face contains non-finite values."}), 500
                except Exception:
                    app.logger.exception("Failed checking finiteness of face_input")
                    return jsonify({"error": "Preprocessed face contains invalid numeric values."}), 500
                # Run prediction
                try:
                    preds = model_local.predict(face_input, verbose=0)
                except Exception as exc:
                    app.logger.exception("Model predict failed for file %s", filename)
                    return jsonify({"error": "Prediction failed", "detail": str(exc)}), 500
                if preds is None:
                    return jsonify({"error": "Prediction returned no output"}), 500
                arr = np.asarray(preds)
                if arr.ndim == 2:
                    probs = arr[0]
                elif arr.ndim == 1:
                    probs = arr
                else:
                    app.logger.error("Unexpected prediction shape %s for file %s", getattr(arr, "shape", None), filename)
                    return jsonify({"error": "Unexpected prediction shape", "shape": list(getattr(arr, "shape", []))}), 500
                if probs.size == 0:
                    return jsonify({"error": "Empty prediction probabilities"}), 500
                # Verify model output matches expected number of classes
                expected_classes = len(labels_local) if isinstance(labels_local, (list, dict)) else 7
                if len(probs) != expected_classes:
                    app.logger.warning(f"Model output has {len(probs)} classes but labels have {expected_classes}. Labels: {labels_local}")
                    print(f"[WARNING] Model output shape mismatch: {len(probs)} classes vs {expected_classes} labels")
                idx = int(np.argmax(probs))
                confidence = float(probs[idx])
                # Debug: Log all prediction probabilities to understand model behavior
                all_probs = {}
                for i in range(len(probs)):
                    if isinstance(labels_local, list) and i < len(labels_local):
                        all_probs[labels_local[i]] = float(probs[i])
                    elif isinstance(labels_local, dict):
                        label_key = str(i) if str(i) in labels_local else i if i in labels_local else f"class_{i}"
                        all_probs[label_key] = float(probs[i])
                    else:
                        all_probs[str(i)] = float(probs[i])
                # Sort by probability (highest first) for easier debugging
                sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
                app.logger.info(f"Prediction probabilities for {filename} (sorted): {sorted_probs}")
                print(f"[DETECT] All emotion probabilities (sorted by confidence):")
                for emotion, prob in sorted_probs:
                    marker = " <-- SELECTED" if emotion == (labels_local[idx] if isinstance(labels_local, list) and idx < len(labels_local) else str(idx)) else ""
                    print(f" {emotion}: {prob:.3f}{marker}")
                print(f"[DETECT] Predicted emotion index: {idx}, confidence: {confidence:.3f}")
                print(f"[DETECT] Available labels: {labels_local}")
                # Resolve label safely
                if isinstance(labels_local, dict):
                    emotion = labels_local.get(str(idx)) or labels_local.get(idx) or list(labels_local.values())[idx]
                elif isinstance(labels_local, list):
                    emotion = labels_local[idx] if 0 <= idx < len(labels_local) else str(idx)
                else:
                    emotion = str(idx)
                print(f"[DETECT] Mapped emotion label: {emotion}")

            # Save image even for low confidence (for debugging/analysis)
            images_dir = app.config.get("IMAGES_DIR", IMAGES_DIR_DEFAULT)
            stored_filename = None
            try:
                stored_filename = save_image(tmp_path, images_dir, used_filename)
            except Exception:
                app.logger.exception("Failed to save image, continuing without storage")
            # Confidence threshold - kept relatively low so challenging capture conditions
            # still return a result, while flagging unreliable predictions.
            min_conf = app.config.get("MIN_CONFIDENCE", DEFAULTS["MIN_CONFIDENCE"])
            # Below the threshold, log the attempt and return 422 so the client can warn the user
            if confidence < min_conf:
                try:
                    log_prediction(DB_PATH, used_filename, "low_confidence", confidence, stored_filename)
                except Exception:
                    app.logger.exception("Failed logging low-confidence prediction")
                return jsonify({
                    "error": "low confidence",
                    "confidence": round(confidence, 3),
                    "filename": stored_filename or used_filename,
                }), 422
            # Log and respond (image already saved above)
            try:
                log_prediction(DB_PATH, used_filename, emotion, confidence, stored_filename)
            except Exception:
                app.logger.exception("Failed to log prediction to DB")
            # Return all probabilities for debugging (frontend can use this to show top emotions)
            all_emotion_probs = {}
            if model_type == "vit":
                # For ViT, all_probs already contains the dict
                all_emotion_probs = {k: round(v, 4) for k, v in all_probs.items()}
            else:
                # For Keras, build from probs array
                for i in range(len(probs)):
                    if isinstance(labels_local, list) and i < len(labels_local):
                        all_emotion_probs[labels_local[i]] = round(float(probs[i]), 4)
                    elif isinstance(labels_local, dict):
                        label_key = str(i) if str(i) in labels_local else i if i in labels_local else f"class_{i}"
                        all_emotion_probs[label_key] = round(float(probs[i]), 4)
            return jsonify({
                "emotion": emotion,
                "confidence": round(confidence, 3),
                "filename": stored_filename or used_filename,
                "all_probabilities": all_emotion_probs,  # Include all probabilities for debugging
                "model": model_selection,
                "model_version": model_version,
            }), 200
        except (ValidationError, APIError, NotFoundError, ServiceUnavailableError) as exc:
            # Let Flask's error handler process these
            raise
        except Exception as exc:
            app.logger.exception("detection error for file %s", filename)
            tb = traceback.format_exc()
            return jsonify({"error": "internal error", "detail": str(exc), "trace": tb}), 500
        finally:
            # cleanup tmp file (image is already saved to images/ if successful)
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except Exception:
                app.logger.exception("failed removing tmp file")
    # ----------------------------
    # Image serving endpoint
    # ----------------------------
    @app.route("/images/<filename>")
    def serve_image(filename: str):
        """
        Serve stored images.

        GET /images/{filename}
        """
        from flask import send_from_directory

        # Rate limiting
        client_id = get_client_identifier(request)
        is_allowed, remaining = images_limiter.is_allowed(client_id)
        if not is_allowed:
            return jsonify({
                "error": "Rate limit exceeded",
                "detail": f"Maximum {images_limiter.max_requests} requests per {images_limiter.window_seconds} seconds",
                "retry_after": images_limiter.window_seconds,
            }), 429
        try:
            images_dir = app.config.get("IMAGES_DIR", IMAGES_DIR_DEFAULT)
            image_path = get_image_path(images_dir, filename)
            if not image_path:
                app.logger.warning("Image not found: %s (checked in %s)", filename, images_dir)
                # Return a JSON 404 directly; an abort() here would be swallowed by the
                # generic except below and surface as a 500.
                return jsonify({"error": "Image not found"}), 404
            # Extract the actual filename from the path (in case secure_filename changed it)
            actual_filename = os.path.basename(image_path)
            return send_from_directory(
                images_dir,
                actual_filename,
                mimetype="image/jpeg",  # Default, will be auto-detected
            )
        except Exception as exc:
            app.logger.exception("Failed to serve image %s", filename)
            return jsonify({"error": "Failed to serve image", "detail": str(exc)}), 500
    return app
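
# ----------------------------
# Usage (illustrative sketch)
# ----------------------------
# A minimal sketch of how this factory might be wired up, assuming a hypothetical
# top-level run.py next to the app/ package (not part of this module):
#
#     from app import create_app
#
#     # Any key in DEFAULTS can be overridden via the config dict.
#     app = create_app({
#         "MIN_CONFIDENCE": 0.25,                 # stricter than the 0.18 default
#         "CORS_ORIGINS": "https://example.com",  # comma-separated string or list of origins
#     })
#
#     if __name__ == "__main__":
#         # Development server only; in production serve the factory directly,
#         # e.g. gunicorn "app:create_app()".
#         app.run(host="0.0.0.0", port=5000)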