// Author: cstria0106
// Commit 2e4ca50 (verified): "Upload folder using huggingface_hub"
// Deployment configuration is read from globals that the host page sets before loading this script.
const {
  SHIT_DETECTOR_MODEL_REPO: modelRepo,
  SHIT_DETECTOR_ORT_VERSION: configuredOrtVersion,
} = window;
const ortVersion = configuredOrtVersion || "1.22.0";
const defaultBase = `https://huggingface.co/${modelRepo}/resolve/main/`;

// A `?modelBase=` query parameter overrides where model assets come from.
// Relative values are resolved against the current page URL.
const params = new URLSearchParams(window.location.search);
const resolvedModelBase = new URL(params.get("modelBase") || defaultBase, window.location.href).href;
// Guarantee a trailing "/" so relative asset names resolve inside the base directory.
const modelBase = resolvedModelBase.endsWith("/") ? resolvedModelBase : `${resolvedModelBase}/`;
/**
 * Mutable runtime state shared by the model loader and the classifier.
 * All fields start as null; loadModel() fills them in once the model is ready.
 */
const state = { session: null, metadata: null, inputName: null };
/**
 * Cached references to the page elements this script reads or updates.
 * Each property is named exactly after the id of the element it holds.
 */
const els = Object.fromEntries(
  [
    "dropzone",
    "fileInput",
    "chooseButton",
    "preview",
    "empty",
    "label",
    "badge",
    "shitProbability",
    "notShitProbability",
    "confidence",
    "threshold",
    "runtime",
    "shitMeterFill",
    "notShitMeterFill",
    "status",
  ].map((id) => [id, document.getElementById(id)]),
);
/**
 * Replace the text of the status line.
 * @param {string} text - Message to display.
 */
function setStatus(text) {
  const statusLine = els.status;
  statusLine.textContent = text;
}
/**
 * Format a fraction as a percentage string with two decimals (0.5 -> "50.00%").
 * @param {number} value - Fraction to format.
 * @returns {string}
 */
function pct(value) {
  const percent = value * 100;
  const digits = percent.toFixed(2);
  return `${digits}%`;
}
/**
 * Numerically stable softmax: subtracts the maximum before exponentiating.
 * @param {number[]} values - Raw scores.
 * @returns {number[]} Probabilities in the same order, summing to 1.
 */
function softmax(values) {
  const peak = Math.max(...values);
  const weights = [];
  let total = 0;
  for (const value of values) {
    const weight = Math.exp(value - peak);
    weights.push(weight);
    total += weight;
  }
  return weights.map((weight) => weight / total);
}
/**
 * Build the absolute URL of a model asset relative to `modelBase`.
 * @param {string} name - Asset file name, e.g. "metadata.json".
 * @returns {string}
 */
function resolveAsset(name) {
  const assetUrl = new URL(name, modelBase);
  return assetUrl.href;
}
/**
 * Fetch metadata.json and check the fields that preprocessing relies on.
 * @returns {Promise<object>} Parsed metadata.
 * @throws {Error} On a non-OK HTTP response or missing/invalid fields.
 */
async function fetchModelMetadata() {
  const response = await fetch(resolveAsset("metadata.json"));
  if (!response.ok) {
    throw new Error(`metadata.json fetch failed: ${response.status}`);
  }
  const metadata = await response.json();
  // imageToTensor() reads input_size, mean[0..2] and std[0..2]; fail here with a
  // clear message instead of later with an opaque canvas/NaN error.
  const size = Number(metadata?.input_size);
  if (!Number.isInteger(size) || size <= 0) {
    throw new Error(`metadata.json has invalid input_size: ${metadata?.input_size}`);
  }
  for (const key of ["mean", "std"]) {
    if (!Array.isArray(metadata[key]) || metadata[key].length < 3) {
      throw new Error(`metadata.json must provide a 3-element "${key}" array`);
    }
  }
  return metadata;
}

/**
 * Configure onnxruntime-web, then load the model metadata and the ONNX session.
 *
 * The metadata request and the model download are independent, so they run
 * in parallel. `state` is only updated after both have succeeded, so a failure
 * never leaves it half-initialised.
 *
 * @returns {Promise<void>}
 * @throws {Error} If metadata.json cannot be fetched or is invalid, or if the
 *   ONNX session cannot be created.
 */
async function loadModel() {
  // Single-threaded wasm; presumably so the page does not need cross-origin
  // isolation for SharedArrayBuffer — confirm before changing.
  ort.env.wasm.numThreads = 1;
  // Load ORT's .wasm binaries from the CDN build matching ortVersion.
  ort.env.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${ortVersion}/dist/`;

  const [metadata, session] = await Promise.all([
    fetchModelMetadata(),
    ort.InferenceSession.create(resolveAsset("shit_detector.onnx"), {
      executionProviders: ["wasm"],
      graphOptimizationLevel: "all",
    }),
  ]);

  state.metadata = metadata;
  state.session = session;
  state.inputName = session.inputNames[0];
  els.threshold.textContent = String(metadata.shit_threshold ?? 0.5);
  setStatus("Model loaded. Choose an image.");
}
/**
 * Convert a decoded image into the model's float32 input tensor.
 *
 * The image is scaled so its shorter edge becomes trunc(input_size * 1.14)
 * pixels and drawn centered on an input_size x input_size canvas, which
 * center-crops it. Each RGB channel is then normalized as
 * (value / 255 - mean[c]) / std[c]. The alpha channel is ignored.
 *
 * @param {HTMLImageElement} image - A fully decoded image.
 * @returns {ort.Tensor} Tensor with dims [1, 3, input_size, input_size].
 * @throws {Error} If the image reports a zero (or otherwise unusable) width or height.
 */
function imageToTensor(image) {
  const { naturalWidth: width, naturalHeight: height } = image;
  // A zero dimension would make the scale Infinity/NaN and the canvas would
  // silently stay blank, producing a meaningless prediction.
  if (!(width > 0 && height > 0)) {
    throw new Error(`Image has unusable dimensions: ${width}x${height}`);
  }

  const size = Number(state.metadata.input_size);
  const resizeShortEdge = Math.trunc(size * 1.14);
  const scale = resizeShortEdge / Math.min(width, height);
  const resizedWidth = Math.round(width * scale);
  const resizedHeight = Math.round(height * scale);
  // Offsets that center the resized image on the size x size canvas.
  const cropX = Math.floor((resizedWidth - size) / 2);
  const cropY = Math.floor((resizedHeight - size) / 2);

  const canvas = document.createElement("canvas");
  canvas.width = size;
  canvas.height = size;
  const ctx = canvas.getContext("2d", { willReadFrequently: true });
  ctx.drawImage(image, -cropX, -cropY, resizedWidth, resizedHeight);
  const pixels = ctx.getImageData(0, 0, size, size).data;

  const { mean, std } = state.metadata;
  const plane = size * size;
  const tensor = new Float32Array(3 * plane);
  // getImageData returns interleaved RGBA bytes; the tensor stores one full
  // plane per channel (R plane, then G, then B).
  for (let channel = 0; channel < 3; channel += 1) {
    const channelMean = mean[channel];
    const channelStd = std[channel];
    const offset = channel * plane;
    for (let i = 0; i < plane; i += 1) {
      tensor[offset + i] = (pixels[i * 4 + channel] / 255 - channelMean) / channelStd;
    }
  }
  return new ort.Tensor("float32", tensor, [1, 3, size, size]);
}
/**
 * Decode an image file into an Image element via a temporary object URL.
 * The object URL is revoked once decoding settles, whether it succeeded or not.
 *
 * @param {Blob} file - Image data selected or dropped by the user.
 * @returns {Promise<HTMLImageElement>} The decoded image.
 */
async function decodeImage(file) {
  const objectUrl = URL.createObjectURL(file);
  const img = Object.assign(new Image(), { decoding: "async", src: objectUrl });
  try {
    await img.decode();
  } finally {
    URL.revokeObjectURL(objectUrl);
  }
  return img;
}
/**
 * Classify one image file and render the prediction in the page.
 *
 * Every call takes a new run id stored on `state`. If a newer call starts
 * while this one is awaiting decoding or inference, this call returns without
 * touching the UI. The preview and the displayed result therefore always come
 * from the most recently started classification. The object URL of the
 * preview being replaced is revoked so earlier files are not kept alive.
 *
 * Probability index 0 is shown as "shit" and index 1 as "not_shit". The label
 * is "shit" when probability 0 reaches metadata.shit_threshold (default 0.5).
 *
 * @param {Blob} file - Image chosen or dropped by the user.
 * @returns {Promise<void>}
 */
async function classify(file) {
  if (!state.session) {
    setStatus("Model is still loading.");
    return;
  }
  state.classifyRunId = (state.classifyRunId ?? 0) + 1;
  const runId = state.classifyRunId;
  const isSuperseded = () => runId !== state.classifyRunId;

  setStatus("Running inference...");
  const image = await decodeImage(file);
  if (isSuperseded()) {
    return;
  }

  const previousPreviewUrl = els.preview.src;
  els.preview.src = URL.createObjectURL(file);
  // Object URLs live until revoked; release the one for the image being replaced.
  if (previousPreviewUrl.startsWith("blob:")) {
    URL.revokeObjectURL(previousPreviewUrl);
  }
  els.preview.style.display = "block";
  els.empty.style.display = "none";

  const input = imageToTensor(image);
  const startedAt = performance.now();
  const outputs = await state.session.run({ [state.inputName]: input });
  const elapsed = performance.now() - startedAt;
  if (isSuperseded()) {
    return;
  }

  const logits = Array.from(outputs[state.session.outputNames[0]].data);
  const scale = Number(state.metadata.logit_scale ?? 1.0);
  const probabilities = softmax(logits.map((value) => value * scale));
  const shitProbability = probabilities[0];
  const notShitProbability = probabilities[1];
  const confidence = Math.max(...probabilities);
  const threshold = Number(state.metadata.shit_threshold ?? 0.5);
  const isShit = shitProbability >= threshold;

  els.label.textContent = isShit ? "shit" : "not_shit";
  els.badge.textContent = isShit ? "positive" : "negative";
  els.badge.className = isShit ? "badge" : "badge ok";
  els.shitProbability.textContent = pct(shitProbability);
  els.notShitProbability.textContent = pct(notShitProbability);
  els.confidence.textContent = pct(confidence);
  els.threshold.textContent = String(threshold);
  els.shitMeterFill.style.width = pct(shitProbability);
  els.notShitMeterFill.style.width = pct(notShitProbability);
  setStatus(`Inference completed in ${elapsed.toFixed(1)} ms.`);
}
/**
 * Start classification of the first file in a file list, if there is one.
 * The classification runs in the background; failures are logged and shown
 * in the status line rather than rethrown.
 *
 * @param {FileList|File[]|null|undefined} files - Files from an input or a drop.
 */
function handleFiles(files) {
  const file = files == null ? undefined : files[0];
  if (file) {
    const reportFailure = (error) => {
      console.error(error);
      setStatus(error instanceof Error ? error.message : "Inference failed.");
    };
    void classify(file).catch(reportFailure);
  }
}
// UI wiring: the button opens the file picker; picking or dropping a file classifies it.
els.chooseButton.addEventListener("click", () => els.fileInput.click());
els.fileInput.addEventListener("change", (event) => {
  handleFiles(event.target.files);
  // handleFiles takes the first file synchronously, so the input can be cleared
  // now. Without this, choosing the same file again fires no "change" event.
  event.target.value = "";
});

// Highlight the dropzone while a drag is over it. preventDefault() on
// dragenter/dragover marks the zone as a valid drop target.
for (const eventName of ["dragenter", "dragover"]) {
  els.dropzone.addEventListener(eventName, (event) => {
    event.preventDefault();
    els.dropzone.classList.add("dragging");
  });
}
for (const eventName of ["dragleave", "drop"]) {
  els.dropzone.addEventListener(eventName, (event) => {
    event.preventDefault();
    // dragleave also fires when the pointer moves onto a child of the
    // dropzone. Keep the highlight unless the drag really left the zone.
    // contains(null) is false, so browsers without relatedTarget fall back to removing it.
    if (eventName === "dragleave" && els.dropzone.contains(event.relatedTarget)) {
      return;
    }
    els.dropzone.classList.remove("dragging");
  });
}
els.dropzone.addEventListener("drop", (event) => handleFiles(event.dataTransfer.files));
// Start loading the model as soon as the script runs; a failure is logged and
// reported in the status line. The promise is intentionally not awaited.
void (async () => {
  try {
    await loadModel();
  } catch (error) {
    console.error(error);
    setStatus(error instanceof Error ? error.message : "Failed to load model.");
  }
})();