// DannyChi's picture
// Upload 65 files
// 104ee77 verified
/**
 * Browser demo that (1) runs an ONNX model on an uploaded RGB image to get a
 * foreground mask and (2) uses that mask to re-map an uploaded grayscale depth
 * image through a clipped histogram equalisation (CLHE_new) followed by a
 * subject-centred tone curve (MainObjectTo128_new).
 *
 * Relies on the global `ort` object and on DOM elements with the ids looked up
 * in the constructor.
 */
class InferenceApp {
  constructor() {
    // Look up the UI elements on the page.
    this.fileInput = document.getElementById('fileInput');
    this.depthfileInput = document.getElementById('depthfileInput');
    this.inputImage = document.getElementById('inputImage');
    this.DepthInputImage = document.getElementById('DepthInputImage');
    this.runBtn = document.getElementById('runBtn');
    this.adjustDepthBtn = document.getElementById('adjustDepthBtn');
    // Canvas for the mask inference result.
    this.resultCanvas = document.getElementById('resultCanvas');
    this.resultCtx = this.resultCanvas.getContext('2d');
    // Canvas for the final adjusted depth image.
    this.resultCanvas_Adjust = document.getElementById('resultCanvas_Adjust');
    this.resultCtx_Adjust = this.resultCanvas_Adjust.getContext('2d');
    // ONNX Runtime session; stays null until init() succeeds.
    this.session = null;
    // Object URLs currently shown in the two <img> previews, kept so they can
    // be revoked when the user picks another file.
    this.inputImageUrl = null;
    this.depthPreviewUrl = null;
    // Wire up the UI events.
    this.fileInput.addEventListener('change', (e) => this.onFileChange(e));
    this.depthfileInput.addEventListener('change', (e) => this.onDepthFileChange(e));
    this.runBtn.addEventListener('click', () => this.runInference());
    this.adjustDepthBtn.addEventListener('click', () => this.runDepthAdjustment());
    // State shared between the two steps.
    this.maskData = null; // Float32Array, min-max normalised model output, flattened
    this.depthData = null; // Uint8ClampedArray, one gray value per pixel
    this.depthWidth = 0;
    this.depthHeight = 0;
  }

  /**
   * Create the inference session: try the WebGPU execution provider first and
   * fall back to WASM if that fails. Rejects if the WASM fallback fails too.
   * @returns {Promise<void>}
   */
  async init() {
    const modelPath = 'isnet_infer.onnx';
    console.log('初始化模型...');
    try {
      console.log('嘗試使用 WebGPU...');
      this.session = await ort.InferenceSession.create(modelPath, {
        executionProviders: ['webgpu'],
      });
      console.log('成功使用 WebGPU');
    } catch (error) {
      console.warn('使用 WebGPU 失敗,改用 WASM:', error);
      console.log('使用 WASM (CPU)...');
      this.session = await ort.InferenceSession.create(modelPath, {
        executionProviders: ['wasm'],
      });
    }
    console.log('模型初始化完成!');
  }

  /**
   * Handler for the RGB file input: show the chosen file in the input <img>.
   * @param {Event} e - change event of the file input.
   */
  onFileChange(e) {
    const file = e.target.files[0];
    if (!file) return;
    const previousUrl = this.inputImageUrl;
    this.inputImageUrl = URL.createObjectURL(file);
    this.inputImage.src = this.inputImageUrl;
    // The <img> now points at the new URL, so the old blob URL can be released.
    if (previousUrl) URL.revokeObjectURL(previousUrl);
  }

  /**
   * Handler for the depth file input: decode the image, resize it to
   * 1024x1024, keep its red channel as `this.depthData`, and show a preview.
   * @param {Event} e - change event of the depth file input.
   */
  onDepthFileChange(e) {
    const file = e.target.files[0];
    if (!file) return;
    // One object URL serves both the decoder image and the preview <img>.
    const url = URL.createObjectURL(file);
    const img = new Image();
    img.onload = () => {
      const resizeWidth = 1024;
      const resizeHeight = 1024;
      // Passing a destination size to drawImage forces the resize.
      const tmpCanvas = document.createElement('canvas');
      tmpCanvas.width = resizeWidth;
      tmpCanvas.height = resizeHeight;
      const tmpCtx = tmpCanvas.getContext('2d');
      tmpCtx.drawImage(img, 0, 0, resizeWidth, resizeHeight);
      const rgba = tmpCtx.getImageData(0, 0, resizeWidth, resizeHeight).data;
      // The depth image is assumed to be pure grayscale (R=G=B), so only the
      // red channel is kept.
      const depth = new Uint8ClampedArray(resizeWidth * resizeHeight);
      for (let p = 0; p < depth.length; p++) {
        depth[p] = rgba[p * 4];
      }
      this.depthData = depth;
      // Dimensions after resizing.
      this.depthWidth = resizeWidth;
      this.depthHeight = resizeHeight;
      // Preview of the uploaded depth image.
      const previousUrl = this.depthPreviewUrl;
      this.depthPreviewUrl = url;
      this.DepthInputImage.src = url;
      if (previousUrl) URL.revokeObjectURL(previousUrl);
      console.log(`已將深度圖縮放為: ${resizeWidth}x${resizeHeight}`);
    };
    img.onerror = () => {
      // Nothing will ever display this URL, so release it right away.
      URL.revokeObjectURL(url);
      console.error(`深度圖載入失敗: ${file.name}`);
    };
    img.src = url;
  }

  /**
   * Convert the input <img> into the model input tensor: resize to 1024x1024,
   * drop alpha, and store each channel as a plane of `value / 255 - 0.5`.
   * @returns {Promise<ort.Tensor>} float32 tensor with dims [1, 3, 1024, 1024].
   */
  async preprocessImage() {
    const desiredWidth = 1024;
    const desiredHeight = 1024;
    const planeSize = desiredWidth * desiredHeight;
    // Draw the <img> scaled onto a temporary canvas to read its pixels.
    const tmpCanvas = document.createElement('canvas');
    tmpCanvas.width = desiredWidth;
    tmpCanvas.height = desiredHeight;
    const tmpCtx = tmpCanvas.getContext('2d');
    tmpCtx.drawImage(this.inputImage, 0, 0, desiredWidth, desiredHeight);
    // Interleaved [r, g, b, a, r, g, b, a, ...].
    const data = tmpCtx.getImageData(0, 0, desiredWidth, desiredHeight).data;
    // Planar layout: all R values, then all G values, then all B values.
    const floatData = new Float32Array(3 * planeSize);
    for (let p = 0; p < planeSize; p++) {
      const o = p * 4;
      floatData[p] = data[o] / 255.0 - 0.5;
      floatData[p + planeSize] = data[o + 1] / 255.0 - 0.5;
      floatData[p + 2 * planeSize] = data[o + 2] / 255.0 - 0.5;
    }
    // Leading 1 is the batch dimension.
    return new ort.Tensor('float32', floatData, [1, 3, desiredHeight, desiredWidth]);
  }

  /**
   * Paint a width x height buffer of 8-bit gray values onto `canvas`, scaled
   * to destWidth x destHeight. The pixels go through a temporary canvas
   * because putImageData cannot scale; drawImage can.
   * @param {CanvasRenderingContext2D} ctx - 2D context of `canvas`.
   * @param {HTMLCanvasElement} canvas - destination canvas (cleared first).
   * @param {ArrayLike<number>} gray - one gray value per pixel, row-major.
   * @param {number} width - source width in pixels.
   * @param {number} height - source height in pixels.
   * @param {number} [destWidth=400] - drawn width on the destination.
   * @param {number} [destHeight=400] - drawn height on the destination.
   */
  drawGrayscalePreview(ctx, canvas, gray, width, height, destWidth = 400, destHeight = 400) {
    const staging = document.createElement('canvas');
    staging.width = width;
    staging.height = height;
    const stagingCtx = staging.getContext('2d');
    const imageData = stagingCtx.createImageData(width, height);
    const rgba = imageData.data;
    for (let p = 0; p < width * height; p++) {
      const o = p * 4;
      rgba[o] = gray[p];
      rgba[o + 1] = gray[p];
      rgba[o + 2] = gray[p];
      rgba[o + 3] = 255;
    }
    stagingCtx.putImageData(imageData, 0, 0);
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    ctx.drawImage(staging, 0, 0, width, height, 0, 0, destWidth, destHeight);
  }

  /**
   * Click handler for the run button: run the model on the RGB image,
   * min-max normalise its output into `this.maskData`, and draw it on the
   * result canvas. Errors are logged to the console, not rethrown.
   * @returns {Promise<void>}
   */
  async runInference() {
    if (!this.session) {
      alert('模型尚未初始化完成,請稍後再試。');
      return;
    }
    if (!this.inputImage.src) {
      alert('請先上傳 RGB 圖片 (做 mask 推論)。');
      return;
    }
    try {
      console.log('開始前處理圖片 (for mask 推論)...');
      const inputTensor = await this.preprocessImage();
      console.log('開始推論...');
      const startTime = performance.now();
      const results = await this.session.run({ input: inputTensor });
      // The model output is assumed to be named 'output' with dims [1, 1, 1024, 1024].
      const outputData = results['output'].data;
      const endTime = performance.now();
      console.log(`推論完成, 耗時: ${endTime - startTime} ms`);
      // Min-max normalisation of the raw output.
      let minVal = Infinity;
      let maxVal = -Infinity;
      for (let i = 0; i < outputData.length; i++) {
        const v = outputData[i];
        if (v < minVal) minVal = v;
        if (v > maxVal) maxVal = v;
      }
      const range = maxVal - minVal;
      const normalized = new Float32Array(outputData.length);
      // A constant output has range 0; dividing would fill the mask with NaN,
      // so in that case the mask is left all zero.
      if (range > 0) {
        for (let i = 0; i < outputData.length; i++) {
          normalized[i] = (outputData[i] - minVal) / range;
        }
      }
      const width = 1024;
      const height = 1024;
      const gray = new Uint8ClampedArray(width * height);
      for (let i = 0; i < gray.length; i++) {
        gray[i] = Math.floor(normalized[i] * 255);
      }
      // Kept for runDepthAdjustment(), which thresholds it at 0.5.
      this.maskData = normalized;
      this.drawGrayscalePreview(this.resultCtx, this.resultCanvas, gray, width, height);
    } catch (e) {
      console.error('推論時發生錯誤:', e);
    }
  }

  /**
   * Click handler for the depth-adjust button: combine the CLHE lookup table
   * with the subject tone curve, apply the result to the uploaded depth image
   * and draw it on the adjusted-result canvas. Requires both a depth image and
   * a mask.
   * @returns {Promise<void>}
   */
  async runDepthAdjustment() {
    if (!this.depthData) {
      alert('尚未上傳深度圖');
      return;
    }
    if (!this.maskData) {
      alert('尚未進行 mask 推論');
      return;
    }
    // Mask and depth are both assumed to be 1024x1024.
    const width = 1024;
    const height = 1024;
    const pixelCount = width * height;
    if (pixelCount !== this.depthWidth * this.depthHeight) {
      alert('目前示範假設深度圖與 mask 同尺寸 1024×1024,請自行在程式裡做對應處理。');
      return;
    }
    console.log('開始對深度圖做 CLHE_new + MainObjectTo128_new...');
    const startTime = performance.now();
    // 1) 0..255 depth -> [0, 1].
    const depthFloat = new Float32Array(pixelCount);
    for (let i = 0; i < pixelCount; i++) {
      depthFloat[i] = this.depthData[i] / 255.0;
    }
    // 2) Binarise the mask at 0.5.
    const maskBin = new Uint8Array(pixelCount);
    for (let i = 0; i < pixelCount; i++) {
      maskBin[i] = this.maskData[i] > 0.5 ? 1 : 0;
    }
    // 3) Clipped histogram equalisation.
    const { depthCLHE, cdfArray } = this.CLHE_new(depthFloat, maskBin, width, height);
    // 4) Tone curve built from the equalised depth values under the mask.
    const depthMaskVals = [];
    for (let i = 0; i < pixelCount; i++) {
      if (maskBin[i] === 1) {
        depthMaskVals.push(depthCLHE[i]);
      }
    }
    const yCurve = this.MainObjectTo128_new(depthMaskVals);
    // 5) Compose both lookups: combineCurve[v] = yCurve[cdfArray[v]].
    const combineCurve = new Float32Array(1024);
    for (let i = 0; i < 1024; i++) {
      combineCurve[i] = yCurve[cdfArray[i]];
    }
    // 6) Apply the combined curve to the original depth.
    const finalDepth = new Float32Array(pixelCount);
    for (let i = 0; i < pixelCount; i++) {
      finalDepth[i] = combineCurve[Math.floor(depthFloat[i] * 1023)];
    }
    const endTime = performance.now();
    console.log(`調整完成, 耗時: ${endTime - startTime} ms`);
    // 7) Curve values are on a 0..1023 scale; convert to 0..255 gray and draw.
    const gray = new Uint8ClampedArray(pixelCount);
    for (let i = 0; i < pixelCount; i++) {
      gray[i] = Math.floor(finalDepth[i] / 1023 * 255);
    }
    this.drawGrayscalePreview(this.resultCtx_Adjust, this.resultCanvas_Adjust, gray, width, height);
    console.log('深度圖調整完成!已顯示於右側畫布。');
  }

  /**
   * Clipped histogram equalisation over 1024 bins, with extra weight on the
   * depth band occupied by the masked subject.
   *
   * Steps: histogram -> clip each bin at 2.5x the mean bin count and spread
   * the clipped excess evenly -> multiply the PDF by 1.5 inside the subject
   * band -> integer CDF (0..1023) -> 101-tap mean filter, rescaled back to the
   * unsmoothed CDF's min/max -> look every pixel up in that CDF.
   *
   * @param {ArrayLike<number>} depthArr - depth values, expected in [0, 1];
   *   values outside that range are clamped to the first/last bin.
   * @param {ArrayLike<number>} maskArr - 1 where the pixel belongs to the subject.
   * @param {number} width - image width (width * height pixels of the mask are read).
   * @param {number} height - image height.
   * @returns {{depthCLHE: Float32Array, cdfArray: Int32Array}} equalised depth
   *   in [0, 1] and the 1024-entry lookup table with values in 0..1023.
   */
  CLHE_new(depthArr, maskArr, width, height) {
    const bins = 1024;
    // Clamping keeps out-of-range values from indexing past the histogram
    // (which would silently drop them and later yield NaN in the lookup).
    const toBin = (v) => Math.min(bins - 1, Math.max(0, Math.floor(v * (bins - 1))));

    // 1) Subject band from the masked depth values.
    const maskedVals = [];
    for (let i = 0; i < width * height; i++) {
      if (maskArr[i] === 1) {
        maskedVals.push(depthArr[i]);
      }
    }
    let median = 0.5;
    let th25 = 0.5;
    let th75 = 0.5;
    if (maskedVals.length > 50) {
      maskedVals.sort((a, b) => a - b);
      const n = maskedVals.length;
      median = maskedVals[Math.floor(n / 2)];
      // Despite the names, these are the 15th and 85th percentiles.
      th25 = maskedVals[Math.floor(n * 0.15)];
      th75 = maskedVals[Math.floor(n * 0.85)];
    }
    console.log(`th25: ${th25}, median: ${median}, th75: ${th75}`);

    // 2) Histogram.
    const hist = new Float32Array(bins);
    for (let i = 0; i < depthArr.length; i++) {
      hist[toBin(depthArr[i])]++;
    }

    // 3) Clip at Limit x mean and redistribute the excess evenly.
    const totalnum = width * height;
    const mean_val = hist.reduce((a, b) => a + b, 0) / bins;
    const Limit = 2.5;
    const limitVal = mean_val * Limit;
    let excessSum = 0.0;
    for (let i = 0; i < bins; i++) {
      if (hist[i] > limitVal) {
        excessSum += hist[i] - limitVal;
        hist[i] = limitVal;
      }
    }
    const avgExcess = excessSum / bins;
    for (let i = 0; i < bins; i++) {
      hist[i] += avgExcess;
    }

    // 4) PDF, boosted inside the subject band, then the CDF.
    const weight = 1.5;
    const lower_bound = th25 * (bins - 1);
    const upper_bound = th75 * (bins - 1);
    const pdfWeighted = new Float32Array(bins);
    for (let i = 0; i < bins; i++) {
      const pdf = Math.fround(hist[i] / totalnum);
      const inRange = i >= lower_bound && i <= upper_bound;
      pdfWeighted[i] = inRange ? pdf * weight : pdf;
    }
    const cdfArray = new Int32Array(bins);
    const sumAll = pdfWeighted.reduce((a, b) => a + b, 0); // normaliser
    let cumulative = 0;
    for (let i = 0; i < bins; i++) {
      cumulative += pdfWeighted[i];
      cdfArray[i] = Math.round((cumulative / sumAll) * (bins - 1));
    }

    // 5) Smooth the CDF, then stretch it back to the original min/max because
    //    the mean filter shrinks the range at both ends.
    const cdfFloat = Float32Array.from(cdfArray);
    const cdfSmooth = this.meanFilter1D(cdfFloat, 101);
    const oldMin = Math.min(...cdfFloat);
    const oldMax = Math.max(...cdfFloat);
    const newMin = Math.min(...cdfSmooth);
    const newMax = Math.max(...cdfSmooth);
    const ratio = (oldMax - oldMin) / (newMax - newMin);
    for (let i = 0; i < bins; i++) {
      cdfFloat[i] = (cdfSmooth[i] - newMin) * ratio + oldMin;
    }
    for (let i = 0; i < bins; i++) {
      cdfArray[i] = Math.round(cdfFloat[i]);
    }

    // 6) Look every pixel up in the smoothed CDF.
    const depthCLHE = new Float32Array(depthArr.length);
    for (let i = 0; i < depthArr.length; i++) {
      depthCLHE[i] = cdfArray[toBin(depthArr[i])] / (bins - 1);
    }
    return { depthCLHE, cdfArray };
  }

  /**
   * Identity tone curve on the same 0..bins-1 scale that MainObjectTo128_new
   * returns on its normal path.
   * @param {number} bins - number of curve entries.
   * @returns {Float32Array} curve with curve[i] === i.
   */
  identityCurve(bins) {
    const curve = new Float32Array(bins);
    for (let i = 0; i < bins; i++) {
      curve[i] = i;
    }
    return curve;
  }

  /**
   * Build a 1024-entry tone curve from the depth values of the masked subject.
   *
   * The band between the subject's 25th and 75th percentile is mapped linearly
   * so that the subject's median lands on 0.5; the ranges below and above the
   * band get a power-law or linear tail (or are pinned to 0 / 1 when the band
   * would overflow). The curve is then smoothed with a 301-tap mean filter,
   * rescaled, stretched if its contrast is low, and returned on a 0..1023
   * scale.
   *
   * Falls back to the identity curve when fewer than 50 values are given, when
   * the median is exactly 0 or 1, or when both percentiles coincide.
   *
   * @param {ArrayLike<number>} depthMaskVals - depth values in [0, 1] of the
   *   masked pixels. Not modified.
   * @returns {Float32Array} 1024 curve values on a 0..1023 scale.
   */
  MainObjectTo128_new(depthMaskVals) {
    const bins = 1024;
    if (depthMaskVals.length < 50) {
      console.log('MainObjectTo128_new: mask 值太少,直接回傳 identity');
      return this.identityCurve(bins);
    }
    // Sort a copy so the caller's array keeps its order.
    const sorted = Array.from(depthMaskVals).sort((p, q) => p - q);
    const n = sorted.length;
    const median = sorted[Math.floor(n / 2)];
    // `let`: both percentiles are adjusted below.
    let th25 = sorted[Math.floor(n * 0.25)];
    let th75 = sorted[Math.floor(n * 0.75)];
    console.log(`MainObjectTo128_new:: th25: ${th25}, median: ${median}, th75: ${th75}`);
    if (median === 0 || median === 1 || th75 - th25 === 0) {
      console.log('[MainObjectTo128_new_js] Condition not met, return identity.');
      return this.identityCurve(bins);
    }
    // Nudge exact 0 / 1 inwards so the logarithms below stay finite.
    if (th25 === 0) th25 = 0.0000001;
    if (th75 === 1) th75 = 0.9999999;

    // Width of the band after mapping: (a * width)^b. The constants are taken
    // as given; their derivation is not shown in this file.
    const a = 0.44535377;
    const b = 0.6315172;
    const AfterRange = Math.pow(a * (th75 - th25), b);
    let scale = AfterRange / (th75 - th25);
    // Shift chosen so that median * scale - shift === 0.5.
    let medianScaleShift = median * scale - 0.5;
    let th25_ = th25 * scale - medianScaleShift;
    let th75_ = th75 * scale - medianScaleShift;

    // Clamp the mapped band to [0, 1] and re-derive scale/shift if it overflowed.
    let low_break = false;
    let high_break = false;
    if (th25_ < 0) {
      th25_ = 0;
      low_break = true;
    }
    if (th75_ > 1) {
      th75_ = 1;
      high_break = true;
    }
    if (low_break) {
      scale = (th75_ - th25_) / (th75 - th25);
      medianScaleShift = th25 * scale;
      th25 = (th25_ + medianScaleShift) / scale;
      th75 = (th75_ + medianScaleShift) / scale;
    }
    if (high_break) {
      scale = (th75_ - th25_) / (th75 - th25);
      medianScaleShift = th75 * scale - 1;
      th25 = (th25_ + medianScaleShift) / scale;
      th75 = (th75_ + medianScaleShift) / scale;
    }

    // Second sample point just inside each end of the band, used to fit the tails.
    const ext = 0.01;
    const th25_2 = (th25 + ext) * scale - medianScaleShift;
    const th75_2 = (th75 - ext) * scale - medianScaleShift;

    // Tail coefficients: y = a1 * x^b1 below the band,
    // y = a2 * (x - th75 + ext)^b2 + th75_2 above it.
    let a1 = 0.0;
    let b1 = 0.0;
    let a2 = 0.0;
    let b2 = 0.0;
    if (!low_break) {
      const logNumer = Math.log(th25_2) - Math.log(th25_);
      const logDenom = Math.log(th25 + ext) - Math.log(th25);
      b1 = logNumer / logDenom;
      if (b1 > 1.2) b1 = 1.2;
      a1 = th25_ / Math.pow(th25, b1);
    }
    if (!high_break) {
      const logNumer = Math.log(1 - th75_2) - Math.log(th75_ - th75_2);
      const logDenom = Math.log(1 - th75 + ext) - Math.log(ext);
      b2 = logNumer / logDenom;
      a2 = (th75_ - th75_2) / Math.pow(ext, b2);
    }
    console.log(`a1: ${a1}, b1: ${b1}, a2: ${a2}, b2: ${b2}`);

    // Sample positions 0..1 (stored as float32, like the curve itself).
    const x = new Float32Array(bins);
    for (let i = 0; i < bins; i++) {
      x[i] = i / (bins - 1);
    }
    const y = new Float32Array(bins);
    for (let i = 0; i < bins; i++) {
      const xi = x[i];
      // Middle segment: linear.
      if (xi >= th25 && xi <= th75) {
        y[i] = xi * scale - medianScaleShift;
      }
      // Left segment.
      if (xi < th25) {
        if (low_break) {
          y[i] = 0;
        } else if (b1 < 0.3) {
          // Very flat power law: use a line with twice the band slope instead.
          y[i] = xi * (scale * 2) - (scale * 2 * th25 - th25_);
        } else {
          y[i] = a1 * Math.pow(xi, b1);
        }
      }
      // Right segment.
      if (xi > th75) {
        if (high_break) {
          y[i] = 1;
        } else if (b2 > 1) {
          // Exponent above 1: continue the band's line instead.
          y[i] = xi * scale - (scale * th75 - th75_);
        } else {
          y[i] = a2 * Math.pow(xi - th75 + ext, b2) + th75_2;
        }
      }
    }

    // Smooth, then stretch back to the unsmoothed curve's min/max.
    const ksize = 301;
    const y_ = this.MeanFilter2_fast_js(y, ksize);
    const yMin = Math.min(...y);
    const yMax = Math.max(...y);
    const y_FMin = Math.min(...y_);
    const y_FMax = Math.max(...y_);
    const ratio = (yMax - yMin) / (y_FMax - y_FMin);
    for (let i = 0; i < bins; i++) {
      y_[i] = (y_[i] - y_FMin) * ratio + yMin;
    }

    // Low-contrast stretch, upper end: raise a maximum below 0.85 by x1.25,
    // capped at 0.85, keeping the minimum fixed.
    const max_th = 0.85;
    const curveMAX = Math.max(...y_);
    if (curveMAX < max_th) {
      const yMin2 = Math.min(...y_);
      const max_new = curveMAX * (0.625 / 0.5);
      const targetMax = max_new > max_th ? max_th : max_new;
      const ratio2 = (targetMax - yMin2) / (curveMAX - yMin2);
      for (let i = 0; i < bins; i++) {
        y_[i] = (y_[i] - yMin2) * ratio2 + yMin2;
      }
    }
    // Lower end: push a minimum above 0.15 down, keeping the maximum fixed.
    const min_th = 0.15;
    const curveMIN = Math.min(...y_);
    if (curveMIN > min_th) {
      const yMax2 = Math.max(...y_);
      const diff = (1 - curveMIN) * (0.625 / 0.5);
      const min_new = 1 - diff;
      let newmin = min_new;
      if (min_new > min_th) {
        newmin = min_th;
      } else if (min_new < 0) {
        newmin = 0;
      }
      const ratio3 = (yMax2 - newmin) / (yMax2 - curveMIN);
      for (let i = 0; i < bins; i++) {
        y_[i] = (y_[i] - curveMIN) * ratio3 + newmin;
      }
    }

    // Return on the 0..1023 scale the caller indexes and divides by.
    for (let i = 0; i < bins; i++) {
      y_[i] *= bins - 1;
    }
    return y_;
  }

  /**
   * Moving-average filter of width `k` with replicate padding (the first and
   * last sample are repeated beyond the ends). The window is centred for odd
   * `k`; use an odd `k`.
   * @param {ArrayLike<number>} arr - input samples.
   * @param {number} k - window width.
   * @returns {Float32Array} filtered samples, same length as `arr`.
   */
  meanFilter1D(arr, k) {
    const half = Math.floor(k / 2);
    const n = arr.length;
    const out = new Float32Array(n);
    // Replicate padding.
    const ext = new Float32Array(n + 2 * half);
    ext.fill(arr[0], 0, half);
    for (let i = 0; i < n; i++) {
      ext[i + half] = arr[i];
    }
    ext.fill(arr[n - 1], n + half);
    // Convolution with a box kernel.
    const oneOverK = 1.0 / k;
    for (let i = 0; i < n; i++) {
      let sum = 0.0;
      for (let j = 0; j < k; j++) {
        sum += ext[i + j];
      }
      out[i] = sum * oneOverK;
    }
    return out;
  }

  /**
   * Moving-average filter of width `k`. The left side is padded with the first
   * sample; the right side is padded by linearly extrapolating the last step
   * (x[n-1] - x[n-2]), so a rising curve keeps rising past the end instead of
   * flattening.
   * @param {ArrayLike<number>} x - input samples.
   * @param {number} k - window width.
   * @returns {Float32Array} filtered samples, same length as `x`.
   */
  MeanFilter2_fast_js(x, k) {
    const n = x.length;
    const leftLen = Math.floor((k - 1) / 2);
    // Equals leftLen for odd k; one more for even k so the last window never
    // reads past the padded buffer.
    const rightLen = k - 1 - leftLen;
    const step = n >= 2 ? x[n - 1] - x[n - 2] : 0;
    const x_ext = new Float32Array(n + leftLen + rightLen);
    x_ext.fill(x[0], 0, leftLen);
    x_ext.set(x, leftLen);
    for (let i = 0; i < rightLen; i++) {
      x_ext[leftLen + n + i] = x[n - 1] + (i + 1) * step;
    }
    // Box kernel weight, rounded to float32 as a Float32Array kernel would hold it.
    const w = Math.fround(1.0 / k);
    const out = new Float32Array(n);
    for (let i = 0; i < n; i++) {
      let sum = 0;
      for (let j = 0; j < k; j++) {
        sum += x_ext[i + j] * w;
      }
      out[i] = sum;
    }
    return out;
  }
}
// Once the DOM is ready, create the app (which binds the UI) and load the model.
document.addEventListener('DOMContentLoaded', async () => {
  const app = new InferenceApp();
  try {
    await app.init(); // tries WebGPU first, then WASM
  } catch (error) {
    // init() rejects when the WASM fallback fails as well; report it instead of
    // leaving an unhandled promise rejection and a silently unusable page.
    console.error('模型初始化失敗:', error);
    alert('模型初始化失敗,請重新整理頁面後再試。');
  }
});