NSFW-API / src /NsfwImageClassifier.ts
root
update
dfe9a5e
import * as tf from '@tensorflow/tfjs-node';
import * as nsfwjs from 'nsfwjs';
import {NSFWJS} from 'nsfwjs';
import {Tensor3D} from '@tensorflow/tfjs';
import {model as config} from 'app/config/model';
// Switch TensorFlow.js to production mode once, at module load time, before
// NsfwImageClassifier loads or runs any model.
// NOTE(review): per the tfjs docs this trades debug-only correctness checks for
// speed — confirm that is intended for every importer of this module.
tf.enableProdMode();
/**
 * Classifies encoded images with an NSFWJS model loaded from `file://model/`.
 *
 * The model is loaded lazily on first use and then shared by every later call
 * on the same instance.
 */
export class NsfwImageClassifier {
  /**
   * The in-flight or settled model load. Caching the promise (not the resolved
   * model) means concurrent first calls — e.g. from `classifyMany` — share one
   * `nsfwjs.load` instead of each starting its own.
   */
  #modelPromise?: Promise<NSFWJS>;

  /**
   * Classifies a single encoded image.
   *
   * @param imageBuffer - Encoded image bytes; decoded with
   *   `tf.node.decodeImage` into a 3-channel tensor.
   * @returns A map from lower-cased class name to the probability the model
   *   reported for it.
   */
  async classify(imageBuffer: Buffer): Promise<Record<string, number>> {
    // Obtain the model before decoding so a failed load cannot leave a decoded
    // tensor behind with nobody to dispose it.
    const model = await this.#getModel();
    const image = tf.node.decodeImage(imageBuffer, 3);
    try {
      const predictions = await model.classify(image as Tensor3D);
      return this.#transformData(predictions);
    } finally {
      // Release the tensor even when classification throws.
      image.dispose();
    }
  }

  /**
   * Classifies several encoded images concurrently.
   *
   * @param imagesBuffers - Encoded image bytes, one entry per image.
   * @returns Per-image results, in the same order as `imagesBuffers`.
   */
  async classifyMany(imagesBuffers: Buffer[]): Promise<Record<string, number>[]> {
    return await Promise.all(imagesBuffers.map(buffer => this.classify(buffer)));
  }

  /**
   * Returns the shared model, starting the load on first call.
   * If the load fails, the cached promise is cleared so a later call retries
   * instead of replaying the same rejection forever.
   */
  #getModel(): Promise<NSFWJS> {
    if (!this.#modelPromise) {
      this.#modelPromise = nsfwjs
        .load('file://model/', {size: config.size})
        .catch((error: unknown) => {
          this.#modelPromise = undefined;
          throw error;
        });
    }
    return this.#modelPromise;
  }

  /**
   * Flattens the model's prediction list into a lookup keyed by lower-cased
   * class name. A later entry overwrites an earlier one if two names collide
   * after lower-casing.
   */
  #transformData(data: { className: string; probability: number }[]): Record<string, number> {
    const result: Record<string, number> = {};
    for (const item of data) {
      result[item.className.toLowerCase()] = item.probability;
    }
    return result;
  }
}