| import type { IProviderSetting } from '~/types/model'; |
| import { BaseProvider } from './base-provider'; |
| import type { ModelInfo, ProviderInfo } from './types'; |
| import * as providers from './registry'; |
| import { createScopedLogger } from '~/utils/logger'; |
|
|
/** Logger for this module, created with the scope name 'LLMManager'. */
const logger = createScopedLogger('LLMManager');

/**
 * Process-wide registry of LLM providers and the models they expose.
 *
 * Obtain it through {@link LLMManager.getInstance}; the constructor is private.
 * When the instance is created, every export of `./registry` that is a subclass
 * of `BaseProvider` is instantiated and registered under its `name`.
 */
export class LLMManager {
  private static _instance: LLMManager;

  /** Registered providers keyed by `provider.name`, in registration order. */
  private _providers: Map<string, BaseProvider> = new Map();

  /** Model list last built by `registerProvider` or `updateModelList`. */
  private _modelList: ModelInfo[] = [];

  // Left as `any` because the public `env` getter exposes this type to callers.
  private readonly _env: any = {};

  private constructor(_env: Record<string, string>) {
    // Assign the environment first so it is already set while providers are registered.
    this._env = _env;
    this._registerProvidersFromDirectory();
  }

  /**
   * Returns the shared instance, creating it on the first call.
   *
   * @param env Environment stored on the instance. Only the call that creates the
   *   instance uses it; later calls return the existing instance and ignore `env`.
   */
  static getInstance(env: Record<string, string> = {}): LLMManager {
    if (!LLMManager._instance) {
      LLMManager._instance = new LLMManager(env);
    }

    return LLMManager._instance;
  }

  /** The environment passed to the `getInstance` call that created this instance. */
  get env() {
    return this._env;
  }

  /**
   * Instantiates and registers every `BaseProvider` subclass exported from `./registry`.
   *
   * Construction and registration are guarded per provider: a provider whose
   * constructor (or registration) throws is logged and skipped, and the remaining
   * providers are still registered. Runs synchronously, so registration has
   * finished by the time the constructor returns.
   */
  private _registerProvidersFromDirectory(): void {
    try {
      for (const exportedItem of Object.values(providers)) {
        if (typeof exportedItem !== 'function' || !(exportedItem.prototype instanceof BaseProvider)) {
          continue;
        }

        // Until an instance exists, identify the provider by its class name.
        let providerName: string = exportedItem.name;

        try {
          const provider = new exportedItem();
          providerName = provider.name;
          this.registerProvider(provider);
        } catch (error: unknown) {
          logger.warn(
            'Failed To Register Provider: ',
            providerName,
            'error:',
            error instanceof Error ? error.message : error,
          );
        }
      }
    } catch (error) {
      logger.error('Error registering providers:', error);
    }
  }

  /**
   * Adds a provider and appends its static models to the cached model list.
   * A provider whose `name` is already registered is skipped with a warning.
   */
  registerProvider(provider: BaseProvider) {
    if (this._providers.has(provider.name)) {
      logger.warn(`Provider ${provider.name} is already registered. Skipping.`);
      return;
    }

    logger.info('Registering Provider: ', provider.name);
    this._providers.set(provider.name, provider);

    // `staticModels` is treated as optional elsewhere in this class; guard the spread too.
    this._modelList = [...this._modelList, ...(provider.staticModels ?? [])];
  }

  /** Returns the provider registered under `name`, or `undefined`. */
  getProvider(name: string): BaseProvider | undefined {
    return this._providers.get(name);
  }

  /** Returns all registered providers in registration order. */
  getAllProviders(): BaseProvider[] {
    return Array.from(this._providers.values());
  }

  /** Returns the model list last built by `registerProvider` or `updateModelList`. */
  getModelList(): ModelInfo[] {
    return this._modelList;
  }

  /**
   * Rebuilds and stores the full model list.
   *
   * Dynamic models are fetched (or read from each provider's cache) for enabled
   * providers that implement `getDynamicModels`. When `providerSettings` is empty or
   * absent every provider is enabled; otherwise only providers whose settings entry
   * exists and has a truthy `enabled`. A failed fetch is logged and contributes no
   * models. Static models are appended for every registered provider, enabled or not.
   *
   * @returns The new list: dynamic models first, then static models.
   */
  async updateModelList(options: {
    apiKeys?: Record<string, string>;
    providerSettings?: Record<string, IProviderSetting>;
    serverEnv?: Record<string, string>;
  }): Promise<ModelInfo[]> {
    const { apiKeys, providerSettings, serverEnv } = options;
    const allProviders = Array.from(this._providers.values());

    const hasSettings = !!providerSettings && Object.keys(providerSettings).length > 0;

    // A provider without a settings entry is treated as disabled instead of crashing.
    const isEnabled = (provider: BaseProvider): boolean =>
      !hasSettings || Boolean(providerSettings?.[provider.name]?.enabled);

    const dynamicModels = await Promise.all(
      allProviders
        .filter(isEnabled)
        .filter(
          (provider): provider is BaseProvider & Required<Pick<ProviderInfo, 'getDynamicModels'>> =>
            !!provider.getDynamicModels,
        )
        .map(async (provider) => {
          const cachedModels = provider.getModelsFromCache(options);

          if (cachedModels) {
            return cachedModels;
          }

          try {
            const models = await provider.getDynamicModels(apiKeys, providerSettings?.[provider.name], serverEnv);
            logger.info(`Caching ${models.length} dynamic models for ${provider.name}`);
            provider.storeDynamicModels(options, models);

            return models;
          } catch (err) {
            logger.error(`Error getting dynamic models ${provider.name} :`, err);
            return [];
          }
        }),
    );

    const modelList = [...dynamicModels.flat(), ...allProviders.flatMap((p) => p.staticModels || [])];
    this._modelList = modelList;

    return modelList;
  }

  /** Returns the static models of every registered provider (not cached). */
  getStaticModelList() {
    return [...this._providers.values()].flatMap((p) => p.staticModels || []);
  }

  /**
   * Returns the models of a single registered provider: its dynamic models
   * (from cache, or freshly fetched and then cached) followed by its static models.
   * A failed dynamic fetch is logged and yields only the static models.
   *
   * @throws Error if no provider with `providerArg.name` is registered.
   */
  async getModelListFromProvider(
    providerArg: BaseProvider,
    options: {
      apiKeys?: Record<string, string>;
      providerSettings?: Record<string, IProviderSetting>;
      serverEnv?: Record<string, string>;
    },
  ): Promise<ModelInfo[]> {
    const provider = this._providers.get(providerArg.name);

    if (!provider) {
      throw new Error(`Provider ${providerArg.name} not found`);
    }

    const staticModels = provider.staticModels || [];

    if (!provider.getDynamicModels) {
      return staticModels;
    }

    const { apiKeys, providerSettings, serverEnv } = options;

    const cachedModels = provider.getModelsFromCache({
      apiKeys,
      providerSettings,
      serverEnv,
    });

    if (cachedModels) {
      logger.info(`Found ${cachedModels.length} cached models for ${provider.name}`);
      return [...cachedModels, ...staticModels];
    }

    logger.info(`Getting dynamic models for ${provider.name}`);

    let dynamicModels: ModelInfo[];

    try {
      dynamicModels = await provider.getDynamicModels(apiKeys, providerSettings?.[provider.name], serverEnv);
      logger.info(`Got ${dynamicModels.length} dynamic models for ${provider.name}`);
      provider.storeDynamicModels(options, dynamicModels);
    } catch (err) {
      logger.error(`Error getting dynamic models ${provider.name} :`, err);
      dynamicModels = [];
    }

    return [...dynamicModels, ...staticModels];
  }

  /**
   * Returns a copy of a registered provider's static models.
   *
   * @throws Error if no provider with `providerArg.name` is registered.
   */
  getStaticModelListFromProvider(providerArg: BaseProvider) {
    const provider = this._providers.get(providerArg.name);

    if (!provider) {
      throw new Error(`Provider ${providerArg.name} not found`);
    }

    return [...(provider.staticModels || [])];
  }

  /**
   * Returns the first registered provider.
   *
   * @throws Error if no providers are registered.
   */
  getDefaultProvider(): BaseProvider {
    const firstProvider = this._providers.values().next().value;

    if (!firstProvider) {
      throw new Error('No providers registered');
    }

    return firstProvider;
  }
}
|
|