|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import hashlib |
|
|
import os |
|
|
import shutil |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from typing import Iterable, List, Optional, Tuple, Union |
|
|
|
|
|
import av |
|
|
import strawberry |
|
|
from app_conf import ( |
|
|
DATA_PATH, |
|
|
DEFAULT_VIDEO_PATH, |
|
|
MAX_UPLOAD_VIDEO_DURATION, |
|
|
UPLOADS_PATH, |
|
|
UPLOADS_PREFIX, |
|
|
) |
|
|
from data.data_types import ( |
|
|
AddPointsInput, |
|
|
CancelPropagateInVideo, |
|
|
CancelPropagateInVideoInput, |
|
|
ClearPointsInFrameInput, |
|
|
ClearPointsInVideo, |
|
|
ClearPointsInVideoInput, |
|
|
CloseSession, |
|
|
CloseSessionInput, |
|
|
RemoveObjectInput, |
|
|
RLEMask, |
|
|
RLEMaskForObject, |
|
|
RLEMaskListOnFrame, |
|
|
StartSession, |
|
|
StartSessionInput, |
|
|
Video, |
|
|
) |
|
|
from data.loader import get_video |
|
|
from data.store import get_videos |
|
|
from data.transcoder import get_video_metadata, transcode, VideoMetadata |
|
|
from inference.data_types import ( |
|
|
AddPointsRequest, |
|
|
CancelPropagateInVideoRequest, |
|
|
CancelPropagateInVideoRequest, |
|
|
ClearPointsInFrameRequest, |
|
|
ClearPointsInVideoRequest, |
|
|
CloseSessionRequest, |
|
|
RemoveObjectRequest, |
|
|
StartSessionRequest, |
|
|
) |
|
|
from inference.predictor import InferenceAPI |
|
|
from strawberry import relay |
|
|
from strawberry.file_uploads import Upload |
|
|
|
|
|
|
|
|
@strawberry.type
class Query:
    """Root GraphQL query type exposing the available videos."""

    @strawberry.field
    def default_video(self) -> Video:
        """
        Return the default video.

        The default is the video whose ``path`` equals ``DEFAULT_VIDEO_PATH``
        (imported from ``app_conf``; documented as settable through the
        environment variable of the same name). If no video matches, the
        first video returned by ``get_videos()`` is used instead.

        Raises:
            Exception: if there are no videos at all. Previously this surfaced
                as an opaque ``StopIteration``.
        """
        all_videos = list(get_videos().values())
        if not all_videos:
            raise Exception("no videos available")

        for video in all_videos:
            if video.path == DEFAULT_VIDEO_PATH:
                return video

        # No configured default found; fall back to the first known video.
        return all_videos[0]

    @relay.connection(relay.ListConnection[Video])
    def videos(
        self,
    ) -> Iterable[Video]:
        """
        Return all available videos.

        Pagination is handled by strawberry's relay ``ListConnection``.
        """
        all_videos = get_videos()
        return all_videos.values()
|
|
|
|
|
|
|
|
def _to_rle_mask_list_on_frame(frame_index, results) -> RLEMaskListOnFrame:
    """
    Convert one frame's inference results into a GraphQL ``RLEMaskListOnFrame``.

    Each item in ``results`` must expose ``object_id`` and a ``mask`` with
    ``counts`` and ``size`` attributes, which are copied verbatim into an
    ``RLEMask`` tagged with ``order="F"``.
    """
    return RLEMaskListOnFrame(
        frame_index=frame_index,
        rle_mask_list=[
            RLEMaskForObject(
                object_id=r.object_id,
                rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"),
            )
            for r in results
        ],
    )


@strawberry.type
class Mutation:
    """Root GraphQL mutation type: video upload and inference-session control."""

    @strawberry.mutation
    def upload_video(
        self,
        file: Upload,
        start_time_sec: Optional[float] = None,
        duration_time_sec: Optional[float] = None,
    ) -> Video:
        """
        Receive a video file, trim/transcode it and store it under UPLOADS_PATH.

        At most ``MAX_UPLOAD_VIDEO_DURATION`` seconds are kept, starting at
        ``start_time_sec``. The stored file is wrapped into a ``Video`` via
        ``get_video(..., generate_poster=False)``.
        """
        filepath, file_key, vm = process_video(
            file,
            max_time=MAX_UPLOAD_VIDEO_DURATION,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        video = get_video(
            filepath,
            UPLOADS_PATH,
            file_key=file_key,
            width=vm.width,
            height=vm.height,
            generate_poster=False,
        )
        return video

    @strawberry.mutation
    def start_session(
        self, input: StartSessionInput, info: strawberry.Info
    ) -> StartSession:
        """
        Start an inference session for the video at ``input.path``.

        ``input.path`` is client-supplied and is interpreted relative to
        DATA_PATH. Paths that resolve outside DATA_PATH (e.g. via ``..``) are
        rejected before anything is handed to the inference API.
        """
        inference_api: InferenceAPI = info.context["inference_api"]

        video_path = f"{DATA_PATH}/{input.path}"
        # Guard against path traversal: the normalized target must stay
        # inside the data root.
        data_root = os.path.abspath(DATA_PATH)
        if os.path.commonpath([data_root, os.path.abspath(video_path)]) != data_root:
            raise Exception("invalid video path")

        request = StartSessionRequest(
            type="start_session",
            path=video_path,
        )
        response = inference_api.start_session(request=request)
        return StartSession(session_id=response.session_id)

    @strawberry.mutation
    def close_session(
        self, input: CloseSessionInput, info: strawberry.Info
    ) -> CloseSession:
        """Close the inference session ``input.session_id``."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = CloseSessionRequest(
            type="close_session",
            session_id=input.session_id,
        )
        response = inference_api.close_session(request)
        return CloseSession(success=response.success)

    @strawberry.mutation
    def add_points(
        self, input: AddPointsInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Add prompt points for one object on one frame and return that frame's masks."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = AddPointsRequest(
            type="add_points",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
            points=input.points,
            labels=input.labels,
            clear_old_points=input.clear_old_points,
        )
        response = inference_api.add_points(request)
        return _to_rle_mask_list_on_frame(response.frame_index, response.results)

    @strawberry.mutation
    def remove_object(
        self, input: RemoveObjectInput, info: strawberry.Info
    ) -> List[RLEMaskListOnFrame]:
        """Remove an object from the session and return the per-frame masks the API reports."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = RemoveObjectRequest(
            type="remove_object", session_id=input.session_id, object_id=input.object_id
        )
        response = inference_api.remove_object(request)
        return [
            _to_rle_mask_list_on_frame(res.frame_index, res.results)
            for res in response.results
        ]

    @strawberry.mutation
    def clear_points_in_frame(
        self, input: ClearPointsInFrameInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Clear one object's points on one frame and return that frame's masks."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = ClearPointsInFrameRequest(
            type="clear_points_in_frame",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
        )
        response = inference_api.clear_points_in_frame(request)
        return _to_rle_mask_list_on_frame(response.frame_index, response.results)

    @strawberry.mutation
    def clear_points_in_video(
        self, input: ClearPointsInVideoInput, info: strawberry.Info
    ) -> ClearPointsInVideo:
        """Clear all points in the session's video."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = ClearPointsInVideoRequest(
            type="clear_points_in_video",
            session_id=input.session_id,
        )
        response = inference_api.clear_points_in_video(request)
        return ClearPointsInVideo(success=response.success)

    @strawberry.mutation
    def cancel_propagate_in_video(
        self, input: CancelPropagateInVideoInput, info: strawberry.Info
    ) -> CancelPropagateInVideo:
        """Ask the inference API to cancel an in-progress propagation for the session."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = CancelPropagateInVideoRequest(
            type="cancel_propagate_in_video",
            session_id=input.session_id,
        )
        response = inference_api.cancel_propagate_in_video(request)
        return CancelPropagateInVideo(success=response.success)
|
|
|
|
|
|
|
|
def get_file_hash(video_path_or_file, chunk_size: int = 1 << 20) -> str:
    """
    Return the hex SHA-256 digest of a file's full contents.

    Args:
        video_path_or_file: a filesystem path (``str`` or ``os.PathLike``) or
            an open binary file object. File objects are rewound to offset 0
            before hashing and are left positioned at EOF afterwards.
        chunk_size: number of bytes read per iteration. Hashing is streamed
            so large videos are never loaded into memory all at once.

    Returns:
        The lowercase hexadecimal SHA-256 digest.
    """
    digest = hashlib.sha256()

    def _consume(stream) -> None:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                break
            digest.update(chunk)

    if isinstance(video_path_or_file, (str, os.PathLike)):
        with open(video_path_or_file, "rb") as in_f:
            _consume(in_f)
    else:
        # Hash the whole file regardless of the caller's current position.
        video_path_or_file.seek(0)
        _consume(video_path_or_file)

    return digest.hexdigest()
|
|
|
|
|
|
|
|
def _get_start_sec_duration_sec(
    start_time_sec: Union[float, None],
    duration_time_sec: Union[float, None],
    max_time: float,
) -> Tuple[float, float]:
    """
    Resolve the trim window (start, duration) for an upload.

    Args:
        start_time_sec: requested start offset in seconds. When ``None``, the
            ``VIDEO_ENCODE_SEEK_TIME`` environment variable is used (default
            ``"0"``). It is parsed as a float so fractional seconds work.
        duration_time_sec: requested duration in seconds. When ``None``,
            ``max_time`` is used.
        max_time: upper bound applied to the duration.

    Returns:
        ``(start_time_sec, duration_time_sec)`` with the duration capped at
        ``max_time``.

    Raises:
        ValueError: if ``VIDEO_ENCODE_SEEK_TIME`` is needed but is not numeric.
    """
    if start_time_sec is None:
        # Only consult the environment when no explicit start was supplied.
        start_time_sec = float(os.environ.get("VIDEO_ENCODE_SEEK_TIME", "0"))

    if duration_time_sec is None:
        duration_time_sec = max_time
    else:
        duration_time_sec = min(duration_time_sec, max_time)

    return start_time_sec, duration_time_sec
|
|
|
|
|
|
|
|
def _validate_video_metadata(video_metadata: VideoMetadata) -> None:
    """Raise if the probed upload lacks a video stream, dimensions, or a duration."""
    if video_metadata.num_video_streams == 0:
        raise Exception("video container does not contain a video stream")
    if video_metadata.width is None or video_metadata.height is None:
        raise Exception("video container does not contain width or height metadata")
    if video_metadata.duration_sec in (None, 0):
        raise Exception("video container does not have time duration metadata")


def process_video(
    file: Upload,
    max_time: float,
    start_time_sec: Optional[float] = None,
    duration_time_sec: Optional[float] = None,
) -> Tuple[str, str, VideoMetadata]:
    """
    Validate, trim and transcode an uploaded video, then move it into UPLOADS_PATH.

    The upload is written to a temporary directory and probed with
    ``get_video_metadata``. It is then transcoded starting at the resolved
    start offset for at most ``max_time`` seconds. The output is stored as
    ``<sha256-of-content>.mp4``.

    Args:
        file: the uploaded file object; its full contents are read once.
        max_time: maximum duration, in seconds, kept from the upload.
        start_time_sec: optional start offset (see ``_get_start_sec_duration_sec``).
        duration_time_sec: optional requested duration, capped at ``max_time``.

    Returns:
        ``(filepath, file_key, video_metadata)``:
        - ``filepath``: the local path of the stored file.
        - ``file_key``: ``UPLOADS_PREFIX + "/<hash>.mp4"``.
        - ``video_metadata``: the metadata of the transcoded output.

    Raises:
        Exception: if the upload is not a decodable video, lacks a video
            stream, width/height or duration metadata, or transcodes to a
            video with no frames.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        in_path = os.path.join(tempdir, "in.mp4")
        out_path = os.path.join(tempdir, "out.mp4")
        with open(in_path, "wb") as in_f:
            in_f.write(file.read())

        try:
            video_metadata = get_video_metadata(in_path)
        except av.InvalidDataError as err:
            raise Exception("not valid video file") from err
        _validate_video_metadata(video_metadata)

        start_time_sec, duration_time_sec = _get_start_sec_duration_sec(
            max_time=max_time,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        transcode(
            in_path,
            out_path,
            video_metadata,
            seek_t=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        # The raw upload is no longer needed; free the disk space early.
        os.remove(in_path)

        out_video_metadata = get_video_metadata(out_path)
        if out_video_metadata.num_video_frames == 0:
            raise Exception(
                "transcode produced empty video; check seek time or your input video"
            )

        # Name the stored file after the SHA-256 of its content.
        file_name = f"{get_file_hash(out_path)}.mp4"
        file_key = UPLOADS_PREFIX + "/" + file_name
        filepath = os.path.join(UPLOADS_PATH, file_name)

        os.makedirs(UPLOADS_PATH, exist_ok=True)
        shutil.move(out_path, filepath)

        return filepath, file_key, out_video_metadata
|
|
|
|
|
|
|
|
# Root GraphQL schema combining the Query and Mutation types.
schema = strawberry.Schema(query=Query, mutation=Mutation)
|
|
|