Spaces:
Sleeping
Sleeping
| import * as starry from '../../src/starry'; | |
| import { PyClients } from './predictors'; | |
| import { Logger } from './ZeroClient'; | |
| import { SpartitoMeasure, EditableMeasure, evaluateMeasure } from '../../src/starry'; | |
| import { EquationPolicy } from '../../src/starry/spartitoMeasure'; | |
| import { genMeasureRectifications } from '../../src/starry/measureRectification'; | |
| import { SolutionStore, DefaultSolutionStore, SaveIssueMeasure } from './store'; | |
| export * from './regulationBead'; | |
// Polyfill: keep the runtime's own `btoa` when it exists; otherwise install a Buffer-based
// replacement that reads the input as a 'binary' string and returns its base64 encoding.
globalThis.btoa = globalThis.btoa || ((str: string) => Buffer.from(str, 'binary').toString('base64'));
/**
 * Reads an integer setting from `process.env`.
 *
 * The value is parsed as a decimal integer. When the variable is unset, empty, or does not
 * start with a parseable integer, `fallback` is returned instead of `NaN`.
 *
 * @param name - Environment variable name.
 * @param fallback - Value used when no integer can be parsed.
 * @returns The parsed integer, or `fallback`.
 */
function readIntEnv(name: string, fallback: number): number {
	const parsed = Number.parseInt(process.env[name] || '', 10);

	return Number.isNaN(parsed) ? fallback : parsed;
}

/** Cap compared against the rectification attempt counter in `solveMeasuresWithRectifications`. */
const RECTIFICATION_SEARCH_ITERATIONS = readIntEnv('RECTIFICATION_SEARCH_ITERATIONS', 30);

/** Default `quotaFactor` used by `solveMeasures`. */
const BASE_QUOTA_FACTOR = readIntEnv('BASE_QUOTA_FACTOR', 40);

/** Quota factor passed to `computeQuota` by `solveMeasuresWithRectifications`. */
const RECTIFICATION_QUOTA_FACTOR = readIntEnv('RECTIFICATION_QUOTA_FACTOR', 80);

/**
 * Weight kept for a measure's existing `matrixH` entry when `doRegulateWithTopo` blends it with
 * the estimated matrix; the estimate receives `1 - MATRIXH_INTERPOLATION_K`.
 */
const MATRIXH_INTERPOLATION_K = 0.9;
/** Options accepted by `solveMeasures` and `solveMeasuresWithRectifications`. */
interface SolveMeasureOptions {
	/** Solver callback forwarded unchanged to the regulation calls. */
	solver?: (...args: any[]) => any;
	/** Upper limit handed to `computeQuota` as its `limit` argument. */
	quotaMax?: number;
	/** Multiplier handed to `computeQuota` as its `factor` argument. */
	quotaFactor?: number;
	/** Store that is queried for existing solutions and receives newly computed ones. */
	solutionStore?: SolutionStore;
	/** When true, `solveMeasures` skips the store lookup and always regulates. */
	ignoreCache?: boolean;
	/** Optional logger; progress is reported through `logger.info`. */
	logger?: Logger;
}
/**
 * Computes the regulation quota for a measure holding `n` events.
 *
 * The result is the smaller of two bounds:
 *  - a growth bound, `ceil((n + 1) * factor * ln(n + 2))`;
 *  - a ceiling, `ceil(limit * min(1, (24 / (n + 1))^2))`, which equals `ceil(limit)` while
 *    `n <= 23` and shrinks quadratically for larger `n`.
 *
 * @param n - Number of events in the measure.
 * @param factor - Multiplier of the growth bound.
 * @param limit - Maximum quota before the large-measure attenuation is applied.
 * @returns The quota to hand to the solver.
 */
const computeQuota = (n: number, factor: number, limit: number): number => {
	const growthBound = Math.ceil((n + 1) * factor * Math.log(n + 2));
	const ceilingBound = Math.ceil(limit * Math.min(1, (24 / (n + 1)) ** 2));

	return Math.min(growthBound, ceilingBound);
};
/** Counters returned by `solveMeasures`. */
interface BaseRegulationStat {
	/** Measures whose solution was taken from the solution store. */
	cached: number;
	/** Measures that were not served from the store (total minus `cached`). */
	computed: number;
	/** Regulated measures whose evaluation was reported as perfect. */
	solved: number;
}
/**
 * Regulates a list of measures concurrently, reusing stored solutions where possible.
 *
 * For each measure:
 *  1. Unless `ignoreCache` is set, the store is queried with `measure.regulationHash`; a hit is
 *     applied to the measure and counted as `cached`, and nothing else happens for that measure.
 *  2. Otherwise the measure is regulated with the 'equations' policy under a quota derived from
 *     its event count, then evaluated.
 *  3. If the evaluation reports no error, the solution is written to the store under
 *     `measure.regulationHash0`, with `priority` set to the negated solution loss.
 *
 * `solved` only counts measures regulated in this call; cache hits are not evaluated here.
 *
 * @param measures - Measures to regulate; they are modified in place.
 * @param options - See `SolveMeasureOptions`.
 * @returns Counts of cached, computed and solved measures.
 */
async function solveMeasures(
	measures: SpartitoMeasure[],
	{ solver, quotaMax = 1000, quotaFactor = BASE_QUOTA_FACTOR, solutionStore = DefaultSolutionStore, ignoreCache = false, logger }: SolveMeasureOptions = {}
): Promise<BaseRegulationStat> {
	let cached = 0;
	let solved = 0;

	logger?.info(`[solveMeasures] begin, measure total: ${measures.length}.`);

	await Promise.all(
		measures.map(async (measure) => {
			if (!ignoreCache) {
				const solution = await solutionStore.get(measure.regulationHash);
				if (solution) {
					measure.applySolution(solution);
					++cached;
					return;
				}
			}

			const quota = computeQuota(measure.events.length, quotaFactor, quotaMax);

			await measure.regulate({
				policy: 'equations',
				quota,
				solver,
			});

			const stat = evaluateMeasure(measure);

			// The store write is not awaited, exactly as before.
			// NOTE(review): if `solutionStat` is missing, `priority` becomes NaN — confirm that regulate() always sets it.
			if (!stat.error) solutionStore.set(measure.regulationHash0, { ...measure.asSolution(), priority: -measure.solutionStat?.loss! });
			if (stat.perfect) ++solved;

			logger?.info(
				`[solveMeasures] measure[${measure.measureIndex}/${measures.length}] regulated: ${stat.perfect ? 'solved' : stat.error ? 'error' : 'issue'}, ${
					measure.regulationHash
				}`
			);
		})
	);

	logger?.info(`[solveMeasures] ${cached}/${measures.length} cache hit, ${solved} solved.`);

	return {
		cached,
		computed: measures.length - cached,
		solved,
	};
}
/**
 * Searches the rectifications generated for a measure for a better regulation solution.
 *
 * The baseline is the measure's current evaluation and current solution. Each rectification is
 * regulated through `EquationPolicy.regulateMeasureWithRectification`, the resulting solution is
 * applied to a deep copy of the measure, and that copy is evaluated. A candidate replaces the
 * best-so-far when any of the following holds:
 *  - its `perfect` value is greater;
 *  - its `error` value is smaller;
 *  - it has no error, its `perfect` value is not smaller, and its solution priority is higher.
 *
 * The search stops at the first perfect candidate, when the generator is exhausted, or once the
 * attempt counter exceeds `RECTIFICATION_SEARCH_ITERATIONS`.
 *
 * This function applies solutions only to copies; the caller applies the returned solution.
 *
 * @param measure - Measure to search solutions for.
 * @param options - Only `solver` and `quotaMax` (default 4000) are read.
 * @returns The best solution found, or the measure's original solution when nothing was better.
 */
const solveMeasuresWithRectifications = async (
	measure: SpartitoMeasure,
	{ solver, quotaMax = 4000 }: SolveMeasureOptions
): Promise<starry.RegulationSolution> => {
	let best = evaluateMeasure(measure);
	let bestSolution: starry.RegulationSolution = measure.asSolution();
	const quota = computeQuota(measure.events.length, RECTIFICATION_QUOTA_FACTOR, quotaMax);
	let attempts = 0;

	// NOTE(review): presumably suppresses a generator-iteration type error under the project's compiler settings — confirm.
	// @ts-ignore
	for (const rec of genMeasureRectifications(measure)) {
		const solution = await EquationPolicy.regulateMeasureWithRectification(measure, rec, { solver, quota });

		// Evaluate the candidate on a copy so the input measure keeps its current state.
		const testMeasure = measure.deepCopy() as SpartitoMeasure;
		testMeasure.applySolution(solution);
		const result = evaluateMeasure(testMeasure);

		const improved =
			result.perfect > best.perfect ||
			result.error < best.error ||
			(!result.error && result.perfect >= best.perfect && solution.priority! > bestSolution.priority!);
		if (improved) {
			best = result;
			bestSolution = solution;
		}

		if (result.perfect) break;

		// NOTE(review): the counter is incremented before a strict `>` comparison, so up to
		// RECTIFICATION_SEARCH_ITERATIONS + 1 candidates are tried — confirm this is intended.
		++attempts;
		if (attempts > RECTIFICATION_SEARCH_ITERATIONS) break;
	}

	return bestSolution;
};
/** Options for `doRegulateWithTopo`, where the prediction clients are mandatory. */
interface RegulateWithTopoOption {
	/** Store that receives the solutions found by the rectification search. */
	solutionStore: SolutionStore;
	/** Prediction clients; `predictScoreImages` and `logger` are used. */
	pyClients: PyClients;
	/** Solver callback forwarded to the rectification search. */
	solver: (...args: any[]) => any;
	/** Optional callback invoked once per processed issue measure. */
	onSaveIssueMeasure?: SaveIssueMeasure;
}
/** Options for `doRegulate`; the topo pass only runs when `pyClients` is provided. */
interface RegulateMaybeWithTopoOption {
	/** Store used by both the base pass and the topo pass. */
	solutionStore: SolutionStore;
	/** Optional prediction clients; when omitted, `doRegulate` skips the topo pass. */
	pyClients?: PyClients;
	/** Solver callback forwarded to both passes. */
	solver: (...args: any[]) => any;
	/** Optional callback forwarded to the topo pass. */
	onSaveIssueMeasure?: SaveIssueMeasure;
}
/** Options for `doSimpleRegulate` and `evaluateScoreQuality`. */
interface RegulateSimpleOption {
	/** Store forwarded to `solveMeasures`. */
	solutionStore: SolutionStore;
	/** Solver callback forwarded to `solveMeasures`. */
	solver: (...args: any[]) => any;
	/** Optional logger forwarded to `solveMeasures`. */
	logger?: Logger;
	/** Quota limit forwarded to `solveMeasures`; `doSimpleRegulate` defaults it to 240. */
	quotaMax?: number;
	/** Quota factor forwarded to `solveMeasures`; `doSimpleRegulate` defaults it to 16. */
	quotaFactor?: number;
}
/** Counters returned by `doRegulateWithTopo` for the measures it processed. */
interface TopoRegulationStat {
	/** Issue measures that evaluate as perfect after the rectification search. */
	solved: number;
	/** Issue measures that are neither perfect nor in error afterwards. */
	issue: number;
	/** Issue measures whose evaluation still reports an error. */
	fatal: number;
}
/**
 * Second regulation pass for the measures of `score.spartito` that do not evaluate as perfect.
 *
 * Steps:
 *  1. Collect the non-perfect ("issue") measures; return all-zero counters when there are none.
 *  2. Create clusters for those measures, request 'topo' predictions for all clusters in one
 *     call, and assign each prediction to its cluster by position.
 *  3. Apply to each measure the clusters whose `index` equals its `measureIndex`, then blend the
 *     measure's `matrixH` with the estimated matrix using `MATRIXH_INTERPOLATION_K`.
 *  4. Run the rectification search per measure concurrently, apply the returned solution, and
 *     write it to the store under both the pre-search `regulationHash0` and the post-apply
 *     `regulationHash`.
 *  5. Report every processed measure through `onSaveIssueMeasure`, with status 2 when the
 *     evaluation reports an error and 1 otherwise.
 *
 * `score.spartito` must already exist; it is dereferenced with a non-null assertion.
 *
 * @param score - Score whose spartito measures are processed in place.
 * @param options - See `RegulateWithTopoOption`; `solutionStore` falls back to `DefaultSolutionStore`.
 * @returns Counts of solved, still-issue and fatal measures among the issue measures.
 */
async function doRegulateWithTopo(
	score: starry.Score,
	{ pyClients, solver, solutionStore = DefaultSolutionStore, onSaveIssueMeasure }: RegulateWithTopoOption
): Promise<TopoRegulationStat> {
	pyClients.logger.info(`[RegulateWithTopo] regulate score: ${score.title}, measures: ${score.spartito!.measures.length}`);

	const issueMeasures = score.spartito!.measures.filter((measure) => {
		const stat = evaluateMeasure(measure);
		return !stat.perfect;
	});
	pyClients.logger.info(`[RegulateWithTopo] basic issues: ${issueMeasures.length}`);

	if (issueMeasures.length === 0) {
		return {
			solved: 0,
			issue: 0,
			fatal: 0,
		};
	}

	const clusters = ([] as starry.EventCluster[]).concat(...issueMeasures.map((measure) => measure.createClusters()));
	const results = await pyClients.predictScoreImages('topo', { clusters });
	// console.assert only logs on failure; processing continues regardless.
	console.assert(results.length === clusters.length, 'prediction number mismatch:', clusters.length, results.length);

	clusters.forEach((cluster, index) => {
		const result = results[index];
		console.assert(result, 'no result for cluster:', cluster.index);

		cluster.assignPrediction(result);
	});

	issueMeasures.forEach((measure) => {
		const cs = clusters.filter((c) => c.index === measure.measureIndex);
		measure.applyClusters(cs);

		// interpolate matrixH: keep K of the current entry and take (1 - K) from the estimate
		const { matrixH } = EquationPolicy.estiamteMeasure(measure);
		matrixH.forEach((row, i) =>
			row.forEach((v, j) => {
				measure.matrixH[i][j] = measure.matrixH[i][j] * MATRIXH_INTERPOLATION_K + v * (1 - MATRIXH_INTERPOLATION_K);
			})
		);
	});

	const solvedIndices: number[] = [];
	const errorIndices: number[] = [];

	// rectification search
	await Promise.all(
		issueMeasures.map(async (measure) => {
			// Read the hash before the solution is applied; `regulationHash` is read again afterwards.
			const hash = measure.regulationHash0;
			const solution = await solveMeasuresWithRectifications(measure, { solver });
			if (solution) {
				measure.applySolution(solution);
				solutionStore.set(hash, solution);
				solutionStore.set(measure.regulationHash, measure.asSolution());
				pyClients.logger.info(`[RegulateWithTopo] solutionStore set: ${measure.measureIndex}, ${hash}, ${measure.regulationHash}`);
			}

			const stat = evaluateMeasure(measure);
			onSaveIssueMeasure?.({
				measureIndex: measure.measureIndex,
				measure: new EditableMeasure(measure),
				status: stat.error ? 2 : 1,
			});

			if (stat.perfect) solvedIndices.push(measure.measureIndex);
			else if (stat.error) errorIndices.push(measure.measureIndex);
		})
	);

	const n_issues = issueMeasures.length - solvedIndices.length - errorIndices.length;
	pyClients.logger.info(`[RegulateWithTopo] score: ${score.title}, solved/issue/fatal: ${solvedIndices.length}/${n_issues}/${errorIndices.length}`);
	if (solvedIndices.length) pyClients.logger.info(`[RegulateWithTopo] solved measures: ${solvedIndices.join(', ')}`);
	if (errorIndices.length) pyClients.logger.info(`[RegulateWithTopo] error measures: ${errorIndices.join(', ')}`);

	return {
		solved: solvedIndices.length,
		issue: n_issues,
		fatal: errorIndices.length,
	};
}
/** Result of one `doRegulate` run on a score. */
interface RegulationStat {
	/** Wall-clock duration of the base pass, in milliseconds. */
	baseCost: number;
	/** Wall-clock duration of the topo pass, in milliseconds. */
	topoCost: number;
	/** Counters of the base pass. */
	baseMeasures: BaseRegulationStat;
	/** Counters of the topo pass; undefined when that pass was skipped. */
	topoMeasures?: TopoRegulationStat;
	/** `qualityScore` of the spartito, read after both passes. */
	qualityScore: number;
}
/**
 * Rebuilds a score's spartito and regulates it, optionally followed by the topo pass.
 *
 * The existing `score.spartito` is discarded, the score is assembled, and a new spartito is
 * made; each of its measures gets a background assigned. The base pass (`solveMeasures`, quota
 * limit 1000) always runs. The topo pass (`doRegulateWithTopo`) runs only when `pyClients` is
 * given. Both passes are timed with `Date.now()`.
 *
 * @param score - Score to regulate; modified in place.
 * @param options - See `RegulateMaybeWithTopoOption`; `solutionStore` falls back to `DefaultSolutionStore`.
 * @returns Timings, per-pass counters (`topoMeasures` is undefined without `pyClients`) and the quality score.
 */
const doRegulate = async (
	score: starry.Score,
	{ pyClients, solver, solutionStore = DefaultSolutionStore, onSaveIssueMeasure }: RegulateMaybeWithTopoOption
): Promise<RegulationStat> => {
	pyClients?.logger?.info(`[doRegulate] score: ${score.title}`);

	// Drop any previous spartito before rebuilding it.
	score.spartito = undefined;

	score.assemble();
	const spartito = score.makeSpartito();

	spartito.measures.forEach((measure) => score.assignBackgroundForMeasure(measure));

	const t0 = Date.now();

	const baseMeasures = await solveMeasures(spartito.measures, { solver, quotaMax: 1000, solutionStore, logger: pyClients?.logger });

	const t1 = Date.now();

	const topoMeasures = pyClients ? await doRegulateWithTopo(score, { pyClients, solver, solutionStore, onSaveIssueMeasure }) : undefined;

	const t2 = Date.now();

	return {
		baseCost: t1 - t0,
		topoCost: t2 - t1,
		baseMeasures,
		topoMeasures,
		qualityScore: spartito.qualityScore,
	};
};
/**
 * Runs only the base regulation pass on the measures of a score that are not yet regulated.
 *
 * The score is assembled and its existing spartito is reused when present; otherwise a new one
 * is made. Only measures whose `regulated` flag is falsy go to `solveMeasures`. Afterwards a
 * `console.assert` reports how many measures are still unregulated; it logs and does not throw.
 *
 * @param score - Score to regulate; modified in place.
 * @param options - See `RegulateSimpleOption`; defaults are `quotaMax = 240`, `quotaFactor = 16`
 *   and `DefaultSolutionStore`.
 */
const doSimpleRegulate = async (
	score: starry.Score,
	{ solver, solutionStore = DefaultSolutionStore, logger, quotaMax = 240, quotaFactor = 16 }: RegulateSimpleOption
): Promise<void> => {
	score.assemble();
	const spartito = score.spartito || score.makeSpartito();
	const measures = spartito.measures.filter((measure) => !measure.regulated);

	await solveMeasures(measures, { solver, quotaMax, quotaFactor, solutionStore, logger });

	console.assert(score.spartito?.regulated, 'doSimpleRegulate: regulation incomplete:', spartito.measures.filter((measure) => !measure.regulated).length);
};
/**
 * Returns the quality score of a score's spartito, regulating it first when necessary.
 *
 * `doSimpleRegulate` is called only when `score.spartito` is missing or not flagged as
 * regulated.
 *
 * @param score - Score to evaluate; may be regulated in place.
 * @param options - Forwarded to `doSimpleRegulate`.
 * @returns `spartito.qualityScore`, or `null` when the spartito is still not regulated.
 */
const evaluateScoreQuality = async (score: starry.Score, options: RegulateSimpleOption): Promise<number | null> => {
	if (!score.spartito?.regulated) await doSimpleRegulate(score, options);

	return score.spartito!.regulated ? score.spartito!.qualityScore : null;
};
/** Aggregate over several `RegulationStat` records, produced by `abstractRegulationStats`. */
interface RegulationSummary {
	/** Number of records aggregated. */
	scoreN: number;
	/** Sum of `baseCost`, in milliseconds. */
	baseCostTotal: number;
	/** Sum of `topoCost`, in milliseconds. */
	topoCostTotal: number;
	/** `baseCostTotal` per computed base measure, in milliseconds; null when none were computed. */
	baseCostPerMeasure: number | null;
	/** `topoCostTotal` per topo-processed measure, in milliseconds; null when none were processed. */
	topoCostPerMeasure: number | null;
	/** Sum of `baseMeasures.cached`. */
	cached: number;
	/** Sum of `baseMeasures.computed`. */
	baseComputed: number;
	/** Sum of `baseMeasures.solved`. */
	baseSolved: number;
	/** Sum of `topoMeasures.solved`. */
	topoSolved: number;
	/** Sum of `topoMeasures.issue`. */
	topoIssue: number;
	/** Sum of `topoMeasures.fatal`. */
	topoFatal: number;
}
/**
 * Aggregates per-score regulation statistics into a single summary.
 *
 * Costs and counters are summed over all records. A record without `topoMeasures` (which
 * `doRegulate` returns when no prediction clients were supplied) contributes zero to the topo
 * counters. The per-measure costs divide the cost totals by the number of computed base measures
 * and by the number of topo-processed measures (solved + issue + fatal) respectively; each is
 * `null` when its divisor is zero.
 *
 * @param stats - Records to aggregate; may be empty.
 * @returns The aggregated summary.
 */
const abstractRegulationStats = (stats: RegulationStat[]): RegulationSummary => {
	let baseCostTotal = 0;
	let topoCostTotal = 0;
	let cached = 0;
	let baseComputed = 0;
	let baseSolved = 0;
	let topoSolved = 0;
	let topoIssue = 0;
	let topoFatal = 0;

	for (const stat of stats) {
		baseCostTotal += stat.baseCost;
		topoCostTotal += stat.topoCost;

		cached += stat.baseMeasures.cached;
		baseComputed += stat.baseMeasures.computed;
		baseSolved += stat.baseMeasures.solved;

		// `topoMeasures` is optional: treat a skipped topo pass as all-zero counters.
		topoSolved += stat.topoMeasures?.solved ?? 0;
		topoIssue += stat.topoMeasures?.issue ?? 0;
		topoFatal += stat.topoMeasures?.fatal ?? 0;
	}

	const topoMeasures = topoSolved + topoIssue + topoFatal;

	const baseCostPerMeasure = baseComputed > 0 ? baseCostTotal / baseComputed : null;
	const topoCostPerMeasure = topoMeasures > 0 ? topoCostTotal / topoMeasures : null;

	return {
		scoreN: stats.length,
		baseCostTotal,
		topoCostTotal,
		baseCostPerMeasure,
		topoCostPerMeasure,
		cached,
		baseComputed,
		baseSolved,
		topoSolved,
		topoIssue,
		topoFatal,
	};
};
| export { doRegulate, doSimpleRegulate, evaluateScoreQuality, abstractRegulationStats }; | |