| import { |
| useCallback, |
| useEffect, |
| useRef, |
| useState, |
| } from 'react' |
| import produce from 'immer' |
| import { isEqual } from 'lodash-es' |
| import type { ValueSelector, Var } from '../../types' |
| import { BlockEnum, VarType } from '../../types' |
| import { |
| useIsChatMode, useNodesReadOnly, |
| useWorkflow, |
| } from '../../hooks' |
| import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types' |
| import { |
| getMultipleRetrievalConfig, |
| getSelectedDatasetsMode, |
| } from './utils' |
| import { RETRIEVE_TYPE } from '@/types/app' |
| import { DATASET_DEFAULT } from '@/config' |
| import type { DataSet } from '@/models/datasets' |
| import { fetchDatasets } from '@/service/datasets' |
| import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' |
| import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' |
| import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' |
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' |
|
|
/**
 * Configuration hook for a knowledge-retrieval workflow node.
 *
 * Wraps the node inputs returned by `useNodeCrud` with change handlers for the
 * query variable, retrieval mode, single-retrieval model settings,
 * multiple-retrieval settings and the selected datasets, and re-exports the
 * single-step-run state obtained from `useOneStepRun`.
 *
 * @param id - Id of the node being configured.
 * @param payload - Node data handed to `useNodeCrud`; its `retrieval_mode` is
 *   also read when the dataset selection changes.
 * @returns The current inputs, the read-only flag, the change handlers, the
 *   selected datasets (only those with a truthy `name`), the rerank-panel open
 *   state and the single-step-run controls.
 */
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
  const { nodesReadOnly: readOnly } = useNodesReadOnly()
  const isChatMode = useIsChatMode()
  const { getBeforeNodesInSameBranch } = useWorkflow()
  // Start node among the nodes returned for this node's branch, if any.
  const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
  const startNodeId = startNode?.id
  const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)

  // Mirror of the most recent value written through `setInputs`, so callbacks
  // and effects can read the latest write via `inputRef.current`.
  const inputRef = useRef(inputs)

  /**
   * Writes new inputs after dropping the retrieval config that does not belong
   * to the active mode: multi-way removes `single_retrieval_config`, any other
   * mode removes `multiple_retrieval_config`.
   */
  const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
    const newInputs = produce(s, (draft) => {
      if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
        delete draft.single_retrieval_config
      else
        delete draft.multiple_retrieval_config
    })

    doSetInputs(newInputs)
    inputRef.current = newInputs
  }, [doSetInputs])

  /** Stores the selector of the variable used as the retrieval query. */
  const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
    const newInputs = produce(inputs, (draft) => {
      draft.query_variable_selector = newVar as ValueSelector
    })
    setInputs(newInputs)
  }, [inputs, setInputs])

  // Current text-generation provider/model, used to seed the single-retrieval model.
  const {
    currentProvider,
    currentModel,
  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

  // Rerank model list and default rerank model.
  const {
    modelList: rerankModelList,
    defaultModel: rerankDefaultModel,
  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

  // Resolve the default rerank model against the rerank model list; the lookup
  // takes the provider as the nested `provider.provider` value.
  const {
    currentModel: currentRerankModel,
  } = useCurrentProviderAndModel(
    rerankModelList,
    rerankDefaultModel
      ? {
        ...rerankDefaultModel,
        provider: rerankDefaultModel.provider.provider,
      }
      : undefined,
  )

  /** Updates provider, name and mode of the single-retrieval model, creating the config if absent. */
  const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
    const newInputs = produce(inputRef.current, (draft) => {
      if (!draft.single_retrieval_config) {
        draft.single_retrieval_config = {
          model: {
            provider: '',
            name: '',
            mode: '',
            completion_params: {},
          },
        }
      }
      const draftModel = draft.single_retrieval_config?.model
      draftModel.provider = model.provider
      draftModel.name = model.modelId
      draftModel.mode = model.mode!
    })
    setInputs(newInputs)
  }, [setInputs])

  /** Replaces the single-retrieval model's completion params; does nothing when they are deep-equal to the stored ones. */
  const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
    // Skip the write when nothing changed.
    if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
      return

    const newInputs = produce(inputRef.current, (draft) => {
      if (!draft.single_retrieval_config) {
        draft.single_retrieval_config = {
          model: {
            provider: '',
            name: '',
            mode: '',
            completion_params: {},
          },
        }
      }
      draft.single_retrieval_config.model.completion_params = newParams
    })
    setInputs(newInputs)
  }, [setInputs])

  // Fill in default retrieval settings. Re-runs when the current provider,
  // current model or default rerank model changes.
  useEffect(() => {
    const inputs = inputRef.current
    // Multi-way already has a rerank provider and a rerank model is available: nothing to seed.
    if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
      return

    // One-way already has a model provider: nothing to seed.
    if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
      return

    const newInput = produce(inputs, (draft) => {
      if (currentProvider?.provider && currentModel?.model) {
        const hasSetModel = draft.single_retrieval_config?.model?.provider
        if (!hasSetModel) {
          draft.single_retrieval_config = {
            model: {
              provider: currentProvider?.provider,
              name: currentModel?.model,
              mode: currentModel?.model_properties?.mode as string,
              completion_params: {},
            },
          }
        }
      }
      const multipleRetrievalConfig = draft.multiple_retrieval_config
      draft.multiple_retrieval_config = {
        top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
        score_threshold: multipleRetrievalConfig?.score_threshold,
        reranking_model: multipleRetrievalConfig?.reranking_model,
        reranking_mode: multipleRetrievalConfig?.reranking_mode,
        weights: multipleRetrievalConfig?.weights,
        // Keep an explicit setting; otherwise enable reranking only when a rerank model is available.
        reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
          ? multipleRetrievalConfig.reranking_enable
          : Boolean(currentRerankModel && rerankDefaultModel),
      }
    })
    setInputs(newInput)
  }, [currentProvider?.provider, currentModel, rerankDefaultModel])

  const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
  const [rerankModelOpen, setRerankModelOpen] = useState(false)

  /**
   * Switches the retrieval mode. Multi-way recomputes the multiple-retrieval
   * config from the selected datasets; any other mode seeds the
   * single-retrieval model from the current provider/model when none is set.
   */
  const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
    const newInputs = produce(inputs, (draft) => {
      draft.retrieval_mode = newMode
      if (newMode === RETRIEVE_TYPE.multiWay) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
      }
      else {
        const hasSetModel = draft.single_retrieval_config?.model?.provider
        if (!hasSetModel) {
          draft.single_retrieval_config = {
            model: {
              provider: currentProvider?.provider || '',
              name: currentModel?.model || '',
              mode: currentModel?.model_properties?.mode as string,
              completion_params: {},
            },
          }
        }
      }
    })
    setInputs(newInputs)
  }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])

  /** Stores a new multiple-retrieval config after passing it through `getMultipleRetrievalConfig` with the selected datasets. */
  const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
    const newInputs = produce(inputs, (draft) => {
      draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
    })
    setInputs(newInputs)
  }, [inputs, setInputs, selectedDatasets, currentRerankModel])

  // On mount (empty dependency list): load details for the stored dataset ids.
  useEffect(() => {
    (async () => {
      const datasetIds = inputRef.current.dataset_ids
      if (datasetIds?.length > 0) {
        const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
        setSelectedDatasets(dataSetsWithDetail)
      }
      // Build the write from the ref as it is NOW, not from a snapshot taken
      // before the await: other `setInputs` calls (e.g. the mount effect below
      // that seeds `query_variable_selector`) run while the request is in
      // flight, and writing back a pre-await snapshot would revert them.
      const newInputs = produce(inputRef.current, (draft) => {
        draft.dataset_ids = datasetIds
      })
      setInputs(newInputs)
    })()
  }, [])

  // On mount (empty dependency list): in chat mode, default an empty query
  // selector to the start node's `sys.query` variable.
  useEffect(() => {
    const inputs = inputRef.current
    let query_variable_selector: ValueSelector = inputs.query_variable_selector
    if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
      query_variable_selector = [startNodeId, 'sys.query']

    setInputs(produce(inputs, (draft) => {
      draft.query_variable_selector = query_variable_selector
    }))
  }, [])

  /**
   * Applies a new dataset selection: stores the ids, recomputes the
   * multiple-retrieval config when the node is in multi-way mode and at least
   * one dataset is selected, and opens the rerank panel for the dataset
   * combinations checked below.
   */
  const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
    const {
      mixtureHighQualityAndEconomic,
      mixtureInternalAndExternal,
      inconsistentEmbeddingModel,
      allInternal,
      allExternal,
    } = getSelectedDatasetsMode(newDatasets)
    const newInputs = produce(inputs, (draft) => {
      draft.dataset_ids = newDatasets.map(d => d.id)

      if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
      }
    })
    setInputs(newInputs)
    setSelectedDatasets(newDatasets)

    // Open the rerank panel for: internal-only selections that mix quality
    // tiers or embedding models, mixed internal/external selections, and
    // external-only selections.
    if (
      (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
      || mixtureInternalAndExternal
      || allExternal
    )
      setRerankModelOpen(true)
  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])

  /** Variable filter: accepts only string-typed variables. */
  const filterVar = useCallback((varPayload: Var) => {
    return varPayload.type === VarType.string
  }, [])

  // Single-step run state; the run input defaults to an empty `query`.
  const {
    isShowSingleRun,
    hideSingleRun,
    runningStatus,
    handleRun,
    handleStop,
    runInputData,
    setRunInputData,
    runResult,
  } = useOneStepRun<KnowledgeRetrievalNodeType>({
    id,
    data: inputs,
    defaultRunInputData: {
      query: '',
    },
  })

  const query = runInputData.query
  /** Updates only the `query` field of the single-step run input. */
  const setQuery = useCallback((newQuery: string) => {
    setRunInputData({
      ...runInputData,
      query: newQuery,
    })
  }, [runInputData, setRunInputData])

  return {
    readOnly,
    inputs,
    handleQueryVarChange,
    filterVar,
    handleRetrievalModeChange,
    handleMultipleRetrievalConfigChange,
    handleModelChanged,
    handleCompletionParamsChange,
    // Expose only datasets that have a truthy name.
    selectedDatasets: selectedDatasets.filter(d => d.name),
    handleOnDatasetsChange,
    isShowSingleRun,
    hideSingleRun,
    runningStatus,
    handleRun,
    handleStop,
    query,
    setQuery,
    runResult,
    rerankModelOpen,
    setRerankModelOpen,
  }
}
|
|
| export default useConfig |
|
|