Spaces:
Paused
Paused
| import express from 'express'; | |
| import logger from '../utils/logger'; | |
| import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; | |
| import type { Embeddings } from '@langchain/core/embeddings'; | |
| import { ChatOpenAI } from '@langchain/openai'; | |
| import { | |
| getAvailableChatModelProviders, | |
| getAvailableEmbeddingModelProviders, | |
| } from '../lib/providers'; | |
| import { searchHandlers } from '../websocket/messageHandler'; | |
| import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; | |
| import { MetaSearchAgentType } from '../search/metaSearchAgent'; | |
| const router = express.Router(); | |
| interface chatModel { | |
| provider: string; | |
| model: string; | |
| customOpenAIBaseURL?: string; | |
| customOpenAIKey?: string; | |
| } | |
| interface embeddingModel { | |
| provider: string; | |
| model: string; | |
| } | |
| interface ChatRequestBody { | |
| optimizationMode: 'speed' | 'balanced'; | |
| focusMode: string; | |
| chatModel?: chatModel; | |
| embeddingModel?: embeddingModel; | |
| query: string; | |
| history: Array<[string, string]>; | |
| } | |
| router.post('/', async (req, res) => { | |
| try { | |
| const body: ChatRequestBody = req.body; | |
| if (!body.focusMode || !body.query) { | |
| return res.status(400).json({ message: 'Missing focus mode or query' }); | |
| } | |
| body.history = body.history || []; | |
| body.optimizationMode = body.optimizationMode || 'balanced'; | |
| const history: BaseMessage[] = body.history.map((msg) => { | |
| if (msg[0] === 'human') { | |
| return new HumanMessage({ | |
| content: msg[1], | |
| }); | |
| } else { | |
| return new AIMessage({ | |
| content: msg[1], | |
| }); | |
| } | |
| }); | |
| const [chatModelProviders, embeddingModelProviders] = await Promise.all([ | |
| getAvailableChatModelProviders(), | |
| getAvailableEmbeddingModelProviders(), | |
| ]); | |
| const chatModelProvider = | |
| body.chatModel?.provider || Object.keys(chatModelProviders)[0]; | |
| const chatModel = | |
| body.chatModel?.model || | |
| Object.keys(chatModelProviders[chatModelProvider])[0]; | |
| const embeddingModelProvider = | |
| body.embeddingModel?.provider || Object.keys(embeddingModelProviders)[0]; | |
| const embeddingModel = | |
| body.embeddingModel?.model || | |
| Object.keys(embeddingModelProviders[embeddingModelProvider])[0]; | |
| let llm: BaseChatModel | undefined; | |
| let embeddings: Embeddings | undefined; | |
| if (body.chatModel?.provider === 'custom_openai') { | |
| if ( | |
| !body.chatModel?.customOpenAIBaseURL || | |
| !body.chatModel?.customOpenAIKey | |
| ) { | |
| return res | |
| .status(400) | |
| .json({ message: 'Missing custom OpenAI base URL or key' }); | |
| } | |
| llm = new ChatOpenAI({ | |
| modelName: body.chatModel.model, | |
| openAIApiKey: body.chatModel.customOpenAIKey, | |
| temperature: 0.7, | |
| configuration: { | |
| baseURL: body.chatModel.customOpenAIBaseURL, | |
| }, | |
| }) as unknown as BaseChatModel; | |
| } else if ( | |
| chatModelProviders[chatModelProvider] && | |
| chatModelProviders[chatModelProvider][chatModel] | |
| ) { | |
| llm = chatModelProviders[chatModelProvider][chatModel] | |
| .model as unknown as BaseChatModel | undefined; | |
| } | |
| if ( | |
| embeddingModelProviders[embeddingModelProvider] && | |
| embeddingModelProviders[embeddingModelProvider][embeddingModel] | |
| ) { | |
| embeddings = embeddingModelProviders[embeddingModelProvider][ | |
| embeddingModel | |
| ].model as Embeddings | undefined; | |
| } | |
| if (!llm || !embeddings) { | |
| return res.status(400).json({ message: 'Invalid model selected' }); | |
| } | |
| const searchHandler: MetaSearchAgentType = searchHandlers[body.focusMode]; | |
| if (!searchHandler) { | |
| return res.status(400).json({ message: 'Invalid focus mode' }); | |
| } | |
| const emitter = await searchHandler.searchAndAnswer( | |
| body.query, | |
| history, | |
| llm, | |
| embeddings, | |
| body.optimizationMode, | |
| [], | |
| ); | |
| let message = ''; | |
| let sources = []; | |
| emitter.on('data', (data) => { | |
| const parsedData = JSON.parse(data); | |
| if (parsedData.type === 'response') { | |
| message += parsedData.data; | |
| } else if (parsedData.type === 'sources') { | |
| sources = parsedData.data; | |
| } | |
| }); | |
| emitter.on('end', () => { | |
| res.status(200).json({ message, sources }); | |
| }); | |
| emitter.on('error', (data) => { | |
| const parsedData = JSON.parse(data); | |
| res.status(500).json({ message: parsedData.data }); | |
| }); | |
| } catch (err: any) { | |
| logger.error(`Error in getting search results: ${err.message}`); | |
| res.status(500).json({ message: 'An error has occurred.' }); | |
| } | |
| }); | |
| export default router; | |