// Deployment settings, presumably set as globals by the host page before this
// script runs; the ORT version falls back to "1.22.0" when unset or empty.
const { SHIT_DETECTOR_MODEL_REPO: modelRepo } = window;
const ortVersion = window.SHIT_DETECTOR_ORT_VERSION || "1.22.0";
const defaultBase = `https://huggingface.co/${modelRepo}/resolve/main/`;
const params = new URLSearchParams(window.location.search);
// Asset folder URL: a `?modelBase=...` query value (resolved against the page
// URL) overrides the Hugging Face default. The result always ends in "/" so that
// relative asset names resolve inside the folder rather than replacing its last segment.
const modelBase = (() => {
  const requestedBase = params.get("modelBase") || defaultBase;
  const absoluteBase = new URL(requestedBase, window.location.href).toString();
  return absoluteBase.endsWith("/") ? absoluteBase : `${absoluteBase}/`;
})();
/** Mutable runtime state: loadModel() fills these in; inference code reads them. */
const state = { session: null, metadata: null, inputName: null };
// Cache of every DOM element the app touches. Each key is exactly the element id
// it is looked up by; a missing element is stored as null.
const els = Object.fromEntries(
  [
    "dropzone",
    "fileInput",
    "chooseButton",
    "preview",
    "empty",
    "label",
    "badge",
    "shitProbability",
    "notShitProbability",
    "confidence",
    "threshold",
    "runtime",
    "shitMeterFill",
    "notShitMeterFill",
    "status",
  ].map((id) => [id, document.getElementById(id)]),
);
/**
 * Shows a message in the status line.
 * @param {string} text - Plain text (assigned via textContent, never parsed as HTML).
 */
function setStatus(text) {
  const { status } = els;
  status.textContent = text;
}
/**
 * Formats a fraction (e.g. a probability) as a percentage with two decimals.
 * @param {number} value - Fraction such as 0.5.
 * @returns {string} e.g. "50.00%".
 */
function pct(value) {
  const percent = value * 100;
  return percent.toFixed(2).concat("%");
}
/**
 * Numerically stable softmax. The maximum is subtracted before exponentiating
 * so that large logits don't overflow to Infinity.
 *
 * @param {number[]} values - Raw scores (logits).
 * @returns {number[]} Probabilities in the same order, summing to 1 up to
 *   rounding. An empty input gives an empty output.
 */
function softmax(values) {
  // reduce rather than Math.max(...values): spreading a very large array into
  // call arguments can exceed the engine's limit and throw a RangeError.
  const max = values.reduce((best, value) => Math.max(best, value), -Infinity);
  const exps = values.map((value) => Math.exp(value - max));
  const sum = exps.reduce((acc, value) => acc + value, 0);
  return exps.map((value) => value / sum);
}
/**
 * Resolves a model asset name relative to the configured asset folder.
 * @param {string} name - Relative asset path, e.g. "metadata.json".
 * @returns {string} Absolute URL string.
 */
function resolveAsset(name) {
  const assetUrl = new URL(name, modelBase);
  return assetUrl.href;
}
/**
 * Configures onnxruntime-web and loads the model assets from `modelBase`.
 *
 * metadata.json and shit_detector.onnx are requested concurrently. `state` is
 * only written after both succeed, so a non-null `state.session` (which
 * classify() checks for readiness) always comes with `state.metadata` and
 * `state.inputName`.
 *
 * @returns {Promise<void>}
 * @throws {Error} When metadata.json responds with a non-OK status, or when the
 *   metadata body or the ONNX session fails to load.
 */
async function loadModel() {
  // Single-threaded WASM. Presumably chosen so the page works without
  // cross-origin isolation (multi-threaded WASM needs SharedArrayBuffer); confirm.
  ort.env.wasm.numThreads = 1;
  // Load the WASM binaries from the CDN release matching the configured ORT version.
  ort.env.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${ortVersion}/dist/`;

  const fetchMetadata = async () => {
    const response = await fetch(resolveAsset("metadata.json"));
    if (!response.ok) {
      throw new Error(`metadata.json fetch failed: ${response.status}`);
    }
    return response.json();
  };

  // The two downloads are independent, so overlap them instead of waiting for
  // the small metadata file before starting the (larger) model download.
  const [metadata, session] = await Promise.all([
    fetchMetadata(),
    ort.InferenceSession.create(resolveAsset("shit_detector.onnx"), {
      executionProviders: ["wasm"],
      graphOptimizationLevel: "all",
    }),
  ]);

  state.metadata = metadata;
  state.inputName = session.inputNames[0];
  state.session = session;
  els.threshold.textContent = String(metadata.shit_threshold ?? 0.5);
  setStatus("Model loaded. Choose an image.");
}
/**
 * Converts a decoded image into the model's normalized float32 input tensor.
 *
 * Steps:
 * 1. Scale the image so its short edge is trunc(input_size * 1.14) pixels.
 * 2. Center-crop an input_size x input_size square.
 * 3. Normalize each RGB channel as (pixel / 255 - mean[c]) / std[c], using
 *    `state.metadata.mean` and `state.metadata.std`. Alpha is ignored.
 *
 * @param {HTMLImageElement} image - A fully decoded image.
 * @returns {ort.Tensor} float32 tensor with dims [1, 3, input_size, input_size].
 * @throws {Error} If `metadata.input_size` is not a positive integer, or the
 *   image has no usable intrinsic width/height.
 */
function imageToTensor(image) {
  const { input_size: rawSize, mean, std } = state.metadata;
  const size = Number(rawSize);
  if (!Number.isInteger(size) || size <= 0) {
    throw new Error(`Invalid input_size in model metadata: ${rawSize}`);
  }
  const { naturalWidth, naturalHeight } = image;
  if (!(naturalWidth > 0 && naturalHeight > 0)) {
    // e.g. an SVG without intrinsic size: the scale below would be Infinity/NaN,
    // drawImage would draw nothing, and a blank canvas would be classified.
    throw new Error(`Image has no usable dimensions (${naturalWidth}x${naturalHeight}).`);
  }

  // The 1.14 short-edge ratio presumably mirrors the training-time resize; confirm.
  const resizeShortEdge = Math.trunc(size * 1.14);
  const scale = resizeShortEdge / Math.min(naturalWidth, naturalHeight);
  const resizedWidth = Math.round(naturalWidth * scale);
  const resizedHeight = Math.round(naturalHeight * scale);
  const cropX = Math.floor((resizedWidth - size) / 2);
  const cropY = Math.floor((resizedHeight - size) / 2);

  const canvas = document.createElement("canvas");
  canvas.width = size;
  canvas.height = size;
  const ctx = canvas.getContext("2d", { willReadFrequently: true });
  // Drawing the resized image at a negative offset onto a size x size canvas performs the crop.
  ctx.drawImage(image, -cropX, -cropY, resizedWidth, resizedHeight);
  const pixels = ctx.getImageData(0, 0, size, size).data;

  // Planar layout: all R values, then all G, then all B (4 bytes per RGBA source pixel).
  const plane = size * size;
  const tensor = new Float32Array(3 * plane);
  for (let i = 0; i < plane; i += 1) {
    const pixel = i * 4;
    tensor[i] = (pixels[pixel] / 255 - mean[0]) / std[0];
    tensor[plane + i] = (pixels[pixel + 1] / 255 - mean[1]) / std[1];
    tensor[2 * plane + i] = (pixels[pixel + 2] / 255 - mean[2]) / std[2];
  }
  return new ort.Tensor("float32", tensor, [1, 3, size, size]);
}
/**
 * Decodes a File/Blob into a ready-to-draw image element.
 *
 * A temporary object URL is created for the file and always revoked afterwards,
 * whether decoding succeeds or fails.
 *
 * @param {Blob} file - Image file to decode.
 * @returns {Promise<HTMLImageElement>} The decoded image.
 * @throws Whatever `image.decode()` rejects with (e.g. for non-image data).
 */
async function decodeImage(file) {
  const objectUrl = URL.createObjectURL(file);
  try {
    const img = Object.assign(new Image(), { decoding: "async", src: objectUrl });
    await img.decode();
    return img;
  } finally {
    URL.revokeObjectURL(objectUrl);
  }
}
/**
 * Classifies one image file and renders the prediction into the page.
 *
 * Pipeline:
 * 1. Decode the file and show it as the preview.
 * 2. Convert it to a tensor and run the ONNX session.
 * 3. Scale the first output's logits by `metadata.logit_scale` (default 1)
 *    and apply softmax.
 * 4. Label the image "shit" when P(class 0) >= `metadata.shit_threshold`
 *    (default 0.5).
 *
 * Overlapping calls (e.g. two quick drops) are handled last-call-wins. A call
 * that has been superseded stops before touching the preview or the result
 * fields, so a slow older image can't overwrite a newer image's output.
 *
 * @param {File} file - Image file picked or dropped by the user.
 * @returns {Promise<void>} Resolves once the page is updated, or once the call
 *   has been superseded. Decode and inference errors propagate to the caller.
 */
async function classify(file) {
  if (!state.session) {
    setStatus("Model is still loading.");
    return;
  }

  // Token for this call; any newer call bumps the counter and marks this one stale.
  const requestId = (state.classifyRequestId ?? 0) + 1;
  state.classifyRequestId = requestId;
  const isSuperseded = () => state.classifyRequestId !== requestId;

  setStatus("Running inference...");
  const image = await decodeImage(file);
  if (isSuperseded()) {
    return;
  }

  // An object URL keeps its File alive until revoked, so release the previous
  // preview's URL instead of leaking one per classified image.
  const previousPreviewUrl = els.preview.src;
  els.preview.src = URL.createObjectURL(file);
  if (previousPreviewUrl.startsWith("blob:")) {
    URL.revokeObjectURL(previousPreviewUrl);
  }
  els.preview.style.display = "block";
  els.empty.style.display = "none";

  const input = imageToTensor(image);
  const startedAt = performance.now();
  const outputs = await state.session.run({ [state.inputName]: input });
  const elapsed = performance.now() - startedAt;
  if (isSuperseded()) {
    return;
  }

  const logits = Array.from(outputs[state.session.outputNames[0]].data);
  const scale = Number(state.metadata.logit_scale ?? 1.0);
  const probabilities = softmax(logits.map((value) => value * scale));
  // Class order used by this page: index 0 = "shit", index 1 = "not_shit".
  const [shitProbability, notShitProbability] = probabilities;
  const confidence = Math.max(...probabilities);
  const threshold = Number(state.metadata.shit_threshold ?? 0.5);
  const isShit = shitProbability >= threshold;

  els.label.textContent = isShit ? "shit" : "not_shit";
  els.badge.textContent = isShit ? "positive" : "negative";
  els.badge.className = isShit ? "badge" : "badge ok";
  els.shitProbability.textContent = pct(shitProbability);
  els.notShitProbability.textContent = pct(notShitProbability);
  els.confidence.textContent = pct(confidence);
  els.threshold.textContent = String(threshold);
  els.shitMeterFill.style.width = pct(shitProbability);
  els.notShitMeterFill.style.width = pct(notShitProbability);
  setStatus(`Inference completed in ${elapsed.toFixed(1)} ms.`);
}
/**
 * Classifies the first file of a FileList (or array-like); any extra files are ignored.
 * Failures are logged and shown in the status line rather than rethrown.
 *
 * @param {FileList | File[] | null | undefined} files - Files from an input or a drop.
 */
function handleFiles(files) {
  const selected = files?.[0];
  if (selected) {
    // Fire-and-forget: the event handlers calling this do not await it.
    void classify(selected).catch((err) => {
      console.error(err);
      let message = "Inference failed.";
      if (err instanceof Error) {
        message = err.message;
      }
      setStatus(message);
    });
  }
}
// --- UI wiring -----------------------------------------------------------------

// The "choose" button proxies to the file input.
els.chooseButton.addEventListener("click", () => els.fileInput.click());
els.fileInput.addEventListener("change", (event) => {
  handleFiles(event.target.files);
  // handleFiles takes the File synchronously, so the selection can be cleared now.
  // Without this, picking the same file a second time fires no "change" event.
  event.target.value = "";
});

// Highlight the drop zone while a drag hovers over it. Calling preventDefault on
// dragenter/dragover is what lets the element accept a drop.
for (const eventName of ["dragenter", "dragover"]) {
  els.dropzone.addEventListener(eventName, (event) => {
    event.preventDefault();
    els.dropzone.classList.add("dragging");
  });
}
// Remove the highlight when the drag leaves or completes. preventDefault on drop
// stops the browser from handling the dropped file itself (e.g. navigating to it).
for (const eventName of ["dragleave", "drop"]) {
  els.dropzone.addEventListener(eventName, (event) => {
    event.preventDefault();
    els.dropzone.classList.remove("dragging");
  });
}
els.dropzone.addEventListener("drop", (event) => handleFiles(event.dataTransfer.files));

// Start loading the model immediately; failures are reported in the status line.
loadModel().catch((error) => {
  console.error(error);
  setStatus(error instanceof Error ? error.message : "Failed to load model.");
});