// MiniSearch / client/modules/textGenerationWithOpenAi.ts
// Synced from https://github.com/felladrin/MiniSearch (commit 16b7924, via github-actions[bot])
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import { streamText } from "ai";
import {
listOpenAiCompatibleModels,
selectRandomModel,
} from "../../shared/openaiModels";
import { addLogEntry } from "./logEntries";
import {
getSettings,
getTextGenerationState,
updateResponse,
updateTextGenerationState,
} from "./pubSub";
import { sleep } from "./sleep";
import {
canStartResponding,
getDefaultChatCompletionCreateParamsStreaming,
getDefaultChatMessages,
getFormattedSearchResults,
} from "./textGenerationUtilities";
import type { ChatMessage } from "./types";
// Abort controller for the stream attempt currently in flight, if any.
// Set at the start of each attempt and reset to null when it settles.
let currentAbortController: AbortController | null = null;
// Inputs for one streamed generation run.
interface StreamOptions {
  // Full conversation history to send to the model.
  messages: ChatMessage[];
  // Invoked with the accumulated response text after each streamed delta.
  onUpdate: (text: string) => void;
}
/**
 * Streams a chat completion from an OpenAI-compatible endpoint and returns
 * the full accumulated response text.
 *
 * If no model is configured in settings, one is picked at random from the
 * endpoint's model list. When a stream attempt fails and other untried
 * models are available, the request is retried with a different model
 * (up to `maxRetries` attempts total, with a linearly growing backoff).
 *
 * Throws `Error("Chat generation interrupted")` when the user aborts
 * (detected via `getTextGenerationState()` or an AbortError), and
 * otherwise rethrows the last stream error once retries are exhausted.
 */
async function createOpenAiStream({
  messages,
  onUpdate,
}: StreamOptions): Promise<string> {
  const settings = getSettings();
  // NOTE(review): the provider "name" is set to the base URL — presumably
  // just an identifier for the SDK; confirm this is intentional.
  const openaiProvider = createOpenAICompatible({
    name: settings.openAiApiBaseUrl,
    baseURL: settings.openAiApiBaseUrl,
    apiKey: settings.openAiApiKey,
  });
  const params = getDefaultChatCompletionCreateParamsStreaming();
  let effectiveModel = settings.openAiApiModel;
  let availableModels: { id: string }[] = [];
  // No model configured: fetch the endpoint's model list and pick one at random.
  if (!effectiveModel) {
    try {
      availableModels = await listOpenAiCompatibleModels(
        settings.openAiApiBaseUrl,
        settings.openAiApiKey,
      );
      const selectedModel = selectRandomModel(availableModels);
      if (selectedModel) effectiveModel = selectedModel;
    } catch (err) {
      // Listing failed — log and fall through; the stream request itself
      // will surface the error (effectiveModel may still be empty here).
      addLogEntry(
        `Failed to list OpenAI models: ${err instanceof Error ? err.message : String(err)}`,
      );
    }
  }
  const maxRetries = 5;
  // Models already attempted, so a retry never re-picks a failing model.
  const attemptedModels = new Set<string>();
  let currentAttempt = 0;
  // Runs one stream attempt; on a retryable failure (flagged by the SDK's
  // onError callback via `shouldRetry`) it recurses with the next model.
  const tryNextModel = async (): Promise<string> => {
    if (currentAttempt >= maxRetries) {
      throw new Error(
        `Failed to generate text after ${maxRetries} retries with different models`,
      );
    }
    if (effectiveModel) {
      attemptedModels.add(effectiveModel);
    }
    currentAttempt++;
    // Fresh controller per attempt so an abort only cancels the live stream.
    currentAbortController = new AbortController();
    // Set inside onError (below) when a different model should be tried.
    let shouldRetry = false;
    try {
      const stream = streamText({
        model: openaiProvider.chatModel(effectiveModel),
        messages,
        maxOutputTokens: params.max_tokens,
        temperature: params.temperature,
        topP: params.top_p,
        frequencyPenalty: params.frequency_penalty,
        presencePenalty: params.presence_penalty,
        abortSignal: currentAbortController.signal,
        // Retries are handled manually here, not by the SDK.
        maxRetries: 0,
        onError: async (error: unknown) => {
          // User-initiated aborts are never retried.
          if (
            getTextGenerationState() === "interrupted" ||
            (error instanceof DOMException && error.name === "AbortError")
          ) {
            throw new Error("Chat generation interrupted");
          }
          // Retry with another random untried model when possible.
          if (availableModels.length > 0 && currentAttempt < maxRetries) {
            const nextModel = selectRandomModel(
              availableModels,
              attemptedModels,
            );
            if (nextModel) {
              addLogEntry(
                `Model "${effectiveModel}" failed, retrying with "${nextModel}" (Attempt ${currentAttempt}/${maxRetries})`,
              );
              effectiveModel = nextModel;
              shouldRetry = true;
              // Linear backoff before the next attempt.
              await sleep(100 * currentAttempt);
              // Rethrow so the outer catch sees the failure and recurses.
              throw error;
            }
          }
          throw error;
        },
      });
      let text = "";
      // Accumulate text deltas, checking for user interruption between parts.
      for await (const part of stream.fullStream) {
        if (getTextGenerationState() === "interrupted") {
          currentAbortController.abort();
          throw new Error("Chat generation interrupted");
        }
        if (part.type === "text-delta") {
          text += part.text;
          onUpdate(text);
        }
      }
      return text;
    } catch (error) {
      if (
        getTextGenerationState() === "interrupted" ||
        (error instanceof DOMException && error.name === "AbortError")
      ) {
        throw new Error("Chat generation interrupted");
      }
      // onError flagged a retryable failure with a fresh model — recurse.
      if (shouldRetry) {
        return tryNextModel();
      }
      throw error;
    } finally {
      // Clear the shared controller once this attempt settles.
      // NOTE(review): this also runs after a recursive retry returns, which
      // is harmless only because the retry has already completed by then.
      currentAbortController = null;
    }
  };
  return tryNextModel();
}
/**
 * Runs a search-grounded generation pass against the configured
 * OpenAI-compatible endpoint, publishing partial responses as they stream in.
 */
export async function generateTextWithOpenAi() {
  await canStartResponding();
  updateTextGenerationState("preparingToGenerate");

  const searchContext = getFormattedSearchResults(true);
  const chatMessages = getDefaultChatMessages(searchContext);

  await createOpenAiStream({
    messages: chatMessages,
    onUpdate: (partialText) => {
      // Flip to "generating" on the first streamed chunk; stay there after.
      if (getTextGenerationState() !== "generating") {
        updateTextGenerationState("generating");
      }
      updateResponse(partialText);
    },
  });
}
/**
 * Streams a chat completion for an arbitrary conversation, forwarding
 * partial responses to the caller and resolving with the final text.
 */
export async function generateChatWithOpenAi(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const finalText = await createOpenAiStream({ messages, onUpdate });
  return finalText;
}