DeOldify / browser /quantized.html
thookham's picture
Initial commit for Hugging Face sync (Clean History)
e9f9fd3
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>DeOldify Quantized (Browser)</title>
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<style>
body {
font-family: sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
h1 {
text-align: center;
}
.container {
display: flex;
flex-direction: column;
align-items: center;
gap: 20px;
}
canvas {
border: 1px solid #ccc;
max-width: 100%;
}
.controls {
margin-bottom: 20px;
}
#status {
font-weight: bold;
margin-top: 10px;
}
</style>
</head>
<body>
<h1>DeOldify Quantized Model</h1>
<p style="text-align: center;">Faster, smaller download (61MB), slightly lower quality.</p>
<div class="container">
<div class="controls">
<input type="file" id="imageInput" accept="image/*" />
</div>
<div id="status">Select an image to start...</div>
<canvas id="outputCanvas"></canvas>
</div>
<script>
// Remote location of the quantized ONNX model; init() also uses this URL as the Cache API key.
const MODEL_URL = "https://huggingface.co/thookham/DeOldify-on-Browser/resolve/main/deoldify-quant.onnx";
// ONNX Runtime inference session; stays null until init() has successfully created it.
let session = null;
/**
 * Convert canvas RGBA pixels into the planar float buffer of an NCHW tensor.
 *
 * The caller wraps the result in a tensor of dims [1, 3, height, width], so the
 * three colour planes must be stored one after another (all red, then all
 * green, then all blue) instead of interleaved per pixel. Alpha is dropped.
 *
 * @param {{data: ArrayLike<number>}} input_imageData - RGBA bytes, 4 per pixel.
 * @param {number} width - Image width in pixels.
 * @param {number} height - Image height in pixels.
 * @returns {Float32Array} Buffer of length width * height * 3, scaled to 0-1.
 */
const preprocess = (input_imageData, width, height) => {
    const rgba = input_imageData.data;
    const planeSize = width * height;
    const floatArr = new Float32Array(planeSize * 3);
    // Never read past the supplied pixel data if it is shorter than width * height.
    const pixelCount = Math.min(planeSize, Math.floor(rgba.length / 4));
    const greenOffset = planeSize;
    const blueOffset = planeSize * 2;
    for (let p = 0, i = 0; p < pixelCount; p++, i += 4) {
        // Scale bytes to the 0-1 range.
        // NOTE(review): confirm the exported model expects 0-1 rather than 0-255.
        floatArr[p] = rgba[i] / 255.0; // red plane
        floatArr[greenOffset + p] = rgba[i + 1] / 255.0; // green plane
        floatArr[blueOffset + p] = rgba[i + 2] / 255.0; // blue plane
    }
    return floatArr;
};
/**
 * Turn a model output tensor into an ImageData that can be painted on a canvas.
 *
 * The tensor is read as planar channel data (dims [batch, channels, height,
 * width]); the first three channels become red, green and blue, and every
 * pixel is made fully opaque.
 *
 * @param {{dims: number[], data: ArrayLike<number>}} tensor - Model output.
 * @returns {ImageData} RGBA image of size width x height taken from the dims.
 */
const postprocess = (tensor) => {
    const [, channels, height, width] = tensor.dims;
    const imageData = new ImageData(width, height);
    const pixels = imageData.data;
    const values = new Float32Array(tensor.data);
    const planeSize = height * width;
    // Only the first three channels are ever written to the image.
    const colourChannels = Math.min(channels, 3);
    // Denormalize from 0-1 to 0-255, clamp, then round to the nearest integer.
    const toByte = (v) => Math.round(Math.min(255, Math.max(0, v * 255.0)));

    for (let p = 0; p < planeSize; p++) {
        const base = p * 4;
        for (let c = 0; c < colourChannels; c++) {
            pixels[base + c] = toByte(values[c * planeSize + p]);
        }
        pixels[base + 3] = 255;
    }
    return imageData;
};
// Model load currently in progress, shared so that overlapping init() calls
// (the page-load call plus an early file selection) trigger a single download.
let pendingInit = null;

/**
 * Ensure the model is loaded and the global `session` is created.
 *
 * Progress and failures are written to the #status element. Load failures are
 * reported there rather than thrown, so callers should check `session`
 * afterwards. Calls made while a load is already running wait for that load
 * instead of starting another one.
 *
 * @returns {Promise<void>}
 */
async function init() {
    if (!pendingInit) {
        pendingInit = loadModel();
    }
    try {
        await pendingInit;
    } finally {
        // Allow a later retry once this attempt has settled.
        pendingInit = null;
    }
}

/**
 * Fetch the model bytes (Cache API first, network second), store a fresh
 * download in the cache, and create the inference session.
 *
 * @returns {Promise<void>}
 */
async function loadModel() {
    const status = document.getElementById('status');
    status.innerText = "Checking cache...";
    // Set only when fetch() itself rejects. HTTP error statuses resolve normally
    // and must not be mistaken for a CORS / file:// problem.
    let networkFailure = false;
    try {
        let buffer;
        const cacheName = 'deoldify-models-v1';
        // Try to load from cache first
        try {
            const cache = await caches.open(cacheName);
            const cachedResponse = await cache.match(MODEL_URL);
            if (cachedResponse) {
                status.innerText = "Loading model from cache...";
                const blob = await cachedResponse.blob();
                buffer = await blob.arrayBuffer();
            }
        } catch (e) {
            console.warn("Cache API not supported or failed:", e);
        }
        // If not in cache, download it
        if (!buffer) {
            status.innerText = "Downloading model from Hugging Face... 0%";
            let response;
            try {
                response = await fetch(MODEL_URL);
            } catch (fetchError) {
                networkFailure = true;
                throw fetchError;
            }
            if (!response.ok) {
                // statusText is frequently empty, so always include the numeric code.
                throw new Error(`Failed to fetch model: HTTP ${response.status} ${response.statusText}`.trim());
            }
            const contentLength = response.headers.get('content-length');
            const total = contentLength ? Number.parseInt(contentLength, 10) : 0;
            let loaded = 0;
            const reader = response.body.getReader();
            const chunks = [];
            while (true) {
                const { done, value } = await reader.read();
                if (done) break;
                chunks.push(value);
                loaded += value.length;
                if (total) {
                    // Cap at 100 in case the header understates the streamed size.
                    const progress = Math.min(100, Math.round((loaded / total) * 100));
                    status.innerText = `Downloading model from Hugging Face... ${progress}%`;
                } else {
                    status.innerText = `Downloading model from Hugging Face... ${(loaded / 1024 / 1024).toFixed(1)} MB`;
                }
            }
            const blob = new Blob(chunks);
            buffer = await blob.arrayBuffer();
            // Save to cache for next time (best effort: a cache failure is not fatal)
            try {
                const cache = await caches.open(cacheName);
                await cache.put(MODEL_URL, new Response(blob));
                console.log("Model saved to cache");
            } catch (e) {
                console.warn("Failed to save to cache:", e);
            }
        }
        status.innerText = "Initializing session...";
        session = await ort.InferenceSession.create(buffer);
        status.innerText = "Model loaded! Select an image.";
        console.log("Session created:", session);
    } catch (e) {
        status.innerText = "Error loading model: " + e.message;
        console.error(e);
        if (networkFailure) {
            status.innerHTML += "<br><br>⚠️ <b>CORS Error Detected</b>: If you are running this file directly (file://), you must use a local server.<br>Run <code>python -m http.server 8000</code> in the terminal and visit <code>http://localhost:8000/quantized.html</code>";
        }
    }
}
/**
 * Decode an image from a URL.
 *
 * @param {string} url - URL (here an object URL) of the image to decode.
 * @returns {Promise<HTMLImageElement>} Resolves once the image has loaded;
 *     rejects if the browser cannot load or decode it.
 */
function loadImage(url) {
    return new Promise((resolve, reject) => {
        const image = new Image();
        // Attach handlers before setting src so neither event can be missed.
        image.onload = () => resolve(image);
        image.onerror = () => reject(new Error("The selected file could not be decoded as an image."));
        image.src = url;
    });
}

// Colorize the selected file: scale it to 256x256, run the model, then draw the
// result on #outputCanvas at the original image size.
document.getElementById('imageInput').addEventListener('change', async function (e) {
    const status = document.getElementById('status');
    if (!session) {
        await init();
    }
    // init() reports its own failure in #status; without a session there is
    // nothing to run, and continuing would overwrite that message.
    if (!session) return;
    const file = e.target.files[0];
    if (!file) return;
    // Validate image type
    if (!file.type.startsWith('image/')) {
        alert('Please select a valid image file.');
        return;
    }
    const objectUrl = URL.createObjectURL(file);
    try {
        const image = await loadImage(objectUrl);
        status.innerText = "Processing...";
        // Pre-processing canvas (256x256)
        const size = 256;
        const canvas = document.createElement("canvas");
        canvas.width = size;
        canvas.height = size;
        const ctx = canvas.getContext("2d");
        ctx.drawImage(image, 0, 0, size, size);
        const input_img = ctx.getImageData(0, 0, size, size);
        // preprocess() already returns a fresh Float32Array, so no extra copy is needed.
        const input = new ort.Tensor(preprocess(input_img, size, size), [1, 3, size, size]);
        // Use the model's declared input name when available, mirroring the
        // flexible output lookup below.
        const inputName = (session.inputNames && session.inputNames[0]) || "input";
        const result = await session.run({ [inputName]: input });
        // Handle potential output name differences
        const output = result["output"] || result["out"] || Object.values(result)[0];
        if (!output) throw new Error("No output tensor found in model result");
        const imgdata = postprocess(output);
        // Render to output canvas
        const outCanvas = document.getElementById('outputCanvas');
        outCanvas.width = image.width;
        outCanvas.height = image.height;
        const outCtx = outCanvas.getContext('2d');
        // Draw 256x256 result to temp canvas
        const tempCanvas = document.createElement('canvas');
        tempCanvas.width = size;
        tempCanvas.height = size;
        tempCanvas.getContext('2d').putImageData(imgdata, 0, 0);
        // Resize to original
        outCtx.drawImage(tempCanvas, 0, 0, image.width, image.height);
        status.innerText = "Done!";
    } catch (err) {
        status.innerText = "Error processing: " + err.message;
        console.error(err);
    } finally {
        // Release the object URL whether decoding, inference or drawing failed or succeeded.
        URL.revokeObjectURL(objectUrl);
    }
});
// Start loading immediately
init();
</script>
</body>
</html>