| import { |
| getModelStrings as getModelStringsState, |
| setModelStrings as setModelStringsState, |
| } from 'src/bootstrap/state.js' |
| import { logError } from '../log.js' |
| import { sequential } from '../sequential.js' |
| import { getInitialSettings } from '../settings/settings.js' |
| import { findFirstMatch, getBedrockInferenceProfiles } from './bedrock.js' |
| import { |
| ALL_MODEL_CONFIGS, |
| CANONICAL_ID_TO_KEY, |
| type CanonicalModelId, |
| type ModelKey, |
| } from './configs.js' |
| import { type APIProvider, getAPIProvider } from './providers.js' |
|
|
| |
| |
| |
| |
/** One model ID string per known model key. */
export type ModelStrings = Record<ModelKey, string>

/** Every key of ALL_MODEL_CONFIGS, captured once when the module loads. */
const MODEL_KEYS: ModelKey[] = Object.keys(ALL_MODEL_CONFIGS) as ModelKey[]
|
|
/**
 * Builds the model-string table for a provider directly from the static
 * model configs.
 *
 * @param provider - Provider whose entry is read from each model config.
 * @returns A record with one entry for every key in MODEL_KEYS.
 */
function getBuiltinModelStrings(provider: APIProvider): ModelStrings {
  return MODEL_KEYS.reduce((table, modelKey) => {
    table[modelKey] = ALL_MODEL_CONFIGS[modelKey][provider]
    return table
  }, {} as ModelStrings)
}
|
|
/**
 * Resolves the Bedrock model strings, preferring entries found in the list
 * returned by getBedrockInferenceProfiles over the built-in Bedrock IDs.
 *
 * A failed profile lookup is logged and the built-in table is returned
 * instead; the same table is returned when the lookup yields no profiles.
 *
 * @returns A record with one entry for every key in MODEL_KEYS.
 */
async function getBedrockModelStrings(): Promise<ModelStrings> {
  const builtin = getBuiltinModelStrings('bedrock')

  let profiles: string[] | undefined
  try {
    profiles = await getBedrockInferenceProfiles()
  } catch (error) {
    logError(error as Error)
    return builtin
  }
  if (!profiles || profiles.length === 0) {
    return builtin
  }

  // Start from the built-in IDs and overwrite only the keys for which a
  // profile matches the model's first-party ID; a falsy match keeps the
  // built-in value.
  const resolved = { ...builtin }
  for (const modelKey of MODEL_KEYS) {
    const match = findFirstMatch(profiles, ALL_MODEL_CONFIGS[modelKey].firstParty)
    if (match) {
      resolved[modelKey] = match
    }
  }
  return resolved
}
|
|
| |
| |
| |
| |
| |
| |
/**
 * Layers the `modelOverrides` from the initial settings on top of a
 * model-string table.
 *
 * @param ms - Table to use as the base; it is never mutated.
 * @returns `ms` itself when no overrides are configured, otherwise a shallow
 *   copy with each recognised override applied.
 */
function applyModelOverrides(ms: ModelStrings): ModelStrings {
  const overrides = getInitialSettings().modelOverrides
  if (!overrides) {
    return ms
  }

  const merged = { ...ms }
  for (const [canonicalId, override] of Object.entries(overrides)) {
    const modelKey = CANONICAL_ID_TO_KEY[canonicalId as CanonicalModelId]
    // Skip IDs missing from CANONICAL_ID_TO_KEY and falsy override values.
    if (!modelKey || !override) {
      continue
    }
    merged[modelKey] = override
  }
  return merged
}
|
|
| |
| |
| |
| |
| |
| |
/**
 * Maps an override value back to the canonical ID it is configured for.
 *
 * @param modelId - Model ID to look up among the `modelOverrides` values of
 *   the initial settings.
 * @returns The key of the first override entry whose value equals `modelId`;
 *   `modelId` unchanged when nothing matches, when no overrides are
 *   configured, or when reading the settings throws.
 */
export function resolveOverriddenModel(modelId: string): string {
  let overrides: Record<string, string> | undefined
  try {
    overrides = getInitialSettings().modelOverrides
  } catch {
    // Settings could not be read: fall back to the ID as given.
    return modelId
  }
  if (!overrides) {
    return modelId
  }

  const match = Object.entries(overrides).find(
    ([, override]) => override === modelId,
  )
  return match === undefined ? modelId : match[0]
}
|
|
/**
 * Fetches the Bedrock model strings and stores them in the bootstrap state,
 * unless the state has already been populated.
 *
 * The state check runs inside the wrapped function, so an invocation that
 * starts after an earlier one has stored a table does nothing. Errors are
 * logged rather than propagated.
 *
 * NOTE(review): `sequential` presumably serialises overlapping invocations —
 * confirm against ../sequential.js.
 */
const updateBedrockModelStrings = sequential(async () => {
  const needsInit = getModelStringsState() === null
  if (needsInit) {
    try {
      setModelStringsState(await getBedrockModelStrings())
    } catch (error) {
      logError(error as Error)
    }
  }
})
|
|
/**
 * Starts populating the model-string state if it is still empty.
 *
 * For Bedrock the update is started without being awaited, so the state may
 * still be null when this function returns. For every other provider the
 * built-in table is stored synchronously.
 */
function initModelStrings(): void {
  if (getModelStringsState() !== null) {
    return
  }

  if (getAPIProvider() === 'bedrock') {
    // Fire and forget: the result is stored by updateBedrockModelStrings.
    void updateBedrockModelStrings()
    return
  }

  setModelStringsState(getBuiltinModelStrings(getAPIProvider()))
}
|
|
/**
 * Returns the current model strings with settings overrides applied.
 *
 * When the state has not been populated yet, initialisation is triggered and
 * the built-in table for the current provider is used for this call, so the
 * function always returns synchronously.
 *
 * @returns The stored table, or the built-in one, passed through
 *   applyModelOverrides.
 */
export function getModelStrings(): ModelStrings {
  const cached = getModelStringsState()
  if (cached !== null) {
    return applyModelOverrides(cached)
  }

  initModelStrings()
  return applyModelOverrides(getBuiltinModelStrings(getAPIProvider()))
}
|
|
| |
| |
| |
| |
| |
/**
 * Populates the model-string state if it is still empty.
 *
 * Non-Bedrock providers get the built-in table stored synchronously. For
 * Bedrock the returned promise settles once updateBedrockModelStrings has
 * finished.
 */
export async function ensureModelStringsInitialized(): Promise<void> {
  if (getModelStringsState() !== null) {
    return
  }

  if (getAPIProvider() === 'bedrock') {
    await updateBedrockModelStrings()
    return
  }

  setModelStringsState(getBuiltinModelStrings(getAPIProvider()))
}
|
|