// Where the ONNX model is fetched from, and the fixed resolution every input
// image is resized to before inference. Both sizes are multiples of 14
// (1092 = 78 * 14, 546 = 39 * 14) — presumably required by the model; confirm
// before changing them.
const MODEL_PATH = '../onnx/model.onnx';
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;

// ONNX Runtime inference session; remains null until init() loads a model.
let session = null;

// Page elements, looked up once by id (same lookup order as before).
const [statusElement, runBtn, imageInput, inputCanvas, outputCanvas] = [
  'status',
  'runBtn',
  'imageInput',
  'inputCanvas',
  'outputCanvas',
].map((id) => document.getElementById(id));

// 2D drawing contexts for the source image and the rendered depth map.
const [inputCtx, outputCtx] = [inputCanvas, outputCanvas].map((canvas) => canvas.getContext('2d'));

/**
 * Loads the ONNX model at MODEL_PATH into the module-level `session`.
 *
 * Tries the WebGPU execution provider first and falls back to WASM if creating
 * that session fails. On success the Run button is enabled. If both attempts
 * fail, each error is logged, the last one is shown in the status element, and
 * the Run button is left untouched. Loading errors are handled here, so the
 * returned promise does not reject because of them.
 */
async function init() {
  // Thrown values are not guaranteed to be Error instances; render any value as text.
  const describe = (err) => (err instanceof Error ? err.message : String(err));

  try {
    // Verbose ORT logging, kept on to help diagnose model-loading problems.
    ort.env.debug = true;
    ort.env.logLevel = 'verbose';

    statusElement.textContent = 'Loading model... (this may take a while)';

    session = await ort.InferenceSession.create(MODEL_PATH, {
      executionProviders: ['webgpu'],
    });
    statusElement.textContent = 'Model loaded. Ready.';
    runBtn.disabled = false;
    return;
  } catch (webgpuError) {
    console.error(webgpuError);
    // Keep the reason visible while the fallback loads (previously this
    // message was written and immediately overwritten).
    statusElement.textContent = `WebGPU failed (${describe(webgpuError)}), trying WASM...`;
  }

  try {
    session = await ort.InferenceSession.create(MODEL_PATH, { executionProviders: ['wasm'] });
    statusElement.textContent = 'Model loaded (WASM). Ready.';
    runBtn.disabled = false;
  } catch (wasmError) {
    console.error(wasmError);
    statusElement.textContent = `Error loading model (WASM): ${describe(wasmError)}`;
  }
}

// When a file is chosen, decode it and draw it onto the input canvas stretched
// to the fixed model resolution (aspect ratio is not preserved), then clear the
// output canvas at the same size.
imageInput.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;

  const url = URL.createObjectURL(file);
  const img = new Image();

  img.onload = () => {
    // The image has loaded; release the blob URL so repeated selections don't
    // keep every previously chosen file alive for the page's lifetime.
    URL.revokeObjectURL(url);

    inputCanvas.width = INPUT_WIDTH;
    inputCanvas.height = INPUT_HEIGHT;
    inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);

    outputCanvas.width = INPUT_WIDTH;
    outputCanvas.height = INPUT_HEIGHT;
    outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
  };

  // Without this, a non-image or corrupt file failed silently and leaked its URL.
  img.onerror = () => {
    URL.revokeObjectURL(url);
    statusElement.textContent = 'Error loading image: the selected file could not be decoded.';
  };

  img.src = url;
});

// Runs depth estimation on the current contents of the input canvas and draws
// the result on the output canvas. The button is disabled while a run is in
// flight and re-enabled afterwards, whether the run succeeded or failed.
runBtn.addEventListener('click', async () => {
  if (!session) return;

  statusElement.textContent = 'Running inference...';
  runBtn.disabled = true;

  try {
    const imageData = inputCtx.getImageData(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    const tensor = preprocess(imageData);

    const results = await session.run({ pixel_values: tensor });
    // Prefer the named output; fall back to the session's first output so an
    // export with a different output name still renders.
    const output = results.predicted_depth || results[session.outputNames[0]];
    if (!output) {
      throw new Error('Model returned no output tensor.');
    }

    // Take the map size from the tensor itself (trailing two dims, presumably
    // [..., height, width] — confirm for this export) instead of assuming it
    // equals the input size, so a size mismatch cannot garble the rendering.
    const dims = output.dims || [];
    const hasSpatialDims = dims.length >= 2;
    const width = hasSpatialDims ? dims[dims.length - 1] : INPUT_WIDTH;
    const height = hasSpatialDims ? dims[dims.length - 2] : INPUT_HEIGHT;
    if (outputCanvas.width !== width || outputCanvas.height !== height) {
      outputCanvas.width = width;
      outputCanvas.height = height;
    }

    visualize(output.data, width, height);
    statusElement.textContent = 'Done.';
  } catch (e) {
    console.error(e);
    statusElement.textContent = `Error running inference: ${e.message}`;
  } finally {
    runBtn.disabled = false;
  }
});

/**
 * Converts RGBA pixels into a planar float32 tensor of shape
 * [1, 3, height, width] (all R values, then all G, then all B). Each channel
 * value is rescaled from 0..255 to 0..1 and then normalized as
 * (value - mean[c]) / std[c]; the alpha channel is dropped.
 *
 * NOTE(review): Hugging Face Depth Anything / DPT image processors normalize
 * with ImageNet statistics (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]).
 * The defaults here (mean 0, std 1) keep the original rescale-only output —
 * confirm which one this model export expects and pass it from the caller.
 *
 * @param {ImageData} imageData - RGBA pixels, e.g. from CanvasRenderingContext2D.getImageData.
 * @param {{mean?: number[], std?: number[]}} [options] - per-channel (R, G, B) normalization.
 * @returns {ort.Tensor} float32 tensor with dims [1, 3, height, width].
 */
function preprocess(imageData, { mean = [0, 0, 0], std = [1, 1, 1] } = {}) {
  const { data, width, height } = imageData;
  const planeSize = width * height;
  const float32Data = new Float32Array(3 * planeSize);

  for (let i = 0; i < planeSize; i++) {
    for (let c = 0; c < 3; c++) {
      // data is interleaved RGBA (4 bytes per pixel); write channel c's plane.
      float32Data[c * planeSize + i] = (data[i * 4 + c] / 255 - mean[c]) / std[c];
    }
  }

  return new ort.Tensor('float32', float32Data, [1, 3, height, width]);
}

/**
 * Draws a one-value-per-pixel float map into outputCtx as an opaque grayscale
 * image at the canvas origin. Values are min-max scaled to [0, 1] and then
 * inverted, so the smallest value becomes white (255) and the largest black (0).
 * A constant map (max === min) is drawn entirely white. Value i goes to pixel i
 * of a width x height ImageData.
 *
 * @param {ArrayLike<number>} data - map values, one per pixel.
 * @param {number} width - ImageData width in pixels.
 * @param {number} height - ImageData height in pixels.
 */
function visualize(data, width, height) {
  let lo = Infinity;
  let hi = -Infinity;
  for (const value of data) {
    if (value < lo) lo = value;
    if (value > hi) hi = value;
  }

  // Guard against division by zero when every value is identical.
  const span = hi - lo || 1;

  const frame = outputCtx.createImageData(width, height);
  const pixels = frame.data;

  for (let p = 0; p < data.length; p++) {
    const shade = Math.floor((1 - (data[p] - lo) / span) * 255);
    const offset = p * 4;
    pixels.fill(shade, offset, offset + 3); // R, G, B
    pixels[offset + 3] = 255; // fully opaque
  }

  outputCtx.putImageData(frame, 0, 0);
}

// Start loading the model. init() reports loading failures itself; this catch
// only surfaces anything unexpected instead of leaving an unhandled rejection.
init().catch((err) => console.error(err));