Spaces:
Running
Running
| # Copyright 2024 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Gunicorn application for passing requests through to the executor command. | |
| Provides a thin, subject-agnostic request server for Vertex endpoints which | |
| handles requests by piping their JSON bodies to the given executor command | |
| and returning the JSON output. | |
| """ | |
| from collections.abc import Mapping | |
| import http | |
| import os | |
| from typing import Any, Optional, Sequence | |
| from absl import app | |
| from absl import logging | |
| import flask | |
| from gunicorn.app import base as gunicorn_base | |
| import pete_predictor_v2 | |
def _create_app() -> flask.Flask:
    """Creates the Flask app that serves prediction requests.

    A single ``pete_predictor_v2.PetePredictor`` is built here and shared by
    every request the returned app handles. One POST route is registered at
    the path named by the ``AIP_PREDICT_ROUTE`` environment variable
    (``/predict`` when the variable is unset).

    Returns:
      The configured Flask application.
    """
    predictor = pete_predictor_v2.PetePredictor()
    flask_app = flask.Flask(__name__)

    def predict() -> tuple[dict[str, Any], int]:
        """Runs the predictor on the request's JSON body.

        Returns:
          A ``(response body, HTTP status)`` pair: 400 when the request has
          no usable JSON body, 500 when the predictor raises
          ``RuntimeError``, otherwise 200 with the predictor's result.
        """
        logging.info("predict route hit")
        # Parse the body exactly once. silent=True makes Flask return None
        # instead of raising when the body is missing or is not valid JSON.
        payload = flask.request.get_json(silent=True)
        if payload is None:
            return {"error": "No JSON body."}, http.HTTPStatus.BAD_REQUEST.value

        logging.debug("Dispatching request to executor.")
        # Only the predictor call is guarded; RuntimeError is the one failure
        # translated into a JSON 500 response here.
        try:
            exec_result = predictor.predict(payload)
        except RuntimeError:
            logging.exception("Internal error handling request: Executor failed.")
            return {
                "error": "Internal server error."
            }, http.HTTPStatus.INTERNAL_SERVER_ERROR.value
        logging.debug("Executor returned results.")
        return exec_result, http.HTTPStatus.OK.value

    predict_route = os.environ.get("AIP_PREDICT_ROUTE", "/predict")
    logging.info("predict route: %s", predict_route)
    flask_app.add_url_rule(predict_route, view_func=predict, methods=["POST"])
    flask_app.config["TRAP_BAD_REQUEST_ERRORS"] = True
    return flask_app
class PredictionApplication(gunicorn_base.BaseApplication):
    """Gunicorn application that serves the prediction Flask app."""

    def __init__(
        self,
        *,
        options: Optional[Mapping[str, Any]] = None,
    ):
        """Stores the gunicorn settings and builds the Flask app.

        Args:
          options: Gunicorn settings keyed by setting name. The mapping is
            copied, and ``preload_app`` is always forced to False in the
            copy regardless of what the caller passed.
        """
        # Work on a private copy so the caller's mapping is never mutated.
        settings = dict(options or {})
        settings["preload_app"] = False
        self.options = settings
        self.application = _create_app()
        # Both attributes are assigned before the base initializer runs,
        # because gunicorn's base class loads configuration during __init__.
        super().__init__()

    def load_config(self):
        """Applies the stored options that gunicorn recognizes.

        Options whose key is not a known gunicorn setting, or whose value is
        None, are skipped.
        """
        for name, setting in self.options.items():
            if setting is None or name not in self.cfg.settings:
                continue
            self.cfg.set(name.lower(), setting)

    def load(self) -> flask.Flask:
        """Returns the Flask app that was built in ``__init__``."""
        return self.application
def main(argv: Sequence[str]) -> None:
    """Runs the prediction server under gunicorn.

    Args:
      argv: Command-line arguments handed over by ``app.run``; unused.
    """
    del argv  # Unused.
    options = {
        "bind": "127.0.0.1:80",
        "workers": 3,
        "timeout": 600,  # gunicorn worker timeout, in seconds
    }
    PredictionApplication(options=options).run()
# Script entry point: absl processes the command line and then calls main.
if __name__ == "__main__":
    app.run(main)