// starry/backend/libs/regulationBead.ts
// Last commit 2b7aae2 (k-l-lambda): "feat: add Python ML services (CPU mode) with model download"
import * as starry from '../../src/starry';
import { Logger } from './ZeroClient';
import { SolutionStore, DefaultSolutionStore, SaveIssueMeasure, MeasureStatus } from './store';
/**
 * Per-score tally of how measures were handled by `regulateWithBeadSolver`:
 *
 * - `cached`:   measures whose solution was loaded from the solution store (reset to 0 in `freshOnly` mode).
 * - `simple`:   pending measures that became `perfect` through the simple regulation policy.
 * - `computed`: pending measures handed over to the bead solver (pending count minus `simple`).
 * - `tryTimes`: number of bead-solver attempts, summed over all solver passes.
 * - `solved`:   pending measures whose final evaluation is `fine`.
 * - `issue`:    pending measures that end neither `fine` nor in `error`.
 * - `fatal`:    pending measures whose final evaluation reports an `error`.
 */
type BeadRegulationCounting = Record<'cached' | 'simple' | 'computed' | 'tryTimes' | 'solved' | 'issue' | 'fatal', number>;
/**
 * Result of one `regulateWithBeadSolver` run over a single score.
 */
interface RegulationBeadStat {
	totalCost: number; // wall-clock duration of the run, in milliseconds
	pickerCost: number; // sum of every picker's `cost` accumulated after it was reset (post-estimation), in milliseconds
	measures: BeadRegulationCounting; // per-measure counters for this run
	qualityScore: number; // `spartito.qualityScore` after regulation
}
/**
 * Aggregate over several `RegulationBeadStat` results (see `abstractRegulationBeadStats`).
 *
 * The per-measure counters (`cached`, `simple`, `computed`, `tryTimes`, `solved`, `issue`, `fatal`)
 * are summed across scores, so they share the shape of `BeadRegulationCounting`.
 */
interface RegulationBeadSummary extends BeadRegulationCounting {
	/** Number of stats (scores) that were aggregated. */
	scoreN: number;
	/** Summed wall-clock time, in milliseconds. */
	totalCost: number;
	/** Summed picker cost, in milliseconds. */
	pickerCost: number;
	/** `totalCost / computed` in milliseconds, or `null` when no measure was computed. */
	costPerMeasure: number | null;
	/** `totalCost / tryTimes` in milliseconds, or `null` when there was no solver attempt. */
	costPerTime: number | null;
}
/**
 * Progress snapshot passed to `RegulateBeadOption.onProgress`.
 *
 * As reported by `regulateWithBeadSolver`: `pass` is the solver pass (1, 2 or 3),
 * `remaining` counts pending measures whose evaluation is not yet `fine`, and
 * `total` is the number of measures in the spartito.
 */
interface ProgressInfo {
	pass: number;
	remaining: number;
	total: number;
}
/** Options accepted by `regulateWithBeadSolver`. */
interface RegulateBeadOption {
	/** Candidate pickers; a measure uses the first one whose `n_seq` exceeds its event count + 1. */
	pickers: starry.BeadPicker[];

	/** When given, progress is logged here; otherwise a compact progress display goes to `process.stdout`. */
	logger?: Logger;

	/** Solution cache keyed by measure regulation hash; defaults to `DefaultSolutionStore`. */
	solutionStore?: SolutionStore;
	/** Do not read cached solutions (improved solutions are still written back). */
	ignoreCache?: boolean;
	/** Only regulate measures without a cached solution; cached-but-not-fine measures are left untouched. */
	freshOnly?: boolean;

	/** Invoked for every regulated measure that is still not `fine` after all passes. */
	onSaveIssueMeasure?: SaveIssueMeasure;
	/** Invoked after every bead-solver attempt on a measure. */
	onProgress?: (measure: starry.SpartitoMeasure, evaluation: starry.MeasureEvaluation, better: boolean, progress: ProgressInfo) => void;
	/** Invoked before each solver pass with its number, pending-condition name and count of not-yet-fine measures. */
	onPassStart?: (pass: number, conditionName: string, pendingCount: number) => void;
}
/**
 * Bookkeeping for one measure while it is being regulated.
 * NOTE: "Reord" is a typo for "Record"; the name is kept because other code in this file references it.
 */
interface MeasureReord {
	/** The live measure; better solutions are merged into it in place. */
	current: starry.SpartitoMeasure;
	/** Deep copy taken before any regulation; its `regulationHash0` is the solution-cache key. */
	origin: starry.SpartitoMeasure;
	/** Best evaluation found so far; `undefined` until the measure is first evaluated. */
	evaluation?: starry.MeasureEvaluation;
	/** Quality score of the cached solution, or 0 when none was loaded. */
	baseQuality: number;
	/** Picker sized for this measure; assigned after construction, and left unset when no picker fits. */
	picker: starry.BeadPicker;
}
/**
 * Tuning values forwarded as-is to `starry.beadSolver.solveMeasure` (merged with the record's picker).
 * NOTE(review): their exact meaning is defined by the solver; the names suggest a search quota cap and
 * multiplier, a loss threshold for stopping, and a picker-time factor — confirm against the solver.
 */
interface BeadSolverOptions {
	quotaMax: number;
	quotaFactor: number;
	stopLoss: number;
	ptFactor: number;
}
/** Selects which measures a solver pass should (re)try, based on their current evaluation. */
enum PendingCondition {
	/** Only measures whose evaluation reports an error. */
	ErrorOnly = 0,
	/** Measures whose evaluation is not `fine`. */
	NotFine = 1,
	/** Measures whose evaluation is not `perfect`. */
	Imperfect = 2,
}

/**
 * Whether a measure with the given evaluation still needs work under `condition`.
 * Any condition other than `ErrorOnly` / `Imperfect` uses the `NotFine` test.
 */
const isPending = (evaluation: starry.MeasureEvaluation, condition: PendingCondition) => {
	if (condition === PendingCondition.ErrorOnly) return evaluation.error;
	if (condition === PendingCondition.Imperfect) return !evaluation.perfect;

	return !evaluation.fine;
};
/**
 * Callback fired after each bead-solver attempt on a measure.
 * `better` tells whether `evaluation` beat the record's previous best (in which case it was applied to `measure`).
 */
type OnUpdate = (
	measure: starry.SpartitoMeasure,
	evaluation: starry.MeasureEvaluation,
	better: boolean
) => void;
/**
 * Runs the bead solver once over every record that is still pending under `pendingCondition`.
 *
 * Each pending record's current measure is deep-copied (sharing `staffGroups`), solved and evaluated.
 * The attempt replaces the record's measure and evaluation when it is better. That is the case when the
 * record had no evaluation yet, when `fine` increased, or when `qualityScore` increased while `fine`
 * stayed equal. `onUpdate` and then `onProgress` fire after every attempt, better or not.
 *
 * @param records          all candidate records; pending ones are mutated in place.
 * @param onUpdate         per-attempt callback.
 * @param stdout           optional stream for the dot/backspace progress display.
 * @param options          solver tuning forwarded to `starry.beadSolver.solveMeasure`.
 * @param pendingCondition which records this pass retries.
 * @param pass             pass number reported through `onProgress`.
 * @param onProgress       optional callback; its `remaining`/`total` are local to this pass.
 * @returns the number of records attempted in this pass.
 */
const solveMeasureRecords = async (
	records: MeasureReord[],
	onUpdate: OnUpdate,
	stdout: NodeJS.WritableStream | null,
	options: Partial<BeadSolverOptions>,
	pendingCondition: PendingCondition = PendingCondition.NotFine,
	pass: number = 0,
	onProgress?: RegulateBeadOption['onProgress']
): Promise<number> => {
	const queue = records.filter((record) => !record.evaluation || isPending(record.evaluation, pendingCondition));
	const total = queue.length;

	// One placeholder dot per queued measure, then rewind so each attempt's marker overwrites its dot.
	stdout?.write('.'.repeat(total));
	stdout?.write('\b'.repeat(total));

	let attempted = 0;
	for (const record of queue) {
		const candidate = record.current.deepCopy();
		candidate.staffGroups = record.current.staffGroups;

		const solution = await starry.beadSolver.solveMeasure(candidate, { picker: record.picker, ...options });
		candidate.applySolution(solution);

		const evaluation = starry.evaluateMeasure(candidate);
		const previous = record.evaluation;

		let better: boolean;
		if (!previous) better = true;
		else if (evaluation.fine > previous.fine) better = true;
		else better = evaluation.qualityScore > previous.qualityScore && evaluation.fine === previous.fine;

		if (better) {
			record.evaluation = evaluation;
			Object.assign(record.current, candidate);
		}

		onUpdate(record.current, evaluation, better);

		attempted += 1;
		onProgress?.(record.current, evaluation, better, { pass, remaining: total - attempted, total });
	}

	if (total > 0) stdout?.write('\n');

	return total;
};
/** Returns the first picker whose `n_seq` exceeds the measure's event count + 1, or `undefined` if none fits. */
const findPickerForMeasure = (pickers: starry.BeadPicker[], measure: starry.SpartitoMeasure): starry.BeadPicker | undefined =>
	pickers.find((picker) => picker.n_seq > measure.events.length + 1);

/**
 * Regulates the measures of `score` with the bead solver, reusing and updating cached solutions.
 *
 * Stages:
 *  1. Clear `score.spartito`, re-assemble and build a spartito, estimate every non-empty measure with a fitting
 *     picker, then rectify time signatures.
 *  2. Unless `ignoreCache`, apply cached solutions from `solutionStore` (keyed by the original `regulationHash0`).
 *  3. Try the `simple` regulation policy on every pending measure.
 *  4. Run three bead-solver passes with growing budgets: not perfect → not fine → errors only.
 *  5. Store improved solutions (the store writes are awaited) and report still-not-fine measures via
 *     `onSaveIssueMeasure`.
 *
 * Side effects: mutates the spartito's measures in place and resets every picker's `cost` to 0 after estimation.
 * Without a `logger`, a compact progress display is written to `process.stdout`.
 *
 * @returns timing, per-measure counters and the final spartito quality score.
 * A rejected `solutionStore.set` makes the returned promise reject.
 */
const regulateWithBeadSolver = async (
	score: starry.Score,
	{ logger, pickers, solutionStore = DefaultSolutionStore, ignoreCache, freshOnly, onSaveIssueMeasure, onProgress, onPassStart }: RegulateBeadOption
): Promise<RegulationBeadStat> => {
	score.spartito = undefined;
	score.assemble();
	const spartito = score.makeSpartito();

	spartito.measures.forEach((measure) => score.assignBackgroundForMeasure(measure));

	const t0 = Date.now();
	logger?.info(`[regulateWithBeadSolver] begin, measure total: ${spartito.measures.length}.`, ignoreCache ? 'ignoreCache' : '', freshOnly ? 'freshOnly' : '');

	// Only non-empty, unpatched measures are regulated.
	const records = spartito.measures
		.filter((measure) => measure.events?.length && !measure.patched)
		.map(
			(measure) =>
				({
					origin: measure.deepCopy(),
					current: measure,
					evaluation: undefined,
					baseQuality: 0,
				} as MeasureReord)
		);

	// rectify time signature
	for (const measure of spartito.measures.filter((measure) => measure.events?.length)) {
		const picker = findPickerForMeasure(pickers, measure);
		if (picker) await starry.beadSolver.estimateMeasure(measure, picker);
	}
	// NOTE(review): `as any` hides the parameter type `rectifyTimeSignatures` expects — confirm Logger is compatible.
	spartito.rectifyTimeSignatures(logger as any);

	// Zero pickers' cost so that `pickerCost` below excludes the estimation stage.
	pickers.forEach((picker) => (picker.cost = 0));

	const counting: BeadRegulationCounting = {
		cached: 0,
		simple: 0,
		computed: 0,
		tryTimes: 0,
		solved: 0,
		issue: 0,
		fatal: 0,
	};

	logger?.info(`[regulateWithBeadSolver] measures estimation finished.`);

	// apply cached solutions
	if (solutionStore && !ignoreCache) {
		for (const record of records) {
			const solution = await solutionStore.get(record.origin.regulationHash0);
			if (solution) {
				record.current.applySolution(solution);
				++counting.cached;
				record.evaluation = starry.evaluateMeasure(record.current);
				record.baseQuality = record.evaluation.qualityScore;
			}
		}
	}

	logger?.info('[regulateWithBeadSolver]', `${counting.cached}/${records.length}`, 'solutions loaded.');

	// Without a logger, progress is shown as compact colored markers on stdout instead.
	const stdout = logger ? null : process.stdout;
	if (counting.cached) stdout?.write(`${counting.cached}c`);

	records.forEach((record) => {
		const picker = findPickerForMeasure(pickers, record.current);
		if (!picker) {
			logger?.info(`[regulateWithBeadSolver] measure[${record.current.measureIndex}] size out of range:`, record.current.events.length);
		} else record.picker = picker;
	});

	// Pending: has a picker, and either no evaluation yet or (unless freshOnly) a cached but not fine one.
	// The cast is made good by the simple-policy stage below, which evaluates every record lacking an evaluation.
	const pendingRecords = records.filter((record) => record.picker && (!record.evaluation || (!record.evaluation.fine && !freshOnly))) as (MeasureReord & {
		evaluation: starry.MeasureEvaluation;
	})[];

	// solve by simple policy
	pendingRecords.forEach((record) => {
		const measure = record.current.deepCopy();
		measure.staffGroups = record.current.staffGroups;
		measure.regulate({ policy: 'simple' });

		const evaluation = starry.evaluateMeasure(measure);
		const better = !record.evaluation || evaluation.qualityScore > record.evaluation.qualityScore;
		if (better) {
			record.evaluation = evaluation;
			Object.assign(record.current, measure);
			if (evaluation.perfect) {
				logger?.info(`[regulateWithBeadSolver] measure[${record.current.measureIndex}] regulated by simple policy.`);
				++counting.simple;
			}
		}
	});

	counting.computed = pendingRecords.length - counting.simple;
	if (counting.simple) stdout?.write(`${counting.simple}s`);

	const onUpdate: OnUpdate = (measure, evaluation, better) => {
		logger?.info(
			`[regulateWithBeadSolver] measure[${measure.measureIndex}/${spartito.measures.length}] regulated${
				better ? '+' : '-'
			}: ${evaluation.qualityScore.toFixed(3)}, ${evaluation.fine ? 'solved' : evaluation.error ? 'error' : 'issue'}, ${measure.regulationHash}`
		);
		// green = solved, red = error, yellow = issue
		stdout?.write(`\x1b[${evaluation.fine ? '32' : evaluation.error ? '31' : '33'}m${better ? '+' : '-'}\x1b[0m`);
	};

	// Global progress: total = all measures, remaining = non-fine measures across all passes
	const totalMeasures = spartito.measures.length;
	const computeRemaining = () => pendingRecords.filter((r) => !r.evaluation?.fine).length;
	const wrappedOnProgress = onProgress
		? (measure: starry.SpartitoMeasure, evaluation: starry.MeasureEvaluation, better: boolean, progress: ProgressInfo) => {
				onProgress(measure, evaluation, better, { pass: progress.pass, remaining: computeRemaining(), total: totalMeasures });
		  }
		: undefined;

	onPassStart?.(1, 'Imperfect', computeRemaining());
	counting.tryTimes += await solveMeasureRecords(
		pendingRecords,
		onUpdate,
		stdout,
		{ stopLoss: 0.05, quotaMax: 200, quotaFactor: 3, ptFactor: 1 },
		PendingCondition.Imperfect,
		1,
		wrappedOnProgress
	);

	onPassStart?.(2, 'NotFine', computeRemaining());
	counting.tryTimes += await solveMeasureRecords(
		pendingRecords,
		onUpdate,
		stdout,
		{ stopLoss: 0.08, quotaMax: 1000, quotaFactor: 20, ptFactor: 1.6 },
		PendingCondition.NotFine,
		2,
		wrappedOnProgress
	);

	onPassStart?.(3, 'ErrorOnly', computeRemaining());
	counting.tryTimes += await solveMeasureRecords(
		pendingRecords,
		onUpdate,
		stdout,
		{ stopLoss: 0.08, quotaMax: 1000, quotaFactor: 40, ptFactor: 3 },
		PendingCondition.ErrorOnly,
		3,
		wrappedOnProgress
	);

	// Store writes are collected and awaited below. Otherwise a failing (possibly async) `set` would become an
	// unhandled rejection and writes could still be in flight when this function returns.
	const storeWrites: unknown[] = [];

	for (const { evaluation, baseQuality, current, origin } of pendingRecords) {
		if (evaluation.fine) ++counting.solved;
		else if (evaluation.error) ++counting.fatal;
		else ++counting.issue;

		// Persist when the result beats the cached baseline, or when there was no (or a zero) baseline.
		if (solutionStore && (evaluation.qualityScore > baseQuality || !baseQuality)) {
			// NOTE(review): when `solutionStat` is undefined the priority is `-undefined`, i.e. NaN — confirm intended.
			storeWrites.push(solutionStore.set(origin.regulationHash0, { ...current.asSolution(origin), priority: -current.solutionStat?.loss! }));
			if (current.regulationHash !== origin.regulationHash0)
				storeWrites.push(solutionStore.set(current.regulationHash, { ...current.asSolution(), priority: -current.solutionStat?.loss! }));
		}

		if (!evaluation.fine) {
			onSaveIssueMeasure?.({
				measureIndex: current.measureIndex,
				measure: new starry.EditableMeasure(current),
				status: evaluation.error ? MeasureStatus.Fatal : MeasureStatus.Issue,
			});
		}
	}

	// `totalCost` is measured before waiting on the store writes.
	const t1 = Date.now();
	await Promise.all(storeWrites);

	const pickerCost = pickers.reduce((cost, picker) => cost + picker.cost, 0);
	const qualityScore = spartito.qualityScore;
	const totalCost = t1 - t0;

	logger?.info('[regulateWithBeadSolver] done in ', totalCost, 'ms, qualityScore:', qualityScore);

	// zero 'cached' statistics for freshOnly mode
	if (freshOnly) counting.cached = 0;

	return {
		totalCost,
		pickerCost,
		measures: counting,
		qualityScore,
	};
};
/**
 * Sums per-score regulation stats into a single summary.
 *
 * `costPerMeasure` is the summed `totalCost` divided by the summed `computed` count, and `costPerTime` is
 * the summed `totalCost` divided by the summed `tryTimes`. Each is `null` when its divisor is not positive,
 * which includes an empty `stats` list.
 */
const abstractRegulationBeadStats = (stats: RegulationBeadStat[]): RegulationBeadSummary => {
	let totalCost = 0;
	let pickerCost = 0;
	const sums = { cached: 0, simple: 0, computed: 0, tryTimes: 0, solved: 0, issue: 0, fatal: 0 };

	stats.forEach(({ totalCost: runCost, pickerCost: runPickerCost, measures }) => {
		totalCost += runCost;
		pickerCost += runPickerCost;

		sums.cached += measures.cached;
		sums.simple += measures.simple;
		sums.computed += measures.computed;
		sums.tryTimes += measures.tryTimes;
		sums.solved += measures.solved;
		sums.issue += measures.issue;
		sums.fatal += measures.fatal;
	});

	return {
		scoreN: stats.length,
		totalCost,
		pickerCost,
		costPerMeasure: sums.computed > 0 ? totalCost / sums.computed : null,
		costPerTime: sums.tryTimes > 0 ? totalCost / sums.tryTimes : null,
		...sums,
	};
};
// Runtime values and type-only symbols are exported separately so type exports are erased cleanly
// (compatible with `isolatedModules` / `verbatimModuleSyntax`).
export { regulateWithBeadSolver, abstractRegulationBeadStats };
export type { RegulationBeadStat, ProgressInfo };