import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
  Box,
  Flex,
  Icon,
  Input,
  Modal,
  ModalBody,
  ModalContent,
  ModalOverlay,
  Spacer,
  Text,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
  $cursorPos,
  $edgePendingUpdate,
  $pendingConnection,
  $templates,
  edgesChanged,
  nodesChanged,
  useAddNodeCmdk,
} from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memoize } from 'lodash-es';
import { computed } from 'nanostores';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCircuitryBold, PiFlaskBold, PiHammerBold } from 'react-icons/pi';
import type { EdgeChange, NodeChange } from 'reactflow';
import type { S } from 'services/api/types';

/**
 * Returns a copy of `value` that is republished at most once every `limit` milliseconds.
 *
 * The initial `value` is returned immediately (it seeds `useState`). Each later change to `value` or `limit`
 * schedules a timer for the time remaining until `limit` ms have passed since the last publish. A newer change
 * cancels the pending timer (effect cleanup) and schedules a fresh one. When the timer fires and at least `limit`
 * ms have elapsed, the latest `value` is published and the clock restarts.
 *
 * `lastRan` starts at mount time, so the first change is also held until `limit` ms after mount.
 *
 * NOTE(review): if the timer callback runs while `Date.now() - lastRan.current < limit` (e.g. because of timer or
 * clock granularity), nothing is published and no retry is scheduled until `value` changes again. Confirm this
 * cannot leave a stale value.
 *
 * @param value - the rapidly changing value to throttle (generic type `T`)
 * @param limit - minimum interval between published updates, in milliseconds
 * @returns the throttled value
 */
const useThrottle = (value: T, limit: number) => {
  const [throttledValue, setThrottledValue] = useState(value);
  // Timestamp (ms) of the last time the throttled value was published.
  const lastRan = useRef(Date.now());
  useEffect(() => {
    const handler = setTimeout(
      function () {
        if (Date.now() - lastRan.current >= limit) {
          setThrottledValue(value);
          lastRan.current = Date.now();
        }
      },
      // Delay only for the remaining part of the current window (may be <= 0, i.e. fire ASAP).
      limit - (Date.now() - lastRan.current)
    );
    return () => {
      clearTimeout(handler);
    };
  }, [value, limit]);
  return throttledValue;
};

/**
 * Returns a callback that adds a node of the given type to the workflow editor.
 *
 * The callback:
 * 1. Builds the node with `useBuildNode`. If that yields nothing, it shows an "unknown node" error toast and stops.
 * 2. Places the node at an unoccupied position near the last known cursor position, falling back to the node's
 *    built position.
 * 3. Marks the new node selected and deselects every other selected node and edge.
 * 4. If a connection drag is pending and the new node is an invocation node, asks `getFirstValidConnection` for a
 *    connection between the pending handle and the new node, and adds the resulting edge if there is one.
 */
const useAddNode = () => {
  const { t } = useTranslation();
  const store = useAppStore();
  const buildInvocation = useBuildNode();
  const templates = useStore($templates);
  const pendingConnection = useStore($pendingConnection);
  const addNode = useCallback(
    (nodeType: string): void => {
      const node = buildInvocation(nodeType);
      if (!node) {
        const errorMessage = t('nodes.unknownNode', {
          nodeType: nodeType,
        });
        toast({
          status: 'error',
          title: errorMessage,
        });
        return;
      }

      // Find a cozy spot for the node
      const cursorPos = $cursorPos.get();
      const { nodes, edges } = selectNodesSlice(store.getState());
      node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y);
      node.selected = true;

      // Deselect all other nodes and edges
      const nodeChanges: NodeChange[] = [{ type: 'add', item: node }];
      const edgeChanges: EdgeChange[] = [];
      nodes.forEach(({ id, selected }) => {
        if (selected) {
          nodeChanges.push({ type: 'select', id, selected: false });
        }
      });
      edges.forEach(({ id, selected }) => {
        if (selected) {
          edgeChanges.push({ type: 'select', id, selected: false });
        }
      });

      // Onwards!
      // (nodeChanges always holds the 'add' change, so this first check is always true.)
      if (nodeChanges.length > 0) {
        store.dispatch(nodesChanged(nodeChanges));
      }
      if (edgeChanges.length > 0) {
        store.dispatch(edgesChanged(edgeChanges));
      }

      // Auto-connect an edge if we just added a node and have a pending connection
      if (pendingConnection && isInvocationNode(node)) {
        const edgePendingUpdate = $edgePendingUpdate.get();
        const { handleType } = pendingConnection;
        // The pending handle keeps its side of the connection; the new node fills the other side, and its handle
        // is left null for getFirstValidConnection to choose.
        const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
        const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
        const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
        const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
        // Re-read state after the dispatches above so the lookup sees the newly added node
        // (assumes dispatch applies synchronously, as in standard Redux).
        const { nodes, edges } = selectNodesSlice(store.getState());
        const connection = getFirstValidConnection(
          source,
          sourceHandle,
          target,
          targetHandle,
          nodes,
          edges,
          templates,
          edgePendingUpdate
        );
        if (connection) {
          const newEdge = connectionToEdge(connection);
          store.dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
        }
      }
    },
    [buildInvocation, pendingConnection, store, t, templates]
  );
  return addNode;
};

// Make the cmdk root and list fill their container.
const cmdkRootSx: SystemStyleObject = {
  '[cmdk-root]': {
    w: 'full',
    h: 'full',
  },
  '[cmdk-list]': {
    w: 'full',
    h: 'full',
  },
};

/**
 * Command palette for adding nodes to the workflow editor.
 *
 * - The registered `addNode` hotkey opens the palette, but only while the active tab is `'workflows'`.
 * - Typing updates `searchTerm`. A copy throttled to 100 ms is computed; presumably that copy is what the item
 *   list filters on (the markup is not shown here).
 * - Selecting an item adds that node type via `useAddNode` and then closes the palette.
 * - Closing hides the palette, clears the search term and clears any pending connection.
 */
export const AddNodeCmdk = memo(() => {
  const { t } = useTranslation();
  const addNodeCmdk = useAddNodeCmdk();
  const inputRef = useRef(null);
  const [searchTerm, setSearchTerm] = useState('');
  const addNode = useAddNode();
  const tab = useAppSelector(selectActiveTab);
  const throttledSearchTerm = useThrottle(searchTerm, 100);

  useRegisteredHotkeys({
    id: 'addNode',
    category: 'workflows',
    callback: addNodeCmdk.setTrue,
    options: { enabled: tab === 'workflows', preventDefault: true },
    dependencies: [addNodeCmdk.setTrue, tab],
  });

  const onChange = useCallback((e: ChangeEvent) => {
    setSearchTerm(e.target.value);
  }, []);

  const onClose = useCallback(() => {
    addNodeCmdk.setFalse();
    setSearchTerm('');
    // Abandon any connection drag that opened the palette.
    $pendingConnection.set(null);
  }, [addNodeCmdk]);

  const onSelect = useCallback(
    (value: string) => {
      addNode(value);
      onClose();
    },
    [addNode, onClose]
  );

  return ( );
});
AddNodeCmdk.displayName = 'AddNodeCmdk';

// Highlight the keyboard/mouse-selected command item.
const cmdkItemSx: SystemStyleObject = {
  '&[data-selected="true"]': {
    bg: 'base.700',
  },
};

/** Data for one entry in the node command list; `value` is the node type passed to `onSelect`. */
type NodeCommandItemData = {
  value: string;
  label: string;
  description: string;
  classification: S['Classification'];
  nodePack: string;
};

/**
 * An array of all templates, excluding deprecated ones.
*/
const $templatesArray = computed($templates, (templates) =>
  Object.values(templates).filter((template) => template.classification !== 'deprecated')
);

/**
 * Builds a loose, case-insensitive matcher for a search term.
 *
 * The term is trimmed and the characters `-[]{}()*+!<=:?./\^$|#,` are removed. The space-separated words are then
 * joined with `.*`, so e.g. "lo im" matches "Load Image".
 *
 * Results are memoized per input string, so one RegExp instance is shared by every `filter` call for that term.
 * The regex is therefore deliberately created WITHOUT the `g` flag. `RegExp.prototype.test` on a global regex is
 * stateful: it reads and advances `lastIndex`. Reusing one cached global regex across many different strings would
 * make matches depend on what the previous call matched, and so would yield false negatives.
 */
const createRegex = memoize(
  (inputValue: string) =>
    new RegExp(
      inputValue
        .trim()
        .replace(/[-[\]{}()*+!<=:?./\\^$|#,]/g, '')
        .split(' ')
        .join('.*'),
      'i'
    )
);

// Filterable items are a subset of Invocation template - we also want to filter for notes or current image node,
// so we are using a less specific type instead of `InvocationTemplate`
type FilterableItem = {
  type: string;
  title: string;
  description: string;
  tags: string[];
  classification: S['Classification'];
  nodePack: string;
};

/**
 * Returns whether `item` matches `searchTerm`.
 *
 * An empty search term matches everything. Otherwise the item matches if any of its title, type, description,
 * node pack, classification or tags (checked in that order) either contains the raw term (case-sensitive) or
 * matches the fuzzy regex from `createRegex` (case-insensitive).
 *
 * Results are memoized under the key `${item.type}-${searchTerm}`, and nothing in this module clears that cache.
 * A cached result is therefore reused even if other fields of an item with the same type change later, e.g. a
 * translated title.
 */
const filter = memoize(
  (item: FilterableItem, searchTerm: string): boolean => {
    if (!searchTerm) {
      return true;
    }
    const regex = createRegex(searchTerm);
    const matches = (value: string): boolean => value.includes(searchTerm) || regex.test(value);
    return [item.title, item.type, item.description, item.nodePack, item.classification, ...item.tags].some(matches);
  },
  (item: FilterableItem, searchTerm: string) => `${item.type}-${searchTerm}`
);

/**
 * Builds the list of nodes offered by the add-node palette.
 *
 * - Without a pending connection: lists every non-deprecated template that matches `searchTerm`, plus the
 *   "current image" and "notes" pseudo-nodes if they match.
 * - With a pending connection: lists only matching templates that have at least one field on the opposite side
 *   of the connection whose type passes `validateConnectionTypes`. That means inputs when dragging from a source
 *   handle, and outputs when dragging from a target handle. The pseudo-nodes are omitted.
 *
 * `onSelect` receives the chosen node type.
 */
const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; onSelect: (value: string) => void }) => {
  const { t } = useTranslation();
  const templatesArray = useStore($templatesArray);
  const pendingConnection = useStore($pendingConnection);
  const currentImageFilterItem = useMemo(
    () => ({
      type: 'current_image',
      title: t('nodes.currentImage'),
      description: t('nodes.currentImageDescription'),
      tags: ['progress', 'image', 'current'],
      classification: 'stable',
      nodePack: 'invokeai',
    }),
    [t]
  );
  const notesFilterItem = useMemo(
    () => ({
      type: 'notes',
      title: t('nodes.notes'),
      description: t('nodes.notesDescription'),
      tags: ['notes'],
      classification: 'stable',
      nodePack: 'invokeai',
    }),
    [t]
  );
  const items = useMemo(() => {
    // If we have a connection in progress, we need to filter the node choices
    const _items: NodeCommandItemData[] = [];
    if (!pendingConnection) {
      for (const template of templatesArray) {
        if (filter(template, searchTerm)) {
          _items.push({
            label: template.title,
            value: template.type,
            description: template.description,
            classification: template.classification,
            nodePack: template.nodePack,
          });
        }
      }
      for (const item of [currentImageFilterItem, notesFilterItem]) {
        if (filter(item, searchTerm)) {
          _items.push({
            label: item.title,
            value: item.type,
            description: item.description,
            classification: item.classification,
            nodePack: item.nodePack,
          });
        }
      }
    } else {
      for (const template of templatesArray) {
        if (filter(template, searchTerm)) {
          // The new node sits on the other end of the pending connection.
          const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
          for (const field of Object.values(candidateFields)) {
            const sourceType =
              pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
            const targetType =
              pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
            if (validateConnectionTypes(sourceType, targetType)) {
              _items.push({
                label: template.title,
                value: template.type,
                description: template.description,
                classification: template.classification,
                nodePack: template.nodePack,
              });
              // One compatible field is enough to list the template.
              break;
            }
          }
        }
      }
    }
    return _items;
  }, [pendingConnection, currentImageFilterItem, searchTerm, notesFilterItem, templatesArray]);
  return ( <> {items.map((item) => ( {item.classification === 'beta' && } {item.classification === 'prototype' && } {item.classification === 'internal' && } {item.label} {item.nodePack} {item.description && {item.description}} ))} );
});
NodeCommandList.displayName = 'CommandListItems';