Spaces:
Paused
Paused
| import { DEFAULT_MODELS, ServiceProvider } from "../constant"; | |
| import { LLMModel } from "../client/api"; | |
/**
 * Hands out stable sort keys for custom models and providers.
 *
 * Keys start at -1000 so custom entries are placed in front of the built-in
 * ones (refer to constant.ts), and each new id gets the next integer. An id
 * that was seen before always receives the key it was first given.
 */
const CustomSeq = {
  val: -1000,
  cache: new Map<string, number>(),
  next: (id: string) => {
    const known = CustomSeq.cache.get(id);
    if (known !== undefined) {
      return known;
    }
    const assigned = CustomSeq.val;
    CustomSeq.val += 1;
    CustomSeq.cache.set(id, assigned);
    return assigned;
  },
};
/**
 * Builds the provider descriptor used for a provider that is not part of the
 * built-in model list.
 *
 * The id is the lower-cased provider name; the sort key comes from CustomSeq,
 * so calling this twice with the same name yields the same `sorted` value.
 */
const customProvider = (providerName: string) => {
  const id = providerName.toLowerCase();
  return {
    id,
    providerName,
    providerType: "custom",
    sorted: CustomSeq.next(providerName),
  };
};
/**
 * Sorts a list of models IN PLACE and returns that same array.
 *
 * When both entries carry a provider, they are ordered by the provider's sort
 * key first and by the model's own sort key on a tie. If either entry has no
 * provider, only the model sort keys are compared.
 */
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
  models.sort((left, right) => {
    if (!left.provider || !right.provider) {
      return left.sorted - right.sorted;
    }
    const providerDiff = left.provider.sorted - right.provider.sorted;
    if (providerDiff !== 0) {
      return providerDiff;
    }
    return left.sorted - right.sorted;
  });
// Matches an `@` that has no further `@` after it on the same line,
// i.e. the last `@` of a single-line string.
const LAST_AT_SIGN = /@(?!.*@)/;

/**
 * Splits a formatted `<model>@<provider>` string into its two parts.
 *
 * The split happens at the last `@`, so the model part may itself contain `@`:
 * - `gpt-4@OpenAi` gives `["gpt-4", "OpenAi"]`
 * - `claude-3-5-sonnet@20240620@Google` gives `["claude-3-5-sonnet@20240620", "Google"]`
 *
 * @param modelWithProvider model name and provider joined by a final `@`
 * @returns `[model, provider]`; provider is undefined when there is no `@`
 */
export function getModelProvider(modelWithProvider: string): [string, string?] {
  const pieces = modelWithProvider.split(LAST_AT_SIGN);
  return [pieces[0], pieces[1]];
}
/**
 * Builds the model lookup table keyed by `<modelName>@<providerId>`.
 *
 * The table starts from `models`, then the comma-separated rules in
 * `customModels` are applied in order:
 *   - `all`, `+all`, `-all`   set availability of every entry collected so far
 *   - `name`, `+name`         mark matching entries available
 *   - `-name`                 mark matching entries unavailable
 *   - `name=Display`          also set the display name of matching entries
 *   - `name@provider`         only match entries of that provider
 * A rule that matches no existing entry creates a new custom model (and custom
 * provider) with the requested availability.
 *
 * ByteDance rules with `=` are swapped: the text after `=` becomes the entry's
 * `name`, and the configured model name becomes its `displayName`.
 *
 * @param models built-in model list used as the starting table
 * @param customModels comma-separated custom model rules (may be empty)
 * @returns the mutable model table
 */
export function collectModelTable(
  models: readonly LLMModel[],
  customModels: string,
) {
  const modelTable: Record<
    string,
    {
      available: boolean;
      name: string;
      displayName: string;
      sorted: number;
      provider?: LLMModel["provider"]; // Marked as optional
      isDefault?: boolean;
    }
  > = {};

  // default models, keyed as <modelName>@<providerId>
  models.forEach((m) => {
    modelTable[`${m.name}@${m.provider?.id}`] = {
      ...m,
      displayName: m.name, // 'provider' is copied over if it exists
    };
  });

  // server custom models
  customModels
    .split(",")
    .filter((v) => v.length > 0)
    .forEach((m) => {
      const available = !m.startsWith("-");
      const nameConfig =
        m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m;
      const [name, displayName] = nameConfig.split("=");

      // enable or disable all models
      if (name === "all") {
        Object.values(modelTable).forEach(
          (model) => (model.available = available),
        );
        return;
      }

      const [customModelName, customProviderName] = getModelProvider(name);

      // 1. update every existing entry that matches the model name (and the
      //    provider, when one was given). Only per-entry values are written here;
      //    `name`/`displayName` of the rule are never reassigned, so one match
      //    cannot corrupt the next one.
      let count = 0;
      for (const fullName in modelTable) {
        const [modelName, providerName] = getModelProvider(fullName);
        if (
          customModelName !== modelName ||
          (customProviderName !== undefined &&
            customProviderName !== providerName)
        ) {
          continue;
        }
        count += 1;
        const entry = modelTable[fullName];
        entry.available = available;
        if (providerName === "bytedance" && displayName) {
          // ByteDance swap. Only done when `=` supplied a value; without it the
          // old code set `name` to undefined.
          entry.name = displayName;
          entry.displayName = modelName;
        } else if (displayName) {
          entry.displayName = displayName;
        }
      }

      // 2. nothing matched: create a new custom model with this availability
      if (count === 0) {
        const provider = customProvider(customProviderName || customModelName);
        let newName = customModelName;
        let newDisplayName = displayName;
        // Same ByteDance swap as step 1. Compare on the lower-cased provider id
        // (as step 1 does) so `@bytedance` and `@ByteDance` behave alike.
        if (displayName && provider.id === "bytedance") {
          [newName, newDisplayName] = [displayName, customModelName];
        }
        const key = `${newName}@${provider.id}`;
        modelTable[key] = {
          name: newName,
          displayName: newDisplayName || newName,
          available,
          provider,
          sorted: CustomSeq.next(key),
        };
      }
    });

  return modelTable;
}
/**
 * Like collectModelTable, but also flags at most one entry with `isDefault`.
 *
 * - A `defaultModel` containing `@` is treated as an exact
 *   `<model>@<providerId>` key. That entry is flagged if present; its
 *   availability is not checked.
 * - Otherwise the first available entry, in table key order, whose model name
 *   equals `defaultModel` is flagged.
 * An empty `defaultModel` leaves the table unchanged.
 */
export function collectModelTableWithDefaultModel(
  models: readonly LLMModel[],
  customModels: string,
  defaultModel: string,
) {
  const modelTable = collectModelTable(models, customModels);
  if (!defaultModel) {
    return modelTable;
  }

  if (defaultModel.includes("@")) {
    if (defaultModel in modelTable) {
      modelTable[defaultModel].isDefault = true;
    }
    return modelTable;
  }

  const defaultKey = Object.keys(modelTable).find(
    (key) =>
      modelTable[key].available && getModelProvider(key)[0] === defaultModel,
  );
  if (defaultKey !== undefined) {
    modelTable[defaultKey].isDefault = true;
  }
  return modelTable;
}
/**
 * Generate full model table.
 *
 * Returns the built-in `models` merged with the `customModels` rules as a flat
 * array, ordered by provider sort key and then by model sort key.
 */
export function collectModels(
  models: readonly LLMModel[],
  customModels: string,
) {
  const entries = Object.values(collectModelTable(models, customModels));
  // sortModelTable sorts in place, so `entries` is already ordered afterwards.
  sortModelTable(entries);
  return entries;
}
/**
 * Sorted flat model list like collectModels, with the `isDefault` flag applied
 * according to `defaultModel` (see collectModelTableWithDefaultModel).
 */
export function collectModelsWithDefaultModel(
  models: readonly LLMModel[],
  customModels: string,
  defaultModel: string,
) {
  const entries = Object.values(
    collectModelTableWithDefaultModel(models, customModels, defaultModel),
  );
  // sortModelTable sorts in place, so `entries` is already ordered afterwards.
  sortModelTable(entries);
  return entries;
}
/**
 * Reports whether a model has been explicitly DISABLED for a provider.
 *
 * NOTE: despite the name, this returns `true` when the model is *not*
 * available. It is true only when `<modelName>@<provider>` exists in the
 * server model table (DEFAULT_MODELS plus `customModels`) with
 * `available === false`; unknown models yield `false`. The inverted meaning
 * is kept because existing callers may depend on it. Prefer
 * isModelNotavailableInServer for new code.
 *
 * @param customModels comma-separated custom model rules
 * @param modelName model name part of the table key
 * @param providerName provider id or name; lower-cased before the lookup
 * @returns true when the model is listed and disabled for that provider
 */
export function isModelAvailableInServer(
  customModels: string,
  modelName: string,
  providerName: string,
) {
  // Table keys use lower-cased provider ids (customProvider lower-cases them,
  // and isModelNotavailableInServer builds keys the same way). Without this, a
  // name like "OpenAI" never matched and the check always returned false.
  const fullName = `${modelName}@${providerName.toLowerCase()}`;
  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  return modelTable[fullName]?.available === false;
}
/**
 * Check if the model name is a GPT-4 related model.
 *
 * A name qualifies when it starts with `gpt-4`, `chatgpt-4o` or `o1`, except
 * names starting with `gpt-4o-mini`. Matching is case-sensitive;
 * isModelNotavailableInServer lower-cases the name before calling.
 *
 * @param modelName The name of the model to check
 * @returns True if the model is a GPT-4 related model (excluding gpt-4o-mini)
 */
export function isGPT4Model(modelName: string): boolean {
  if (modelName.startsWith("gpt-4o-mini")) {
    return false;
  }
  return ["gpt-4", "chatgpt-4o", "o1"].some((prefix) =>
    modelName.startsWith(prefix),
  );
}
/**
 * Checks if a model is not available on any of the specified providers in the server.
 *
 * Returns true immediately when `DISABLE_GPT4` is `"1"` and the (lower-cased)
 * model name is GPT-4 related. Otherwise the model counts as available, and
 * the function returns false, as soon as ONE provider in `providerNames` has it
 * enabled in the table built from DEFAULT_MODELS plus `customModels`.
 *
 * @param customModels A string of custom models, comma-separated.
 * @param modelName The name of the model to check.
 * @param providerNames A string or array of provider names to check against.
 * @returns True if the model is not available on any of the specified providers, false otherwise.
 */
export function isModelNotavailableInServer(
  customModels: string,
  modelName: string,
  providerNames: string | string[],
): boolean {
  // Check DISABLE_GPT4 environment variable
  if (
    process.env.DISABLE_GPT4 === "1" &&
    isGPT4Model(modelName.toLowerCase())
  ) {
    return true;
  }

  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  const providerNamesArray = Array.isArray(providerNames)
    ? providerNames
    : [providerNames];

  for (const providerName of providerNamesArray) {
    if (providerName === ServiceProvider.ByteDance) {
      // collectModelTable may rewrite a ByteDance entry's `name` so it no longer
      // matches its key, so look the model up by `name`, limited to ByteDance
      // entries. If it is unavailable there, keep checking the remaining
      // providers instead of returning early.
      const byteDanceModel = Object.values(modelTable).find(
        (v) => v.name === modelName && v.provider?.id === "bytedance",
      );
      if (byteDanceModel?.available) return false;
      continue;
    }
    const fullName = `${modelName}@${providerName.toLowerCase()}`;
    if (modelTable[fullName]?.available === true) return false;
  }
  return true;
}