| import type { IProviderSetting } from '~/types/model';
|
| import { BaseProvider } from './base-provider';
|
| import type { ModelInfo, ProviderInfo } from './types';
|
| import * as providers from './registry';
|
| import { createScopedLogger } from '~/utils/logger';
|
|
|
// Module-level logger scoped to 'LLMManager'; shared by every method of the class below.
const logger = createScopedLogger('LLMManager');
|
/**
 * Process-wide registry of LLM providers and the models they expose.
 *
 * On construction, every exported class from `./registry` whose prototype
 * derives from `BaseProvider` is instantiated and registered. The combined
 * model list starts out as the providers' static models. `updateModelList`
 * can replace it with dynamic + static models.
 */
export class LLMManager {
  private static _instance: LLMManager;

  /** Registered providers keyed by `provider.name`, in registration order. */
  private readonly _providers: Map<string, BaseProvider> = new Map();

  /** Most recently computed model list. */
  private _modelList: ModelInfo[] = [];

  // NOTE(review): kept as `any` so existing consumers of the `env` getter keep compiling; consider narrowing.
  private readonly _env: any = {};

  private constructor(_env: Record<string, string>) {
    this._registerProvidersFromDirectory();
    this._env = _env;
  }

  /**
   * Returns the shared manager, creating it on the first call.
   *
   * @param env Environment captured only when the instance is first created;
   *   it is ignored on every later call.
   * @returns The singleton `LLMManager`.
   */
  static getInstance(env: Record<string, string> = {}): LLMManager {
    if (!LLMManager._instance) {
      LLMManager._instance = new LLMManager(env);
    }

    return LLMManager._instance;
  }

  /** Environment that was passed when the singleton was created. */
  get env() {
    return this._env;
  }

  /**
   * Instantiates and registers every `BaseProvider` subclass exported from `./registry`.
   *
   * Synchronous on purpose: nothing here is awaited, so returning a promise would only
   * leave an unobserved promise in the constructor. A provider whose registration throws
   * is logged and skipped; any other failure is logged and stops the loop.
   */
  private _registerProvidersFromDirectory(): void {
    try {
      for (const exportedItem of Object.values(providers)) {
        // Only concrete provider classes are registered; other exports are ignored.
        if (typeof exportedItem === 'function' && exportedItem.prototype instanceof BaseProvider) {
          const provider = new exportedItem();

          try {
            this.registerProvider(provider);
          } catch (error: unknown) {
            const message = error instanceof Error ? error.message : String(error);
            logger.warn('Failed To Register Provider: ', provider.name, 'error:', message);
          }
        }
      }
    } catch (error) {
      logger.error('Error registering providers:', error);
    }
  }

  /**
   * Adds a provider and appends its static models to the model list.
   * A provider whose name is already registered is skipped with a warning.
   *
   * @param provider Provider instance to register.
   */
  registerProvider(provider: BaseProvider) {
    if (this._providers.has(provider.name)) {
      logger.warn(`Provider ${provider.name} is already registered. Skipping.`);
      return;
    }

    logger.info('Registering Provider: ', provider.name);
    this._providers.set(provider.name, provider);

    // Guard like the rest of this class does: spreading an absent `staticModels` would throw.
    this._modelList = [...this._modelList, ...(provider.staticModels || [])];
  }

  /** Returns the provider registered under `name`, or `undefined`. */
  getProvider(name: string): BaseProvider | undefined {
    return this._providers.get(name);
  }

  /** Returns all registered providers in registration order. */
  getAllProviders(): BaseProvider[] {
    return Array.from(this._providers.values());
  }

  /** Returns the most recently computed model list (the internal array, not a copy). */
  getModelList(): ModelInfo[] {
    return this._modelList;
  }

  /**
   * Recomputes the model list and stores it as the current list.
   *
   * The result contains the dynamic models of every enabled provider that
   * implements `getDynamicModels` (served from the provider's cache when
   * available), followed by the static models of *all* registered providers.
   * A provider whose dynamic lookup fails contributes no dynamic models; the
   * error is logged.
   *
   * @param options.providerSettings When given, only providers whose entry has a
   *   truthy `enabled` are queried for dynamic models. A provider with no entry is
   *   treated as disabled rather than causing a crash.
   * @returns The new model list.
   */
  async updateModelList(options: {
    apiKeys?: Record<string, string>;
    providerSettings?: Record<string, IProviderSetting>;
    serverEnv?: Record<string, string>;
  }): Promise<ModelInfo[]> {
    const { apiKeys, providerSettings, serverEnv } = options;

    // Optional chaining: the settings map may not list every registered provider.
    const isEnabled = (name: string): boolean => !providerSettings || Boolean(providerSettings[name]?.enabled);

    const dynamicModels = await Promise.all(
      Array.from(this._providers.values())
        .filter((provider) => isEnabled(provider.name))
        .filter(
          (provider): provider is BaseProvider & Required<Pick<ProviderInfo, 'getDynamicModels'>> =>
            !!provider.getDynamicModels,
        )
        .map(async (provider) => {
          const cachedModels = provider.getModelsFromCache(options);

          if (cachedModels) {
            return cachedModels;
          }

          return provider
            .getDynamicModels(apiKeys, providerSettings?.[provider.name], serverEnv)
            .then((models) => {
              logger.info(`Caching ${models.length} dynamic models for ${provider.name}`);
              provider.storeDynamicModels(options, models);

              return models;
            })
            .catch((err) => {
              logger.error(`Error getting dynamic models ${provider.name} :`, err);
              return [];
            });
        }),
    );

    const modelList = [
      ...dynamicModels.flat(),
      ...Array.from(this._providers.values()).flatMap((p) => p.staticModels || []),
    ];
    this._modelList = modelList;

    return modelList;
  }

  /** Returns a fresh array of every registered provider's static models. */
  getStaticModelList() {
    return [...this._providers.values()].flatMap((p) => p.staticModels || []);
  }

  /**
   * Returns dynamic (cached or freshly fetched) followed by static models for one provider.
   *
   * @param providerArg Provider to look up; it is resolved by name against the registry.
   * @throws Error if no provider with that name is registered.
   * @returns Static models only when the provider has no `getDynamicModels`; a failed
   *   dynamic fetch is logged and yields static models only.
   */
  async getModelListFromProvider(
    providerArg: BaseProvider,
    options: {
      apiKeys?: Record<string, string>;
      providerSettings?: Record<string, IProviderSetting>;
      serverEnv?: Record<string, string>;
    },
  ): Promise<ModelInfo[]> {
    const provider = this._providers.get(providerArg.name);

    if (!provider) {
      throw new Error(`Provider ${providerArg.name} not found`);
    }

    const staticModels = provider.staticModels || [];

    if (!provider.getDynamicModels) {
      return staticModels;
    }

    const { apiKeys, providerSettings, serverEnv } = options;

    const cachedModels = provider.getModelsFromCache({
      apiKeys,
      providerSettings,
      serverEnv,
    });

    if (cachedModels) {
      logger.info(`Found ${cachedModels.length} cached models for ${provider.name}`);
      return [...cachedModels, ...staticModels];
    }

    logger.info(`Getting dynamic models for ${provider.name}`);

    // `getDynamicModels` is known to exist here (guarded above), so no optional call is needed.
    const dynamicModels = await provider
      .getDynamicModels(apiKeys, providerSettings?.[provider.name], serverEnv)
      .then((models) => {
        logger.info(`Got ${models.length} dynamic models for ${provider.name}`);
        provider.storeDynamicModels(options, models);

        return models;
      })
      .catch((err) => {
        logger.error(`Error getting dynamic models ${provider.name} :`, err);
        return [];
      });

    return [...dynamicModels, ...staticModels];
  }

  /**
   * Returns a copy of one provider's static models.
   *
   * @throws Error if no provider with that name is registered.
   */
  getStaticModelListFromProvider(providerArg: BaseProvider) {
    const provider = this._providers.get(providerArg.name);

    if (!provider) {
      throw new Error(`Provider ${providerArg.name} not found`);
    }

    return [...(provider.staticModels || [])];
  }

  /**
   * Returns the first registered provider.
   *
   * @throws Error if no providers are registered.
   */
  getDefaultProvider(): BaseProvider {
    const firstProvider = this._providers.values().next().value;

    if (!firstProvider) {
      throw new Error('No providers registered');
    }

    return firstProvider;
  }
}
|
|
|