| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import {generateThumbnail} from '@/common/components/video/editor/VideoEditorUtils';
|
| import VideoWorkerContext from '@/common/components/video/VideoWorkerContext';
|
| import Logger from '@/common/logger/Logger';
|
| import {
|
| SAM2ModelAddNewPointsMutation,
|
| SAM2ModelAddNewPointsMutation$data,
|
| } from '@/common/tracker/__generated__/SAM2ModelAddNewPointsMutation.graphql';
|
| import {SAM2ModelCancelPropagateInVideoMutation} from '@/common/tracker/__generated__/SAM2ModelCancelPropagateInVideoMutation.graphql';
|
| import {SAM2ModelClearPointsInFrameMutation} from '@/common/tracker/__generated__/SAM2ModelClearPointsInFrameMutation.graphql';
|
| import {SAM2ModelClearPointsInVideoMutation} from '@/common/tracker/__generated__/SAM2ModelClearPointsInVideoMutation.graphql';
|
| import {SAM2ModelCloseSessionMutation} from '@/common/tracker/__generated__/SAM2ModelCloseSessionMutation.graphql';
|
| import {SAM2ModelRemoveObjectMutation} from '@/common/tracker/__generated__/SAM2ModelRemoveObjectMutation.graphql';
|
| import {SAM2ModelStartSessionMutation} from '@/common/tracker/__generated__/SAM2ModelStartSessionMutation.graphql';
|
| import {
|
| BaseTracklet,
|
| Mask,
|
| SegmentationPoint,
|
| StreamingState,
|
| Tracker,
|
| Tracklet,
|
| } from '@/common/tracker/Tracker';
|
| import {TrackerOptions} from '@/common/tracker/Trackers';
|
| import {
|
| ClearPointsInVideoResponse,
|
| SessionStartFailedResponse,
|
| SessionStartedResponse,
|
| StreamingCompletedResponse,
|
| StreamingStartedResponse,
|
| StreamingStateUpdateResponse,
|
| TrackletCreatedResponse,
|
| TrackletDeletedResponse,
|
| TrackletsUpdatedResponse,
|
| } from '@/common/tracker/TrackerTypes';
|
| import {convertMaskToRGBA} from '@/common/utils/MaskUtils';
|
| import multipartStream from '@/common/utils/MultipartStream';
|
| import {Stats} from '@/debug/stats/Stats';
|
| import {INFERENCE_API_ENDPOINT} from '@/demo/DemoConfig';
|
| import {createEnvironment} from '@/graphql/RelayEnvironment';
|
| import {
|
| DataArray,
|
| Masks,
|
| RLEObject,
|
| decode,
|
| encode,
|
| toBbox,
|
| } from '@/jscocotools/mask';
|
| import {THEME_COLORS} from '@/theme/colors';
|
| import invariant from 'invariant';
|
| import {IEnvironment, commitMutation, graphql} from 'relay-runtime';
|
|
|
/** Constructor options accepted by `SAM2Model` (only the inference endpoint). */
type Options = Pick<TrackerOptions, 'inferenceEndpoint'>;

/**
 * Client-side view of the current inference session.
 * `id` is `null` while no session is open; `tracklets` is keyed by object id.
 */
interface Session {
  id: string | null;
  tracklets: {[id: number]: Tracklet};
}

/** Masks for every tracked object on a single frame, as read from the propagation stream. */
interface StreamMasksResult {
  frameIndex: number;
  rleMaskList: {
    objectId: number;
    rleMask: RLEObject;
  }[];
}

/** Marker yielded by the propagation stream once the caller has aborted it. */
interface StreamMasksAbortResult {
  aborted: boolean;
}
|
|
|
/**
 * Tracker backed by a remote SAM 2 inference server.
 *
 * Session management, point prompts and object removal are issued as Relay
 * GraphQL mutations. Mask propagation across the video is streamed from the
 * `${endpoint}/propagate_in_video` HTTP endpoint as a
 * `multipart/x-savi-stream` response. Results are reported through
 * `_sendResponse` and pushed into the shared video worker context.
 */
export class SAM2Model extends Tracker {
  private _endpoint: string;
  private _environment: IEnvironment;

  // Controller for the most recent propagation stream started by `streamMasks`.
  private abortController: AbortController | null = null;
  private _session: Session = {
    id: null,
    tracklets: {},
  };
  private _streamingState: StreamingState = 'none';

  // RLE encoding of an all-zero mask at the current video size. Incoming
  // masks whose `counts` match it are flagged as empty. It is reset whenever
  // a session starts or is cleaned up, because the next video may have
  // different dimensions.
  private _emptyMask: RLEObject | null = null;

  // Scratch canvas, reused across calls, for rasterizing decoded masks.
  private _maskCanvas: OffscreenCanvas;
  private _maskCtx: OffscreenCanvasRenderingContext2D;

  private _stats?: Stats;

  /**
   * @param context Worker-side video context (frames, masks, navigation).
   * @param options Inference endpoint; defaults to `INFERENCE_API_ENDPOINT`.
   */
  constructor(
    context: VideoWorkerContext,
    options: Options = {
      inferenceEndpoint: INFERENCE_API_ENDPOINT,
    },
  ) {
    super(context);
    this._endpoint = options.inferenceEndpoint;
    this._environment = createEnvironment(options.inferenceEndpoint);

    this._maskCanvas = new OffscreenCanvas(0, 0);
    const maskCtx = this._maskCanvas.getContext('2d');
    invariant(maskCtx != null, 'context cannot be null');
    this._maskCtx = maskCtx;
  }

  /**
   * Starts a server-side session for the video at `videoPath`.
   *
   * This never rejects. On failure a `sessionStartFailed` response is sent
   * and the promise still resolves. On success a `sessionStarted` response is
   * sent, local tracklets are cleared and one fresh tracklet is created.
   */
  public startSession(videoPath: string): Promise<void> {
    // Force the update so listeners resync even if the state is already 'none'.
    this._updateStreamingState('none', true);
    // The new video may have different dimensions; recompute lazily in updatePoints.
    this._emptyMask = null;

    return new Promise(resolve => {
      try {
        commitMutation<SAM2ModelStartSessionMutation>(this._environment, {
          mutation: graphql`
            mutation SAM2ModelStartSessionMutation($input: StartSessionInput!) {
              startSession(input: $input) {
                sessionId
              }
            }
          `,
          variables: {
            input: {
              path: videoPath,
            },
          },
          onCompleted: response => {
            const {sessionId} = response.startSession;
            this._session.id = sessionId;

            this._sendResponse<SessionStartedResponse>('sessionStarted', {
              sessionId,
            });

            this._clearTracklets();
            this.createTracklet();
            resolve();
          },
          onError: error => {
            Logger.error(error);
            this._sendResponse<SessionStartFailedResponse>(
              'sessionStartFailed',
            );
            resolve();
          },
        });
      } catch (error) {
        Logger.error(error);
        this._sendResponse<SessionStartFailedResponse>('sessionStartFailed');
        resolve();
      }
    });
  }

  /**
   * Clears local session state immediately, then asks the server to close the
   * session. Resolves at once if no session was open. Rejects if the server
   * reports failure or the mutation errors.
   */
  public closeSession(): Promise<void> {
    const sessionId = this._session.id;

    this._cleanup();

    if (sessionId === null) {
      return Promise.resolve();
    }
    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelCloseSessionMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelCloseSessionMutation($input: CloseSessionInput!) {
            closeSession(input: $input) {
              success
            }
          }
        `,
        variables: {
          input: {
            sessionId,
          },
        },
        onCompleted: response => {
          const {success} = response.closeSession;
          if (success === false) {
            reject(new Error('Failed to close session'));
            return;
          }
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Adds an empty, uninitialized tracklet locally, using one more than the
   * highest existing id (or 0). Broadcasts the tracklet list and sends a
   * `trackletCreated` response.
   */
  public createTracklet(): void {
    const nextId =
      Object.values(this._session.tracklets).reduce(
        (prev, curr) => Math.max(prev, curr.id),
        -1,
      ) + 1;

    const newTracklet = {
      id: nextId,
      color: THEME_COLORS[nextId % THEME_COLORS.length],
      thumbnail: null,
      points: [],
      masks: [],
      isInitialized: false,
    };

    this._session.tracklets[nextId] = newTracklet;

    this._updateTracklets();

    this._sendResponse<TrackletCreatedResponse>('trackletCreated', {
      tracklet: newTracklet,
    });
  }

  /**
   * Removes object `trackletId` on the server, applies the mask updates the
   * server returns, then drops the tracklet locally.
   *
   * Rejects with the string `'No active session'` when no session is open,
   * and with the mutation error on failure. A `trackletDeleted` response is
   * sent in both the success and failure cases.
   */
  public deleteTracklet(trackletId: number): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    const tracklet = this._session.tracklets[trackletId];
    invariant(
      tracklet != null,
      'tracklet for tracklet id %s not initialized',
      trackletId,
    );

    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelRemoveObjectMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelRemoveObjectMutation($input: RemoveObjectInput!) {
            removeObject(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {objectId: trackletId, sessionId},
        },
        onCompleted: response => {
          const trackletUpdates = response.removeObject;
          this._sendResponse<TrackletDeletedResponse>('trackletDeleted', {
            isSuccessful: true,
          });
          for (const trackletUpdate of trackletUpdates) {
            this._updateTrackletMasks(
              trackletUpdate,
              trackletUpdate.frameIndex === this._context.frameIndex,
              false,
            ).catch(error => Logger.error(error));
          }
          this._removeTrackletMasks(tracklet);
          resolve();
        },
        onError: error => {
          this._sendResponse<TrackletDeletedResponse>('trackletDeleted', {
            isSuccessful: false,
          });
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Replaces the prompt points of `objectId` on `frameIndex` and applies the
   * resulting masks (including thumbnails).
   *
   * Points are in pixel coordinates, `[x, y, label]`. They are normalized by
   * the context's width/height before being sent. An empty `points` array is
   * delegated to `clearPointsInFrame`. Marks streaming as 'required'.
   */
  public updatePoints(
    frameIndex: number,
    objectId: number,
    points: SegmentationPoint[],
  ): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    // Lazily encode an all-zero mask at the current video size so that
    // incoming masks can be compared against it to detect emptiness.
    if (this._emptyMask === null) {
      const tensor = new Masks(
        Math.trunc(this._context.height),
        Math.trunc(this._context.width),
        1,
      ).toDataArray();
      this._emptyMask = encode(tensor)[0];
    }

    const tracklet = this._session.tracklets[objectId];
    invariant(
      tracklet != null,
      'tracklet for object id %s not initialized',
      objectId,
    );

    this._updateStreamingState('required');

    if (points.length === 0) {
      return this.clearPointsInFrame(frameIndex, objectId);
    }
    return new Promise((resolve, reject) => {
      const normalizedPoints = points.map(p => [
        p[0] / this._context.width,
        p[1] / this._context.height,
      ]);
      const labels = points.map(p => p[2]);
      commitMutation<SAM2ModelAddNewPointsMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelAddNewPointsMutation($input: AddPointsInput!) {
            addPoints(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {
            sessionId,
            frameIndex,
            objectId,
            labels: labels,
            points: normalizedPoints,
            clearOldPoints: true,
          },
        },
        onCompleted: response => {
          tracklet.points[frameIndex] = points;
          tracklet.isInitialized = true;
          // Thumbnail generation is asynchronous; log failures instead of
          // leaving an unhandled rejection.
          this._updateTrackletMasks(response.addPoints, true).catch(error =>
            Logger.error(error),
          );
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Clears the prompt points of `objectId` on `frameIndex` and applies the
   * resulting masks (including thumbnails). Marks streaming as 'required'.
   */
  public clearPointsInFrame(
    frameIndex: number,
    objectId: number,
  ): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    const tracklet = this._session.tracklets[objectId];
    invariant(
      tracklet != null,
      'tracklet for object id %s not initialized',
      objectId,
    );

    this._updateStreamingState('required');

    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelClearPointsInFrameMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelClearPointsInFrameMutation(
            $input: ClearPointsInFrameInput!
          ) {
            clearPointsInFrame(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {
            sessionId,
            frameIndex,
            objectId,
          },
        },
        onCompleted: response => {
          tracklet.points[frameIndex] = [];
          tracklet.isInitialized = true;
          this._updateTrackletMasks(response.clearPointsInFrame, true).catch(
            error => Logger.error(error),
          );
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Clears all prompts in the video on the server. On success it also drops
   * all local tracklets and redraws the current frame.
   *
   * The outcome is reported through a `clearPointsInVideo` response. Once the
   * mutation has been dispatched, the returned promise always resolves,
   * including on failure, so callers never hang. It rejects only with
   * `'No active session'`.
   */
  public clearPointsInVideo(): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    this._updateStreamingState('none');

    return new Promise(resolve => {
      const reportFailure = () => {
        this._sendResponse<ClearPointsInVideoResponse>('clearPointsInVideo', {
          isSuccessful: false,
        });
        // Previously this path never settled the promise, leaving awaiting
        // callers suspended forever.
        resolve();
      };

      commitMutation<SAM2ModelClearPointsInVideoMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelClearPointsInVideoMutation(
            $input: ClearPointsInVideoInput!
          ) {
            clearPointsInVideo(input: $input) {
              success
            }
          }
        `,
        variables: {
          input: {
            sessionId,
          },
        },
        onCompleted: response => {
          const {success} = response.clearPointsInVideo;
          if (!success) {
            reportFailure();
            return;
          }

          this._clearTracklets();

          // Re-render the current frame now that all masks are gone.
          this._context.goToFrame(this._context.frameIndex);
          this._updateTracklets();
          this._sendResponse<ClearPointsInVideoResponse>('clearPointsInVideo', {
            isSuccessful: true,
          });
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reportFailure();
        },
      });
    });
  }

  /**
   * Propagates masks through the video starting at `frameIndex`, applying
   * each streamed frame as it arrives.
   *
   * Existing masks are cleared first. The streaming state moves through
   * 'requesting' → 'partial' → 'full', or 'aborting' → 'aborted' if
   * `abortStreamMasks` is called. Errors are logged and rethrown; on error no
   * `streamingCompleted` response is sent.
   */
  public async streamMasks(frameIndex: number): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }
    try {
      this._sendResponse<StreamingStartedResponse>('streamingStarted');

      this._context.clearMasks();
      this._clearTrackletMasks();

      const controller = new AbortController();
      this.abortController = controller;

      this._updateStreamingState('requesting');
      const generator = this._streamMasksForSession(
        controller,
        sessionId,
        frameIndex,
      );

      let isAborted = false;
      for await (const result of generator) {
        if ('aborted' in result) {
          // Ask the server to stop propagating before reporting 'aborted'.
          this._updateStreamingState('aborting');
          await this._abortRequest();
          this._updateStreamingState('aborted');
          isAborted = true;
        } else {
          await this._updateTrackletMasks(result, false);
          this._updateStreamingState('partial');
        }
      }

      if (!isAborted) {
        this._updateStreamingState('full');
      }
    } catch (error) {
      Logger.error(error);
      throw error;
    }

    this._sendResponse<StreamingCompletedResponse>('streamingCompleted');
  }

  /**
   * Signals the current propagation stream (if any) to stop and immediately
   * sends a `streamingCompleted` response.
   */
  public abortStreamMasks() {
    this.abortController?.abort();
    this._sendResponse<StreamingCompletedResponse>('streamingCompleted');
  }

  /** Enables timing of mask decoding in `_updateTrackletMasks`. */
  public enableStats(): void {
    this._stats = new Stats('ms', 'D', 1000 / 25);
  }

  /** Forgets the session id, all local tracklets and the cached empty mask. */
  private _cleanup() {
    this._session.id = null;
    // Use the same id-keyed dictionary shape as the initial state.
    this._session.tracklets = {};
    this._emptyMask = null;
  }

  /** Drops all local tracklets and clears masks in the video context. */
  private _clearTracklets() {
    this._session.tracklets = {};
    this._context.clearMasks();
  }

  /**
   * Records `state` and broadcasts a `streamingStateUpdate`. The broadcast is
   * skipped when the state is unchanged, unless `forceUpdate` is set.
   */
  private _updateStreamingState(
    state: StreamingState,
    forceUpdate: boolean = false,
  ) {
    if (!forceUpdate && this._streamingState === state) {
      return;
    }
    this._streamingState = state;
    this._sendResponse<StreamingStateUpdateResponse>('streamingStateUpdate', {
      state,
    });
  }

  /** Clears a tracklet's masks from the context, deletes it locally and redraws. */
  private async _removeTrackletMasks(tracklet: Tracklet) {
    this._context.clearTrackletMasks(tracklet);
    delete this._session.tracklets[tracklet.id];

    // Re-render the current frame without the removed tracklet.
    this._context.goToFrame(this._context.frameIndex);
    this._updateTracklets();
  }

  /**
   * Stores the masks for one frame on their tracklets and pushes them to the
   * video context.
   *
   * @param data Frame index plus per-object RLE masks.
   * @param updateThumbnails Regenerate thumbnails for non-empty masks from the
   *   context's current frame.
   * @param shouldGoToFrame Forwarded to `context.updateTracklets`.
   */
  private async _updateTrackletMasks(
    data: SAM2ModelAddNewPointsMutation$data['addPoints'],
    updateThumbnails: boolean,
    shouldGoToFrame: boolean = true,
  ) {
    const {frameIndex, rleMaskList} = data;

    for (const {objectId, rleMask} of rleMaskList) {
      const track = this._session.tracklets[objectId];
      if (track == null) {
        // The object may already have been removed locally (e.g. deleted
        // while a stream was in flight). Skip it instead of throwing.
        continue;
      }
      const {size, counts} = rleMask;
      const rleObject: RLEObject = {
        size: [size[0], size[1]],
        counts: counts,
      };
      const isEmpty = counts === this._emptyMask?.counts;

      // Time the RLE decode + bbox computation; begin/end must be paired.
      this._stats?.begin();
      const decodedMask = decode([rleObject]);
      const bbox = toBbox([rleObject]);
      this._stats?.end();

      const mask: Mask = {
        data: rleObject,
        shape: [...decodedMask.shape],
        // Convert [x, y, w, h] into [[x0, y0], [x1, y1]].
        bounds: [
          [bbox[0], bbox[1]],
          [bbox[0] + bbox[2], bbox[1] + bbox[3]],
        ],
        isEmpty,
      };
      track.masks[frameIndex] = mask;

      if (updateThumbnails && !isEmpty) {
        const {ctx} = await this._compressMaskForCanvas(decodedMask);
        const frame = this._context.currentFrame as VideoFrame;
        await generateThumbnail(track, frameIndex, mask, frame, ctx);
      }
    }

    this._context.updateTracklets(
      frameIndex,
      Object.values(this._session.tracklets),
      shouldGoToFrame,
    );

    this._updateTracklets();
  }

  /**
   * Broadcasts a `trackletsUpdated` response. Mask payloads (`data`) are
   * stripped, so only shape, bounds and emptiness are sent.
   */
  private _updateTracklets() {
    const tracklets: BaseTracklet[] = Object.values(
      this._session.tracklets,
    ).map(tracklet => {
      const {
        id,
        color,
        isInitialized,
        points: trackletPoints,
        thumbnail,
        masks,
      } = tracklet;
      return {
        id,
        color,
        isInitialized,
        points: trackletPoints,
        thumbnail,
        masks: masks.map(mask => ({
          shape: mask.shape,
          bounds: mask.bounds,
          isEmpty: mask.isEmpty,
        })),
      };
    });

    this._sendResponse<TrackletsUpdatedResponse>('trackletsUpdated', {
      tracklets,
    });
  }

  /**
   * Replaces every tracklet with a shallow copy whose `masks` is empty, then
   * broadcasts the list.
   */
  private _clearTrackletMasks() {
    const keys = Object.keys(this._session.tracklets);
    for (const key of keys) {
      const trackletId = Number(key);
      const tracklet = {...this._session.tracklets[trackletId], masks: []};
      this._session.tracklets[trackletId] = tracklet;
    }
    this._updateTracklets();
  }

  /**
   * Rasterizes a decoded mask into a fresh canvas of size
   * `shape[1]` x `shape[0]` (width x height) and PNG-encodes it.
   *
   * The RGBA data is first written to the scratch canvas with `shape[0]` as
   * its width. It is then drawn through rotate(π/2) + scale(1, -1), which maps
   * (x, y) → (y, x), i.e. a transpose. NOTE(review): presumably this is
   * because the decoder returns column-major data; confirm.
   */
  private async _compressMaskForCanvas(
    decodedMask: DataArray,
  ): Promise<{compressedData: Blob; ctx: OffscreenCanvasRenderingContext2D}> {
    const data = convertMaskToRGBA(decodedMask.data as Uint8Array);

    this._maskCanvas.width = decodedMask.shape[0];
    this._maskCanvas.height = decodedMask.shape[1];

    const imageData = new ImageData(
      data,
      decodedMask.shape[0],
      decodedMask.shape[1],
    );
    this._maskCtx.putImageData(imageData, 0, 0);

    const canvas = new OffscreenCanvas(
      decodedMask.shape[1],
      decodedMask.shape[0],
    );

    const ctx = canvas.getContext('2d');
    invariant(ctx != null, 'context cannot be null');
    ctx.save();
    ctx.rotate(Math.PI / 2);
    ctx.scale(1, -1);
    ctx.drawImage(this._maskCanvas, 0, 0);
    ctx.restore();

    const compressedData = await canvas.convertToBlob({type: 'image/png'});

    return {compressedData, ctx};
  }

  /**
   * POSTs to `${endpoint}/propagate_in_video` and yields one result per JSON
   * part of the multipart response. Parts with any other content type are
   * ignored.
   *
   * If `abortController` is aborted, the abort is noticed before the next
   * read: `{aborted: true}` is yielded once and the generator ends. The
   * stream reader's lock is released however the generator exits.
   *
   * @throws Error if the response is not `multipart/x-savi-stream` or has no body.
   */
  private async *_streamMasksForSession(
    abortController: AbortController,
    sessionId: string,
    startFrameIndex: undefined | number = 0,
  ): AsyncGenerator<StreamMasksResult | StreamMasksAbortResult, undefined> {
    const url = `${this._endpoint}/propagate_in_video`;

    const requestBody = {
      session_id: sessionId,
      start_frame_index: startFrameIndex,
    };

    const response = await fetch(url, {
      method: 'POST',
      body: JSON.stringify(requestBody),
      headers: {
        'Content-Type': 'application/json',
      },
    });

    const streamContentType = response.headers.get('Content-Type');
    if (
      streamContentType == null ||
      !streamContentType.startsWith('multipart/x-savi-stream;')
    ) {
      throw new Error(
        'endpoint needs to support Content-Type "multipart/x-savi-stream"',
      );
    }

    const responseBody = response.body;
    if (responseBody == null) {
      throw new Error('response body is null');
    }

    const reader = multipartStream(streamContentType, responseBody).getReader();
    const textDecoder = new TextDecoder();

    try {
      while (true) {
        if (abortController.signal.aborted) {
          yield {aborted: true};
          return;
        }

        const {done, value} = await reader.read();
        if (done) {
          return;
        }

        const {headers: partHeaders, body} = value;
        // A part without a Content-Type is skipped rather than crashing.
        const partContentType = partHeaders.get('Content-Type') ?? '';
        if (!partContentType.startsWith('application/json')) {
          continue;
        }

        const jsonResponse = JSON.parse(textDecoder.decode(body));
        const maskResults = jsonResponse.results;
        const rleMaskList = maskResults.map(
          (mask: {object_id: number; mask: RLEObject}) => ({
            objectId: mask.object_id,
            rleMask: mask.mask,
          }),
        );
        yield {
          frameIndex: jsonResponse.frame_index,
          rleMaskList,
        };
      }
    } finally {
      // Release the lock on every exit path (abort, completion, or error);
      // releaseLock() on an already-released reader is a no-op.
      reader.releaseLock();
    }
  }

  /**
   * Asks the server to cancel propagation for the current session. Rejects
   * if the server reports failure or the mutation errors.
   */
  private async _abortRequest(): Promise<void> {
    const sessionId = this._session.id;
    invariant(sessionId != null, 'session id cannot be empty');
    return new Promise((resolve, reject) => {
      try {
        commitMutation<SAM2ModelCancelPropagateInVideoMutation>(
          this._environment,
          {
            mutation: graphql`
              mutation SAM2ModelCancelPropagateInVideoMutation(
                $input: CancelPropagateInVideoInput!
              ) {
                cancelPropagateInVideo(input: $input) {
                  success
                }
              }
            `,
            variables: {
              input: {
                sessionId,
              },
            },
            onCompleted: response => {
              const {success} = response.cancelPropagateInVideo;
              if (!success) {
                reject(`could not abort session ${sessionId}`);
                return;
              }
              resolve();
            },
            onError: error => {
              Logger.error(error);
              reject(error);
            },
          },
        );
      } catch (error) {
        Logger.error(error);
        reject(error);
      }
    });
  }
}
|
|
|