Spaces:
Running
Running
| import { useChat } from 'ai/react'; | |
| import { toast } from 'react-hot-toast'; | |
| import { useEffect, useRef, useState } from 'react'; | |
| import { ChatWithMessages, MessageUI, MessageUserInput } from '../types'; | |
| import { | |
| dbPostCreateMessage, | |
| dbPostUpdateMessageResponse, | |
| } from '../db/functions'; | |
| import { | |
| convertAssistantUIMessageToDBMessageResponse, | |
| convertDBMessageToUIMessage, | |
| } from '../utils/message'; | |
/**
 * React hook that drives a Vision Agent conversation for a single chat.
 *
 * Seeds `useChat` with the chat's persisted messages, streams assistant replies
 * from `/api/vision-agent`, persists each new user message with
 * `dbPostCreateMessage`, and stores the finished assistant reply on that
 * message with `dbPostUpdateMessageResponse`.
 *
 * @param chat - The chat to drive; its `messages`, `id` and `mediaUrl` are read.
 * @returns The UI messages, an `append` that sends and persists a user message
 *   (rejects if persisting fails), `reload` to regenerate the last response,
 *   and the `isLoading` flag.
 */
const useVisionAgent = (chat: ChatWithMessages) => {
  const { messages: dbMessages, id, mediaUrl } = chat;
  const latestDbMessage = dbMessages[dbMessages.length - 1];

  // Temporary solution: mediaUrl has to be sent in the request body,
  // separately from the messages themselves.
  const currMediaUrl = useRef<string>(mediaUrl);

  // Resolves to the id of the DB message that the next finished assistant
  // response belongs to. It is kept as a promise so `onFinish` waits for an
  // in-flight `dbPostCreateMessage` instead of racing it. Racing it would write
  // the response onto the previous message.
  const pendingMessageId = useRef<Promise<string | undefined>>(
    Promise.resolve(latestDbMessage?.id),
  );

  const { messages, append, isLoading, reload } = useChat({
    api: '/api/vision-agent',
    streamMode: 'text',
    onResponse(response) {
      if (response.status !== 200) {
        toast.error(response.statusText);
      }
    },
    onFinish: async message => {
      try {
        const messageId = await pendingMessageId.current;
        // No DB message to attach to (e.g. creating it failed): don't
        // overwrite some other message's response.
        if (!messageId) return;
        await dbPostUpdateMessageResponse(
          messageId,
          convertAssistantUIMessageToDBMessageResponse(message),
        );
      } catch (err) {
        toast.error(
          err instanceof Error ? err.message : 'Failed to save the response',
        );
      }
    },
    sendExtraMessageFields: true,
    initialMessages: convertDBMessageToUIMessage(dbMessages),
    // Read by useChat from the latest render; requests started from `append`
    // below override it per-request so they never send a stale mediaUrl.
    body: {
      mediaUrl: currMediaUrl.current,
      id,
    },
    onError: err => {
      if (err) toast.error(err.message);
    },
  });

  /**
   * If the user just arrived with only the initial message, reload once so the
   * assistant produces the first response.
   */
  const once = useRef(true);
  useEffect(() => {
    if (
      !isLoading &&
      messages.length === 1 &&
      messages[0].role === 'user' &&
      once.current
    ) {
      once.current = false;
      void reload();
    }
  }, [isLoading, messages, reload]);

  /**
   * Sends a user message to the agent and persists it in the DB.
   * Resolves once the DB message exists; rejects if creating it fails.
   */
  const appendUserMessage = async (messageInput: MessageUserInput) => {
    // Keep the ref current so later `reload()` calls use the latest media.
    currMediaUrl.current = messageInput.mediaUrl;

    // Built as a separate (non-literal) value: `mediaUrl` is an extra field
    // that useChat forwards because `sendExtraMessageFields` is enabled.
    // NOTE(review): the chat id is reused as the message id, so every user
    // message shares it — confirm the API relies on this before changing it.
    const userMessage = {
      id,
      role: 'user' as const,
      content: messageInput.prompt,
      mediaUrl: messageInput.mediaUrl,
    };
    // useChat's `body` option is captured from the last render, so pass this
    // message's mediaUrl explicitly for this request.
    void append(userMessage, {
      options: { body: { mediaUrl: messageInput.mediaUrl, id } },
    });

    // Register the pending id synchronously, before the stream can finish.
    const created = dbPostCreateMessage(id, messageInput).then(resp => resp.id);
    pendingMessageId.current = created.catch(() => undefined);
    await created;
  };

  return {
    messages: messages as MessageUI[],
    append: appendUserMessage,
    reload,
    isLoading,
  };
};

export default useVisionAgent;