| import { useSetModalState } from '@/hooks/common-hooks'; |
| import { useFetchFlow, useResetFlow, useSetFlow } from '@/hooks/flow-hooks'; |
| import { IGraph } from '@/interfaces/database/flow'; |
| import { useIsFetching } from '@tanstack/react-query'; |
| import React, { |
| ChangeEvent, |
| KeyboardEventHandler, |
| useCallback, |
| useEffect, |
| useMemo, |
| useState, |
| } from 'react'; |
| import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow'; |
| |
| import { variableEnabledFieldMap } from '@/constants/chat'; |
| import { |
| ModelVariableType, |
| settledModelVariableMap, |
| } from '@/constants/knowledge'; |
| import { useFetchModelId, useSendMessageWithSse } from '@/hooks/logic-hooks'; |
| import { Variable } from '@/interfaces/database/chat'; |
| import api from '@/utils/api'; |
| import { useDebounceEffect } from 'ahooks'; |
| import { FormInstance, message } from 'antd'; |
| import { humanId } from 'human-id'; |
| import trim from 'lodash/trim'; |
| import { useParams } from 'umi'; |
| import { v4 as uuid } from 'uuid'; |
| import { |
| NodeMap, |
| Operator, |
| RestrictedUpstreamMap, |
| SwitchElseTo, |
| initialAkShareValues, |
| initialArXivValues, |
| initialBaiduFanyiValues, |
| initialBaiduValues, |
| initialBeginValues, |
| initialBingValues, |
| initialCategorizeValues, |
| initialConcentratorValues, |
| initialDeepLValues, |
| initialDuckValues, |
| initialExeSqlValues, |
| initialGenerateValues, |
| initialGithubValues, |
| initialGoogleScholarValues, |
| initialGoogleValues, |
| initialJin10Values, |
| initialKeywordExtractValues, |
| initialMessageValues, |
| initialPubMedValues, |
| initialQWeatherValues, |
| initialRelevantValues, |
| initialRetrievalValues, |
| initialRewriteQuestionValues, |
| initialSwitchValues, |
| initialWenCaiValues, |
| initialWikipediaValues, |
| initialYahooFinanceValues, |
| } from './constant'; |
| import { ICategorizeForm, IRelevantForm, ISwitchForm } from './interface'; |
| import useGraphStore, { RFState } from './store'; |
| import { |
| buildDslComponentsByGraph, |
| generateSwitchHandleText, |
| receiveMessageError, |
| replaceIdWithText, |
| } from './utils'; |
|
|
/**
 * Picks the canvas-related slice out of the graph store state: the node and
 * edge lists plus the change/connect/selection handlers and the node setter.
 */
const selector = (state: RFState) => {
  const {
    nodes,
    edges,
    onNodesChange,
    onEdgesChange,
    onConnect,
    setNodes,
    onSelectionChange,
  } = state;

  return {
    nodes,
    edges,
    onNodesChange,
    onEdgesChange,
    onConnect,
    setNodes,
    onSelectionChange,
  };
};
|
|
/**
 * Subscribes to the graph store through `selector` and returns the selected
 * canvas slice (nodes, edges and their handlers).
 */
export const useSelectCanvasData = () => useGraphStore(selector);
|
|
/**
 * Builds a lookup from operator type to the default form values used when a
 * node of that type is created.
 *
 * Generate, Categorize, Relevant, RewriteQuestion and KeywordExtract get the
 * model id returned by `useFetchModelId` merged in as `llm_id`; Answer starts
 * with an empty form; every other operator uses its `initial*Values` constant
 * as-is.
 *
 * @returns A memoized function mapping an `Operator` to its default form values.
 */
export const useInitializeOperatorParams = () => {
  const llmId = useFetchModelId();

  // Rebuilt only when the fetched model id changes.
  const defaultsByOperator = useMemo(
    () => ({
      [Operator.Begin]: initialBeginValues,
      [Operator.Retrieval]: initialRetrievalValues,
      [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId },
      [Operator.Answer]: {},
      [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId },
      [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId },
      [Operator.RewriteQuestion]: {
        ...initialRewriteQuestionValues,
        llm_id: llmId,
      },
      [Operator.Message]: initialMessageValues,
      [Operator.KeywordExtract]: {
        ...initialKeywordExtractValues,
        llm_id: llmId,
      },
      [Operator.DuckDuckGo]: initialDuckValues,
      [Operator.Baidu]: initialBaiduValues,
      [Operator.Wikipedia]: initialWikipediaValues,
      [Operator.PubMed]: initialPubMedValues,
      [Operator.ArXiv]: initialArXivValues,
      [Operator.Google]: initialGoogleValues,
      [Operator.Bing]: initialBingValues,
      [Operator.GoogleScholar]: initialGoogleScholarValues,
      [Operator.DeepL]: initialDeepLValues,
      [Operator.GitHub]: initialGithubValues,
      [Operator.BaiduFanyi]: initialBaiduFanyiValues,
      [Operator.QWeather]: initialQWeatherValues,
      [Operator.ExeSQL]: initialExeSqlValues,
      [Operator.Switch]: initialSwitchValues,
      [Operator.WenCai]: initialWenCaiValues,
      [Operator.AkShare]: initialAkShareValues,
      [Operator.YahooFinance]: initialYahooFinanceValues,
      [Operator.Jin10]: initialJin10Values,
      [Operator.Concentrator]: initialConcentratorValues,
    }),
    [llmId],
  );

  return useCallback(
    (operatorName: Operator) => defaultsByOperator[operatorName],
    [defaultsByOperator],
  );
};
|
|
/**
 * Provides a drag-start handler factory for operator items.
 *
 * `handleDragStart(operatorId)` returns an event handler that stores the
 * operator id in the drag payload under the `application/reactflow` key and
 * marks the drag as a move.
 */
export const useHandleDrag = () => {
  const handleDragStart = useCallback(
    (operatorId: string) => (event: React.DragEvent<HTMLDivElement>) => {
      const { dataTransfer } = event;
      dataTransfer.setData('application/reactflow', operatorId);
      dataTransfer.effectAllowed = 'move';
    },
    [],
  );

  return { handleDragStart };
};
|
|
/**
 * Provides the drag-over / drop handlers that turn a dropped operator into a
 * new node in the graph store.
 *
 * @returns `onDrop` and `onDragOver` handlers plus `setReactFlowInstance`,
 *   which must be given the React Flow instance so drop coordinates can be
 *   converted to flow coordinates.
 */
export const useHandleDrop = () => {
  const addNode = useGraphStore((state) => state.addNode);
  const [reactFlowInstance, setReactFlowInstance] =
    useState<ReactFlowInstance<any, any>>();
  const initializeOperatorParams = useInitializeOperatorParams();

  const onDragOver = useCallback((event: React.DragEvent<HTMLDivElement>) => {
    event.preventDefault();
    event.dataTransfer.dropEffect = 'move';
  }, []);

  const onDrop = useCallback(
    (event: React.DragEvent<HTMLDivElement>) => {
      event.preventDefault();

      // The payload is the operator id written by the drag-start handler.
      const operatorId = event.dataTransfer.getData('application/reactflow');
      if (!operatorId) {
        return;
      }
      const operator = operatorId as Operator;

      // Falls back to the origin when no React Flow instance has been set yet.
      const dropPosition = reactFlowInstance?.screenToFlowPosition({
        x: event.clientX,
        y: event.clientY,
      }) || { x: 0, y: 0 };

      // The node id and the display name each use their own generated human id.
      const nodeId = `${operatorId}:${humanId()}`;
      const nodeName = humanId();

      const droppedNode = {
        id: nodeId,
        type: NodeMap[operator] || 'ragNode',
        position: dropPosition,
        data: {
          label: operatorId,
          name: nodeName,
          form: initializeOperatorParams(operator),
        },
        sourcePosition: Position.Right,
        targetPosition: Position.Left,
      };

      addNode(droppedNode);
    },
    [reactFlowInstance, addNode, initializeOperatorParams],
  );

  return { onDrop, onDragOver, setReactFlowInstance };
};
|
|
/**
 * Couples the drawer's open/close state with the "clicked node" kept in the
 * graph store.
 *
 * @returns `drawerVisible`, `hideDrawer`, a `showDrawer(node)` that records
 *   the node's id as the clicked node before opening the drawer, and
 *   `clickedNode`, looked up from the store by the recorded id.
 */
export const useShowDrawer = () => {
  const { clickedNodeId, setClickedNodeId, getNode } = useGraphStore(
    (state) => state,
  );
  const { visible, hideModal, showModal } = useSetModalState();

  const openForNode = useCallback(
    (node: Node) => {
      setClickedNodeId(node.id);
      showModal();
    },
    [showModal, setClickedNodeId],
  );

  const clickedNode = getNode(clickedNodeId);

  return {
    drawerVisible: visible,
    hideDrawer: hideModal,
    showDrawer: openForNode,
    clickedNode,
  };
};
|
|
/**
 * Provides a key-up handler that calls the store's `deleteEdge` when the
 * Delete key is released; every other key is ignored.
 */
export const useHandleKeyUp = () => {
  const deleteEdge = useGraphStore((state) => state.deleteEdge);

  const handleKeyUp: KeyboardEventHandler = useCallback(
    (event) => {
      if (event.code !== 'Delete') {
        return;
      }
      deleteEdge();
    },
    [deleteEdge],
  );

  return { handleKeyUp };
};
|
|
/**
 * Provides `saveGraph`, which persists the current canvas through `setFlow`.
 *
 * The payload carries the flow id from the route params, the fetched flow's
 * title, and the fetched DSL with its `graph` replaced by the store's current
 * nodes/edges and its `components` rebuilt from that graph.
 *
 * @returns `saveGraph`, an async function resolving to whatever `setFlow` returns.
 */
export const useSaveGraph = () => {
  const { data } = useFetchFlow();
  const { setFlow } = useSetFlow();
  const { id } = useParams();
  const { nodes, edges } = useGraphStore((state) => state);

  const saveGraph = useCallback(async () => {
    const components = buildDslComponentsByGraph(nodes, edges);
    const graph = { nodes, edges };
    const dsl = { ...data.dsl, graph, components };

    return setFlow({ id, title: data.title, dsl });
  }, [nodes, edges, setFlow, id, data]);

  return { saveGraph };
};
|
|
/**
 * Registers a debounced effect (1000 ms wait) keyed on the store's nodes and
 * edges.
 *
 * NOTE: the effect body is currently empty, so a graph change triggers no
 * work; the hook only keeps the subscription and debounce wiring in place.
 */
export const useWatchGraphChange = () => {
  const nodes = useGraphStore((state) => state.nodes);
  const edges = useGraphStore((state) => state.edges);

  useDebounceEffect(() => {}, [nodes, edges], { wait: 1000 });
};
|
|
/**
 * Provides a form `onValuesChange` handler that mirrors form values into the
 * node identified by `id`.
 *
 * When the only changed field is `parameter` and its new value is a key of
 * `settledModelVariableMap`, the values from that map entry are merged over
 * the form values before they are stored.
 *
 * @param id Id of the node whose form is being edited; nothing is stored when
 *   it is missing.
 */
export const useHandleFormValuesChange = (id?: string) => {
  const updateNodeForm = useGraphStore((state) => state.updateNodeForm);

  const handleValuesChange = useCallback(
    (changedValues: any, values: any) => {
      const onlyParameterChanged =
        Object.keys(changedValues).length === 1 &&
        'parameter' in changedValues &&
        changedValues['parameter'] in settledModelVariableMap;

      const nextValues: any = onlyParameterChanged
        ? {
            ...values,
            ...settledModelVariableMap[
              changedValues['parameter'] as keyof typeof settledModelVariableMap
            ],
          }
        : values;

      if (id) {
        updateNodeForm(id, nextValues);
      }
    },
    [updateNodeForm, id],
  );

  return { handleValuesChange };
};
|
|
/**
 * Provides a function that loads a graph into the store.
 *
 * Missing `nodes` / `edges` default to empty arrays, and a graph with neither
 * nodes nor edges is ignored so the store is left untouched.
 */
const useSetGraphInfo = () => {
  const { setEdges, setNodes } = useGraphStore((state) => state);

  return useCallback(
    (graph: IGraph) => {
      const { nodes = [], edges = [] } = graph;
      const isEmptyGraph = nodes.length === 0 && edges.length === 0;
      if (isEmptyGraph) {
        return;
      }

      setNodes(nodes);
      setEdges(edges);
    },
    [setEdges, setNodes],
  );
};
|
|
/**
 * Fetches the flow and pushes its stored graph into the graph store.
 *
 * Whenever the fetched data changes, `data.dsl.graph` (or an empty graph when
 * absent) is handed to the graph-info setter. The hook also installs the
 * graph-change watcher and calls `refetch` from an effect.
 *
 * @returns The query's `loading` flag and the fetched flow as `flowDetail`.
 */
export const useFetchDataOnMount = () => {
  const { loading, data: flowDetail, refetch } = useFetchFlow();
  const setGraphInfo = useSetGraphInfo();

  useEffect(() => {
    const storedGraph = flowDetail?.dsl?.graph ?? ({} as IGraph);
    setGraphInfo(storedGraph);
  }, [setGraphInfo, flowDetail]);

  useWatchGraphChange();

  useEffect(() => {
    refetch();
  }, [refetch]);

  return { loading, flowDetail };
};
|
|
/**
 * Reports whether any query under the `flowDetail` key is currently fetching.
 */
export const useFlowIsFetching = () => {
  const activeFetchCount = useIsFetching({ queryKey: ['flowDetail'] });
  return activeFetchCount > 0;
};
|
|
/**
 * Seeds an LLM-settings form with defaults.
 *
 * Each key of `variableEnabledFieldMap` becomes a boolean switch value, and
 * the `ModelVariableType.Precise` entry of `settledModelVariableMap` is merged
 * on top before the values are written to the form.
 *
 * @param form The antd form instance to populate; nothing happens without it.
 */
export const useSetLlmSetting = (form?: FormInstance) => {
  // Placeholder: no stored setting is read here, so every switch resolves to
  // `true` below.
  const initialLlmSetting = undefined;

  useEffect(() => {
    const switchBoxValues: Record<string, boolean> = {};
    Object.keys(variableEnabledFieldMap).forEach((field) => {
      switchBoxValues[field] =
        initialLlmSetting === undefined
          ? true
          : !!initialLlmSetting[
              variableEnabledFieldMap[
                field as keyof typeof variableEnabledFieldMap
              ] as keyof Variable
            ];
    });

    const preciseValues = settledModelVariableMap[ModelVariableType.Precise];

    form?.setFieldsValue({
      ...switchBoxValues,
      ...preciseValues,
    });
  }, [form, initialLlmSetting]);
};
|
|
/**
 * Provides the predicate deciding whether a proposed connection may be made.
 *
 * A connection is rejected when it links a node to itself, when an edge with
 * the same source and target already exists, or when the target's operator
 * type is listed in `RestrictedUpstreamMap` for the source's operator type.
 *
 * Note: when the source's operator type has no entry in
 * `RestrictedUpstreamMap`, the result is `undefined`, which is falsy.
 */
export const useValidateConnection = () => {
  const { edges, getOperatorTypeFromId } = useGraphStore((state) => state);

  const isValidConnection = useCallback(
    (connection: Connection) => {
      const { source, target } = connection;

      // A node may not connect to itself.
      if (target === source) {
        return false;
      }

      // Only one edge is allowed between the same source/target pair.
      const alreadyLinked = edges.some(
        (edge) => edge.source === source && edge.target === target,
      );
      if (alreadyLinked) {
        return false;
      }

      const sourceOperator = getOperatorTypeFromId(source) as Operator;
      return RestrictedUpstreamMap[sourceOperator]?.every(
        (restricted) => restricted !== getOperatorTypeFromId(target),
      );
    },
    [edges, getOperatorTypeFromId],
  );

  return isValidConnection;
};
|
|
/**
 * Manages the editable name of a node.
 *
 * The local `name` state follows the node's stored name. On blur the edited
 * name is committed through `updateNodeName` unless it is blank or already
 * used by a node in the store, in which case the input reverts to the stored
 * name. An error message is shown only when the rejected name is a duplicate
 * that differs from the stored name.
 *
 * @param node The node being renamed.
 * @returns The current `name` plus `handleNameBlur` and `handleNameChange`.
 */
export const useHandleNodeNameChange = (node?: Node) => {
  const [name, setName] = useState<string>('');
  const { updateNodeName, nodes } = useGraphStore((state) => state);
  const id = node?.id;
  const previousName = node?.data.name;

  const handleNameBlur = useCallback(() => {
    const isDuplicate = nodes.some((other) => other.data.name === name);
    const isBlank = trim(name) === '';

    if (!isBlank && !isDuplicate) {
      if (id) {
        updateNodeName(id, name);
      }
      return;
    }

    // Leaving the name unchanged also counts as a "duplicate" (it matches the
    // node itself), so only warn when the user actually typed another name.
    if (isDuplicate && previousName !== name) {
      message.error('The name cannot be repeated');
    }
    setName(previousName);
  }, [name, id, updateNodeName, previousName, nodes]);

  const handleNameChange = useCallback((event: ChangeEvent<any>) => {
    setName(event.target.value);
  }, []);

  // Keep the input in sync when the stored name changes.
  useEffect(() => {
    setName(previousName);
  }, [previousName]);

  return { name, handleNameBlur, handleNameChange };
};
|
|
/**
 * Provides the "run" handler used before opening the debug drawer.
 *
 * Steps run in sequence and stop at the first one that does not succeed:
 * save the graph, reset the flow, then send `{ id }` to `api.runCanvas`.
 * Save and reset succeed when they return `retcode === 0`. If the send result
 * is an error, its `retmsg` is shown; otherwise the flow is refetched and
 * `show` is called.
 *
 * @param show Callback that opens the debug drawer.
 */
export const useSaveGraphBeforeOpeningDebugDrawer = (show: () => void) => {
  const { id } = useParams();
  const { saveGraph } = useSaveGraph();
  const { resetFlow } = useResetFlow();
  const { refetch } = useFetchFlow();
  const { send } = useSendMessageWithSse(api.runCanvas);

  const handleRun = useCallback(async () => {
    const saveRet = await saveGraph();
    if (saveRet?.retcode !== 0) {
      return;
    }

    const resetRet = await resetFlow();
    if (resetRet?.retcode !== 0) {
      return;
    }

    const sendRet = await send({ id });
    if (receiveMessageError(sendRet)) {
      message.error(sendRet?.data?.retmsg);
      return;
    }

    refetch();
    show();
  }, [saveGraph, resetFlow, id, send, show, refetch]);

  return handleRun;
};
|
|
/**
 * Runs `replaceIdWithText` over `output`, resolving each id to the `name`
 * stored in the matching node's data.
 *
 * @param output The value handed to `replaceIdWithText`.
 * @returns Whatever `replaceIdWithText` returns for that value.
 */
export const useReplaceIdWithText = (output: unknown) => {
  const getNode = useGraphStore((state) => state.getNode);

  return replaceIdWithText(
    output,
    (id?: string) => getNode(id)?.data.name,
  );
};
|
|
| |
| |
| |
| |
/**
 * Keeps the outgoing edges of Relevant, Categorize and Switch nodes in sync
 * with the targets recorded in their forms.
 *
 * Whenever the store's nodes change, each node of one of those three operator
 * types has its downstream edges rebuilt from its form and handed to
 * `setEdgesByNodeId`. Nodes of any other type are left alone.
 *
 * The effect falls back to an empty form when a node has none, so the
 * Categorize and Switch builders treat a missing `category_description` /
 * `conditions` as empty instead of throwing.
 */
export const useWatchNodeFormDataChange = () => {
  const { getNode, nodes, setEdgesByNodeId } = useGraphStore((state) => state);

  /** One edge per category whose `to` target is set; the category key is the source handle. */
  const buildCategorizeEdgesByFormData = useCallback(
    (nodeId: string, form: ICategorizeForm) => {
      // Guard: the form may be the `{}` fallback and lack this field.
      const categoryDescription = form.category_description ?? {};
      const downstreamEdges = Object.keys(categoryDescription).reduce<Edge[]>(
        (pre, sourceHandle) => {
          const target = categoryDescription[sourceHandle]?.to;
          if (target) {
            pre.push({
              id: uuid(),
              source: nodeId,
              target,
              sourceHandle,
            });
          }

          return pre;
        },
        [],
      );

      setEdgesByNodeId(nodeId, downstreamEdges);
    },
    [setEdgesByNodeId],
  );

  /** One edge for each of the `yes` / `no` form fields that holds a target. */
  const buildRelevantEdgesByFormData = useCallback(
    (nodeId: string, form: IRelevantForm) => {
      const downstreamEdges = ['yes', 'no'].reduce<Edge[]>((pre, cur) => {
        const target = form[cur as keyof IRelevantForm] as string;
        if (target) {
          pre.push({ id: uuid(), source: nodeId, target, sourceHandle: cur });
        }

        return pre;
      }, []);

      setEdgesByNodeId(nodeId, downstreamEdges);
    },
    [setEdgesByNodeId],
  );

  /** One edge per condition with a `to` target, plus one for the else branch when set. */
  const buildSwitchEdgesByFormData = useCallback(
    (nodeId: string, form: ISwitchForm) => {
      // Guard: the form may be the `{}` fallback and lack this field.
      const conditions = form.conditions ?? [];
      const downstreamEdges = conditions.reduce<Edge[]>((pre, _, idx) => {
        const target = conditions[idx]?.to;
        if (target) {
          pre.push({
            id: uuid(),
            source: nodeId,
            target,
            // The handle text is derived from the condition's index.
            sourceHandle: generateSwitchHandleText(idx),
          });
        }

        return pre;
      }, []);

      // The else branch is stored on the form under the `SwitchElseTo` key,
      // which is also used as its source handle.
      const elseTo = form[SwitchElseTo];
      if (elseTo) {
        downstreamEdges.push({
          id: uuid(),
          source: nodeId,
          target: elseTo,
          sourceHandle: SwitchElseTo,
        });
      }

      setEdgesByNodeId(nodeId, downstreamEdges);
    },
    [setEdgesByNodeId],
  );

  useEffect(() => {
    nodes.forEach((node) => {
      const currentNode = getNode(node.id);
      const form = currentNode?.data.form ?? {};
      const operatorType = currentNode?.data.label;
      switch (operatorType) {
        case Operator.Relevant:
          buildRelevantEdgesByFormData(node.id, form as IRelevantForm);
          break;
        case Operator.Categorize:
          buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm);
          break;
        case Operator.Switch:
          buildSwitchEdgesByFormData(node.id, form as ISwitchForm);
          break;
        default:
          break;
      }
    });
  }, [
    nodes,
    buildCategorizeEdgesByFormData,
    getNode,
    buildRelevantEdgesByFormData,
    buildSwitchEdgesByFormData,
  ]);
};
|
|
| |
/**
 * Operator types filtered out when building the component-id select options.
 */
const ExcludedNodes: Operator[] = [
  Operator.Categorize,
  Operator.Relevant,
  Operator.Begin,
  Operator.Answer,
];
|
|
/**
 * Builds `{ label, value }` select options from the nodes in the store.
 *
 * The node identified by `nodeId` and every node whose operator type is in
 * `ExcludedNodes` are left out. Each remaining node contributes its name as
 * the label and its id as the value.
 *
 * @param nodeId Id of the node to leave out of the options.
 */
export const useBuildComponentIdSelectOptions = (nodeId?: string) => {
  const nodes = useGraphStore((state) => state.nodes);

  return useMemo(() => {
    const selectableNodes = nodes.filter(
      (candidate) =>
        candidate.id !== nodeId &&
        ExcludedNodes.every((excluded) => excluded !== candidate.data.label),
    );

    return selectableNodes.map((candidate) => ({
      label: candidate.data.name,
      value: candidate.id,
    }));
  }, [nodes, nodeId]);
};
|
|