// starry/backend/omr-service/src/routes/predictor.ts
// Exported from starry-refactor (commit 3a2977d, 2026-02-20).
import { FastifyInstance } from 'fastify';
import { getPredictor } from '../lib/zero-client.js';
/**
* Synchronous predictor routes that accept base64 images
* POST /api/predict/:type with { images: base64[] }
*/
/** Request body accepted by every POST /predict/:type route. */
interface PredictorBody {
  /** Input images, each a base64-encoded string (decoded with Buffer.from(img, 'base64')). */
  images: string[];
}
/**
 * Recursively walk a predictor result and replace Buffer values stored under
 * an `image` key with `data:image/png;base64,...` data URLs.
 *
 * Buffers found anywhere else (including a Buffer passed in directly, or a
 * Buffer that is an array element) are returned unchanged by reference;
 * primitives and null/undefined pass through as-is.
 *
 * @param obj - Arbitrary predictor result (object, array, or primitive).
 * @returns A new object/array with `image` Buffers stringified; non-object
 *          inputs are returned untouched.
 */
function convertBufferImages(obj: any): any {
  if (obj == null || typeof obj !== 'object') return obj;
  if (Buffer.isBuffer(obj)) return obj;
  if (Array.isArray(obj)) return obj.map((item) => convertBufferImages(item));
  const result: any = {};
  for (const [key, value] of Object.entries(obj)) {
    if (key === 'image' && Buffer.isBuffer(value)) {
      // Already a Buffer — encode directly; the Buffer.from() copy the old
      // code made here was redundant.
      result[key] = `data:image/png;base64,${value.toString('base64')}`;
    } else if (typeof value === 'object' && value !== null) {
      result[key] = convertBufferImages(value);
    } else {
      result[key] = value;
    }
  }
  return result;
}
/**
 * Register synchronous predictor routes on the given Fastify instance.
 *
 * Every POST route accepts `{ images: base64string[] }`, decodes the images
 * to Buffers, forwards them to the matching predictor client, and responds
 * with `{ code: 0, data: ... }` — or `{ code: 400, message }` (HTTP 400) when
 * the images array is missing or empty.
 */
export default async function predictorRoutes(fastify: FastifyInstance) {
  /** Shared 400-payload for requests lacking a non-empty images array. */
  const MISSING_IMAGES = { code: 400, message: 'images array required' };

  /** Decode base64 images to Buffers; null when the array is absent/empty. */
  function decodeImages(images: string[] | undefined): Buffer[] | null {
    if (!images || images.length === 0) return null;
    return images.map((img) => Buffer.from(img, 'base64'));
  }

  // Layout predictor - returns detection results
  fastify.post<{ Body: PredictorBody }>('/predict/layout', async (request, reply) => {
    const buffers = decodeImages(request.body.images);
    if (!buffers) {
      reply.code(400);
      return MISSING_IMAGES;
    }
    const client = getPredictor('layout');
    const results = await client.request<any[]>('predictDetection', [buffers]);
    return {
      code: 0,
      data: results.map((result) => convertBufferImages(result)),
    };
  });

  // Gauge predictor - returns deformation-corrected images
  fastify.post<{ Body: PredictorBody }>('/predict/gauge', async (request, reply) => {
    const buffers = decodeImages(request.body.images);
    if (!buffers) {
      reply.code(400);
      return MISSING_IMAGES;
    }
    const client = getPredictor('gauge');
    const results = await client.request<any[]>('predict', [buffers], { by_buffer: true });
    return {
      code: 0,
      data: results.map((result) => convertBufferImages(result)),
    };
  });

  // Mask predictor - returns staff mask images
  fastify.post<{ Body: PredictorBody }>('/predict/mask', async (request, reply) => {
    const buffers = decodeImages(request.body.images);
    if (!buffers) {
      reply.code(400);
      return MISSING_IMAGES;
    }
    const client = getPredictor('mask');
    const results = await client.request<any[]>('predict', [buffers], { by_buffer: true });
    return {
      code: 0,
      data: results.map((result) => convertBufferImages(result)),
    };
  });

  // Semantic predictor - returns semantic point clusters
  fastify.post<{ Body: PredictorBody }>('/predict/semantic', async (request, reply) => {
    const buffers = decodeImages(request.body.images);
    if (!buffers) {
      reply.code(400);
      return MISSING_IMAGES;
    }
    const client = getPredictor('semantic');
    const results = await client.request<any[]>('predict', [buffers]);
    return { code: 0, data: results };
  });

  // Text predictor - combines loc + ocr for text detection
  fastify.post<{ Body: PredictorBody }>('/predict/text', async (request, reply) => {
    const buffers = decodeImages(request.body.images);
    if (!buffers) {
      reply.code(400);
      return MISSING_IMAGES;
    }
    // Run text localization over all pages first.
    const locClient = getPredictor('loc');
    const locResults = await locClient.request<any[]>('predict', [buffers]);
    // Then run OCR per page, in parallel, on the found locations.
    const ocrClient = getPredictor('ocr');
    const results = await Promise.all(
      buffers.map(async (buffer, i) => {
        const locations = locResults[i];
        if (!locations || locations.length === 0) {
          return { areas: [], imageSize: null };
        }
        const ocrResult = await ocrClient.request<any>('predict', [], {
          buffers: [buffer],
          location: locations,
        });
        return {
          areas: ocrResult,
          // NOTE(review): hardcoded [height, width] tuple expected by the
          // frontend; presumably a fixed normalized page size — confirm with
          // the consumer rather than deriving from the actual image.
          imageSize: [856, 600],
        };
      })
    );
    return { code: 0, data: results };
  });

  // Jianpu (simplified notation) predictor
  fastify.post<{ Body: PredictorBody }>('/predict/jianpu', async (request, reply) => {
    const buffers = decodeImages(request.body.images);
    if (!buffers) {
      reply.code(400);
      return MISSING_IMAGES;
    }
    const client = getPredictor('jianpu');
    const results = await client.request<any[]>('predict', [buffers]);
    return { code: 0, data: results };
  });

  // Viewport config - returns vision spec for semantic/gauge models
  fastify.get('/predict/viewport', async () => {
    return {
      code: 0,
      data: {
        semantic: {
          viewportHeight: 192,
          viewportUnit: 8,
        },
        gauge: {
          viewportHeight: 256,
          viewportUnit: 8,
        },
      },
    };
  });
}