import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2';
// Skip local model checks since we are fetching from HF Hub
env.allowLocalModels = false;
// Hugging Face Hub model id for the depth-estimation pipeline.
const MODEL_ID = 'phiph/DA-2-WebGPU';
// Working resolution for both canvases (2:1 aspect). NOTE(review): presumably
// chosen to match the model's expected input size — confirm against the model card.
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;
// Populated by init(); stays null until a pipeline has loaded successfully.
let depth_estimator = null;
// Cached references to the page's UI elements and 2D drawing contexts.
const statusElement = document.getElementById('status');
const runBtn = document.getElementById('runBtn');
const imageInput = document.getElementById('imageInput');
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const inputCtx = inputCanvas.getContext('2d');
const outputCtx = outputCanvas.getContext('2d');
// Initialize Transformers.js Pipeline
/**
 * Loads the depth-estimation pipeline, preferring WebGPU and falling back
 * to WASM when WebGPU initialization throws.
 * Side effects: assigns `depth_estimator`, updates `statusElement`, and
 * enables `runBtn` on success.
 */
async function init() {
  // Both attempts differ only in the execution backend.
  const load = (device) =>
    pipeline('depth-estimation', MODEL_ID, {
      device,
      dtype: 'fp32', // Important: Model is FP32
    });

  try {
    statusElement.textContent = 'Loading model... (this may take a while)';
    depth_estimator = await load('webgpu');
    statusElement.textContent = 'Model loaded. Ready.';
    runBtn.disabled = false;
  } catch (e) {
    console.error(e);
    statusElement.textContent = 'Error loading model: ' + e.message;
    // Fallback to wasm if webgpu fails
    try {
      statusElement.textContent = 'WebGPU failed, trying WASM...';
      depth_estimator = await load('wasm');
      statusElement.textContent = 'Model loaded (WASM). Ready.';
      runBtn.disabled = false;
    } catch (e2) {
      statusElement.textContent = 'Error loading model (WASM): ' + e2.message;
    }
  }
}
// Preview the selected file on the input canvas at the model's working
// resolution, and clear any stale depth map from a previous run.
imageInput.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;
  const img = new Image();
  const objectUrl = URL.createObjectURL(file);
  img.onload = () => {
    // Release the blob URL once decoded — otherwise each file selection
    // leaks an object URL for the lifetime of the page.
    URL.revokeObjectURL(objectUrl);
    inputCanvas.width = INPUT_WIDTH;
    inputCanvas.height = INPUT_HEIGHT;
    inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    // Clear output
    outputCanvas.width = INPUT_WIDTH;
    outputCanvas.height = INPUT_HEIGHT;
    outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
  };
  img.onerror = () => {
    // Surface undecodable files instead of failing silently.
    URL.revokeObjectURL(objectUrl);
    statusElement.textContent = 'Error: could not decode the selected image.';
  };
  img.src = objectUrl;
});
// Run depth estimation on the current input-canvas contents and render
// the result. The button is disabled for the duration of inference.
runBtn.addEventListener('click', async () => {
  if (!depth_estimator) return;
  statusElement.textContent = 'Running inference...';
  runBtn.disabled = true;
  try {
    // Pass the canvas contents (what the user actually sees) to the
    // pipeline; it handles preprocessing (resize, rescale) internally.
    const url = inputCanvas.toDataURL();
    const output = await depth_estimator(url);
    // output.depth holds the raw depth tensor for standard pipelines, but a
    // custom model may return a different structure — handle both.
    if (output.depth) {
      // Visualize the raw tensor manually to be safe
      visualize(output.depth.data, INPUT_WIDTH, INPUT_HEIGHT);
      // Fix: previously an unconditional 'Done.' after this if/else
      // clobbered the diagnostic status set in the else branch.
      statusElement.textContent = 'Done.';
    } else {
      // Fallback if structure is different
      console.log("Output structure:", output);
      statusElement.textContent = 'Done (Check console for output structure).';
    }
  } catch (e) {
    console.error(e);
    statusElement.textContent = 'Error running inference: ' + e.message;
  } finally {
    runBtn.disabled = false;
  }
});
/**
 * Renders a raw depth array onto the output canvas as an inverted
 * grayscale image: the smallest value (closest) maps to white, the
 * largest (farthest) to black.
 * @param {ArrayLike<number>} data - flat depth values, length width*height
 * @param {number} width - output image width in pixels
 * @param {number} height - output image height in pixels
 */
function visualize(data, width, height) {
  // Single pass to find the dynamic range for normalization.
  let lo = Infinity;
  let hi = -Infinity;
  for (const v of data) {
    if (v < lo) lo = v;
    if (v > hi) hi = v;
  }
  // Guard against a flat image (range 0) dividing by zero.
  const span = (hi - lo) || 1;
  const img = outputCtx.createImageData(width, height);
  const px = img.data;
  for (let i = 0; i < data.length; i++) {
    // Depth is distance, so invert: near (small value) renders bright.
    const shade = Math.floor((1 - (data[i] - lo) / span) * 255);
    const o = i * 4;
    px[o] = shade;     // R
    px[o + 1] = shade; // G
    px[o + 2] = shade; // B
    px[o + 3] = 255;   // Alpha (fully opaque)
  }
  outputCtx.putImageData(img, 0, 0);
}
// Kick off model loading as soon as the module is evaluated.
init();