starry / backend /libs /regulation.ts
k-l-lambda's picture
feat: add Python ML services (CPU mode) with model download
2b7aae2
import * as starry from '../../src/starry';
import { PyClients } from './predictors';
import { Logger } from './ZeroClient';
import { SpartitoMeasure, EditableMeasure, evaluateMeasure } from '../../src/starry';
import { EquationPolicy } from '../../src/starry/spartitoMeasure';
import { genMeasureRectifications } from '../../src/starry/measureRectification';
import { SolutionStore, DefaultSolutionStore, SaveIssueMeasure } from './store';
export * from './regulationBead';
// Provide a Buffer-backed `btoa` only when the runtime does not already expose one.
if (!globalThis.btoa) {
	globalThis.btoa = (str) => Buffer.from(str, 'binary').toString('base64');
}
/**
 * Reads an integer tuning knob from the environment.
 *
 * The value is parsed as a base-10 integer. When the variable is unset, empty or
 * not parseable, `fallback` is returned, so a malformed value can never leak NaN
 * into quota computations or loop bounds.
 */
const readIntEnv = (name: string, fallback: number): number => {
	const parsed = parseInt(process.env[name] || '', 10);
	return Number.isNaN(parsed) ? fallback : parsed;
};

/** Cap on the rectification counter in solveMeasuresWithRectifications (env override, default 30). */
const RECTIFICATION_SEARCH_ITERATIONS = readIntEnv('RECTIFICATION_SEARCH_ITERATIONS', 30);
/** Default `quotaFactor` of solveMeasures (env override, default 40). */
const BASE_QUOTA_FACTOR = readIntEnv('BASE_QUOTA_FACTOR', 40);
/** Quota factor used by the rectification search (env override, default 80). */
const RECTIFICATION_QUOTA_FACTOR = readIntEnv('RECTIFICATION_QUOTA_FACTOR', 80);
/** Weight kept for a measure's existing matrixH when it is blended with the estimated matrix. */
const MATRIXH_INTERPOLATION_K = 0.9;
/** Options accepted by solveMeasures and solveMeasuresWithRectifications. */
interface SolveMeasureOptions {
/** Solver callback forwarded as-is to the measure regulation calls. */
solver?: (...args: any[]) => any;
/** Upper limit handed to computeQuota (defaults: 1000 in solveMeasures, 4000 in the rectification search). */
quotaMax?: number;
/** Growth factor handed to computeQuota (solveMeasures defaults it to BASE_QUOTA_FACTOR). */
quotaFactor?: number;
/** Store queried for an existing solution and updated with new ones (solveMeasures defaults it to DefaultSolutionStore). */
solutionStore?: SolutionStore;
/** When true, solveMeasures skips the store lookup and always regulates. */
ignoreCache?: boolean;
/** Optional logger that receives progress messages through `info`. */
logger?: Logger;
}
/**
 * Derives the solver quota for a measure containing `n` events.
 *
 * The result is the smaller of two rounded-up values: a demand that grows as
 * `(n + 1) * factor * ln(n + 2)`, and a ceiling that equals `limit` while
 * `n + 1 <= 24` and is scaled by `(24 / (n + 1))^2` beyond that.
 */
const computeQuota = (n: number, factor: number, limit: number) => {
	const demand = Math.ceil((n + 1) * factor * Math.log(n + 2));
	const ceilingScale = Math.min(1, (24 / (n + 1)) ** 2);
	const ceiling = Math.ceil(limit * ceilingScale);
	return Math.min(demand, ceiling);
};
/** Counters returned by solveMeasures. */
interface BaseRegulationStat {
/** Measures whose solution was found in the solution store and applied without regulating. */
cached: number;
/** Measures that were not served from the store (total minus `cached`). */
computed: number;
/** Regulated measures that evaluateMeasure reported as perfect. */
solved: number;
}
/**
 * Regulates a batch of measures concurrently, consulting the solution store first.
 *
 * Unless `ignoreCache` is set, a measure whose `regulationHash` has a stored
 * solution gets that solution applied and is counted as cached. Every other
 * measure is regulated with the 'equations' policy under a quota from
 * computeQuota; unless evaluation reports an error, the resulting solution is
 * written to the store under `regulationHash0`.
 *
 * @param measures measures to regulate (mutated in place)
 * @returns counts of cached, computed and perfectly solved measures
 */
async function solveMeasures(
	measures: SpartitoMeasure[],
	{ solver, quotaMax = 1000, quotaFactor = BASE_QUOTA_FACTOR, solutionStore = DefaultSolutionStore, ignoreCache = false, logger }: SolveMeasureOptions = {}
): Promise<BaseRegulationStat> {
	let cached = 0;
	let solved = 0;

	logger?.info(`[solveMeasures] begin, measure total: ${measures.length}.`);

	const regulateOne = async (measure: SpartitoMeasure): Promise<void> => {
		// Cache lookup is skipped entirely (no await) when ignoreCache is set.
		const stored = ignoreCache ? undefined : await solutionStore.get(measure.regulationHash);
		if (stored) {
			measure.applySolution(stored);
			++cached;
			return;
		}

		await measure.regulate({
			policy: 'equations',
			quota: computeQuota(measure.events.length, quotaFactor, quotaMax),
			solver,
		});

		const stat = evaluateMeasure(measure);
		if (!stat.error) solutionStore.set(measure.regulationHash0, { ...measure.asSolution(), priority: -measure?.solutionStat?.loss! });
		if (stat.perfect) ++solved;

		let outcome = 'issue';
		if (stat.perfect) outcome = 'solved';
		else if (stat.error) outcome = 'error';

		logger?.info(`[solveMeasures] measure[${measure.measureIndex}/${measures.length}] regulated: ${outcome}, ${measure.regulationHash}`);
	};

	await Promise.all(measures.map((measure) => regulateOne(measure)));

	logger?.info(`[solveMeasures] ${cached}/${measures.length} cache hit, ${solved} solved.`);

	const computed = measures.length - cached;
	return { cached, computed, solved };
}
/**
 * Searches the rectifications generated for a measure and returns the best solution found.
 *
 * The measure's current solution is the starting candidate. Each rectification is
 * solved, applied to a deep copy of the measure and evaluated. A candidate replaces
 * the current best when it becomes perfect, removes an error, or (being error-free
 * and no less perfect) carries a higher priority. The search stops at the first
 * perfect result or once the counter exceeds RECTIFICATION_SEARCH_ITERATIONS.
 * The input measure itself is not given any solution here.
 */
const solveMeasuresWithRectifications = async (
	measure: SpartitoMeasure,
	{ solver, quotaMax = 4000 }: SolveMeasureOptions
): Promise<starry.RegulationSolution> => {
	let bestStat = evaluateMeasure(measure);
	let bestSolution: starry.RegulationSolution = measure.asSolution();

	const quota = computeQuota(measure.events.length, RECTIFICATION_QUOTA_FACTOR, quotaMax);

	let tried = 0;
	// @ts-ignore
	for (const rectification of genMeasureRectifications(measure)) {
		const candidate = await EquationPolicy.regulateMeasureWithRectification(measure, rectification, { solver, quota });

		// Evaluate on a copy so the original measure stays untouched during the search.
		const probe = measure.deepCopy() as SpartitoMeasure;
		probe.applySolution(candidate);
		const outcome = evaluateMeasure(probe);

		const improves =
			outcome.perfect > bestStat.perfect ||
			outcome.error < bestStat.error ||
			(!outcome.error && outcome.perfect >= bestStat.perfect && candidate.priority! > bestSolution.priority!);
		if (improves) {
			bestStat = outcome;
			bestSolution = candidate;
		}

		if (outcome.perfect) break;
		if (++tried > RECTIFICATION_SEARCH_ITERATIONS) break;
	}

	return bestSolution;
};
/** Options for doRegulateWithTopo, where the prediction clients are mandatory. */
interface RegulateWithTopoOption {
/** Store that receives the solutions found by the rectification search. */
solutionStore: SolutionStore;
/** Prediction clients; used for the 'topo' prediction request and for logging. */
pyClients: PyClients;
/** Solver callback forwarded to the rectification search. */
solver: (...args: any[]) => any;
/** Optional callback invoked once per issue measure after the rectification search. */
onSaveIssueMeasure?: SaveIssueMeasure;
}
/** Options for doRegulate; same fields as the topo options, but the prediction clients are optional. */
interface RegulateMaybeWithTopoOption {
/** Store used for cached and newly found solutions. */
solutionStore: SolutionStore;
/** Prediction clients; when omitted, doRegulate skips the topo stage. */
pyClients?: PyClients;
/** Solver callback forwarded to the regulation stages. */
solver: (...args: any[]) => any;
/** Optional callback forwarded to the topo stage. */
onSaveIssueMeasure?: SaveIssueMeasure;
}
/** Options for doSimpleRegulate and evaluateScoreQuality. */
interface RegulateSimpleOption {
/** Store used for cached and newly found solutions. */
solutionStore: SolutionStore;
/** Solver callback forwarded to solveMeasures. */
solver: (...args: any[]) => any;
/** Optional logger forwarded to solveMeasures. */
logger?: Logger;
/** Quota limit forwarded to solveMeasures (doSimpleRegulate defaults it to 240). */
quotaMax?: number;
/** Quota factor forwarded to solveMeasures (doSimpleRegulate defaults it to 16). */
quotaFactor?: number;
}
/** Counters returned by doRegulateWithTopo, covering the measures that were not perfect beforehand. */
interface TopoRegulationStat {
/** Issue measures evaluated as perfect after the rectification search. */
solved: number;
/** Issue measures that ended up neither perfect nor in error. */
issue: number;
/** Issue measures whose evaluation still reports an error. */
fatal: number;
}
/**
 * Second regulation stage: re-solves the measures that are not yet perfect,
 * using 'topo' predictions from the prediction clients.
 *
 * Steps: collect the non-perfect measures of `score.spartito`, build their
 * clusters, request predictions and assign them, apply the clusters back to each
 * measure and blend its matrixH with the estimated one, then run the
 * rectification search per measure and store the solutions.
 *
 * @returns how many of the issue measures ended up solved, still with issues, or in error
 */
async function doRegulateWithTopo(
	score: starry.Score,
	{ pyClients, solver, solutionStore = DefaultSolutionStore, onSaveIssueMeasure }: RegulateWithTopoOption
): Promise<TopoRegulationStat> {
	const log = (message: string) => pyClients.logger.info(message);

	log(`[RegulateWithTopo] regulate score: ${score.title}, measures: ${score.spartito!.measures.length}`);

	const issueMeasures = score.spartito!.measures.filter((measure) => !evaluateMeasure(measure).perfect);
	log(`[RegulateWithTopo] basic issues: ${issueMeasures.length}`);

	// Nothing to fix: skip the prediction round-trip entirely.
	if (!issueMeasures.length) return { solved: 0, issue: 0, fatal: 0 };

	const clusters = ([] as starry.EventCluster[]).concat(...issueMeasures.map((measure) => measure.createClusters()));
	const predictions = await pyClients.predictScoreImages('topo', { clusters });
	console.assert(predictions.length === clusters.length, 'prediction number mismatch:', clusters.length, predictions.length);

	// Predictions are matched to clusters by position.
	for (let i = 0; i < clusters.length; ++i) {
		const prediction = predictions[i];
		console.assert(prediction, 'no result for cluster:', clusters[i].index);
		clusters[i].assignPrediction(prediction);
	}

	const blendWeight = 1 - MATRIXH_INTERPOLATION_K;
	for (const measure of issueMeasures) {
		measure.applyClusters(clusters.filter((cluster) => cluster.index === measure.measureIndex));

		// interpolate matrixH
		const estimated = EquationPolicy.estiamteMeasure(measure).matrixH;
		for (let i = 0; i < estimated.length; ++i) {
			const row = estimated[i];
			for (let j = 0; j < row.length; ++j) {
				measure.matrixH[i][j] = measure.matrixH[i][j] * MATRIXH_INTERPOLATION_K + row[j] * blendWeight;
			}
		}
	}

	const solvedIndices: number[] = [];
	const errorIndices: number[] = [];

	// rectification search
	await Promise.all(
		issueMeasures.map(async (measure) => {
			// Capture the hash before applying the solution; both hashes are stored below.
			const originalHash = measure.regulationHash0;
			const solution = await solveMeasuresWithRectifications(measure, { solver });
			if (solution) {
				measure.applySolution(solution);
				solutionStore.set(originalHash, solution);
				solutionStore.set(measure.regulationHash, measure.asSolution());
				log(`[RegulateWithTopo] solutionStore set: ${measure.measureIndex}, ${originalHash}, ${measure.regulationHash}`);
			}

			const { perfect, error } = evaluateMeasure(measure);
			onSaveIssueMeasure?.({
				measureIndex: measure.measureIndex,
				measure: new EditableMeasure(measure),
				status: error ? 2 : 1,
			});

			if (perfect) solvedIndices.push(measure.measureIndex);
			else if (error) errorIndices.push(measure.measureIndex);
		})
	);

	const solvedCount = solvedIndices.length;
	const fatalCount = errorIndices.length;
	const issueCount = issueMeasures.length - solvedCount - fatalCount;

	log(`[RegulateWithTopo] score: ${score.title}, solved/issue/fatal: ${solvedCount}/${issueCount}/${fatalCount}`);
	if (solvedCount) log(`[RegulateWithTopo] solved measures: ${solvedIndices.join(', ')}`);
	if (fatalCount) log(`[RegulateWithTopo] error measures: ${errorIndices.join(', ')}`);

	return { solved: solvedCount, issue: issueCount, fatal: fatalCount };
}
/** Per-score result returned by doRegulate. */
interface RegulationStat {
/** Wall-clock time spent in the base solveMeasures stage. */
baseCost: number; // in milliseconds
/** Wall-clock time between the end of the base stage and the end of the topo stage. */
topoCost: number; // in milliseconds
/** Counters from the base stage. */
baseMeasures: BaseRegulationStat;
/** Counters from the topo stage; undefined when doRegulate ran without prediction clients. */
topoMeasures?: TopoRegulationStat;
/** `qualityScore` of the spartito built for this run. */
qualityScore: number;
}
/**
 * Full regulation of a score: rebuilds its spartito from scratch, runs the base
 * solving stage on every measure and, when prediction clients are supplied,
 * the topo stage on the remaining issue measures.
 *
 * @returns stage timings (milliseconds), per-stage counters and the spartito quality score
 */
const doRegulate = async (
	score: starry.Score,
	{ pyClients, solver, solutionStore = DefaultSolutionStore, onSaveIssueMeasure }: RegulateMaybeWithTopoOption
): Promise<RegulationStat> => {
	pyClients?.logger?.info(`[doRegulate] score: ${score.title}`);

	// Drop any previous spartito so the score is regulated from a clean state.
	score.spartito = undefined;
	score.assemble();
	const spartito = score.makeSpartito();
	for (const measure of spartito.measures) score.assignBackgroundForMeasure(measure);

	const baseStart = Date.now();
	const baseMeasures = await solveMeasures(spartito.measures, { solver, quotaMax: 1000, solutionStore, logger: pyClients?.logger });
	const baseEnd = Date.now();

	// The topo stage needs the prediction clients; without them it is skipped.
	const topoMeasures = pyClients ? await doRegulateWithTopo(score, { pyClients, solver, solutionStore, onSaveIssueMeasure }) : undefined;
	const topoEnd = Date.now();

	return {
		baseCost: baseEnd - baseStart,
		topoCost: topoEnd - baseEnd,
		baseMeasures,
		topoMeasures,
		qualityScore: spartito.qualityScore,
	};
};
/**
 * Lightweight regulation: assembles the score, reuses its spartito (or makes one)
 * and runs solveMeasures only on the measures not yet flagged as regulated,
 * with smaller default quotas (limit 240, factor 16).
 * Logs an assertion message if the score's spartito is still not regulated afterwards.
 */
const doSimpleRegulate = async (
	score: starry.Score,
	{ solver, solutionStore = DefaultSolutionStore, logger, quotaMax = 240, quotaFactor = 16 }: RegulateSimpleOption
): Promise<void> => {
	score.assemble();

	const spartito = score.spartito || score.makeSpartito();
	const pending = spartito.measures.filter((measure) => !measure.regulated);

	await solveMeasures(pending, { solver, quotaMax, quotaFactor, solutionStore, logger });

	console.assert(
		score.spartito?.regulated,
		'doSimpleRegulate: regulation incomplete:',
		spartito.measures.filter((measure) => !measure.regulated).length
	);
};
/**
 * Returns the quality score of a score's spartito, running doSimpleRegulate first
 * when the spartito is missing or not yet regulated.
 *
 * @returns `spartito.qualityScore`, or null when the spartito is still not regulated
 */
const evaluateScoreQuality = async (score: starry.Score, options: RegulateSimpleOption): Promise<number | null> => {
	if (!score.spartito?.regulated) {
		await doSimpleRegulate(score, options);
	}

	if (!score.spartito!.regulated) return null;

	return score.spartito!.qualityScore;
};
/** Aggregate over several RegulationStat entries, as produced by abstractRegulationStats. */
interface RegulationSummary {
/** Number of stat entries aggregated. */
scoreN: number;
/** Sum of `baseCost` over all entries. */
baseCostTotal: number; // in milliseconds
/** Sum of `topoCost` over all entries. */
topoCostTotal: number; // in milliseconds
/** `baseCostTotal` divided by the total of computed base measures; null when that total is zero. */
baseCostPerMeasure: number | null; // in milliseconds
/** `topoCostTotal` divided by the total of topo-stage measures; null when that total is zero. */
topoCostPerMeasure: number | null; // in milliseconds
/** Sum of `baseMeasures.cached`. */
cached: number;
/** Sum of `baseMeasures.computed`. */
baseComputed: number;
/** Sum of `baseMeasures.solved`. */
baseSolved: number;
/** Sum of `topoMeasures.solved`. */
topoSolved: number;
/** Sum of `topoMeasures.issue`. */
topoIssue: number;
/** Sum of `topoMeasures.fatal`. */
topoFatal: number;
}
/**
 * Aggregates per-score regulation stats into one summary.
 *
 * Costs and counters are summed in input order. The per-measure costs divide the
 * cost totals by the number of computed base measures and by the number of
 * topo-stage measures (solved + issue + fatal); each is null when its divisor is zero.
 *
 * Entries without `topoMeasures` (scores regulated without prediction clients)
 * contribute nothing to the topo counters instead of causing a TypeError.
 *
 * @param stats stat entries, e.g. as returned by doRegulate
 * @returns totals over all entries; an empty input yields zeros and null per-measure costs
 */
const abstractRegulationStats = (stats: RegulationStat[]): RegulationSummary => {
	let baseCostTotal = 0;
	let topoCostTotal = 0;
	let cached = 0;
	let baseComputed = 0;
	let baseSolved = 0;
	let topoSolved = 0;
	let topoIssue = 0;
	let topoFatal = 0;

	for (const stat of stats) {
		baseCostTotal += stat.baseCost;
		topoCostTotal += stat.topoCost;

		cached += stat.baseMeasures.cached;
		baseComputed += stat.baseMeasures.computed;
		baseSolved += stat.baseMeasures.solved;

		// topoMeasures is optional: the topo stage is skipped when no prediction clients are given.
		if (stat.topoMeasures) {
			topoSolved += stat.topoMeasures.solved;
			topoIssue += stat.topoMeasures.issue;
			topoFatal += stat.topoMeasures.fatal;
		}
	}

	const topoMeasureTotal = topoSolved + topoIssue + topoFatal;

	return {
		scoreN: stats.length,
		baseCostTotal,
		topoCostTotal,
		baseCostPerMeasure: baseComputed > 0 ? baseCostTotal / baseComputed : null,
		topoCostPerMeasure: topoMeasureTotal > 0 ? topoCostTotal / topoMeasureTotal : null,
		cached,
		baseComputed,
		baseSolved,
		topoSolved,
		topoIssue,
		topoFatal,
	};
};
export { doRegulate, doSimpleRegulate, evaluateScoreQuality, abstractRegulationStats };