Spaces:
Running
Running
| import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; | |
| import { streamText } from "ai"; | |
| import type { ChatMessage } from "gpt-tokenizer/GptEncoding"; | |
| import { | |
| getSettings, | |
| getTextGenerationState, | |
| updateResponse, | |
| updateTextGenerationState, | |
| } from "./pubSub"; | |
| import { | |
| canStartResponding, | |
| getDefaultChatCompletionCreateParamsStreaming, | |
| getDefaultChatMessages, | |
| getFormattedSearchResults, | |
| } from "./textGenerationUtilities"; | |
/**
 * AbortController for an OpenAI stream request. Set by `createOpenAiStream`
 * when it starts a request and reset to `null` in its `finally` block.
 */
let currentAbortController: null | AbortController = null;
/** Arguments accepted by `createOpenAiStream`. */
interface StreamOptions {
  /** Receives the response text accumulated so far each time new text arrives. */
  onUpdate: (text: string) => void;
  /** Conversation forwarded to the chat model. */
  messages: ChatMessage[];
}
/**
 * Streams a chat completion from the OpenAI-compatible endpoint configured in
 * the app settings, reporting the accumulated response after every text delta.
 *
 * @param options.messages - Conversation sent to the model; messages without a
 *   role are sent with role `"user"`.
 * @param options.onUpdate - Called with the full text generated so far (not just
 *   the latest delta) each time a text delta arrives.
 * @returns The complete generated text.
 * @throws Error("Chat generation interrupted") when the text-generation state is
 *   `"interrupted"` or the request was aborted.
 * @throws The error reported by the stream when the provider request fails.
 */
async function createOpenAiStream({
  messages,
  onUpdate,
}: StreamOptions): Promise<string> {
  const settings = getSettings();
  const openaiProvider = createOpenAICompatible({
    name: settings.openAiApiBaseUrl,
    baseURL: settings.openAiApiBaseUrl,
    apiKey: settings.openAiApiKey,
  });
  const params = getDefaultChatCompletionCreateParamsStreaming();

  // Keep a local handle: the module-level slot is shared, so a concurrent call
  // may overwrite or null it while this stream is still running, and we must
  // always abort *this* request.
  const abortController = new AbortController();
  currentAbortController = abortController;

  try {
    const stream = streamText({
      model: openaiProvider.chatModel(settings.openAiApiModel),
      messages: messages.map((msg) => ({
        role: msg.role || "user",
        content: msg.content,
      })),
      maxOutputTokens: params.max_tokens,
      temperature: params.temperature,
      topP: params.top_p,
      frequencyPenalty: params.frequency_penalty,
      presencePenalty: params.presence_penalty,
      abortSignal: abortController.signal,
    });

    let text = "";
    for await (const part of stream.fullStream) {
      // Interruption is only observed when the next stream part arrives.
      if (getTextGenerationState() === "interrupted") {
        abortController.abort();
        throw new Error("Chat generation interrupted");
      }
      if (part.type === "text-delta") {
        text += part.text;
        onUpdate(text);
      } else if (part.type === "error") {
        // streamText does not throw provider failures; it reports them as
        // "error" parts. Without this branch a failed request (bad key,
        // HTTP 5xx, unknown model) would silently resolve with "".
        const streamError = part.error;
        throw streamError instanceof Error
          ? streamError
          : new Error(`OpenAI stream error: ${String(streamError)}`);
      }
    }
    return text;
  } catch (error) {
    if (
      getTextGenerationState() === "interrupted" ||
      (error instanceof DOMException && error.name === "AbortError")
    ) {
      throw new Error("Chat generation interrupted");
    }
    throw error;
  } finally {
    // Only clear the shared slot if a newer call has not replaced it.
    if (currentAbortController === abortController) {
      currentAbortController = null;
    }
  }
}
/**
 * Produces the main response through the configured OpenAI-compatible API.
 *
 * Waits until responding may start, marks the state as
 * `"preparingToGenerate"`, builds the default chat messages from the formatted
 * search results, and streams the answer. The state switches to `"generating"`
 * on the first update that finds it in any other state, and every partial
 * response is published via `updateResponse`.
 */
export async function generateTextWithOpenAi(): Promise<void> {
  await canStartResponding();
  updateTextGenerationState("preparingToGenerate");

  const formattedResults = getFormattedSearchResults(true);
  const chatMessages = getDefaultChatMessages(formattedResults);

  const publishPartial = (partialText: string) => {
    const isGenerating = getTextGenerationState() === "generating";
    if (!isGenerating) {
      updateTextGenerationState("generating");
    }
    updateResponse(partialText);
  };

  await createOpenAiStream({ messages: chatMessages, onUpdate: publishPartial });
}
/**
 * Streams a chat reply for an arbitrary conversation through the configured
 * OpenAI-compatible API.
 *
 * @param messages - Conversation to send to the model.
 * @param onUpdate - Receives the response text accumulated so far on each delta.
 * @returns The complete response text.
 */
export async function generateChatWithOpenAi(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
): Promise<string> {
  const fullResponse = await createOpenAiStream({ onUpdate, messages });
  return fullResponse;
}