import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2';

// Skip local model checks since we are fetching from HF Hub
env.allowLocalModels = false;
// Enable caching
env.useBrowserCache = true;
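
// If the model files were hosted alongside the page instead, the env flags
// could point at them (a sketch; '/models/' is a hypothetical path):
//   env.allowLocalModels = true;
//   env.localModelPath = '/models/';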

const MODEL_ID = 'phiph/DA-2-WebGPU';
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;

let depth_estimator = null;
const statusElement = document.getElementById('status');
const runBtn = document.getElementById('runBtn');
const imageInput = document.getElementById('imageInput');
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const inputCtx = inputCanvas.getContext('2d');
const outputCtx = outputCanvas.getContext('2d');

// Initialize Transformers.js Pipeline
async function init() {
    try {
        statusElement.textContent = 'Loading model... (this may take a while)';
        
        // Initialize the pipeline
        depth_estimator = await pipeline('depth-estimation', MODEL_ID, {
            device: 'webgpu',
            dtype: 'fp32', // Important: Model is FP32
            quantized: false
        });

        statusElement.textContent = 'Model loaded. Ready.';
        runBtn.disabled = false;
    } catch (e) {
        console.error(e);
        // Fall back to WASM if WebGPU initialization fails
        try {
            statusElement.textContent = 'WebGPU failed (' + e.message + '), trying WASM...';
            depth_estimator = await pipeline('depth-estimation', MODEL_ID, {
                device: 'wasm',
                dtype: 'fp32',
                quantized: false
            });
            statusElement.textContent = 'Model loaded (WASM). Ready.';
            runBtn.disabled = false;
        } catch (e2) {
            statusElement.textContent = 'Error loading model (WASM): ' + e2.message;
        }
    }
}
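
// A minimal sketch of a pre-flight backend check (an alternative to the
// try/catch fallback above; not wired into init()). The standard
// navigator.gpu property is only exposed in WebGPU-capable browsers.
async function pickDevice() {
    if (navigator.gpu) {
        try {
            // requestAdapter() resolves to null when no suitable GPU exists
            const adapter = await navigator.gpu.requestAdapter();
            if (adapter) return 'webgpu';
        } catch {
            // fall through to wasm
        }
    }
    return 'wasm';
}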

imageInput.addEventListener('change', (e) => {
    const file = e.target.files[0];
    if (!file) return;

    const img = new Image();
    img.onload = () => {
        inputCanvas.width = INPUT_WIDTH;
        inputCanvas.height = INPUT_HEIGHT;
        inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);
        
        // Clear output
        outputCanvas.width = INPUT_WIDTH;
        outputCanvas.height = INPUT_HEIGHT;
        outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);

        // Release the object URL once the image has been drawn
        URL.revokeObjectURL(img.src);
    };
    img.src = URL.createObjectURL(file);
});

runBtn.addEventListener('click', async () => {
    if (!depth_estimator) return;
    
    statusElement.textContent = 'Running inference...';
    runBtn.disabled = true;

    try {
        // Get the image source from the canvas (or the file URL directly)
        // Using the canvas data ensures we are passing what the user sees
        const url = inputCanvas.toDataURL();

        // Run inference
        // The pipeline handles preprocessing (resize, rescale) automatically
        const output = await depth_estimator(url);
        
        // The depth-estimation pipeline returns:
        //   output.predicted_depth: the raw depth tensor
        //   output.depth: a single-channel RawImage (values 0-255)
        if (output.depth) {
            // Visualize manually so we control the mapping; use the
            // RawImage's own dimensions rather than assuming the input size
            const { data, width, height } = output.depth;
            visualize(data, width || INPUT_WIDTH, height || INPUT_HEIGHT);
            statusElement.textContent = 'Done.';
        } else {
            // Fallback if the output structure differs for this custom model
            console.log('Output structure:', output);
            statusElement.textContent = 'Done (check console for output structure).';
        }
    } catch (e) {
        console.error(e);
        statusElement.textContent = 'Error running inference: ' + e.message;
    } finally {
        runBtn.disabled = false;
    }
});

function visualize(data, width, height) {
    // Match the output canvas to the depth map dimensions
    outputCanvas.width = width;
    outputCanvas.height = height;

    // Find min and max for normalization
    let min = Infinity;
    let max = -Infinity;
    for (let i = 0; i < data.length; i++) {
        if (data[i] < min) min = data[i];
        if (data[i] > max) max = data[i];
    }

    const range = max - min;
    const imageData = outputCtx.createImageData(width, height);
    
    for (let i = 0; i < data.length; i++) {
        // Normalize to 0-1
        const val = (data[i] - min) / (range || 1);
        
        // Grayscale mapping. The model outputs distance, so smaller values
        // are closer; invert so near (min) renders white and far (max)
        // renders black.
        const pixelVal = Math.floor((1 - val) * 255);

        imageData.data[i * 4] = pixelVal; // R
        imageData.data[i * 4 + 1] = pixelVal; // G
        imageData.data[i * 4 + 2] = pixelVal; // B
        imageData.data[i * 4 + 3] = 255; // Alpha
    }
    
    outputCtx.putImageData(imageData, 0, 0);
}
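
// A minimal sketch of a colour-mapped alternative to the grayscale mapping
// in visualize() (assumption: the same normalized 0-1 depth values). This is
// a rough black -> red -> yellow -> white gradient, not a true Magma colormap:
function depthToColor(val) {
    const t = 1 - val; // invert: near = bright
    const r = Math.min(255, Math.floor(t * 3 * 255));
    const g = Math.min(255, Math.max(0, Math.floor((t * 3 - 1) * 255)));
    const b = Math.min(255, Math.max(0, Math.floor((t * 3 - 2) * 255)));
    return [r, g, b];
}
// Usage inside the pixel loop: const [r, g, b] = depthToColor(val);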

init();