|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import {generateThumbnail} from '@/common/components/video/editor/VideoEditorUtils'; |
|
|
import VideoWorkerContext from '@/common/components/video/VideoWorkerContext'; |
|
|
import Logger from '@/common/logger/Logger'; |
|
|
import { |
|
|
SAM2ModelAddNewPointsMutation, |
|
|
SAM2ModelAddNewPointsMutation$data, |
|
|
} from '@/common/tracker/__generated__/SAM2ModelAddNewPointsMutation.graphql'; |
|
|
import {SAM2ModelCancelPropagateInVideoMutation} from '@/common/tracker/__generated__/SAM2ModelCancelPropagateInVideoMutation.graphql'; |
|
|
import {SAM2ModelClearPointsInFrameMutation} from '@/common/tracker/__generated__/SAM2ModelClearPointsInFrameMutation.graphql'; |
|
|
import {SAM2ModelClearPointsInVideoMutation} from '@/common/tracker/__generated__/SAM2ModelClearPointsInVideoMutation.graphql'; |
|
|
import {SAM2ModelCloseSessionMutation} from '@/common/tracker/__generated__/SAM2ModelCloseSessionMutation.graphql'; |
|
|
import {SAM2ModelRemoveObjectMutation} from '@/common/tracker/__generated__/SAM2ModelRemoveObjectMutation.graphql'; |
|
|
import {SAM2ModelStartSessionMutation} from '@/common/tracker/__generated__/SAM2ModelStartSessionMutation.graphql'; |
|
|
import { |
|
|
BaseTracklet, |
|
|
Mask, |
|
|
SegmentationPoint, |
|
|
StreamingState, |
|
|
Tracker, |
|
|
Tracklet, |
|
|
} from '@/common/tracker/Tracker'; |
|
|
import {TrackerOptions} from '@/common/tracker/Trackers'; |
|
|
import { |
|
|
ClearPointsInVideoResponse, |
|
|
SessionStartFailedResponse, |
|
|
SessionStartedResponse, |
|
|
StreamingCompletedResponse, |
|
|
StreamingStartedResponse, |
|
|
StreamingStateUpdateResponse, |
|
|
TrackletCreatedResponse, |
|
|
TrackletDeletedResponse, |
|
|
TrackletsUpdatedResponse, |
|
|
} from '@/common/tracker/TrackerTypes'; |
|
|
import {convertMaskToRGBA} from '@/common/utils/MaskUtils'; |
|
|
import multipartStream from '@/common/utils/MultipartStream'; |
|
|
import {Stats} from '@/debug/stats/Stats'; |
|
|
import {INFERENCE_API_ENDPOINT} from '@/demo/DemoConfig'; |
|
|
import {createEnvironment} from '@/graphql/RelayEnvironment'; |
|
|
import { |
|
|
DataArray, |
|
|
Masks, |
|
|
RLEObject, |
|
|
decode, |
|
|
encode, |
|
|
toBbox, |
|
|
} from '@/jscocotools/mask'; |
|
|
import {THEME_COLORS} from '@/theme/colors'; |
|
|
import invariant from 'invariant'; |
|
|
import {IEnvironment, commitMutation, graphql} from 'relay-runtime'; |
|
|
|
|
|
/** Constructor options for `SAM2Model`: only the inference endpoint is configurable. */
type Options = Pick<TrackerOptions, 'inferenceEndpoint'>;

/** Client-side view of the current inference session. */
type Session = {
  /** Session id returned by the server, or null when no session is open. */
  id: string | null;
  /** Tracklets keyed by their numeric tracklet id. */
  tracklets: Record<number, Tracklet>;
};

/** One frame of masks yielded by the propagation stream. */
type StreamMasksResult = {
  /** Index of the frame these masks belong to. */
  frameIndex: number;
  /** One RLE-encoded mask per object id. */
  rleMaskList: {
    objectId: number;
    rleMask: RLEObject;
  }[];
};

/** Marker yielded by the propagation stream once it has been aborted. */
type StreamMasksAbortResult = {
  aborted: boolean;
};
|
|
|
|
|
/**
 * Tracker that delegates segmentation to a remote inference service.
 *
 * Session, point and object operations are sent as Relay GraphQL mutations
 * (`commitMutation`). Mask propagation is streamed from
 * `${inferenceEndpoint}/propagate_in_video` as a `multipart/x-savi-stream`
 * response. Per-tracklet state lives in `_session`. Updates are pushed to the
 * video worker context and reported through `_sendResponse`.
 */
export class SAM2Model extends Tracker {
  /** Base URL of the inference service; used for the streaming propagate request. */
  private _endpoint: string;
  /** Relay environment bound to the inference endpoint; used for every mutation. */
  private _environment: IEnvironment;

  /** Controller for the most recent `streamMasks` call; aborted by `abortStreamMasks`. */
  private abortController: AbortController | null = null;
  /** Active session id (null when none) and tracklets keyed by tracklet id. */
  private _session: Session = {
    id: null,
    tracklets: {},
  };
  /** Last streaming state broadcast via `streamingStateUpdate`. */
  private _streamingState: StreamingState = 'none';

  /**
   * RLE encoding of an empty mask at the current frame size. Incoming masks
   * whose `counts` equal this one are flagged `isEmpty`. It is computed lazily
   * in `updatePoints` and reset whenever the session changes, because the frame
   * size can differ between videos.
   */
  private _emptyMask: RLEObject | null = null;

  /** Scratch canvas and context reused by `_compressMaskForCanvas`. */
  private _maskCanvas: OffscreenCanvas;
  private _maskCtx: OffscreenCanvasRenderingContext2D;

  /** Optional timing stats; created by `enableStats`. */
  private _stats?: Stats;

  /**
   * @param context Video worker context (frame index/size, current frame, mask rendering).
   * @param options Inference endpoint; defaults to `INFERENCE_API_ENDPOINT`.
   */
  constructor(
    context: VideoWorkerContext,
    options: Options = {
      inferenceEndpoint: INFERENCE_API_ENDPOINT,
    },
  ) {
    super(context);
    this._endpoint = options.inferenceEndpoint;
    this._environment = createEnvironment(options.inferenceEndpoint);

    this._maskCanvas = new OffscreenCanvas(0, 0);
    const maskCtx = this._maskCanvas.getContext('2d');
    invariant(maskCtx != null, 'context cannot be null');
    this._maskCtx = maskCtx;
  }

  /**
   * Starts a new inference session for `videoPath`.
   *
   * The streaming state is force-reset to 'none'. On success:
   * - the session id is stored and `sessionStarted` is sent;
   * - existing tracklets and the cached empty mask are cleared;
   * - one fresh tracklet is created.
   *
   * On failure `sessionStartFailed` is sent. The returned promise always
   * resolves; failures are reported only through responses and the logger.
   */
  public startSession(videoPath: string): Promise<void> {
    this._updateStreamingState('none', true);

    return new Promise(resolve => {
      try {
        commitMutation<SAM2ModelStartSessionMutation>(this._environment, {
          mutation: graphql`
            mutation SAM2ModelStartSessionMutation($input: StartSessionInput!) {
              startSession(input: $input) {
                sessionId
              }
            }
          `,
          variables: {
            input: {
              path: videoPath,
            },
          },
          onCompleted: response => {
            const {sessionId} = response.startSession;
            this._session.id = sessionId;

            this._sendResponse<SessionStartedResponse>('sessionStarted', {
              sessionId,
            });

            // Drop state left over from a previous session. The cached empty
            // mask depends on the frame size, which may differ for this video.
            this._clearTracklets();
            this._emptyMask = null;

            // Start the user off with one empty tracklet.
            this.createTracklet();
            resolve();
          },
          onError: error => {
            Logger.error(error);
            this._sendResponse<SessionStartFailedResponse>(
              'sessionStartFailed',
            );
            resolve();
          },
        });
      } catch (error) {
        Logger.error(error);
        this._sendResponse<SessionStartFailedResponse>('sessionStartFailed');
        resolve();
      }
    });
  }

  /**
   * Closes the current session.
   *
   * Local state is cleared immediately, before the server round-trip. Resolves
   * at once when there is no session. Rejects if the server reports
   * `success: false` or the mutation errors.
   */
  public closeSession(): Promise<void> {
    const sessionId = this._session.id;

    // Reset local state first so the client no longer uses the closing session.
    this._cleanup();

    if (sessionId === null) {
      return Promise.resolve();
    }

    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelCloseSessionMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelCloseSessionMutation($input: CloseSessionInput!) {
            closeSession(input: $input) {
              success
            }
          }
        `,
        variables: {
          input: {
            sessionId,
          },
        },
        onCompleted: response => {
          const {success} = response.closeSession;
          if (success === false) {
            reject(new Error('Failed to close session'));
            return;
          }
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Creates a new, empty tracklet locally (no server call).
   *
   * The id is one more than the largest existing id (0 when there are none).
   * The color cycles through `THEME_COLORS`. Sends `trackletsUpdated` followed
   * by `trackletCreated`.
   */
  public createTracklet(): void {
    const nextId =
      Object.values(this._session.tracklets).reduce(
        (prev, curr) => Math.max(prev, curr.id),
        -1,
      ) + 1;

    const newTracklet = {
      id: nextId,
      color: THEME_COLORS[nextId % THEME_COLORS.length],
      thumbnail: null,
      points: [],
      masks: [],
      isInitialized: false,
    };

    this._session.tracklets[nextId] = newTracklet;

    this._updateTracklets();

    this._sendResponse<TrackletCreatedResponse>('trackletCreated', {
      tracklet: newTracklet,
    });
  }

  /**
   * Removes tracklet `trackletId` on the server, then locally.
   *
   * On success:
   * - `trackletDeleted` (isSuccessful: true) is sent;
   * - the masks returned by the server are applied to the tracklets;
   * - the deleted tracklet's masks and entry are removed.
   *
   * On error `trackletDeleted` (isSuccessful: false) is sent and the promise
   * rejects. Rejects with 'No active session' when no session is open.
   */
  public deleteTracklet(trackletId: number): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    const tracklet = this._session.tracklets[trackletId];
    invariant(
      tracklet != null,
      'tracklet for tracklet id %s not initialized',
      trackletId,
    );

    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelRemoveObjectMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelRemoveObjectMutation($input: RemoveObjectInput!) {
            removeObject(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {objectId: trackletId, sessionId},
        },
        onCompleted: response => {
          const trackletUpdates = response.removeObject;
          this._sendResponse<TrackletDeletedResponse>('trackletDeleted', {
            isSuccessful: true,
          });
          for (const trackletUpdate of trackletUpdates) {
            // Thumbnails are regenerated only for the frame currently shown,
            // because `_updateTrackletMasks` crops them from the context's
            // current frame. These updates are intentionally not awaited (to
            // keep the original ordering relative to the removal below), but
            // failures are logged instead of becoming unhandled rejections.
            this._updateTrackletMasks(
              trackletUpdate,
              trackletUpdate.frameIndex === this._context.frameIndex,
              false,
            ).catch(error => Logger.error(error));
          }
          this._removeTrackletMasks(tracklet).catch(error =>
            Logger.error(error),
          );
          resolve();
        },
        onError: error => {
          this._sendResponse<TrackletDeletedResponse>('trackletDeleted', {
            isSuccessful: false,
          });
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Replaces the points of object `objectId` on frame `frameIndex`.
   *
   * Points are `[x, y, label]` in pixel coordinates. They are sent normalized
   * by the frame width/height with `clearOldPoints: true`. An empty `points`
   * array is delegated to `clearPointsInFrame`. The streaming state becomes
   * 'required'. The promise settles only after the returned masks (and
   * thumbnails) have been applied.
   */
  public updatePoints(
    frameIndex: number,
    objectId: number,
    points: SegmentationPoint[],
  ): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    // Lazily build the reference "empty" RLE for the current frame size. It is
    // used to flag empty masks in `_updateTrackletMasks`.
    if (this._emptyMask === null) {
      // NOTE(review): assumes `Masks(h, w, 1)` is zero-initialized.
      const tensor = new Masks(
        Math.trunc(this._context.height),
        Math.trunc(this._context.width),
        1,
      ).toDataArray();
      this._emptyMask = encode(tensor)[0];
    }

    const tracklet = this._session.tracklets[objectId];
    invariant(
      tracklet != null,
      'tracklet for object id %s not initialized',
      objectId,
    );

    // Any point edit invalidates previously propagated masks.
    this._updateStreamingState('required');

    if (points.length === 0) {
      return this.clearPointsInFrame(frameIndex, objectId);
    }

    return new Promise((resolve, reject) => {
      // Normalize pixel coordinates relative to the frame size.
      const normalizedPoints = points.map(p => [
        p[0] / this._context.width,
        p[1] / this._context.height,
      ]);
      const labels = points.map(p => p[2]);
      commitMutation<SAM2ModelAddNewPointsMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelAddNewPointsMutation($input: AddPointsInput!) {
            addPoints(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {
            sessionId,
            frameIndex,
            objectId,
            labels: labels,
            points: normalizedPoints,
            clearOldPoints: true,
          },
        },
        onCompleted: response => {
          tracklet.points[frameIndex] = points;
          tracklet.isInitialized = true;
          // Wait for masks/thumbnails to be applied so callers do not race the
          // update (e.g. by starting a stream), and surface its failures.
          this._updateTrackletMasks(response.addPoints, true).then(
            () => resolve(),
            error => {
              Logger.error(error);
              reject(error);
            },
          );
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Clears all points of object `objectId` on frame `frameIndex` on the server.
   *
   * The returned masks are then applied locally, and the streaming state
   * becomes 'required'. The promise settles only after the masks (and
   * thumbnails) have been applied.
   */
  public clearPointsInFrame(
    frameIndex: number,
    objectId: number,
  ): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    const tracklet = this._session.tracklets[objectId];
    invariant(
      tracklet != null,
      'tracklet for object id %s not initialized',
      objectId,
    );

    // Clearing points invalidates previously propagated masks.
    this._updateStreamingState('required');

    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelClearPointsInFrameMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelClearPointsInFrameMutation(
            $input: ClearPointsInFrameInput!
          ) {
            clearPointsInFrame(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {
            sessionId,
            frameIndex,
            objectId,
          },
        },
        onCompleted: response => {
          tracklet.points[frameIndex] = [];
          tracklet.isInitialized = true;
          this._updateTrackletMasks(response.clearPointsInFrame, true).then(
            () => resolve(),
            error => {
              Logger.error(error);
              reject(error);
            },
          );
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Clears every point in the video on the server and drops all local
   * tracklets.
   *
   * The streaming state becomes 'none'. The outcome is reported via the
   * `clearPointsInVideo` response. The returned promise always resolves once
   * the server has answered (success, `success: false`, or error), so callers
   * never hang; it rejects only with 'No active session'.
   */
  public clearPointsInVideo(): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    this._updateStreamingState('none');

    return new Promise(resolve => {
      commitMutation<SAM2ModelClearPointsInVideoMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelClearPointsInVideoMutation(
            $input: ClearPointsInVideoInput!
          ) {
            clearPointsInVideo(input: $input) {
              success
            }
          }
        `,
        variables: {
          input: {
            sessionId,
          },
        },
        onCompleted: response => {
          const {success} = response.clearPointsInVideo;
          if (!success) {
            this._sendResponse<ClearPointsInVideoResponse>(
              'clearPointsInVideo',
              {isSuccessful: false},
            );
            resolve();
            return;
          }

          this._clearTracklets();

          // Re-navigate to the current frame, presumably to redraw it without
          // the cleared masks.
          this._context.goToFrame(this._context.frameIndex);
          this._updateTracklets();
          this._sendResponse<ClearPointsInVideoResponse>('clearPointsInVideo', {
            isSuccessful: true,
          });
          resolve();
        },
        onError: error => {
          this._sendResponse<ClearPointsInVideoResponse>('clearPointsInVideo', {
            isSuccessful: false,
          });
          Logger.error(error);
          resolve();
        },
      });
    });
  }

  /**
   * Propagates masks through the video, starting at `frameIndex`.
   *
   * Existing masks are cleared first. Each streamed frame is applied via
   * `_updateTrackletMasks` and moves the state to 'partial'. If the stream is
   * aborted, the server-side propagation is cancelled and the state goes
   * 'aborting' -> 'aborted'; otherwise it ends at 'full'. Sends
   * `streamingStarted` at the start and `streamingCompleted` at the end.
   * Errors are logged and rethrown.
   */
  public async streamMasks(frameIndex: number): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    try {
      this._sendResponse<StreamingStartedResponse>('streamingStarted');

      // Propagation recomputes masks, so start from a clean slate.
      this._context.clearMasks();
      this._clearTrackletMasks();

      // Stored so that `abortStreamMasks` can signal this stream.
      const controller = new AbortController();
      this.abortController = controller;

      this._updateStreamingState('requesting');
      const generator = this._streamMasksForSession(
        controller,
        sessionId,
        frameIndex,
      );

      let isAborted = false;
      for await (const result of generator) {
        if ('aborted' in result) {
          // Stop propagation on the server as well.
          this._updateStreamingState('aborting');
          await this._abortRequest();
          this._updateStreamingState('aborted');
          isAborted = true;
        } else {
          await this._updateTrackletMasks(result, false);
          this._updateStreamingState('partial');
        }
      }

      if (!isAborted) {
        this._updateStreamingState('full');
      }
    } catch (error) {
      Logger.error(error);
      throw error;
    }

    this._sendResponse<StreamingCompletedResponse>('streamingCompleted');
  }

  /**
   * Signals the in-flight `streamMasks` call (if any) to abort and sends
   * `streamingCompleted`. The stream notices the abort before reading its next
   * part.
   */
  public abortStreamMasks() {
    this.abortController?.abort();
    this._sendResponse<StreamingCompletedResponse>('streamingCompleted');
  }

  /**
   * Enables timing stats collection.
   * NOTE(review): the `Stats` arguments ('ms', 'D', 1000 / 25) are passed
   * through as-is; their meaning is defined by `Stats`.
   */
  public enableStats(): void {
    this._stats = new Stats('ms', 'D', 1000 / 25);
  }

  /** Forgets the session id, all tracklets and the cached empty mask. */
  private _cleanup() {
    this._session.id = null;
    this._session.tracklets = {};
    // Frame size may differ in the next session; recompute lazily.
    this._emptyMask = null;
  }

  /** Drops all tracklets and clears the masks held by the context. */
  private _clearTracklets() {
    this._session.tracklets = {};
    this._context.clearMasks();
  }

  /**
   * Stores `state` and broadcasts `streamingStateUpdate`.
   * Repeated states are suppressed unless `forceUpdate` is true.
   */
  private _updateStreamingState(
    state: StreamingState,
    forceUpdate: boolean = false,
  ) {
    if (!forceUpdate && this._streamingState === state) {
      return;
    }
    this._streamingState = state;
    this._sendResponse<StreamingStateUpdateResponse>('streamingStateUpdate', {
      state,
    });
  }

  /**
   * Clears `tracklet`'s masks in the context and removes it from the session.
   * It then re-navigates to the current frame and broadcasts the remaining
   * tracklets.
   */
  private async _removeTrackletMasks(tracklet: Tracklet) {
    this._context.clearTrackletMasks(tracklet);
    delete this._session.tracklets[tracklet.id];

    // Re-navigate to the current frame, presumably to redraw it without the
    // removed tracklet's mask.
    this._context.goToFrame(this._context.frameIndex);
    this._updateTracklets();
  }

  /**
   * Applies one frame of RLE masks to the matching tracklets. Then passes all
   * tracklets to `this._context.updateTracklets` and broadcasts
   * `trackletsUpdated`.
   *
   * @param data Frame index plus per-object RLE masks; the same shape is
   *   returned by the add/clear/remove mutations and by the mask stream.
   * @param updateThumbnails When true, regenerates the thumbnail of every
   *   tracklet with a non-empty mask from the context's current frame.
   * @param shouldGoToFrame Forwarded to `this._context.updateTracklets`.
   */
  private async _updateTrackletMasks(
    data: SAM2ModelAddNewPointsMutation$data['addPoints'],
    updateThumbnails: boolean,
    shouldGoToFrame: boolean = true,
  ) {
    const {frameIndex, rleMaskList} = data;

    for (const {objectId, rleMask} of rleMaskList) {
      const track = this._session.tracklets[objectId];
      if (track == null) {
        // No local tracklet for this object id (e.g. it was deleted or the
        // session was reset while the update was in flight). Writing to
        // `track.masks` would throw, so skip it.
        continue;
      }

      const {size, counts} = rleMask;
      const rleObject: RLEObject = {
        size: [size[0], size[1]],
        counts: counts,
      };
      // Identical RLE counts to the reference empty mask => nothing segmented.
      const isEmpty = counts === this._emptyMask?.counts;

      // NOTE(review): no matching `end()` call is visible in this class;
      // confirm how `Stats` measures intervals.
      this._stats?.begin();

      const decodedMask = decode([rleObject]);
      const bbox = toBbox([rleObject]);

      // bbox is [x, y, width, height]; bounds are stored as two corners.
      const mask: Mask = {
        data: rleObject as RLEObject,
        shape: [...decodedMask.shape],
        bounds: [
          [bbox[0], bbox[1]],
          [bbox[0] + bbox[2], bbox[1] + bbox[3]],
        ],
        isEmpty,
      } as const;
      track.masks[frameIndex] = mask;

      if (updateThumbnails && !isEmpty) {
        const {ctx} = await this._compressMaskForCanvas(decodedMask);
        const frame = this._context.currentFrame as VideoFrame;
        await generateThumbnail(track, frameIndex, mask, frame, ctx);
      }
    }

    this._context.updateTracklets(
      frameIndex,
      Object.values(this._session.tracklets),
      shouldGoToFrame,
    );

    this._updateTracklets();
  }

  /**
   * Broadcasts `trackletsUpdated` with a summary of every tracklet. Mask RLE
   * data is omitted; only each mask's shape, bounds and isEmpty flag are sent.
   */
  private _updateTracklets() {
    const tracklets: BaseTracklet[] = Object.values(
      this._session.tracklets,
    ).map(tracklet => {
      const {
        id,
        color,
        isInitialized,
        points: trackletPoints,
        thumbnail,
        masks,
      } = tracklet;
      return {
        id,
        color,
        isInitialized,
        points: trackletPoints,
        thumbnail,
        masks: masks.map(mask => ({
          shape: mask.shape,
          bounds: mask.bounds,
          isEmpty: mask.isEmpty,
        })),
      };
    });

    this._sendResponse<TrackletsUpdatedResponse>('trackletsUpdated', {
      tracklets,
    });
  }

  /**
   * Replaces every tracklet with a shallow copy whose `masks` is empty, then
   * broadcasts the result. New objects are created rather than mutating the
   * existing ones.
   */
  private _clearTrackletMasks() {
    const keys = Object.keys(this._session.tracklets);
    for (const key of keys) {
      const trackletId = Number(key);
      const tracklet = {...this._session.tracklets[trackletId], masks: []};
      this._session.tracklets[trackletId] = tracklet;
    }
    this._updateTracklets();
  }

  /**
   * Renders a decoded mask onto a new canvas sized
   * `shape[1]` x `shape[0]` (width x height). Also returns a PNG blob of that
   * canvas.
   */
  private async _compressMaskForCanvas(
    decodedMask: DataArray,
  ): Promise<{compressedData: Blob; ctx: OffscreenCanvasRenderingContext2D}> {
    const data = convertMaskToRGBA(decodedMask.data as Uint8Array);

    // The RGBA data is first laid out with width = shape[0] and
    // height = shape[1]. NOTE(review): this suggests the decoded data is
    // column-major; it is transposed below.
    this._maskCanvas.width = decodedMask.shape[0];
    this._maskCanvas.height = decodedMask.shape[1];

    const imageData = new ImageData(
      data,
      decodedMask.shape[0],
      decodedMask.shape[1],
    );
    this._maskCtx.putImageData(imageData, 0, 0);

    const canvas = new OffscreenCanvas(
      decodedMask.shape[1],
      decodedMask.shape[0],
    );

    const ctx = canvas.getContext('2d');
    invariant(ctx != null, 'context cannot be null');
    ctx.save();
    // rotate(pi/2) followed by scale(1, -1) maps (x, y) -> (y, x), i.e. it
    // transposes the scratch image onto the destination canvas.
    ctx.rotate(Math.PI / 2);
    ctx.scale(1, -1);
    ctx.drawImage(this._maskCanvas, 0, 0);
    ctx.restore();

    const compressedData = await canvas.convertToBlob({type: 'image/png'});

    return {compressedData, ctx};
  }

  /**
   * POSTs to `${endpoint}/propagate_in_video` and yields one result per JSON
   * part of the `multipart/x-savi-stream` response.
   *
   * When `abortController` is aborted, the next iteration releases the reader,
   * yields `{aborted: true}` and returns. Parts whose Content-Type is missing
   * or not JSON are skipped.
   *
   * @throws Error if the HTTP status is not OK, the response is not a
   *   `multipart/x-savi-stream`, or the body is null.
   */
  private async *_streamMasksForSession(
    abortController: AbortController,
    sessionId: string,
    startFrameIndex: undefined | number = 0,
  ): AsyncGenerator<StreamMasksResult | StreamMasksAbortResult, undefined> {
    const url = `${this._endpoint}/propagate_in_video`;

    const requestBody = {
      session_id: sessionId,
      start_frame_index: startFrameIndex,
    };

    const headers: {[name: string]: string} = {
      'Content-Type': 'application/json',
    };

    const response = await fetch(url, {
      method: 'POST',
      body: JSON.stringify(requestBody),
      headers,
    });

    // Report HTTP failures as such instead of as a content-type mismatch.
    if (!response.ok) {
      throw new Error(
        `propagate_in_video request failed with status ${response.status}`,
      );
    }

    const contentType = response.headers.get('Content-Type');
    if (
      contentType == null ||
      !contentType.startsWith('multipart/x-savi-stream;')
    ) {
      throw new Error(
        'endpoint needs to support Content-Type "multipart/x-savi-stream"',
      );
    }

    const responseBody = response.body;
    if (responseBody == null) {
      throw new Error('response body is null');
    }

    const reader = multipartStream(contentType, responseBody).getReader();
    const textDecoder = new TextDecoder();

    while (true) {
      // Abort is checked between parts rather than passed to fetch. This way
      // the caller receives an explicit marker and can cancel on the server.
      if (abortController.signal.aborted) {
        reader.releaseLock();
        yield {aborted: true};
        return;
      }

      const {done, value} = await reader.read();
      if (done) {
        return;
      }

      const {headers: partHeaders, body} = value;
      const partContentType = partHeaders.get('Content-Type');
      if (
        partContentType == null ||
        !partContentType.startsWith('application/json')
      ) {
        continue;
      }

      const jsonResponse = JSON.parse(textDecoder.decode(body));
      const maskResults = jsonResponse.results;
      // Map the server's snake_case fields onto the client-side shape.
      const rleMaskList = maskResults.map(
        (mask: {object_id: number; mask: RLEObject}) => {
          return {
            objectId: mask.object_id,
            rleMask: mask.mask,
          };
        },
      );
      yield {
        frameIndex: jsonResponse.frame_index,
        rleMaskList,
      };
    }
  }

  /**
   * Asks the server to cancel propagation for the current session.
   * Rejects if the server reports `success: false` or the mutation fails.
   */
  private async _abortRequest(): Promise<void> {
    const sessionId = this._session.id;
    invariant(sessionId != null, 'session id cannot be empty');
    return new Promise((resolve, reject) => {
      try {
        commitMutation<SAM2ModelCancelPropagateInVideoMutation>(
          this._environment,
          {
            mutation: graphql`
              mutation SAM2ModelCancelPropagateInVideoMutation(
                $input: CancelPropagateInVideoInput!
              ) {
                cancelPropagateInVideo(input: $input) {
                  success
                }
              }
            `,
            variables: {
              input: {
                sessionId,
              },
            },
            onCompleted: response => {
              const {success} = response.cancelPropagateInVideo;
              if (!success) {
                reject(`could not abort session ${sessionId}`);
                return;
              }
              resolve();
            },
            onError: error => {
              Logger.error(error);
              reject(error);
            },
          },
        );
      } catch (error) {
        Logger.error(error);
        reject(error);
      }
    });
  }
}
|
|
|