import logging
import os
import sys
from typing import Any, Dict

from flask import Flask, request, jsonify
from werkzeug.exceptions import HTTPException
from machine_learning.load_models import (
    not_for_me,
    recommended_for_you,
    similarity,
)


# Log to stdout; the level comes from the LOG_LEVEL environment variable
# (default INFO).
logging.basicConfig(
    level=os.environ.get("LOG_LEVEL", "INFO"),
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
log = logging.getLogger("ff1000-api")


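# Registry mapping the <model_name> segment of /predict/<model_name> to the
# corresponding model loaded by machine_learning.load_models.
MODELS: Dict[str, Any] = {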
    "nfm": not_for_me,
    "rfy": recommended_for_you,
    "similarity": similarity,
}


def create_app() -> Flask:
    app = Flask(__name__)

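    @app.errorhandler(Exception)
    def handle_exception(e: Exception):
        # HTTP errors (404, 405, ...) are returned as JSON with their own status code.
        if isinstance(e, HTTPException):
            return jsonify(error=e.name, message=e.description), e.code
        # Anything else is unexpected: log the traceback and report a 500.
        log.exception("Unhandled exception")
        return jsonify(error="InternalServerError", message=str(e)), 500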

    @app.get("/health")
    def healthz():
        return jsonify(status="ok")

    @app.post("/predict/<model_name>")
    def predict(model_name: str):
        if model_name not in MODELS:
            return (
                jsonify(error="UnknownModel", message=f"valid models: {list(MODELS)}"),
                400,
            )

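        # force=True parses the body as JSON regardless of the Content-Type
        # header; malformed JSON raises an error, reported here as InvalidJSON.
        try: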
            payload = request.get_json(force=True, silent=False)
        except Exception:
            return jsonify(error="InvalidJSON", message="body must be valid JSON"), 400

        if not isinstance(payload, dict) or "items" not in payload:
            return jsonify(error="BadRequest", message="json must have key 'items'"), 400

        inputs = payload["items"]
        if not isinstance(inputs, list):
            return jsonify(error="BadRequest", message="'items' must be a list"), 400

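        model = MODELS[model_name]
        try:
            # `inputs` is wrapped in a list so the model receives a batch with a
            # single sample (assumes a scikit-learn style `predict` interface).
            preds = model.predict([inputs])
            # Model outputs may be NumPy arrays, which jsonify cannot serialize;
            # convert them to plain Python lists when possible.
            if hasattr(preds, "tolist"):
                preds = preds.tolist()
        except Exception as e:
            log.exception("Prediction failed")
            return jsonify(error="PredictionError", message=str(e)), 500

        return jsonify(model=model_name, predictions=preds)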

    return app


app = create_app()
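
# Example usage (a sketch, not part of the original service): exercising the
# API in-process with Flask's test client, so no server needs to be running.
# "rfy" is one of the registered model names; the item values are placeholders,
# since the input format the models expect is not documented in this file.
#
#   client = app.test_client()
#   resp = client.post("/predict/rfy", json={"items": [1, 2, 3]})
#   resp.status_code   # 200 on success, 400 for bad input, 500 if prediction fails
#   resp.get_json()    # {"model": "rfy", "predictions": ...}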