| import { useCallback } from 'react' |
|
|
| import { |
| Assistant, |
| ConversationalExtension, |
| ExtensionTypeEnum, |
| Thread, |
| ThreadAssistantInfo, |
| ThreadState, |
| Model, |
| AssistantTool, |
| } from '@janhq/core' |
| import { atom, useAtomValue, useSetAtom } from 'jotai' |
|
|
| import { fileUploadAtom } from '@/containers/Providers/Jotai' |
|
|
| import { generateThreadId } from '@/utils/thread' |
|
|
| import { useActiveModel } from './useActiveModel' |
| import useRecommendedModel from './useRecommendedModel' |
|
|
| import useSetActiveThread from './useSetActiveThread' |
|
|
| import { extensionManager } from '@/extension' |
|
|
| import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' |
| import { selectedModelAtom } from '@/helpers/atoms/Model.atom' |
| import { |
| threadsAtom, |
| threadStatesAtom, |
| updateThreadAtom, |
| setThreadModelParamsAtom, |
| isGeneratingResponseAtom, |
| } from '@/helpers/atoms/Thread.atom' |
|
|
/**
 * Write-only atom that registers a freshly created thread in local state.
 *
 * Seeds an empty per-thread state entry (no more history to load, not
 * waiting for a response, no last message) keyed by the thread id. It then
 * places the new thread at the front of the thread list.
 */
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
  const initialThreadState: ThreadState = {
    hasMore: false,
    waitingForResponse: false,
    lastMessage: undefined,
  }

  // Build a new map rather than mutating the one held by the atom.
  set(threadStatesAtom, {
    ...get(threadStatesAtom),
    [newThread.id]: initialThreadState,
  })

  // Newest thread goes first.
  set(threadsAtom, [newThread, ...get(threadsAtom)])
})
|
|
/**
 * Upper bound applied to a model's default `ctx_len` setting and
 * `max_tokens` parameter when they are copied onto a new thread.
 * Values at or below the cap are kept as-is.
 */
const DEFAULT_TOKEN_CAP = 2048

/**
 * Hook exposing helpers to create new conversation threads and to persist
 * thread metadata.
 *
 * @returns
 *   - `requestCreateNewThread(assistant, model?)` — builds a thread for the
 *     assistant, registers it in local state, persists it and activates it.
 *   - `updateThreadMetadata(thread)` — updates the thread in local state and
 *     saves it through the conversational extension, if one is registered.
 */
export const useCreateNewThread = () => {
  const createNewThread = useSetAtom(createNewThreadAtom)
  const { setActiveThread } = useSetActiveThread()
  const updateThread = useSetAtom(updateThreadAtom)
  const setFileUpload = useSetAtom(fileUploadAtom)
  const setSelectedModel = useSetAtom(selectedModelAtom)
  const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)

  const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
  const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)

  const { recommendedModel, downloadedModels } = useRecommendedModel()

  const threads = useAtomValue(threadsAtom)
  const { stopInference } = useActiveModel()

  /**
   * Updates `thread` in local state, then saves it via the conversational
   * extension. If no such extension is registered, the save is skipped.
   */
  const updateThreadMetadata = useCallback(
    async (thread: Thread) => {
      updateThread(thread)

      await extensionManager
        .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
        ?.saveThread(thread)
    },
    [updateThread]
  )

  /**
   * Creates a new thread for `assistant` and makes it the active thread.
   *
   * @param assistant - Assistant whose id, name, tools and instructions
   *   are copied onto the thread.
   * @param model - Model to bind to the thread. When omitted, the
   *   recommended model (or the first downloaded model) is used. In that
   *   case creation is skipped if the most recent thread has no last
   *   message yet.
   * @returns Resolves to `null` when creation was skipped. Otherwise it
   *   resolves after the thread has been persisted and activated.
   */
  const requestCreateNewThread = async (
    assistant: Assistant,
    model?: Model | undefined
  ) => {
    // Clear the generating flag and stop any running inference first.
    setIsGeneratingResponse(false)
    stopInference()

    const defaultModel = model ?? recommendedModel ?? downloadedModels[0]

    if (!model) {
      // Avoid stacking empty threads: if the newest thread has no last
      // message yet, reuse it instead of creating another one.
      const lastMessage = threads[0]?.metadata?.lastMessage
      if (!lastMessage && threads.length) {
        return null
      }
    }

    // `?.[0]?.` also guards against an empty `tools` array, where
    // `tools[0]` is undefined.
    const assistantTools: AssistantTool = {
      type: 'retrieval',
      enabled: true,
      settings: assistant.tools?.[0]?.settings,
    }

    // Cap oversized defaults. Smaller values are left untouched.
    const ctxLen = defaultModel?.settings.ctx_len
    const overriddenSettings =
      ctxLen && ctxLen > DEFAULT_TOKEN_CAP ? { ctx_len: DEFAULT_TOKEN_CAP } : {}

    const maxTokens = defaultModel?.parameters.max_tokens
    const overriddenParameters =
      maxTokens && maxTokens > DEFAULT_TOKEN_CAP
        ? { max_tokens: DEFAULT_TOKEN_CAP }
        : {}

    const createdAt = Date.now()
    const assistantInfo: ThreadAssistantInfo = {
      assistant_id: assistant.id,
      assistant_name: assistant.name,
      tools: experimentalEnabled ? [assistantTools] : assistant.tools,
      model: {
        id: defaultModel?.id ?? '*',
        settings: { ...defaultModel?.settings, ...overriddenSettings },
        parameters: { ...defaultModel?.parameters, ...overriddenParameters },
        engine: defaultModel?.engine,
      },
      instructions: assistant.instructions,
    }

    const thread: Thread = {
      id: generateThreadId(assistant.id),
      object: 'thread',
      title: 'New Thread',
      assistants: [assistantInfo],
      created: createdAt,
      updated: createdAt,
    }

    // Add the thread to local state before it is persisted.
    createNewThread(thread)

    setSelectedModel(defaultModel)
    // Apply the same overrides that are stored on the thread's model, so
    // the runtime params match the persisted ones.
    setThreadModelParams(thread.id, {
      ...defaultModel?.settings,
      ...defaultModel?.parameters,
      ...overriddenSettings,
      ...overriddenParameters,
    })

    // Reset any pending file uploads.
    setFileUpload([])

    await updateThreadMetadata(thread)

    setActiveThread(thread)
  }

  return {
    requestCreateNewThread,
    updateThreadMetadata,
  }
}
|
|