| | import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library'; |
| | import type { Selector } from '@reduxjs/toolkit'; |
| | import { addAppListener } from 'app/store/middleware/listenerMiddleware'; |
| | import type { AppStore, RootState } from 'app/store/store'; |
| | import { withResultAsync } from 'common/util/result'; |
| | import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; |
| | import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; |
| | import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; |
| | import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; |
| | import type { SubscriptionHandler } from 'features/controlLayers/konva/util'; |
| | import { createReduxSubscription, getPrefixedId } from 'features/controlLayers/konva/util'; |
| | import { |
| | selectCanvasSettingsSlice, |
| | settingsBrushWidthChanged, |
| | settingsColorChanged, |
| | settingsEraserWidthChanged, |
| | } from 'features/controlLayers/store/canvasSettingsSlice'; |
| | import { |
| | bboxChangedFromCanvas, |
| | controlLayerAdded, |
| | entityBrushLineAdded, |
| | entityEraserLineAdded, |
| | entityMoved, |
| | entityRasterized, |
| | entityRectAdded, |
| | entityReset, |
| | inpaintMaskAdded, |
| | rasterLayerAdded, |
| | rgAdded, |
| | } from 'features/controlLayers/store/canvasSlice'; |
| | import { selectCanvasStagingAreaSlice } from 'features/controlLayers/store/canvasStagingAreaSlice'; |
| | import { |
| | selectAllRenderableEntities, |
| | selectBbox, |
| | selectCanvasSlice, |
| | selectGridSize, |
| | } from 'features/controlLayers/store/selectors'; |
| | import type { |
| | CanvasState, |
| | EntityBrushLineAddedPayload, |
| | EntityEraserLineAddedPayload, |
| | EntityIdentifierPayload, |
| | EntityMovedPayload, |
| | EntityRasterizedPayload, |
| | EntityRectAddedPayload, |
| | Rect, |
| | RgbaColor, |
| | } from 'features/controlLayers/store/types'; |
| | import { isRenderableEntityIdentifier, RGBA_BLACK } from 'features/controlLayers/store/types'; |
| | import type { Graph } from 'features/nodes/util/graph/generation/Graph'; |
| | import { atom, computed } from 'nanostores'; |
| | import type { Logger } from 'roarr'; |
| | import { getImageDTO } from 'services/api/endpoints/images'; |
| | import { queueApi } from 'services/api/endpoints/queue'; |
| | import type { BatchConfig, ImageDTO, S } from 'services/api/types'; |
| | import { QueueError } from 'services/events/errors'; |
| | import type { Param0 } from 'tsafe'; |
| | import { assert } from 'tsafe'; |
| |
|
| | import type { CanvasEntityAdapter } from './CanvasEntity/types'; |
| |
|
export class CanvasStateApiModule extends CanvasModuleBase {
  readonly type = 'state_api';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasManager;
  readonly manager: CanvasManager;
  readonly log: Logger;

  /**
   * The redux store this module reads canvas state from and dispatches canvas actions to.
   */
  store: AppStore;

  constructor(store: AppStore, manager: CanvasManager) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = manager;
    this.manager = manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating state api module');

    this.store = store;
  }

  /**
   * Runs a selector against the current root state and returns its result.
   */
  runSelector = <T>(selector: Selector<RootState, T>) => {
    return selector(this.store.getState());
  };

  /**
   * Subscribes `handler` to changes of the value produced by `selector`.
   * Delegates to `createReduxSubscription` and returns whatever it returns (presumably an unsubscribe function —
   * verify against `createReduxSubscription`).
   */
  createStoreSubscription = <T>(selector: Selector<RootState, T>, handler: SubscriptionHandler<T>) => {
    return createReduxSubscription(this.store, selector, handler);
  };

  /**
   * Registers a listener with the app's listener middleware by dispatching `addAppListener`.
   * Returns the dispatch result (NOTE(review): for the RTK listener middleware this is expected to be an unsubscribe
   * function — confirm).
   */
  addStoreListener = (arg: Parameters<typeof addAppListener>[0]) => {
    return this.store.dispatch(addAppListener(arg));
  };

  /**
   * Returns the current canvas slice of the root state.
   */
  getCanvasState = (): CanvasState => {
    return this.runSelector(selectCanvasSlice);
  };

  /**
   * Dispatches `entityReset` for the given entity.
   */
  resetEntity = (arg: EntityIdentifierPayload) => {
    this.store.dispatch(entityReset(arg));
  };

  /**
   * Dispatches `entityMoved` to update an entity's position.
   */
  setEntityPosition = (arg: EntityMovedPayload) => {
    this.store.dispatch(entityMoved(arg));
  };

  /**
   * Dispatches `entityBrushLineAdded`.
   */
  addBrushLine = (arg: EntityBrushLineAddedPayload) => {
    this.store.dispatch(entityBrushLineAdded(arg));
  };

  /**
   * Dispatches `entityEraserLineAdded`.
   */
  addEraserLine = (arg: EntityEraserLineAddedPayload) => {
    this.store.dispatch(entityEraserLineAdded(arg));
  };

  /**
   * Dispatches `entityRectAdded`.
   */
  addRect = (arg: EntityRectAddedPayload) => {
    this.store.dispatch(entityRectAdded(arg));
  };

  /**
   * Dispatches `rasterLayerAdded`.
   */
  addRasterLayer = (arg: Param0<typeof rasterLayerAdded>) => {
    this.store.dispatch(rasterLayerAdded(arg));
  };

  /**
   * Dispatches `controlLayerAdded`.
   */
  addControlLayer = (arg: Param0<typeof controlLayerAdded>) => {
    this.store.dispatch(controlLayerAdded(arg));
  };

  /**
   * Dispatches `inpaintMaskAdded`.
   */
  addInpaintMask = (arg: Param0<typeof inpaintMaskAdded>) => {
    this.store.dispatch(inpaintMaskAdded(arg));
  };

  /**
   * Dispatches `rgAdded` (adds a regional guidance entity).
   */
  addRegionalGuidance = (arg: Param0<typeof rgAdded>) => {
    this.store.dispatch(rgAdded(arg));
  };

  /**
   * Dispatches `entityRasterized`.
   */
  rasterizeEntity = (arg: EntityRasterizedPayload) => {
    this.store.dispatch(entityRasterized(arg));
  };

  /**
   * Sets the generation bbox by dispatching `bboxChangedFromCanvas`.
   */
  setGenerationBbox = (rect: Rect) => {
    this.store.dispatch(bboxChangedFromCanvas(rect));
  };

  /**
   * Sets the brush width in canvas settings.
   */
  setBrushWidth = (width: number) => {
    this.store.dispatch(settingsBrushWidthChanged(width));
  };

  /**
   * Sets the eraser width in canvas settings.
   */
  setEraserWidth = (width: number) => {
    this.store.dispatch(settingsEraserWidthChanged(width));
  };

  /**
   * Sets the drawing color in canvas settings. Returns the dispatch result.
   */
  setColor = (color: RgbaColor) => {
    return this.store.dispatch(settingsColorChanged(color));
  };

  /**
   * Enqueues a graph and resolves with the image produced by one of its nodes.
   *
   * The batch is tagged with a freshly generated `origin` so that socket events belonging to this run can be told
   * apart from everything else on the queue. The promise resolves once the node `outputNodeId` reports an
   * `image_output` result and the corresponding image DTO has been fetched.
   *
   * @param arg.graph The graph to enqueue (a single run).
   * @param arg.outputNodeId Id of the node whose image output is returned. Must exist in `graph`.
   * @param arg.destination Optional destination forwarded to the batch.
   * @param arg.prepend Forwarded to the batch config's `prepend` flag.
   * @param arg.timeout Optional timeout in milliseconds. When it elapses the batch is canceled and the promise rejects.
   * @param arg.signal Optional abort signal. Aborting cancels the batch and rejects the promise.
   * @returns The image DTO of the output node's image.
   * @throws Synchronously, if `graph` does not contain `outputNodeId`.
   *
   * The returned promise rejects with:
   * - a `QueueError` when the queue item fails and all error details are present, otherwise a plain `Error`;
   * - `Error('Graph canceled')` when the queue item is canceled or the signal aborts. This includes a signal that is
   *   already aborted on entry, in which case nothing is enqueued;
   * - `Error('Graph timed out')` when the timeout elapses;
   * - the underlying error if the enqueue request or the image DTO fetch fails.
   */
  runGraphAndReturnImageOutput = (arg: {
    graph: Graph;
    outputNodeId: string;
    destination?: string;
    prepend?: boolean;
    timeout?: number;
    signal?: AbortSignal;
  }): Promise<ImageDTO> => {
    const { graph, outputNodeId, destination, prepend, timeout, signal } = arg;

    if (!graph.hasNode(outputNodeId)) {
      throw new Error(`Graph does not contain node with id: ${outputNodeId}`);
    }

    // An AbortSignal that is already aborted never dispatches another 'abort' event. Without this check the listener
    // below would never run and the graph would execute anyway, so bail out before enqueueing anything.
    if (signal?.aborted) {
      return Promise.reject(new Error('Graph canceled'));
    }

    // A unique origin lets us pick this run's socket events out of everything else on the queue.
    const origin = getPrefixedId(graph.id);

    const batch: BatchConfig = {
      prepend,
      batch: {
        graph: graph.getGraph(),
        origin,
        destination,
        runs: 1,
      },
    };

    // Set once the output image has been fetched. After that, the timeout and abort paths do nothing.
    let didSucceed = false;

    let timeoutId: number | null = null;
    const _clearTimeout = () => {
      if (timeoutId !== null) {
        window.clearTimeout(timeoutId);
        timeoutId = null;
      }
    };

    // The batch id is only known once the enqueue request resolves. A cancel requested before then is remembered and
    // carried out as soon as the id arrives. Examples are a very short timeout or an abort right after this call.
    // Otherwise the batch would keep running on the server even though the caller has already given up on it.
    let batchId: string | null = null;
    let isCancelRequested = false;
    const cancelGraph = () => {
      isCancelRequested = true;
      const id = batchId;
      if (id === null) {
        return;
      }
      this.store.dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [id] }, { track: false }));
    };

    return new Promise<ImageDTO>((resolve, reject) => {
      // Detaches both socket listeners. Safe to call more than once.
      const clearListeners = () => {
        this.manager.socket.off('invocation_complete', invocationCompleteHandler);
        this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedHandler);
      };

      // Settle helpers. Both detach the abort listener so that a long-lived signal does not retain this closure.
      const succeed = (imageDTO: ImageDTO) => {
        didSucceed = true;
        signal?.removeEventListener('abort', abortHandler);
        resolve(imageDTO);
      };
      const fail = (error: unknown) => {
        signal?.removeEventListener('abort', abortHandler);
        reject(error);
      };

      const invocationCompleteHandler = async (event: S['InvocationCompleteEvent']) => {
        // Ignore events from other runs.
        if (event.origin !== origin) {
          return;
        }

        // Ignore completions of nodes other than the requested output node.
        if (event.invocation_source_id !== outputNodeId) {
          return;
        }

        // The output node has finished, so no further events for this run are needed.
        _clearTimeout();
        clearListeners();

        const { result } = event;
        if (result.type !== 'image_output') {
          fail(new Error(`Graph output node did not return an image output, got: ${result.type}`));
          return;
        }

        const getImageDTOResult = await withResultAsync(() => getImageDTO(result.image.image_name));
        if (getImageDTOResult.isErr()) {
          fail(getImageDTOResult.error);
          return;
        }

        succeed(getImageDTOResult.value);
      };

      const queueItemStatusChangedHandler = (event: S['QueueItemStatusChangedEvent']) => {
        // Ignore events from other runs.
        if (event.origin !== origin) {
          return;
        }

        // Still running; keep waiting.
        if (event.status === 'pending' || event.status === 'in_progress') {
          return;
        }

        if (event.status === 'completed') {
          // Normally the output-node completion event arrives first, and its handler detaches this listener.
          // Reaching this point means that event was not seen, or events arrived out of order. Rejecting here would be
          // wrong in the out-of-order case, so this only warns. The promise then stays pending until the output-node
          // event, the timeout or the abort signal settles it.
          this.log.warn('Queue item completed without output node completion event');
          return;
        }

        // Terminal non-success status: stop listening and reject.
        _clearTimeout();
        clearListeners();

        if (event.status === 'failed') {
          const { error_type, error_message, error_traceback } = event;
          if (error_type && error_message && error_traceback) {
            fail(new QueueError(error_type, error_message, error_traceback));
          } else {
            fail(new Error('Queue item failed, but no error details were provided'));
          }
        } else {
          // The only remaining terminal status is a cancellation.
          fail(new Error('Graph canceled'));
        }
      };

      const abortHandler = () => {
        if (didSucceed) {
          // Already resolved; nothing to cancel.
          return;
        }
        this.log.trace('Graph canceled by signal');
        _clearTimeout();
        clearListeners();
        cancelGraph();
        fail(new Error('Graph canceled'));
      };

      // Listen before enqueueing so that no event for this origin can be missed.
      this.manager.socket.on('invocation_complete', invocationCompleteHandler);
      this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedHandler);

      if (timeout) {
        timeoutId = window.setTimeout(() => {
          timeoutId = null;
          if (didSucceed) {
            // Already resolved; nothing to cancel.
            return;
          }
          this.log.trace('Graph canceled by timeout');
          clearListeners();
          cancelGraph();
          fail(new Error('Graph timed out'));
        }, timeout);
      }

      signal?.addEventListener('abort', abortHandler, { once: true });

      const enqueueRequest = this.store.dispatch(
        queueApi.endpoints.enqueueBatch.initiate(batch, {
          // NOTE(review): a shared fixed cache key presumably lets other consumers of the enqueueBatch mutation observe
          // this request's status — confirm with the queue UI.
          fixedCacheKey: 'enqueueBatch',
          // Do not track this mutation's result in the RTK Query store.
          track: false,
        })
      );

      enqueueRequest
        .unwrap()
        .then((data) => {
          const batch_id = data.batch.batch_id;
          assert(batch_id, 'Enqueue result is missing batch_id');
          batchId = batch_id;
          // A timeout or abort may have happened while the enqueue request was in flight.
          // Cancel the batch now that its id is known.
          if (isCancelRequested) {
            cancelGraph();
          }
        })
        .catch((error: unknown) => {
          // The graph never made it onto the queue (or the response was unusable), so no events will arrive.
          // Release the listeners and timer instead of leaking them.
          _clearTimeout();
          clearListeners();
          fail(error);
        });
    });
  };

  /**
   * Returns the generation bbox state.
   */
  getBbox = () => {
    return this.runSelector(selectBbox);
  };

  /**
   * Returns the canvas settings slice.
   */
  getSettings = () => {
    return this.runSelector(selectCanvasSettingsSlice);
  };

  /**
   * Returns the grid size used when snapping.
   * - `1` when snap-to-grid is disabled (effectively no snapping);
   * - `8` when ctrl or meta is held (fine snapping);
   * - `64` otherwise.
   */
  getGridSize = (): number => {
    const snapToGrid = this.getSettings().snapToGrid;
    if (!snapToGrid) {
      return 1;
    }
    const useFine = this.$ctrlKey.get() || this.$metaKey.get();
    if (useFine) {
      return 8;
    }
    return 64;
  };

  /**
   * Returns the regional guidance state from the canvas slice.
   */
  getRegionsState = () => {
    return this.getCanvasState().regionalGuidance;
  };

  /**
   * Returns the raster layers state from the canvas slice.
   */
  getRasterLayersState = () => {
    return this.getCanvasState().rasterLayers;
  };

  /**
   * Returns the control layers state from the canvas slice.
   */
  getControlLayersState = () => {
    return this.getCanvasState().controlLayers;
  };

  /**
   * Returns the inpaint masks state from the canvas slice.
   */
  getInpaintMasksState = () => {
    return this.getCanvasState().inpaintMasks;
  };

  /**
   * Returns the canvas staging area slice.
   */
  getStagingArea = () => {
    return this.runSelector(selectCanvasStagingAreaSlice);
  };

  /**
   * Returns the bbox grid size from `selectGridSize`.
   */
  getBboxGridSize = (): number => {
    return this.runSelector(selectGridSize);
  };

  /**
   * Whether the entity with the given id is the currently selected entity.
   */
  getIsSelected = (id: string): boolean => {
    return this.getCanvasState().selectedEntityIdentifier?.id === id;
  };

  /**
   * Counts the renderable entities that are enabled.
   */
  getRenderedEntityCount = (): number => {
    const renderableEntities = selectAllRenderableEntities(this.getCanvasState());
    let count = 0;
    for (const entity of renderableEntities) {
      if (entity.isEnabled) {
        count++;
      }
    }
    return count;
  };

  /**
   * Returns the adapter of the selected entity.
   * Returns `null` when nothing is selected or when the selected entity is not renderable.
   */
  getSelectedEntityAdapter = (): CanvasEntityAdapter | null => {
    const state = this.getCanvasState();
    if (!state.selectedEntityIdentifier) {
      return null;
    }
    if (!isRenderableEntityIdentifier(state.selectedEntityIdentifier)) {
      return null;
    }
    return this.manager.getAdapter(state.selectedEntityIdentifier);
  };

  /**
   * Returns the color to draw with.
   * This is the settings color, except when a regional guidance or inpaint mask entity is selected. In that case it
   * is always `RGBA_BLACK`, since those entities are not drawn with the user's color.
   */
  getCurrentColor = (): RgbaColor => {
    let color: RgbaColor = this.getSettings().color;
    const selectedEntity = this.getSelectedEntityAdapter();
    if (selectedEntity) {
      // Mask-like entities ignore the user's color.
      if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
        color = RGBA_BLACK;
      }
    }
    return color;
  };

  /**
   * Returns the color for the brush preview.
   * For a selected regional guidance or inpaint mask entity, this is that entity's fill color at 50% alpha.
   * Otherwise it is the settings color.
   */
  getBrushPreviewColor = (): RgbaColor => {
    const selectedEntity = this.getSelectedEntityAdapter();
    if (selectedEntity?.state.type === 'regional_guidance' || selectedEntity?.state.type === 'inpaint_mask') {
      // Preview with the entity's own fill color, half transparent.
      return { ...selectedEntity.state.fill.color, a: 0.5 };
    } else {
      return this.getSettings().color;
    }
  };

  /**
   * The raster/control layer adapter currently being filtered, or `null`.
   */
  $filteringAdapter = atom<CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer | null>(null);

  /**
   * Whether an adapter is currently being filtered (derived from `$filteringAdapter`).
   */
  $isFiltering = computed(this.$filteringAdapter, (filteringAdapter) => Boolean(filteringAdapter));

  /**
   * The adapter currently being transformed, or `null`.
   */
  $transformingAdapter = atom<CanvasEntityAdapter | null>(null);

  /**
   * Whether an adapter is currently being transformed (derived from `$transformingAdapter`).
   */
  $isTransforming = computed(this.$transformingAdapter, (transformingAdapter) => Boolean(transformingAdapter));

  /**
   * The adapter currently being rasterized, or `null`.
   */
  $rasterizingAdapter = atom<CanvasEntityAdapter | null>(null);

  /**
   * Whether an adapter is currently being rasterized (derived from `$rasterizingAdapter`).
   */
  $isRasterizing = computed(this.$rasterizingAdapter, (rasterizingAdapter) => Boolean(rasterizingAdapter));

  /**
   * The raster/control layer adapter currently being segmented, or `null`.
   */
  $segmentingAdapter = atom<CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer | null>(null);

  /**
   * Whether an adapter is currently being segmented (derived from `$segmentingAdapter`).
   */
  $isSegmenting = computed(this.$segmentingAdapter, (segmentingAdapter) => Boolean(segmentingAdapter));

  /**
   * Whether the space key is held.
   * NOTE(review): this module only owns the atom; it is presumably updated by keyboard handling elsewhere.
   */
  $spaceKey = atom<boolean>(false);

  /**
   * Alt key state, re-exposed from `@invoke-ai/ui-library`.
   */
  $altKey = $alt;

  /**
   * Ctrl key state, re-exposed from `@invoke-ai/ui-library`.
   */
  $ctrlKey = $ctrl;

  /**
   * Meta key state, re-exposed from `@invoke-ai/ui-library`.
   */
  $metaKey = $meta;

  /**
   * Shift key state, re-exposed from `@invoke-ai/ui-library`.
   */
  $shiftKey = $shift;

  /**
   * Returns a plain debug snapshot of this module: identity plus each in-progress operation.
   * Adapters are reduced to their entity identifiers.
   */
  repr = () => {
    return {
      id: this.id,
      type: this.type,
      path: this.path,
      $filteringAdapter: this.$filteringAdapter.get()?.entityIdentifier ?? null,
      $isFiltering: this.$isFiltering.get(),
      $transformingAdapter: this.$transformingAdapter.get()?.entityIdentifier ?? null,
      $isTransforming: this.$isTransforming.get(),
      $rasterizingAdapter: this.$rasterizingAdapter.get()?.entityIdentifier ?? null,
      $isRasterizing: this.$isRasterizing.get(),
      $segmentingAdapter: this.$segmentingAdapter.get()?.entityIdentifier ?? null,
      $isSegmenting: this.$isSegmenting.get(),
    };
  };
}
| |
|