Spaces:
Build error
Build error
| import { NextRequest, NextResponse } from "next/server"; | |
| import { getServerSideConfig } from "@/app/config/server"; | |
| import { auth } from "../../../auth"; | |
| import { ChatOpenAI } from "langchain/chat_models/openai"; | |
| import { BaseCallbackHandler } from "langchain/callbacks"; | |
| import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema"; | |
| import { BufferMemory, ChatMessageHistory } from "langchain/memory"; | |
| import { initializeAgentExecutorWithOptions } from "langchain/agents"; | |
| import { ACCESS_CODE_PREFIX } from "@/app/constant"; | |
| import { OpenAI } from "langchain/llms/openai"; | |
| import { OpenAIEmbeddings } from "langchain/embeddings/openai"; | |
| import * as langchainTools from "langchain/tools"; | |
| import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; | |
| import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; | |
| import { WebBrowser } from "langchain/tools/webbrowser"; | |
| import { Calculator } from "langchain/tools/calculator"; | |
| import { DynamicTool, Tool } from "langchain/tools"; | |
| import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; | |
| import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; | |
| import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv"; | |
// Server-side configuration, read once at module load. The handler below
// reads `apiKey` and `baseUrl` from it.
const serverConfig = getServerSideConfig();
/** One chat turn as it arrives in the request body. */
interface RequestMessage {
  /**
   * Author of the turn. The handler maps "system", "user" and "assistant"
   * to chat-history messages; any other value is skipped.
   */
  role: string;
  /** Text of the turn. */
  content: string;
}
/** JSON payload accepted by the agent route. */
interface RequestBody {
  /** Conversation so far; the final entry is used as the agent's input. */
  messages: RequestMessage[];
  /** Model name forwarded to both the chat model and the helper LLM. */
  model: string;
  /** Forwarded to the chat model as its `streaming` flag. */
  stream?: boolean;
  /** Sampling temperature forwarded to the chat model. */
  temperature: number;
  /** Forwarded to the chat model as `presencePenalty`. */
  presence_penalty?: number;
  /** Forwarded to the chat model as `frequencyPenalty`. */
  frequency_penalty?: number;
  /** Forwarded to the chat model as `topP`. */
  top_p?: number;
  /** Optional endpoint override; ignored unless it starts with http:// or https://. */
  baseUrl?: string;
  /** Declared on the wire format; the handler in this file does not read it. */
  apiKey?: string;
  /** Upper bound on agent iterations, forwarded to the executor. */
  maxIterations: number;
  /** When true, tool invocations are also reported on the event stream. */
  returnIntermediateSteps: boolean;
  /**
   * Names of the tools to enable. Optional because the handler already
   * treats a missing list as empty (`reqBody.useTools ?? []`).
   */
  useTools?: (undefined | string)[];
}
/**
 * One event written to the response stream.
 *
 * A fresh instance reports success and is not a tool message; callers set
 * `message` (and `toolName` for tool events) before serialising it.
 */
class ResponseBody {
  /** False when the event carries an error message. */
  isSuccess: boolean;
  /** Event text; assigned by the caller after construction. */
  message!: string;
  /** True when the event describes a tool invocation. */
  isToolMessage: boolean;
  /** Name of the invoked tool; only set on tool events. */
  toolName?: string;

  constructor() {
    this.isSuccess = true;
    this.isToolMessage = false;
  }
}
/**
 * Object wrapper around a single string input.
 * NOTE(review): nothing in the visible part of this file references it —
 * confirm whether it is still needed.
 */
interface ToolInput {
  /** The raw input string. */
  input: string;
}
/**
 * Resolves the OpenAI-compatible endpoint for a request.
 *
 * Precedence: `requestedBaseUrl` (only when it is an http/https URL), then the
 * server configuration, then the public OpenAI endpoint. The result always
 * ends in `/v1`.
 */
function resolveBaseUrl(requestedBaseUrl?: string): string {
  let baseUrl = "https://api.openai.com/v1";
  if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
  if (
    requestedBaseUrl?.startsWith("http://") ||
    requestedBaseUrl?.startsWith("https://")
  ) {
    baseUrl = requestedBaseUrl;
  }
  if (!baseUrl.endsWith("/v1")) {
    baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
  }
  return baseUrl;
}

/**
 * Chooses the tool behind the "web-search" capability from the environment.
 *
 * DuckDuckGo is the default; CHOOSE_SEARCH_ENGINE may switch to Google or
 * Baidu; BING_SEARCH_API_KEY and then SERPAPI_API_KEY override that choice
 * (the last one set wins).
 */
function selectSearchTool(): Tool {
  let searchTool: Tool = new DuckDuckGo();
  switch (process.env.CHOOSE_SEARCH_ENGINE) {
    case "google":
      searchTool = new GoogleSearch();
      break;
    case "baidu":
      searchTool = new BaiduSearch();
      break;
  }
  if (process.env.BING_SEARCH_API_KEY) {
    const bingSearchTool = new langchainTools.BingSerpAPI(
      process.env.BING_SEARCH_API_KEY,
    );
    searchTool = new DynamicTool({
      name: "bing_search",
      description: bingSearchTool.description,
      func: async (input: string) => bingSearchTool.call(input),
    });
  }
  if (process.env.SERPAPI_API_KEY) {
    const serpAPITool = new langchainTools.SerpAPI(process.env.SERPAPI_API_KEY);
    searchTool = new DynamicTool({
      name: "google_search",
      description: serpAPITool.description,
      func: async (input: string) => serpAPITool.call(input),
    });
  }
  return searchTool;
}

/**
 * Converts request messages into chat-history messages.
 * Entries whose role is not "system", "user" or "assistant" are skipped.
 */
function toChatHistory(
  messages: RequestMessage[],
): (SystemMessage | HumanMessage | AIMessage)[] {
  const history: (SystemMessage | HumanMessage | AIMessage)[] = [];
  for (const message of messages) {
    switch (message.role) {
      case "system":
        history.push(new SystemMessage(message.content));
        break;
      case "user":
        history.push(new HumanMessage(message.content));
        break;
      case "assistant":
        history.push(new AIMessage(message.content));
        break;
    }
  }
  return history;
}

/**
 * Route handler: runs a tool-using agent over the posted conversation and
 * streams its output back as `data: <json>\n\n` events.
 *
 * @param req Incoming request. OPTIONS is answered immediately; otherwise the
 *   body is parsed as a {@link RequestBody}.
 * @returns 401 when `auth` rejects the request, 400 when `messages` is empty,
 *   500 with `{ error }` when setup throws, otherwise a `text/event-stream`
 *   response whose events are serialised {@link ResponseBody} objects.
 */
async function handle(req: NextRequest): Promise<Response> {
  if (req.method === "OPTIONS") {
    return NextResponse.json({ body: "OK" }, { status: 200 });
  }

  try {
    const authResult = auth(req);
    if (authResult.error) {
      return NextResponse.json(authResult, { status: 401 });
    }

    const reqBody: RequestBody = await req.json();
    const messages = Array.isArray(reqBody.messages) ? reqBody.messages : [];
    const lastMessage = messages[messages.length - 1];
    if (!lastMessage) {
      // Without a final message there is no agent input; fail fast and clearly.
      return new Response(
        JSON.stringify({ error: "messages must contain at least one entry" }),
        { status: 400, headers: { "Content-Type": "application/json" } },
      );
    }

    const authToken = req.headers.get("Authorization") ?? "";
    const token = authToken.trim().replaceAll("Bearer ", "").trim();
    // A bearer token that is not an access code is the caller's own API key.
    const usesOwnKey = Boolean(token) && !token.startsWith(ACCESS_CODE_PREFIX);
    const apiKey = usesOwnKey ? token : serverConfig.apiKey;
    // Security: only honour a client-chosen endpoint together with the
    // client's own key. Otherwise the server's key would be sent to whatever
    // host the request names.
    const baseUrl = resolveBaseUrl(usesOwnKey ? reqBody.baseUrl : undefined);
    console.log("[baseUrl]", baseUrl);

    const encoder = new TextEncoder();
    const transformStream = new TransformStream();
    const writer = transformStream.writable.getWriter();
    let streamClosed = false;

    // `writer.write` is invoked synchronously (before the first await) so
    // events are queued in call order and always ahead of a later close.
    const send = async (payload: ResponseBody): Promise<void> => {
      if (streamClosed) return;
      try {
        await writer.write(
          encoder.encode(`data: ${JSON.stringify(payload)}\n\n`),
        );
      } catch (writeError) {
        // The stream is unusable (e.g. the client went away); stop writing.
        streamClosed = true;
        console.error("[stream write]", writeError);
      }
    };

    // Idempotent: several callbacks may try to finish the same stream.
    const closeStream = async (): Promise<void> => {
      if (streamClosed) return;
      streamClosed = true;
      try {
        await writer.close();
      } catch (closeError) {
        console.error("[stream close]", closeError);
      }
    };

    // Emits one error event and finishes the stream; later reports are no-ops.
    const reportError = async (err: unknown): Promise<void> => {
      if (streamClosed) return;
      const response = new ResponseBody();
      response.isSuccess = false;
      response.message = err instanceof Error ? err.message : String(err);
      const pendingWrite = send(response);
      const pendingClose = closeStream();
      await pendingWrite;
      await pendingClose;
    };

    const handler = BaseCallbackHandler.fromMethods({
      async handleLLMNewToken(chunk: string) {
        if (!chunk) return;
        const response = new ResponseBody();
        response.message = chunk;
        await send(response);
      },
      async handleChainError(err) {
        console.log("[handleChainError]", err, "writer error");
        await reportError(err);
      },
      async handleChainEnd(_outputs, _runId, parentRunId) {
        console.log("[handleChainEnd]");
        // Only the outermost chain finishing means the answer is complete.
        if (parentRunId == null) await closeStream();
      },
      async handleLLMError(e: Error) {
        console.log("[handleLLMError]", e, "writer error");
        await reportError(e);
      },
      async handleAgentAction(action) {
        try {
          console.log("[handleAgentAction]", action.tool);
          if (!reqBody.returnIntermediateSteps) return;
          const response = new ResponseBody();
          response.isToolMessage = true;
          response.message = JSON.stringify(action.toolInput);
          response.toolName = action.tool;
          await send(response);
        } catch (ex) {
          console.error("[handleAgentAction]", ex);
          await reportError(ex);
        }
      },
      handleToolStart(tool) {
        console.log("[handleToolStart]", { tool });
      },
      async handleToolEnd(output, runId, parentRunId, tags) {
        console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
      },
      handleAgentEnd() {
        console.log("[handleAgentEnd]");
      },
    });

    // Helper LLM and embeddings used by the web-browser tool.
    const model = new OpenAI(
      {
        temperature: 0,
        modelName: reqBody.model,
        openAIApiKey: apiKey,
      },
      { basePath: baseUrl },
    );
    const embeddings = new OpenAIEmbeddings(
      { openAIApiKey: apiKey },
      { basePath: baseUrl },
    );

    const useTools = reqBody.useTools ?? [];
    const tools: langchainTools.StructuredTool[] = [];
    const webBrowserTool = new WebBrowser({ model, embeddings });
    const calculatorTool = new Calculator();
    const arxivAPITool = new ArxivAPIWrapper();
    if (useTools.includes("web-search")) tools.push(selectSearchTool());
    if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool);
    if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool);
    if (useTools.includes(arxivAPITool.name)) tools.push(arxivAPITool);

    // Remaining names are looked up as exports of `langchain/tools`.
    for (const toolName of useTools) {
      if (!toolName) continue;
      // Own-property check keeps names such as "constructor" from resolving.
      if (!Object.prototype.hasOwnProperty.call(langchainTools, toolName)) {
        continue;
      }
      const exported: unknown =
        langchainTools[toolName as keyof typeof langchainTools];
      if (typeof exported === "function") {
        const ToolClass = exported as new () => langchainTools.StructuredTool;
        tools.push(new ToolClass());
      }
    }

    // Everything except the final message becomes conversation memory.
    const memory = new BufferMemory({
      memoryKey: "chat_history",
      returnMessages: true,
      inputKey: "input",
      outputKey: "output",
      chatHistory: new ChatMessageHistory(toChatHistory(messages.slice(0, -1))),
    });

    const llm = new ChatOpenAI(
      {
        modelName: reqBody.model,
        openAIApiKey: apiKey,
        temperature: reqBody.temperature,
        streaming: reqBody.stream,
        topP: reqBody.top_p,
        presencePenalty: reqBody.presence_penalty,
        frequencyPenalty: reqBody.frequency_penalty,
      },
      { basePath: baseUrl },
    );

    const executor = await initializeAgentExecutorWithOptions(tools, llm, {
      agentType: "openai-functions",
      returnIntermediateSteps: reqBody.returnIntermediateSteps,
      maxIterations: reqBody.maxIterations,
      memory: memory,
    });

    // Deliberately not awaited: the agent runs while the response streams.
    // The catch/finally pair guarantees the stream is finished even when a
    // failure bypasses the callbacks above.
    void executor
      .call({ input: lastMessage.content }, [handler])
      .catch((err: unknown) => {
        console.error("[executor.call]", err);
        return reportError(err);
      })
      .finally(closeStream);

    console.log("returning response");
    return new Response(transformStream.readable, {
      headers: { "Content-Type": "text/event-stream" },
    });
  } catch (e) {
    const message = e instanceof Error ? e.message : String(e);
    return new Response(JSON.stringify({ error: message }), {
      status: 500,
      headers: { "Content-Type": "application/json" },
    });
  }
}
// Route segment option. NOTE(review): presumably selects the Edge runtime
// for this route — confirm against the Next.js version in use.
export const runtime = "edge";

// The same handler serves both HTTP methods.
export const GET = handle;
export const POST = handle;