import * as starry from '../../src/starry';
import { PyClients } from './predictors';
import { Logger } from './ZeroClient';
import { SpartitoMeasure, EditableMeasure, evaluateMeasure } from '../../src/starry';
import { EquationPolicy } from '../../src/starry/spartitoMeasure';
import { genMeasureRectifications } from '../../src/starry/measureRectification';
import { SolutionStore, DefaultSolutionStore, SaveIssueMeasure } from './store';

export * from './regulationBead';

// Polyfill `btoa` for runtimes that lack it, encoding via Buffer's 'binary' (latin1) interpretation.
globalThis.btoa = globalThis.btoa || ((str) => Buffer.from(str, 'binary').toString('base64'));

/** Bound on the number of rectification candidates tried per measure (env-overridable). */
const RECTIFICATION_SEARCH_ITERATIONS = parseInt(process.env.RECTIFICATION_SEARCH_ITERATIONS || '30', 10);
/** Quota scale factor for the base regulation pass (env-overridable). */
const BASE_QUOTA_FACTOR = parseInt(process.env.BASE_QUOTA_FACTOR || '40', 10);
/** Quota scale factor for each rectification attempt (env-overridable). */
const RECTIFICATION_QUOTA_FACTOR = parseInt(process.env.RECTIFICATION_QUOTA_FACTOR || '80', 10);
/** Weight kept from a measure's own matrixH when blending it with the EquationPolicy estimate. */
const MATRIXH_INTERPOLATION_K = 0.9;

interface SolveMeasureOptions {
	solver?: (...args: any[]) => any;
	quotaMax?: number;
	quotaFactor?: number;
	solutionStore?: SolutionStore;
	ignoreCache?: boolean;
	logger?: Logger;
}

/**
 * Compute the regulation quota for a measure with `n` events.
 *
 * The result is the smaller of a growth term, ceil((n + 1) * factor * ln(n + 2)), and a cap.
 * The cap equals `limit` for n <= 23 and shrinks as (24 / (n + 1))^2 beyond that.
 */
const computeQuota = (n: number, factor: number, limit: number): number =>
	Math.min(Math.ceil((n + 1) * factor * Math.log(n + 2)), Math.ceil(limit * Math.min(1, (24 / (n + 1)) ** 2)));

interface BaseRegulationStat {
	cached: number;
	computed: number;
	solved: number;
}

/**
 * Regulate a batch of measures concurrently.
 *
 * Unless `ignoreCache` is set, a solution found in `solutionStore` under `measure.regulationHash`
 * is applied directly and counted as a cache hit. Otherwise the measure is regulated with the
 * 'equations' policy under a quota from `computeQuota`. When the result is error-free per
 * `evaluateMeasure`, it is written back to the store with priority = -loss.
 *
 * @returns counts of cache hits, computed measures, and measures that evaluated as perfect
 *   (perfect counts only include computed measures, not cache hits).
 */
async function solveMeasures(
	measures: SpartitoMeasure[],
	{ solver, quotaMax = 1000, quotaFactor = BASE_QUOTA_FACTOR, solutionStore = DefaultSolutionStore, ignoreCache = false, logger }: SolveMeasureOptions = {}
): Promise<BaseRegulationStat> {
	let cached = 0;
	let solved = 0;

	logger?.info(`[solveMeasures] begin, measure total: ${measures.length}.`);

	await Promise.all(
		measures.map(async (measure) => {
			if (!ignoreCache) {
				const solution = await solutionStore.get(measure.regulationHash);
				if (solution) {
					measure.applySolution(solution);
					++cached;
					return;
				}
			}

			const quota = computeQuota(measure.events.length, quotaFactor, quotaMax);

			await measure.regulate({
				policy: 'equations',
				quota,
				solver,
			});

			const stat = evaluateMeasure(measure);
			// NOTE(review): the cache is read by `regulationHash` but written by `regulationHash0`;
			// confirm these keys are meant to differ.
			if (!stat.error) solutionStore.set(measure.regulationHash0, { ...measure.asSolution(), priority: -measure?.solutionStat?.loss! });
			if (stat.perfect) ++solved;

			logger?.info(
				`[solveMeasures] measure[${measure.measureIndex}/${measures.length}] regulated: ${stat.perfect ? 'solved' : stat.error ? 'error' : 'issue'}, ${
					measure.regulationHash
				}`
			);
		})
	);

	logger?.info(`[solveMeasures] ${cached}/${measures.length} cache hit, ${solved} solved.`);

	return {
		cached,
		computed: measures.length - cached,
		solved,
	};
}

/**
 * Search rectification candidates for a measure and return the best solution found.
 *
 * Each candidate from `genMeasureRectifications` is regulated with
 * `EquationPolicy.regulateMeasureWithRectification` and evaluated on a deep copy of the measure.
 * A candidate replaces the current best when any of these holds:
 * - it has more perfect results;
 * - it has fewer errors;
 * - it is error-free, at least as perfect, and has a higher priority.
 *
 * The search stops at the first perfect candidate, or after the iteration bound is exceeded.
 * The returned solution is not applied here; the caller is responsible for that.
 */
const solveMeasuresWithRectifications = async (
	measure: SpartitoMeasure,
	{ solver, quotaMax = 4000 }: SolveMeasureOptions
): Promise<starry.RegulationSolution> => {
	let best = evaluateMeasure(measure);
	let bestSolution: starry.RegulationSolution = measure.asSolution();
	const quota = computeQuota(measure.events.length, RECTIFICATION_QUOTA_FACTOR, quotaMax);
	let n_rec = 0;

	// @ts-ignore
	for (const rec of genMeasureRectifications(measure)) {
		const solution = await EquationPolicy.regulateMeasureWithRectification(measure, rec, { solver, quota });

		const testMeasure = measure.deepCopy() as SpartitoMeasure;
		testMeasure.applySolution(solution);
		const result = evaluateMeasure(testMeasure);

		if (
			result.perfect > best.perfect ||
			result.error < best.error ||
			// NOTE: if either priority is undefined this comparison evaluates to false.
			(!result.error && result.perfect >= best.perfect && solution.priority! > bestSolution.priority!)
		) {
			best = result;
			bestSolution = solution;
		}

		if (result.perfect) break;

		++n_rec;
		// NOTE(review): `>` allows RECTIFICATION_SEARCH_ITERATIONS + 1 non-perfect attempts; confirm intended.
		if (n_rec > RECTIFICATION_SEARCH_ITERATIONS) break;
	}

	return bestSolution;
};

interface RegulateWithTopoOption {
	solutionStore: SolutionStore;
	pyClients: PyClients;
	solver: (...args: any[]) => any;
	onSaveIssueMeasure?: SaveIssueMeasure;
}

interface RegulateMaybeWithTopoOption {
	solutionStore: SolutionStore;
	pyClients?: PyClients;
	solver: (...args: any[]) => any;
	onSaveIssueMeasure?: SaveIssueMeasure;
}

interface RegulateSimpleOption {
	solutionStore: SolutionStore;
	solver: (...args: any[]) => any;
	logger?: Logger;
	quotaMax?: number;
	quotaFactor?: number;
}

interface TopoRegulationStat {
	solved: number;
	issue: number;
	fatal: number;
}

/**
 * Second regulation pass for measures that `evaluateMeasure` does not rate as perfect.
 *
 * The pass runs these steps:
 * 1. Build event clusters for the issue measures and request 'topo' predictions for them.
 * 2. Apply the predictions to their measures.
 * 3. Blend each measure's matrixH with the `EquationPolicy` estimate.
 * 4. Run the rectification search per measure.
 *
 * Found solutions are applied and written to `solutionStore` under both the pre-application
 * `regulationHash0` and the post-application `regulationHash`. Every issue measure is reported via
 * `onSaveIssueMeasure`, with status 2 when it still has an error and 1 otherwise.
 *
 * @returns counts of measures that ended perfect (solved), erroneous (fatal), or neither (issue).
 */
async function doRegulateWithTopo(
	score: starry.Score,
	{ pyClients, solver, solutionStore = DefaultSolutionStore, onSaveIssueMeasure }: RegulateWithTopoOption
): Promise<TopoRegulationStat> {
	pyClients.logger.info(`[RegulateWithTopo] regulate score: ${score.title}, measures: ${score.spartito!.measures.length}`);

	const issueMeasures = score.spartito!.measures.filter((measure) => {
		const stat = evaluateMeasure(measure);
		return !stat.perfect;
	});
	pyClients.logger.info(`[RegulateWithTopo] basic issues: ${issueMeasures.length}`);

	if (issueMeasures.length === 0) {
		return {
			solved: 0,
			issue: 0,
			fatal: 0,
		};
	}

	const clusters = ([] as starry.EventCluster[]).concat(...issueMeasures.map((measure) => measure.createClusters()));
	const results = await pyClients.predictScoreImages('topo', { clusters });
	console.assert(results.length === clusters.length, 'prediction number mismatch:', clusters.length, results.length);

	clusters.forEach((cluster, index) => {
		const result = results[index];
		console.assert(result, 'no result for cluster:', cluster.index);

		cluster.assignPrediction(result);
	});

	issueMeasures.forEach((measure) => {
		// Clusters carry the index of the measure they were created from.
		const cs = clusters.filter((c) => c.index === measure.measureIndex);
		measure.applyClusters(cs);

		// Interpolate matrixH: keep MATRIXH_INTERPOLATION_K of the current value, blend in the rest from the estimate.
		const { matrixH } = EquationPolicy.estiamteMeasure(measure);
		matrixH.forEach((row, i) =>
			row.forEach((v, j) => {
				measure.matrixH[i][j] = measure.matrixH[i][j] * MATRIXH_INTERPOLATION_K + v * (1 - MATRIXH_INTERPOLATION_K);
			})
		);
	});

	const solvedIndices: number[] = [];
	const errorIndices: number[] = [];

	// Rectification search.
	await Promise.all(
		issueMeasures.map(async (measure) => {
			// Capture the hash before the solution is applied, so the cache can be keyed by it.
			const hash = measure.regulationHash0;

			const solution = await solveMeasuresWithRectifications(measure, { solver });
			if (solution) {
				measure.applySolution(solution);

				solutionStore.set(hash, solution);
				solutionStore.set(measure.regulationHash, measure.asSolution());
				pyClients.logger.info(`[RegulateWithTopo] solutionStore set: ${measure.measureIndex}, ${hash}, ${measure.regulationHash}`);
			}

			const stat = evaluateMeasure(measure);
			onSaveIssueMeasure?.({
				measureIndex: measure.measureIndex,
				measure: new EditableMeasure(measure),
				status: stat.error ? 2 : 1,
			});
			if (stat.perfect) solvedIndices.push(measure.measureIndex);
			else if (stat.error) errorIndices.push(measure.measureIndex);
		})
	);

	const n_issues = issueMeasures.length - solvedIndices.length - errorIndices.length;
	pyClients.logger.info(`[RegulateWithTopo] score: ${score.title}, solved/issue/fatal: ${solvedIndices.length}/${n_issues}/${errorIndices.length}`);
	if (solvedIndices.length) pyClients.logger.info(`[RegulateWithTopo] solved measures: ${solvedIndices.join(', ')}`);
	if (errorIndices.length) pyClients.logger.info(`[RegulateWithTopo] error measures: ${errorIndices.join(', ')}`);

	return {
		solved: solvedIndices.length,
		issue: n_issues,
		fatal: errorIndices.length,
	};
}

interface RegulationStat {
	baseCost: number; // in milliseconds
	topoCost: number; // in milliseconds
	baseMeasures: BaseRegulationStat;
	topoMeasures?: TopoRegulationStat; // absent when regulation ran without pyClients
	qualityScore: number;
}

/**
 * Full regulation of a score.
 *
 * Any existing spartito is discarded; the score is re-assembled and a fresh spartito is built.
 * Each measure is assigned its background and then goes through `solveMeasures` (quotaMax 1000).
 * When `pyClients` is provided, the topology pass `doRegulateWithTopo` follows.
 *
 * @returns wall-clock cost of each phase in milliseconds, their stats, and the spartito quality score.
 */
const doRegulate = async (
	score: starry.Score,
	{ pyClients, solver, solutionStore = DefaultSolutionStore, onSaveIssueMeasure }: RegulateMaybeWithTopoOption
): Promise<RegulationStat> => {
	pyClients?.logger?.info(`[doRegulate] score: ${score.title}`);

	score.spartito = undefined;
	score.assemble();
	const spartito = score.makeSpartito();

	spartito.measures.forEach((measure) => score.assignBackgroundForMeasure(measure));

	const t0 = Date.now();

	const baseMeasures = await solveMeasures(spartito.measures, { solver, quotaMax: 1000, solutionStore, logger: pyClients?.logger });

	const t1 = Date.now();

	const topoMeasures = pyClients ? await doRegulateWithTopo(score, { pyClients, solver, solutionStore, onSaveIssueMeasure }) : undefined;

	const t2 = Date.now();

	return {
		baseCost: t1 - t0,
		topoCost: t2 - t1,
		baseMeasures,
		topoMeasures,
		qualityScore: spartito.qualityScore,
	};
};

/**
 * Lightweight regulation: reuse the score's existing spartito if present (building one otherwise)
 * and regulate only the measures not yet marked `regulated`.
 *
 * This uses smaller default quotas than `doRegulate`. If regulation is still incomplete afterwards,
 * the condition is reported via `console.assert` rather than by throwing.
 */
const doSimpleRegulate = async (
	score: starry.Score,
	{ solver, solutionStore = DefaultSolutionStore, logger, quotaMax = 240, quotaFactor = 16 }: RegulateSimpleOption
): Promise<void> => {
	score.assemble();
	const spartito = score.spartito || score.makeSpartito();
	const measures = spartito.measures.filter((measure) => !measure.regulated);

	await solveMeasures(measures, { solver, quotaMax, quotaFactor, solutionStore, logger });

	console.assert(score.spartito?.regulated, 'doSimpleRegulate: regulation incomplete:', spartito.measures.filter((measure) => !measure.regulated).length);
};

/**
 * Return the spartito quality score of `score`, running `doSimpleRegulate` first if needed.
 *
 * @returns the quality score, or `null` if the spartito is still not fully regulated.
 */
const evaluateScoreQuality = async (score: starry.Score, options: RegulateSimpleOption): Promise<number | null> => {
	if (!score.spartito?.regulated) await doSimpleRegulate(score, options);

	return score.spartito!.regulated ? score.spartito!.qualityScore : null;
};

interface RegulationSummary {
	scoreN: number;

	baseCostTotal: number; // in milliseconds
	topoCostTotal: number; // in milliseconds
	baseCostPerMeasure: number | null; // in milliseconds
	topoCostPerMeasure: number | null; // in milliseconds

	cached: number;
	baseComputed: number;
	baseSolved: number;
	topoSolved: number;
	topoIssue: number;
	topoFatal: number;
}

/**
 * Aggregate per-score `RegulationStat`s into totals and per-measure costs.
 *
 * A stat without `topoMeasures` (regulated without pyClients) contributes zero topo measures
 * instead of throwing. Per-measure costs are `null` when the corresponding measure count is zero.
 */
const abstractRegulationStats = (stats: RegulationStat[]): RegulationSummary => {
	let baseCostTotal = 0;
	let topoCostTotal = 0;
	let cached = 0;
	let baseComputed = 0;
	let baseSolved = 0;
	let topoSolved = 0;
	let topoIssue = 0;
	let topoFatal = 0;

	for (const stat of stats) {
		baseCostTotal += stat.baseCost;
		topoCostTotal += stat.topoCost;

		cached += stat.baseMeasures.cached;
		baseComputed += stat.baseMeasures.computed;
		baseSolved += stat.baseMeasures.solved;

		topoSolved += stat.topoMeasures?.solved ?? 0;
		topoIssue += stat.topoMeasures?.issue ?? 0;
		topoFatal += stat.topoMeasures?.fatal ?? 0;
	}

	// Base cost is amortized over computed (non-cached) measures; topo cost over all topo-processed measures.
	const topoMeasures = topoSolved + topoIssue + topoFatal;
	const baseCostPerMeasure = baseComputed > 0 ? baseCostTotal / baseComputed : null;
	const topoCostPerMeasure = topoMeasures > 0 ? topoCostTotal / topoMeasures : null;

	return {
		scoreN: stats.length,

		baseCostTotal,
		topoCostTotal,
		baseCostPerMeasure,
		topoCostPerMeasure,

		cached,
		baseComputed,
		baseSolved,
		topoSolved,
		topoIssue,
		topoFatal,
	};
};

export { doRegulate, doSimpleRegulate, evaluateScoreQuality, abstractRegulationStats };