File size: 4,135 Bytes
6f1c297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import * as zmq from 'zeromq';
import { pack, unpack } from 'msgpackr';
import { config, PredictorType } from '../config.js';

/** Reply envelope obtained by unpacking a predictor's msgpack response. */
interface ZeroResponse {
	/** Status code; `0` is treated as success, any other value as a failure. */
	code: number;
	/** Text used as the thrown error's message when `code` is non-zero. */
	msg: string;
	/** Payload handed back to the caller when `code` is `0`; may be absent. */
	data?: any;
}

/**
 * Request/reply client for a single predictor service: messages are encoded
 * with msgpack and exchanged over a ZeroMQ REQ socket.
 *
 * A REQ socket must strictly alternate send and receive, so overlapping
 * requests on one (shared) client would otherwise fail or mix up replies.
 * Calls to `request()` are therefore queued and executed one at a time.
 * When an exchange fails at the transport level (for example a timeout) the
 * socket is discarded and a fresh one is created for the next request, since
 * a REQ socket that missed its reply cannot send again.
 */
export class ZeroClient {
	private socket: zmq.Request;
	private readonly address: string;
	private connected = false;
	/** True once the current socket has been closed; `connect()` then swaps in a fresh one. */
	private disposed = false;
	/** Settles when the most recently queued request has finished; never rejects. */
	private tail: Promise<void> = Promise.resolve();

	/**
	 * @param address ZeroMQ endpoint of the predictor service.
	 */
	constructor(address: string) {
		this.address = address;
		this.socket = ZeroClient.createSocket();
	}

	/** Creates a REQ socket configured with this client's timeouts. */
	private static createSocket(): zmq.Request {
		const socket = new zmq.Request();
		socket.receiveTimeout = 300000; // 5 minutes timeout
		socket.sendTimeout = 15000; // 15 seconds
		return socket;
	}

	/**
	 * Connects the socket to the configured address. Does nothing when already
	 * connected; replaces the socket first if the previous one was closed.
	 */
	async connect(): Promise<void> {
		if (this.connected) return;
		if (this.disposed) {
			// A closed socket cannot be reused, so start over with a new one.
			this.socket = ZeroClient.createSocket();
			this.disposed = false;
		}
		await this.socket.connect(this.address);
		this.connected = true;
	}

	/**
	 * Calls `method` on the predictor and resolves with the `data` field of its reply.
	 * Requests issued while another one is in flight wait for their turn.
	 *
	 * @param method Name of the remote method.
	 * @param args Optional positional arguments, sent as `args`.
	 * @param kwargs Optional keyword arguments, sent as `kwargs`.
	 * @returns The reply's `data` field.
	 * @throws Error with the reply's `msg` (or a generic message) when the reply
	 *   code is non-zero; transport errors from the socket are rethrown as-is.
	 */
	async request<T = any>(method: string, args?: any[], kwargs?: Record<string, any>): Promise<T> {
		// `tail` never rejects, so a failed request does not block the ones queued after it.
		const outcome = this.tail.then(() => this.exchange<T>(method, args, kwargs));
		this.tail = outcome.then(
			() => undefined,
			() => undefined,
		);
		return outcome;
	}

	/**
	 * Closes the underlying socket. The client stays usable: the next
	 * `connect()` or `request()` opens a new socket.
	 */
	async close(): Promise<void> {
		this.discardSocket();
	}

	/** Performs one encode / send / receive / decode cycle. */
	private async exchange<T>(method: string, args?: any[], kwargs?: Record<string, any>): Promise<T> {
		const msg: Record<string, unknown> = { method };
		if (args) msg.args = args;
		if (kwargs) msg.kwargs = kwargs;

		const [response] = await this.roundTrip(pack(msg));
		const result = unpack(response) as ZeroResponse;

		if (result.code === 0) {
			return result.data as T;
		}
		throw new Error(result.msg || 'Predictor request failed');
	}

	/** Sends one payload and waits for the reply frames, discarding the socket on failure. */
	private async roundTrip(payload: ReturnType<typeof pack>) {
		try {
			await this.connect();
			await this.socket.send(payload);
			return await this.socket.receive();
		} catch (err) {
			// The REQ socket may be stuck mid-exchange; drop it so later requests can recover.
			this.discardSocket();
			throw err;
		}
	}

	/** Closes the current socket (at most once) and marks it for replacement. */
	private discardSocket(): void {
		if (this.disposed) return;
		this.socket.close();
		this.connected = false;
		this.disposed = true;
	}
}

// Predictor client pool: at most one ZeroClient per predictor type. Entries are
// added on demand by getPredictor() and removed by closeAllPredictors().
const clients = new Map<PredictorType, ZeroClient>();

/**
 * Returns the pooled client for the given predictor type, creating it from
 * the address in `config.predictors` the first time that type is requested.
 *
 * @param type Predictor type to look up.
 * @returns The client registered in the pool for `type`.
 */
export function getPredictor(type: PredictorType): ZeroClient {
	const pooled = clients.get(type);
	if (pooled !== undefined) {
		return pooled;
	}

	const created = new ZeroClient(config.predictors[type]);
	clients.set(type, created);
	return created;
}

/**
 * Closes every pooled predictor client and empties the pool.
 *
 * The pool is cleared up front and every client is given a chance to close,
 * even if an earlier one fails, so a single failure cannot leave other
 * sockets open or stale clients in the pool.
 *
 * @throws The first error raised by a client's `close()`, after all clients
 *   have been attempted.
 */
export async function closeAllPredictors(): Promise<void> {
	const pending = [...clients.values()];
	clients.clear();

	let failed = false;
	let firstError: unknown;
	for (const client of pending) {
		try {
			await client.close();
		} catch (err) {
			// Remember only the first failure; keep closing the rest.
			if (!failed) {
				failed = true;
				firstError = err;
			}
		}
	}

	if (failed) {
		throw firstError;
	}
}

// High-level predictor functions
/**
 * One entry of the array that predictLayout() expects back from the layout
 * predictor's `predictDetection` method.
 */
export interface LayoutResult {
	/** Detection output as three arrays — presumably parallel, one entry per detected box (TODO confirm). */
	detection: {
		/** One numeric array per box; coordinate format is not visible here — TODO confirm. */
		boxes: number[][];
		labels: string[];
		scores: number[];
	};
	// NOTE(review): the meaning and units of `theta` and `interval` are defined
	// by the predictor service and cannot be determined from this file.
	theta: number;
	interval: number;
	/** Optional width/height pair; presumably the size of the source image — TODO confirm. */
	sourceSize?: {
		width: number;
		height: number;
	};
}

/** One entry of the array that predictGauge() expects back from the gauge predictor. */
export interface GaugeResult {
	/** Binary image data; encoding/format is not visible here — TODO confirm. */
	image: Buffer;
}

/** One entry of the array that predictMask() expects back from the mask predictor. */
export interface MaskResult {
	/** Binary image data; encoding/format is not visible here — TODO confirm. */
	image: Buffer;
}

/** One entry of the array that predictSemantic() expects back from the semantic predictor. */
export interface SemanticResult {
	/** Untyped cluster entries; their structure is defined by the predictor service. */
	clusters: any[];
}

/** One entry of the array that predictLoc() expects back from the loc (text localization) predictor. */
export interface LocResult {
	/** One numeric array per box; coordinate format is not visible here — TODO confirm. */
	boxes: number[][];
	/** Presumably one score per entry in `boxes` — TODO confirm. */
	scores: number[];
}

/** Reply payload that predictOcr() expects back from the OCR predictor. */
export interface OcrResult {
	texts: string[];
	/** Presumably one score per entry in `texts` — TODO confirm. */
	scores: number[];
}

/**
 * Layout: predictDetection([buffer]).
 *
 * Sends the image as a single-element batch to the `layout` predictor and
 * returns the first entry of its reply.
 *
 * @param imageData Image bytes to analyse.
 * @returns The first result in the predictor's reply.
 * @throws Error if the request fails or the reply holds no result
 *   (missing payload or empty array).
 */
export async function predictLayout(imageData: Buffer): Promise<LayoutResult> {
	const client = getPredictor('layout');
	const results = await client.request<LayoutResult[]>('predictDetection', [[imageData]]);
	// The reply payload is optional on the wire, so check before indexing into it.
	const first = Array.isArray(results) ? results[0] : undefined;
	if (first == null) {
		throw new Error('Layout predictor returned no results');
	}
	return first;
}

/**
 * Gauge: predict([buffer], by_buffer=true).
 *
 * Sends the image as a single-element batch to the `gauge` predictor and
 * returns the first entry of its reply.
 *
 * @param imageData Image bytes to process.
 * @returns The first result in the predictor's reply.
 * @throws Error if the request fails or the reply holds no result
 *   (missing payload or empty array).
 */
export async function predictGauge(imageData: Buffer): Promise<GaugeResult> {
	const client = getPredictor('gauge');
	const results = await client.request<GaugeResult[]>('predict', [[imageData]], { by_buffer: true });
	// The reply payload is optional on the wire, so check before indexing into it.
	const first = Array.isArray(results) ? results[0] : undefined;
	if (first == null) {
		throw new Error('Gauge predictor returned no results');
	}
	return first;
}

/**
 * Mask: predict([buffer], by_buffer=true).
 *
 * Sends the image as a single-element batch to the `mask` predictor and
 * returns the first entry of its reply.
 *
 * @param imageData Image bytes to process.
 * @returns The first result in the predictor's reply.
 * @throws Error if the request fails or the reply holds no result
 *   (missing payload or empty array).
 */
export async function predictMask(imageData: Buffer): Promise<MaskResult> {
	const client = getPredictor('mask');
	const results = await client.request<MaskResult[]>('predict', [[imageData]], { by_buffer: true });
	// The reply payload is optional on the wire, so check before indexing into it.
	const first = Array.isArray(results) ? results[0] : undefined;
	if (first == null) {
		throw new Error('Mask predictor returned no results');
	}
	return first;
}

/**
 * Semantic: predict([buffer]).
 *
 * Sends the image as a single-element batch to the `semantic` predictor and
 * returns the first entry of its reply.
 *
 * @param imageData Image bytes to process.
 * @returns The first result in the predictor's reply.
 * @throws Error if the request fails or the reply holds no result
 *   (missing payload or empty array).
 */
export async function predictSemantic(imageData: Buffer): Promise<SemanticResult> {
	const client = getPredictor('semantic');
	const results = await client.request<SemanticResult[]>('predict', [[imageData]]);
	// The reply payload is optional on the wire, so check before indexing into it.
	const first = Array.isArray(results) ? results[0] : undefined;
	if (first == null) {
		throw new Error('Semantic predictor returned no results');
	}
	return first;
}

/**
 * Loc (text localization): predict([buffer]).
 *
 * Sends the image as a single-element batch to the `loc` predictor and
 * returns the first entry of its reply.
 *
 * @param imageData Image bytes to process.
 * @returns The first result in the predictor's reply.
 * @throws Error if the request fails or the reply holds no result
 *   (missing payload or empty array).
 */
export async function predictLoc(imageData: Buffer): Promise<LocResult> {
	const client = getPredictor('loc');
	const results = await client.request<LocResult[]>('predict', [[imageData]]);
	// The reply payload is optional on the wire, so check before indexing into it.
	const first = Array.isArray(results) ? results[0] : undefined;
	if (first == null) {
		throw new Error('Loc predictor returned no results');
	}
	return first;
}

/**
 * OCR: predict(buffers=[buffer], location=[...]).
 *
 * Calls the `ocr` predictor's `predict` method with an empty positional
 * argument list and the image and locations passed as keyword arguments.
 *
 * @param imageData Image bytes, sent as the single element of `buffers`.
 * @param locations Values forwarded unchanged as the `location` keyword argument.
 * @returns The predictor's reply payload.
 */
export async function predictOcr(imageData: Buffer, locations: any[]): Promise<OcrResult> {
	const keywordArgs = {
		buffers: [imageData],
		location: locations,
	};
	return getPredictor('ocr').request<OcrResult>('predict', [], keywordArgs);
}

/**
 * Brackets: predict(buffers=[buffer]).
 *
 * Calls the `brackets` predictor's `predict` method with an empty positional
 * argument list and the image passed as the single element of `buffers`.
 *
 * @param imageData Image bytes to process.
 * @returns The predictor's reply payload, untyped.
 */
export async function predictBrackets(imageData: Buffer): Promise<any> {
	const keywordArgs = { buffers: [imageData] };
	return getPredictor('brackets').request('predict', [], keywordArgs);
}