|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
from typing import Any, Generator |
|
|
|
|
|
from app_conf import ( |
|
|
GALLERY_PATH, |
|
|
GALLERY_PREFIX, |
|
|
POSTERS_PATH, |
|
|
POSTERS_PREFIX, |
|
|
UPLOADS_PATH, |
|
|
UPLOADS_PREFIX, |
|
|
) |
|
|
from data.loader import preload_data |
|
|
from data.schema import schema |
|
|
from data.store import set_videos |
|
|
from flask import Flask, make_response, Request, request, Response, send_from_directory |
|
|
from flask_cors import CORS |
|
|
from inference.data_types import PropagateDataResponse, PropagateInVideoRequest |
|
|
from inference.multipart import MultipartResponseBuilder |
|
|
from inference.predictor import InferenceAPI |
|
|
from strawberry.flask.views import GraphQLView |
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Flask application; CORS is enabled with credential support.
app = Flask(__name__)
cors = CORS(app, supports_credentials=True)

# Load the video data once at import time and hand it to data.store.
videos = preload_data()
set_videos(videos)

# Single inference backend instance, used by the streaming route and by the
# GraphQL view context below.
inference_api = InferenceAPI()
|
|
|
|
|
|
|
|
@app.route("/healthy")
def healthy() -> Response:
    """Health-check endpoint: reply with the body ``OK`` and HTTP status 200."""
    status_code = 200
    return make_response("OK", status_code)
|
|
|
|
|
|
|
|
@app.route(f"/{GALLERY_PREFIX}/<path:path>", methods=["GET"])
def send_gallery_video(path: str) -> Response:
    """Serve a file from ``GALLERY_PATH``.

    Args:
        path: Path of the requested file, relative to ``GALLERY_PATH``,
            captured from the URL.

    Returns:
        The response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original exception is attached as ``__cause__``.
    """
    try:
        return send_from_directory(GALLERY_PATH, path)
    except Exception as err:
        # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
        # SystemExit are not rewritten; chaining keeps the real failure
        # visible in tracebacks.
        # NOTE(review): this also converts the HTTP error raised by
        # send_from_directory for a missing file into ValueError -- confirm
        # that is the intended response for clients.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
|
|
@app.route(f"/{POSTERS_PREFIX}/<path:path>", methods=["GET"])
def send_poster_image(path: str) -> Response:
    """Serve a file from ``POSTERS_PATH``.

    Args:
        path: Path of the requested file, relative to ``POSTERS_PATH``,
            captured from the URL.

    Returns:
        The response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original exception is attached as ``__cause__``.
    """
    try:
        return send_from_directory(POSTERS_PATH, path)
    except Exception as err:
        # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
        # SystemExit are not rewritten; chaining keeps the real failure
        # visible in tracebacks.
        # NOTE(review): this also converts the HTTP error raised by
        # send_from_directory for a missing file into ValueError -- confirm
        # that is the intended response for clients.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
|
|
@app.route(f"/{UPLOADS_PREFIX}/<path:path>", methods=["GET"])
def send_uploaded_video(path: str) -> Response:
    """Serve a file from ``UPLOADS_PATH``.

    Args:
        path: Path of the requested file, relative to ``UPLOADS_PATH``,
            captured from the URL.

    Returns:
        The response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original exception is attached as ``__cause__``.
    """
    try:
        return send_from_directory(UPLOADS_PATH, path)
    except Exception as err:
        # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
        # SystemExit are not rewritten; chaining keeps the real failure
        # visible in tracebacks.
        # NOTE(review): this also converts the HTTP error raised by
        # send_from_directory for a missing file into ValueError -- confirm
        # that is the intended response for clients.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
|
|
|
|
|
@app.route("/propagate_in_video", methods=["POST"])
def propagate_in_video() -> Response:
    """Stream propagation results for a session as a multipart response.

    Reads ``session_id`` (required) and ``start_frame_index`` (optional,
    defaults to 0) from the JSON request body and returns a streaming
    ``Response`` fed by ``gen_track_with_mask_stream``.
    """
    payload = request.json
    session_id = payload["session_id"]
    start_frame_index = payload.get("start_frame_index", 0)

    # The same boundary string is given to the generator and advertised in
    # the response mimetype.
    boundary = "frame"
    stream = gen_track_with_mask_stream(
        boundary,
        session_id=session_id,
        start_frame_index=start_frame_index,
    )
    return Response(stream, mimetype=f"multipart/x-savi-stream; boundary={boundary}")
|
|
|
|
|
|
|
|
def gen_track_with_mask_stream(
    boundary: str,
    session_id: str,
    start_frame_index: int,
) -> Generator[bytes, None, None]:
    """Yield one multipart message per chunk produced by the inference API.

    Args:
        boundary: Multipart boundary string passed to the message builder.
        session_id: Session identifier placed in the propagation request.
        start_frame_index: Frame index placed in the propagation request.

    Yields:
        The result of ``get_message()`` for each built multipart part, whose
        body is the chunk's JSON encoded as UTF-8.
    """
    # Both the request construction and the iteration run inside the
    # inference API's autocast context.
    with inference_api.autocast_context():
        propagate_request = PropagateInVideoRequest(
            type="propagate_in_video",
            session_id=session_id,
            start_frame_index=start_frame_index,
        )
        chunks = inference_api.propagate_in_video(request=propagate_request)

        for chunk in chunks:
            # A fresh header dict is built for every part.
            part_headers = {
                "Content-Type": "application/json; charset=utf-8",
                "Frame-Current": "-1",
                "Frame-Total": "-1",
                "Mask-Type": "RLE[]",
            }
            part = MultipartResponseBuilder.build(
                boundary=boundary,
                headers=part_headers,
                body=chunk.to_json().encode("UTF-8"),
            )
            yield part.get_message()
|
|
|
|
|
|
|
|
class MyGraphQLView(GraphQLView):
    """GraphQL view whose context carries the module's ``inference_api``."""

    def get_context(self, request: Request, response: Response) -> Any:
        """Return the context mapping for this view.

        ``request`` and ``response`` are accepted but not used; the mapping
        holds only the module-level ``inference_api`` instance.
        """
        context = {"inference_api": inference_api}
        return context
|
|
|
|
|
|
|
|
|
|
|
# Register the GraphQL endpoint, served by MyGraphQLView, at /graphql.
app.add_url_rule(
    "/graphql",
    view_func=MyGraphQLView.as_view(
        "graphql_view",
        schema=schema,
        # Queries sent via HTTP GET are turned off.
        allow_queries_via_get=False,
        # Multipart uploads are explicitly enabled for this view.
        multipart_uploads_enabled=True,
    ),
)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Only when run as a script: start Flask's built-in server on port 5000,
    # bound to all interfaces.
    app.run(port=5000, host="0.0.0.0")
|
|
|