| import { getStore } from 'app/store/nanostores/store'; |
| import type { ModelIdentifierField } from 'features/nodes/types/common'; |
| import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; |
| import type { ModelIdentifier } from 'features/nodes/types/v2/common'; |
| import { modelsApi } from 'services/api/endpoints/models'; |
| import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; |
|
|
| |
| |
| |
/**
 * Error signalling that a model config lookup did not produce a config.
 */
class ModelConfigNotFoundError extends Error {
  /**
   * @param message Human-readable description of the lookup that failed.
   */
  constructor(message: string) {
    super(message);
    // Report the concrete class name instead of the inherited "Error".
    const concreteName = this.constructor.name;
    this.name = concreteName;
  }
}
|
|
| |
| |
| |
/**
 * Error signalling that a model config or model identifier was rejected as invalid.
 */
export class InvalidModelConfigError extends Error {
  /**
   * @param message Human-readable description of what was invalid.
   */
  constructor(message: string) {
    super(message);
    // Report the concrete class name instead of the inherited "Error".
    const concreteName = this.constructor.name;
    this.name = concreteName;
  }
}
|
|
| |
| |
| |
| |
| |
| |
/**
 * Fetches a model config by its key through the `getModelConfig` endpoint.
 *
 * @param key The model key passed to the endpoint.
 * @returns The unwrapped model config.
 * @throws ModelConfigNotFoundError If dispatching or unwrapping the request fails for any reason.
 */
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
  const store = getStore();
  try {
    const action = modelsApi.endpoints.getModelConfig.initiate(key);
    const request = store.dispatch(action);
    // Drop the subscription right away; only the unwrapped result is needed.
    // NOTE(review): presumably this keeps the one-off request from holding a cache subscription — confirm.
    request.unsubscribe();
    const config = await request.unwrap();
    return config;
  } catch {
    // The underlying failure is intentionally replaced by a uniform "not found" error.
    throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
  }
};
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
/**
 * Fetches a model config by its name, base and type through the `getModelConfigByAttrs` endpoint.
 *
 * @param name The model name.
 * @param base The base model type.
 * @param type The model type.
 * @returns The unwrapped model config.
 * @throws ModelConfigNotFoundError If dispatching or unwrapping the request fails for any reason.
 */
const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: ModelType): Promise<AnyModelConfig> => {
  const store = getStore();
  try {
    const action = modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type });
    const request = store.dispatch(action);
    // Drop the subscription right away; only the unwrapped result is needed.
    // NOTE(review): presumably this keeps the one-off request from holding a cache subscription — confirm.
    request.unsubscribe();
    const config = await request.unwrap();
    return config;
  } catch {
    // The underlying failure is intentionally replaced by a uniform "not found" error.
    throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
  }
};
|
|
| |
| |
| |
| |
| |
| |
/**
 * Fetches a model config for a model identifier.
 *
 * The lookup is attempted by the identifier's key first. If that fails, it is retried using the
 * identifier's name, base and type.
 *
 * @param identifier The model identifier to resolve.
 * @returns The model config found by key, or failing that, by name/base/type.
 * @throws ModelConfigNotFoundError If both lookups fail.
 */
export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
  try {
    return await fetchModelConfig(identifier.key);
  } catch {
    // Key lookup failed; fall through to the attribute-based lookup below.
  }

  try {
    return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
  } catch {
    // Spell out the identifying fields: interpolating the object itself would render "[object Object]".
    throw new ModelConfigNotFoundError(
      `Unable to retrieve model config for identifier ${identifier.key} (${identifier.name}/${identifier.base}/${identifier.type})`
    );
  }
};
|
|
| |
| |
| |
| |
| |
| |
| |
/**
 * Fetches a model config by key and checks it against a type guard.
 *
 * @param key The model key to fetch.
 * @param typeGuard Predicate that narrows the fetched config to `T`.
 * @returns The fetched config, narrowed to `T`.
 * @throws ModelConfigNotFoundError Propagated from `fetchModelConfig` when the config cannot be retrieved.
 * @throws InvalidModelConfigError If the fetched config does not satisfy `typeGuard`.
 */
export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
  key: string,
  typeGuard: (config: AnyModelConfig) => config is T
) => {
  const fetched = await fetchModelConfig(key);
  if (typeGuard(fetched)) {
    return fetched;
  }
  // The config exists but is not of the requested kind.
  throw new InvalidModelConfigError(`Invalid model type for key ${key}: ${fetched.type}`);
};
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
/**
 * Resolves the key of the model config that a model identifier refers to.
 *
 * - If `isModelIdentifier` accepts the value, the config is fetched by the identifier's key. If that
 *   fails, it is fetched by the identifier's name and base together with the `type` argument.
 * - Otherwise, if `isModelIdentifierV2` accepts the value, the config is fetched by the identifier's
 *   `model_name` and `base_model` together with the `type` argument.
 * - Otherwise the value is rejected.
 *
 * @param modelIdentifier The candidate model identifier; any value is accepted and validated at runtime.
 * @param type The model type used for name/base/type lookups.
 * @param message Optional error message used when the identifier is rejected. An empty string falls
 * back to the default message.
 * @returns The key of the resolved model config.
 * @throws ModelConfigNotFoundError Propagated from the underlying lookups when no config is found.
 * @throws InvalidModelConfigError If the value is not accepted by either identifier guard.
 */
export const getModelKey = async (
  modelIdentifier: unknown | ModelIdentifierField | ModelIdentifier,
  type: ModelType,
  message?: string
): Promise<string> => {
  if (isModelIdentifier(modelIdentifier)) {
    try {
      // Prefer the lookup by key.
      return (await fetchModelConfig(modelIdentifier.key)).key;
    } catch {
      // Key lookup failed; retry by name/base, using the caller-supplied type rather than the identifier's.
      return (await fetchModelConfigByAttrs(modelIdentifier.name, modelIdentifier.base, type)).key;
    }
  } else if (isModelIdentifierV2(modelIdentifier)) {
    // V2 identifiers are resolved by name/base/type only.
    return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
  }

  if (message) {
    throw new InvalidModelConfigError(message);
  }

  // Build a readable description of the rejected value. Interpolating it directly would render
  // objects as "[object Object]" and would throw a TypeError for symbols.
  let described: string;
  try {
    described =
      typeof modelIdentifier === 'object' && modelIdentifier !== null
        ? JSON.stringify(modelIdentifier)
        : String(modelIdentifier);
  } catch {
    // JSON.stringify throws on e.g. circular structures; fall back to the default string form.
    described = String(modelIdentifier);
  }
  throw new InvalidModelConfigError(`Invalid model identifier: ${described}`);
};
|
|