Spaces:
Sleeping
Sleeping
| import * as zmq from 'zeromq'; | |
| import { pack, unpack } from 'msgpackr'; | |
| import { config, PredictorType } from '../config.js'; | |
/**
 * Reply envelope decoded from a predictor service.
 *
 * `code === 0` means success, with the payload in `data`. Any other code is a
 * failure, described by `msg`. `data` is `unknown` so callers must narrow or
 * cast it deliberately instead of silently receiving `any`.
 */
interface ZeroResponse {
  code: number;
  msg: string;
  data?: unknown;
}
/**
 * Minimal msgpack-over-ZeroMQ REQ client for a predictor service.
 *
 * A REQ socket must strictly alternate send/receive and allows only one pending
 * operation. Because one instance is shared by all callers, requests are
 * serialized through an internal promise queue. If an exchange fails midway
 * (e.g. the 5-minute receive timeout fires), the socket is left unable to send
 * again. It is therefore closed and transparently recreated on the next request.
 */
export class ZeroClient {
  private socket: zmq.Request;
  private address: string;
  private connected = false;
  /** Tail of the request queue; always settles successfully so one failure cannot block later calls. */
  private queue: Promise<unknown> = Promise.resolve();

  constructor(address: string) {
    this.address = address;
    this.socket = ZeroClient.createSocket();
  }

  /** Builds a REQ socket with the client's send/receive timeouts applied. */
  private static createSocket(): zmq.Request {
    const socket = new zmq.Request();
    socket.receiveTimeout = 300000; // 5 minutes timeout
    socket.sendTimeout = 15000;
    return socket;
  }

  /**
   * Connects to the configured address if not already connected.
   * If the socket was closed (by `close()` or by failure recovery), a fresh one is created first.
   */
  async connect(): Promise<void> {
    if (this.connected) return;
    if (this.socket.closed) {
      this.socket = ZeroClient.createSocket();
    }
    await this.socket.connect(this.address);
    this.connected = true;
  }

  /**
   * Calls `method` on the remote service and resolves with its `data` payload.
   *
   * Calls are queued and executed one at a time on the underlying socket.
   *
   * @param method - Remote method name.
   * @param args - Optional positional arguments, sent as `args`.
   * @param kwargs - Optional keyword arguments, sent as `kwargs`.
   * @returns The reply's `data` field, cast to `T`.
   * @throws Error with the service's `msg` when the reply code is non-zero; transport errors
   *   (including timeouts) are rethrown after the socket is reset.
   */
  async request<T = any>(method: string, args?: any[], kwargs?: Record<string, any>): Promise<T> {
    const run = () => this.exchange<T>(method, args, kwargs);
    const pending = this.queue.then(run);
    // The queue only orders calls; a rejected call must not poison the ones behind it.
    this.queue = pending.catch(() => undefined);
    return pending;
  }

  /** Performs one send/receive round trip; must only be invoked via the queue in `request`. */
  private async exchange<T>(method: string, args?: any[], kwargs?: Record<string, any>): Promise<T> {
    if (!this.connected) {
      await this.connect();
    }

    const message: { method: string; args?: any[]; kwargs?: Record<string, any> } = { method };
    if (args) message.args = args;
    if (kwargs) message.kwargs = kwargs;
    // Serialize before touching the socket so an unpackable argument doesn't force a reset.
    const payload = pack(message);

    let reply: Buffer | undefined;
    try {
      await this.socket.send(payload);
      [reply] = await this.socket.receive();
    } catch (err) {
      // A REQ socket that failed mid-exchange is stuck in the wrong state for the
      // next send; discard it so the next request reconnects on a fresh socket.
      this.reset();
      throw err;
    }
    if (reply === undefined) {
      throw new Error('Predictor returned an empty reply');
    }

    const result = unpack(reply) as ZeroResponse | null | undefined;
    if (result?.code === 0) {
      return result.data as T;
    }
    throw new Error(result?.msg || 'Predictor request failed');
  }

  /** Closes the current socket and marks the client disconnected. */
  private reset(): void {
    if (!this.socket.closed) {
      this.socket.close();
    }
    this.connected = false;
  }

  /**
   * Closes the underlying socket (even if it was never connected).
   * The client may still be reused: the next request reconnects on a new socket.
   */
  async close(): Promise<void> {
    this.reset();
  }
}
/** Module-level pool: at most one lazily created client per predictor type. */
const clients: Map<PredictorType, ZeroClient> = new Map();

/**
 * Returns the pooled client for `type`.
 *
 * On first use, a client is created for `config.predictors[type]` and cached.
 * Construction does not connect; the client connects on its first request.
 */
export function getPredictor(type: PredictorType): ZeroClient {
  const cached = clients.get(type);
  if (cached !== undefined) {
    return cached;
  }
  const created = new ZeroClient(config.predictors[type]);
  clients.set(type, created);
  return created;
}
/**
 * Closes every pooled predictor client, awaiting each in turn, then empties the
 * pool so that later `getPredictor` calls build fresh clients.
 */
export async function closeAllPredictors(): Promise<void> {
  const pooled = clients.values();
  for (const predictor of pooled) {
    await predictor.close();
  }
  clients.clear();
}
// Result shapes for the high-level predictor functions.

/** Reply item from the layout predictor's `predictDetection`. */
export interface LayoutResult {
  /** Detection output; presumably parallel arrays with one entry per detected region — confirm. */
  detection: {
    boxes: Array<Array<number>>;
    labels: Array<string>;
    scores: Array<number>;
  };
  /** Value reported by the layout service; meaning is defined server-side. */
  theta: number;
  /** Value reported by the layout service; meaning is defined server-side. */
  interval: number;
  /** Optional size reported by the service. */
  sourceSize?: { width: number; height: number };
}

/** Reply item from the gauge predictor (`by_buffer` mode). */
export interface GaugeResult {
  image: Buffer;
}

/** Reply item from the mask predictor (`by_buffer` mode). */
export interface MaskResult {
  image: Buffer;
}

/** Reply item from the semantic predictor; cluster entries are left untyped. */
export interface SemanticResult {
  clusters: Array<any>;
}

/** Reply item from the text-localization predictor. */
export interface LocResult {
  boxes: Array<Array<number>>;
  scores: Array<number>;
}

/** Reply from the OCR predictor. */
export interface OcrResult {
  texts: Array<string>;
  scores: Array<number>;
}
/**
 * Unwraps the single result from a batch-style predictor reply.
 *
 * Each batch predictor is sent a one-image batch, so one result is expected.
 * An empty or non-array reply is reported as an error instead of leaking
 * `undefined` through a non-optional return type.
 *
 * @param results - The decoded reply `data`.
 * @param predictor - Predictor name used in the error message.
 * @throws Error when no first result is present.
 */
function firstResult<T>(results: T[] | null | undefined, predictor: string): T {
  const first = Array.isArray(results) ? results[0] : undefined;
  if (first === undefined) {
    throw new Error(`${predictor} predictor returned no results`);
  }
  return first;
}

/** Layout: `predictDetection([buffer])`; returns the result for the single submitted image. */
export async function predictLayout(imageData: Buffer): Promise<LayoutResult> {
  const client = getPredictor('layout');
  const results = await client.request<LayoutResult[]>('predictDetection', [[imageData]]);
  return firstResult(results, 'layout');
}

/** Gauge: `predict([buffer], by_buffer=true)`; returns the result for the single submitted image. */
export async function predictGauge(imageData: Buffer): Promise<GaugeResult> {
  const client = getPredictor('gauge');
  const results = await client.request<GaugeResult[]>('predict', [[imageData]], { by_buffer: true });
  return firstResult(results, 'gauge');
}

/** Mask: `predict([buffer], by_buffer=true)`; returns the result for the single submitted image. */
export async function predictMask(imageData: Buffer): Promise<MaskResult> {
  const client = getPredictor('mask');
  const results = await client.request<MaskResult[]>('predict', [[imageData]], { by_buffer: true });
  return firstResult(results, 'mask');
}

/** Semantic: `predict([buffer])`; returns the result for the single submitted image. */
export async function predictSemantic(imageData: Buffer): Promise<SemanticResult> {
  const client = getPredictor('semantic');
  const results = await client.request<SemanticResult[]>('predict', [[imageData]]);
  return firstResult(results, 'semantic');
}

/** Loc (text localization): `predict([buffer])`; returns the result for the single submitted image. */
export async function predictLoc(imageData: Buffer): Promise<LocResult> {
  const client = getPredictor('loc');
  const results = await client.request<LocResult[]>('predict', [[imageData]]);
  return firstResult(results, 'loc');
}
/**
 * OCR: calls `predict(buffers=[buffer], location=[...])` on the OCR predictor
 * and returns the reply as-is (no first-element unwrapping).
 *
 * @param imageData - Image bytes, sent as the only entry of `buffers`.
 * @param locations - Forwarded verbatim as the `location` keyword argument;
 *   presumably regions produced by text localization — confirm with the service.
 */
export async function predictOcr(imageData: Buffer, locations: any[]): Promise<OcrResult> {
  const ocr = getPredictor('ocr');
  const kwargs = { buffers: [imageData], location: locations };
  return ocr.request<OcrResult>('predict', [], kwargs);
}
/**
 * Brackets: calls `predict(buffers=[buffer])` on the brackets predictor.
 * The reply is returned untyped because its shape is not modelled here.
 */
export async function predictBrackets(imageData: Buffer): Promise<any> {
  const brackets = getPredictor('brackets');
  const kwargs = { buffers: [imageData] };
  return brackets.request('predict', [], kwargs);
}