Reza2kn's picture
Fix ONNX sidecar loading and WebGPU sample paths
8653de1 verified
Raw
History Blame Contribute Delete
14.5 kB
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>PIXLRelight WebGPU INT4</title>
<style>
/* Base typography and page background. */
:root { color-scheme: light; font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; }
body { margin: 0; background: #f8fafc; color: #0f172a; }
main { max-width: 1180px; margin: 0 auto; padding: 20px; }
/* Dark intro banner at the top of the page. */
.hero { padding: 22px 24px; background: #0f172a; color: white; border-radius: 8px; }
.hero h1 { margin: 0 0 6px; font-size: 28px; line-height: 1.1; }
.hero p { margin: 0; color: #cbd5e1; font-size: 14px; }
/* Two-column workspace of white panels; collapses to one column in the media query below. */
.grid { display: grid; grid-template-columns: minmax(280px, 1fr) minmax(280px, 1fr); gap: 16px; margin-top: 16px; }
.panel { background: white; border: 1px solid #e2e8f0; border-radius: 8px; padding: 14px; }
/* Horizontal, wrapping toolbar rows for buttons and sliders. */
.row { display: flex; gap: 10px; align-items: center; flex-wrap: wrap; }
/* Shared control look; the file input's built-in button is styled to match. */
button, select, input[type="file"]::file-selector-button {
border: 1px solid #cbd5e1; background: white; border-radius: 6px; padding: 8px 11px; color: #0f172a; cursor: pointer;
}
button.primary { background: #f59e0b; border-color: #d97706; color: #111827; font-weight: 700; }
/* The script disables the Run button while the model loads/runs, hence the "wait" cursor. */
button:disabled { opacity: .55; cursor: wait; }
/* Labels stack their caption text above the wrapped control. */
label { display: grid; gap: 5px; font-size: 12px; color: #475569; }
input[type="range"] { width: 150px; }
/* Canvases keep a 512x512 drawing buffer but are displayed responsively at 1:1;
   the click handler therefore normalizes pointer coordinates by the displayed size. */
canvas { display: block; width: 100%; max-width: 512px; aspect-ratio: 1 / 1; background: #111827; border-radius: 6px; }
.canvases { display: grid; grid-template-columns: repeat(2, minmax(220px, 1fr)); gap: 12px; align-items: start; }
/* pre-wrap so multi-line status messages (joined with "\n" in the script) show as separate lines. */
.status { margin-top: 12px; padding: 10px 12px; background: #fff7ed; border: 1px solid #fed7aa; border-radius: 6px; font-size: 13px; color: #7c2d12; white-space: pre-wrap; }
.muted { font-size: 12px; color: #64748b; }
/* Narrow screens: single-column layout for both the workspace and the canvas pair. */
@media (max-width: 820px) { .grid, .canvases { grid-template-columns: 1fr; } }
</style>
</head>
<body>
<main>
<section class="hero">
<h1>PIXLRelight in-browser WebGPU</h1>
<p>This path loads the INT4 MatMulNBits ONNX renderer directly in your browser with ONNX Runtime WebGPU. No Python CPU inference is used here.</p>
</section>
<!-- Workspace: source/light-map panel (left) and lights/output panel (right).
     A div rather than a section because this region has no heading of its own. -->
<div class="grid">
<div class="panel">
<div class="row">
<select id="sample" aria-label="Sample image">
<option value="https://huggingface.co/spaces/Reza2kn/PIXLRelight-ONNX-Light-Studio/resolve/main/samples/room00_source.png">Warm room</option>
<option value="https://huggingface.co/spaces/Reza2kn/PIXLRelight-ONNX-Light-Studio/resolve/main/samples/room01_source.png">Studio chair</option>
</select>
<button id="loadSample" type="button">Load sample</button>
<input id="file" type="file" accept="image/*" aria-label="Upload source image">
</div>
<p class="muted">Click the light map to move the selected light. This page generates 9-channel target intrinsics from the colored light field and runs the INT4 ONNX renderer on WebGPU.</p>
<div class="canvases">
<div>
<label id="sourceLabel">Source</label>
<canvas id="source" width="512" height="512" role="img" aria-labelledby="sourceLabel"></canvas>
</div>
<div>
<label id="lightmapLabel">Light map</label>
<!-- Clicking moves the selected light; the X/Y sliders provide the same control from the keyboard. -->
<canvas id="lightmap" width="512" height="512" role="img" aria-labelledby="lightmapLabel"></canvas>
</div>
</div>
</div>
<div class="panel">
<div class="row">
<label>Selected
<select id="selected">
<option value="0">Light 1</option>
<option value="1">Light 2</option>
</select>
</label>
<button id="add" type="button">Add</button>
<button id="remove" type="button">Remove</button>
<button id="run" class="primary" type="button">Load WebGPU model and relight</button>
</div>
<div class="row" style="margin-top:10px">
<label>X <input id="x" type="range" min="0" max="1" value="0.28" step="0.005"></label>
<label>Y <input id="y" type="range" min="0" max="1" value="0.22" step="0.005"></label>
<label>Radius <input id="radius" type="range" min="0.05" max="1" value="0.42" step="0.01"></label>
<label>Intensity <input id="intensity" type="range" min="0" max="2.5" value="1.25" step="0.01"></label>
<label>Color <input id="color" type="color" value="#ffd08a"></label>
</div>
<div style="margin-top:12px">
<label id="outputLabel">Relit output</label>
<canvas id="output" width="512" height="512" role="img" aria-labelledby="outputLabel"></canvas>
</div>
<!-- Live region: loading/progress/error messages are announced politely to screen readers. -->
<div id="status" class="status" role="status">Checking WebGPU support...</div>
</div>
</div>
</main>
<script type="module">
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0/dist/ort.webgpu.min.mjs";
// INT4 (MatMulNBits) ONNX renderer graph. Its weights live in the separate
// sidecar file DATA, which ensureSession() supplies via the `externalData` option.
const MODEL = "https://huggingface.co/Reza2kn/PIXLRelight-ONNX/resolve/main/int4/pixlrelight_renderer_int4.onnx";
const DATA = "https://huggingface.co/Reza2kn/PIXLRelight-ONNX/resolve/main/int4/pixlrelight_renderer_int4.onnx.data";
// Working resolution: canvases are SIZE x SIZE and model tensors are [1, C, SIZE, SIZE].
const SIZE = 512;
const sourceCanvas = document.getElementById("source");
const lightCanvas = document.getElementById("lightmap");
const outputCanvas = document.getElementById("output");
const statusEl = document.getElementById("status");
// Inputs that edit the currently selected light (see syncControls/applyControls).
const controls = {
selected: document.getElementById("selected"),
x: document.getElementById("x"),
y: document.getElementById("y"),
radius: document.getElementById("radius"),
intensity: document.getElementById("intensity"),
color: document.getElementById("color"),
};
// Cached ort.InferenceSession; stays null until ensureSession() succeeds.
let session = null;
// Source image as planar CHW float RGB in [0, 1]; fed to the model as `source_images`.
let sourceFloat = new Float32Array(3 * SIZE * SIZE);
// Last image drawn to the source canvas (assigned in drawSource; not read elsewhere in this script).
let sourceImage = null;
// Light list. x/y are fractions of the canvas size, radius scales the Gaussian
// spread (sigma = radius * SIZE * 0.42 px), intensity scales the contribution,
// and color is a "#rrggbb" string.
let lights = [
{x: 0.28, y: 0.22, radius: 0.42, intensity: 1.25, color: "#ffd08a"},
{x: 0.78, y: 0.62, radius: 0.32, intensity: 0.85, color: "#86b7ff"},
];
/**
 * Replace the status panel's text with `message`.
 * Rendered as plain text (never parsed as HTML).
 */
function setStatus(message) {
  statusEl.textContent = message;
}
/**
 * Convert a "#rrggbb" color string into [r, g, b] with each channel in 0..1.
 * The first "#" is stripped; the next three two-digit pairs are parsed as hex.
 */
function hexToRgb01(hex) {
  const digits = hex.replace("#", "");
  const channels = [];
  for (let offset = 0; offset < 6; offset += 2) {
    channels.push(parseInt(digits.slice(offset, offset + 2), 16) / 255);
  }
  return channels;
}
/**
 * Read the top-left SIZE x SIZE pixels of `canvas` and return them as a
 * Float32Array of length 3 * SIZE * SIZE in planar CHW order
 * (all red values, then all green, then all blue), each scaled to 0..1.
 * Alpha is ignored.
 */
function canvasToFloatCHW(canvas) {
  const rgba = canvas.getContext("2d").getImageData(0, 0, SIZE, SIZE).data;
  const plane = SIZE * SIZE;
  const chw = new Float32Array(3 * plane);
  for (let p = 0; p < plane; p++) {
    const base = p * 4;
    chw[p] = rgba[base] / 255;
    chw[plane + p] = rgba[base + 1] / 255;
    chw[2 * plane + p] = rgba[base + 2] / 255;
  }
  return chw;
}
/**
 * Center-crop `img` to a square, scale it into the SIZE x SIZE source canvas,
 * and refresh `sourceFloat` (the model's `source_images` data) from that canvas.
 */
function drawSource(img) {
  const width = img.naturalWidth || img.width;
  const height = img.naturalHeight || img.height;
  const side = Math.min(width, height);
  const ctx = sourceCanvas.getContext("2d");
  ctx.clearRect(0, 0, SIZE, SIZE);
  // Crop origin places the square at the center of the longer dimension.
  ctx.drawImage(img, (width - side) / 2, (height - side) / 2, side, side, 0, 0, SIZE, SIZE);
  sourceFloat = canvasToFloatCHW(sourceCanvas);
  sourceImage = img;
}
/**
 * Load the image at `url` with crossOrigin="anonymous" and wait for it to decode.
 * Then draw it as the new source image.
 * The anonymous CORS request keeps the canvas readable by getImageData.
 * Rejects if the image cannot be fetched or decoded.
 */
async function loadImageUrl(url) {
  const picture = new Image();
  picture.crossOrigin = "anonymous";
  picture.src = url;
  return picture.decode().then(() => drawSource(picture));
}
/**
 * Copy the selected light's properties into the slider/color inputs.
 * Each input in `controls` shares its key with the matching light property.
 */
function syncControls() {
  const light = lights[Number(controls.selected.value)];
  for (const key of ["x", "y", "radius", "intensity", "color"]) {
    controls[key].value = light[key];
  }
}
/**
 * Replace the selected light with the values currently shown in the inputs,
 * then repaint the light map.
 */
function applyControls() {
  const readNumber = name => Number(controls[name].value);
  lights[Number(controls.selected.value)] = {
    x: readNumber("x"),
    y: readNumber("y"),
    radius: readNumber("radius"),
    intensity: readNumber("intensity"),
    color: controls.color.value,
  };
  drawLightMap();
}
/**
 * Compute the colored light field for the current `lights` list.
 *
 * Every pixel starts at a flat 0.055 in each channel. Each light adds a
 * Gaussian blob centered at (x, y) * (SIZE - 1), with sigma =
 * max(1, radius * SIZE * 0.42) px, scaled by intensity and by the light's RGB
 * color. The whole field is then divided by the largest channel value seen
 * (never by less than 1), so every entry lies in [0, 1].
 *
 * Returns a Float32Array of length 3 * SIZE * SIZE in planar CHW order, the
 * same layout canvasToFloatCHW() produces. It contains only the light field,
 * without the UI handles that drawLightMap() paints on top.
 */
function computeLightField() {
  const plane = SIZE * SIZE;
  const field = new Float32Array(3 * plane).fill(0.055);
  let maxv = 1;
  for (const light of lights) {
    const rgb = hexToRgb01(light.color);
    const cx = light.x * (SIZE - 1);
    const cy = light.y * (SIZE - 1);
    const sigma = Math.max(1, light.radius * SIZE * 0.42);
    const denom = 2 * sigma * sigma;
    for (let y = 0; y < SIZE; y++) {
      const dy2 = (y - cy) ** 2;
      for (let x = 0; x < SIZE; x++) {
        const falloff = Math.exp(-(((x - cx) ** 2 + dy2) / denom)) * light.intensity;
        const i = y * SIZE + x;
        for (let c = 0; c < 3; c++) {
          const k = c * plane + i;
          field[k] += falloff * rgb[c];
          maxv = Math.max(maxv, field[k]);
        }
      }
    }
  }
  for (let k = 0; k < field.length; k++) field[k] /= maxv;
  return field;
}

/**
 * Paint the light field into the light-map canvas, then draw a handle for each
 * light. The handle is a white ring filled with the light's color. The
 * selected light's handle is larger (radius 11 px vs 8 px).
 */
function drawLightMap() {
  const ctx = lightCanvas.getContext("2d");
  const img = ctx.createImageData(SIZE, SIZE);
  const field = computeLightField();
  const plane = SIZE * SIZE;
  for (let i = 0; i < plane; i++) {
    img.data[i * 4] = Math.min(255, Math.round(field[i] * 255));
    img.data[i * 4 + 1] = Math.min(255, Math.round(field[plane + i] * 255));
    img.data[i * 4 + 2] = Math.min(255, Math.round(field[2 * plane + i] * 255));
    img.data[i * 4 + 3] = 255;
  }
  ctx.putImageData(img, 0, 0);
  const selectedIdx = Number(controls.selected.value);
  lights.forEach((light, idx) => {
    const cx = light.x * (SIZE - 1);
    const cy = light.y * (SIZE - 1);
    ctx.beginPath();
    ctx.arc(cx, cy, idx === selectedIdx ? 11 : 8, 0, Math.PI * 2);
    ctx.lineWidth = 3;
    ctx.strokeStyle = "white";
    ctx.stroke();
    ctx.fillStyle = light.color;
    ctx.fill();
  });
}

/**
 * Build the 9-channel `target_intrinsics` input, [9, SIZE, SIZE] flattened CHW.
 *
 * - Channels 0-2: min(1, src * 0.85 + 0.15), where src is the source image.
 * - Channels 3-5: the light field from computeLightField().
 * - Channels 6-8: clamp01((light - 0.5) * 0.45 + src * 0.12 + 0.5).
 *
 * The light field is computed directly from `lights`. It is not read back from
 * the light-map canvas, because that canvas also holds the UI handles (white
 * rings and colored dots) that must not leak into the model input.
 * NOTE(review): how the renderer interprets each triplet is not stated in this
 * file; confirm against the model export.
 */
function makeIntrinsics() {
  const light = computeLightField();
  const intr = new Float32Array(9 * SIZE * SIZE);
  const plane = SIZE * SIZE;
  for (let i = 0; i < plane; i++) {
    for (let c = 0; c < 3; c++) {
      const src = sourceFloat[c * plane + i];
      const lm = light[c * plane + i];
      intr[c * plane + i] = Math.min(1, src * 0.85 + 0.15);
      intr[(c + 3) * plane + i] = lm;
      intr[(c + 6) * plane + i] = Math.min(1, Math.max(0, (lm - 0.5) * 0.45 + src * 0.12 + 0.5));
    }
  }
  return intr;
}
/**
 * Paint planar CHW RGB data (values nominally in 0..1) into `canvas` as an
 * opaque SIZE x SIZE image. Values are scaled by 255, rounded, and clamped to
 * 0..255.
 */
function drawCHW(floatData, canvas) {
  const ctx = canvas.getContext("2d");
  const img = ctx.createImageData(SIZE, SIZE);
  const plane = SIZE * SIZE;
  const toByte = v => Math.max(0, Math.min(255, Math.round(v * 255)));
  for (let p = 0, o = 0; p < plane; p++, o += 4) {
    img.data[o] = toByte(floatData[p]);
    img.data[o + 1] = toByte(floatData[plane + p]);
    img.data[o + 2] = toByte(floatData[2 * plane + p]);
    img.data[o + 3] = 255;
  }
  ctx.putImageData(img, 0, 0);
}
/**
 * Create the ONNX Runtime WebGPU session for the INT4 renderer once and return it.
 *
 * The session is cached in the module-level `session`, so later calls return
 * immediately. Nothing is cached on failure, so the next click retries.
 *
 * Throws if the browser exposes no WebGPU API or no usable GPU adapter, or if
 * ONNX Runtime fails to fetch or initialize the model.
 */
async function ensureSession() {
  if (session) return session;
  if (!("gpu" in navigator)) throw new Error("WebGPU is not available in this browser. Use current Chrome/Edge with WebGPU enabled.");
  // navigator.gpu can exist while requestAdapter() resolves to null (for example a
  // blocklisted GPU/driver or disabled hardware acceleration). Check before the
  // large model download so the user sees a clear reason, not a late ORT error.
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) throw new Error("WebGPU is exposed but no GPU adapter is available (GPU blocklisted or hardware acceleration disabled).");
  setStatus("Loading INT4 ONNX over the network. First load is large (~480 MB), then browser cache should help.");
  // Thread cap for ORT's WASM backend. NOTE(review): ORT only uses >1 thread when
  // the page is cross-origin isolated; otherwise it falls back to a single thread.
  ort.env.wasm.numThreads = Math.min(4, navigator.hardwareConcurrency || 4);
  session = await ort.InferenceSession.create(MODEL, {
    executionProviders: ["webgpu"],
    // Maps the sidecar file name the graph refers to onto its hosted URL.
    externalData: [{path: "pixlrelight_renderer_int4.onnx.data", data: DATA}],
    freeDimensionOverrides: {batch: 1},
    graphOptimizationLevel: "all",
  });
  setStatus("WebGPU session ready. Provider requested: webgpu. Running relight next.");
  return session;
}
/**
 * Click handler for the Run button.
 *
 * Ensures the WebGPU session exists and builds the `source_images` /
 * `target_intrinsics` inputs. It then runs inference, draws the result into
 * the output canvas, and reports timings. Any error is shown in the status
 * panel. The button is disabled while this runs and re-enabled afterwards in
 * every case.
 */
async function run() {
  const btn = document.getElementById("run");
  btn.disabled = true;
  const t0 = performance.now();
  try {
    const sess = await ensureSession();
    const feeds = {
      source_images: new ort.Tensor("float32", sourceFloat, [1, 3, SIZE, SIZE]),
      target_intrinsics: new ort.Tensor("float32", makeIntrinsics(), [1, 9, SIZE, SIZE]),
    };
    // t0..t1 covers session setup plus input construction; t1..t2 is sess.run only.
    const t1 = performance.now();
    const results = await sess.run(feeds);
    const t2 = performance.now();
    // Prefer the `rgb` output. Fall back to the first declared output so a
    // re-export with a different output name still renders.
    const output = results.rgb ?? results[sess.outputNames[0]];
    if (!output) throw new Error("Model returned no output tensor.");
    drawCHW(output.data, outputCanvas);
    setStatus(`Done on browser WebGPU path.\nModel/load time this click: ${((t1 - t0) / 1000).toFixed(2)}s\nInference: ${((t2 - t1) / 1000).toFixed(2)}s`);
  } catch (err) {
    console.error(err);
    setStatus(`WebGPU run failed: ${err.message || err}\nNo Python CPU fallback was used on this page.`);
  } finally {
    btn.disabled = false;
  }
}
/**
 * Log an image-loading failure and show it in the status panel, so it does not
 * end up as an unhandled promise rejection.
 */
function reportImageError(what, err) {
  console.error(err);
  setStatus(`${what}: ${err?.message || err}`);
}

/** Clamp a number to the closed range [0, 1]. */
const clamp01 = v => Math.min(1, Math.max(0, v));

document.getElementById("loadSample").onclick = async () => {
  try {
    await loadImageUrl(document.getElementById("sample").value);
  } catch (err) {
    reportImageError("Could not load sample image", err);
  }
};
document.getElementById("file").onchange = async e => {
  const file = e.target.files?.[0];
  if (!file) return;
  const url = URL.createObjectURL(file);
  const img = new Image();
  img.src = url;
  try {
    await img.decode();
    drawSource(img);
  } catch (err) {
    reportImageError("Could not read that image file", err);
  } finally {
    // The pixels are already copied into the source canvas (or decoding failed),
    // so the blob URL is no longer needed; revoke it to free the file data.
    URL.revokeObjectURL(url);
  }
};
for (const el of [controls.x, controls.y, controls.radius, controls.intensity, controls.color]) el.oninput = applyControls;
controls.selected.onchange = syncControls;
document.getElementById("add").onclick = () => {
  if (lights.length >= 5) return;
  lights.push({x: 0.5, y: 0.35, radius: 0.28, intensity: 0.9, color: "#ffffff"});
  const option = document.createElement("option");
  option.value = String(lights.length - 1);
  option.textContent = `Light ${lights.length}`;
  controls.selected.appendChild(option);
  controls.selected.value = String(lights.length - 1);
  syncControls();
  drawLightMap();
};
document.getElementById("remove").onclick = () => {
  if (lights.length <= 1) return;
  lights.splice(Number(controls.selected.value), 1);
  // Rebuild the options so labels and values stay contiguous (Light 1..N / 0..N-1).
  controls.selected.innerHTML = "";
  lights.forEach((_, i) => {
    const option = document.createElement("option");
    option.value = String(i);
    option.textContent = `Light ${i + 1}`;
    controls.selected.appendChild(option);
  });
  controls.selected.value = "0";
  syncControls();
  drawLightMap();
};
lightCanvas.onclick = e => {
  // The canvas is displayed at a responsive CSS size, so normalize by the displayed
  // rect. Clamp so positions stay within the sliders' 0..1 range.
  const rect = lightCanvas.getBoundingClientRect();
  const light = lights[Number(controls.selected.value)];
  light.x = clamp01((e.clientX - rect.left) / rect.width);
  light.y = clamp01((e.clientY - rect.top) / rect.height);
  syncControls();
  drawLightMap();
};
document.getElementById("run").onclick = run;
if ("gpu" in navigator) setStatus("WebGPU detected. Load a sample, move lights, then run the INT4 model.");
else setStatus("WebGPU not detected in this browser. This page intentionally does not use CPU fallback.");
syncControls();
drawLightMap();
try {
  await loadImageUrl("https://huggingface.co/spaces/Reza2kn/PIXLRelight-ONNX-Light-Studio/resolve/main/samples/room00_source.png");
} catch (err) {
  reportImageError("Could not load sample image", err);
}
</script>
</body>
</html>