|
|
<!DOCTYPE html> |
|
|
<html lang="en"> |
|
|
|
|
|
<head>
  <meta charset="UTF-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>DeOldify Quantized (Browser)</title>
  <!-- onnxruntime-web is pinned to the 1.x major line. An unversioned jsDelivr URL
       resolves to whatever is latest on npm, including a future breaking major release. -->
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@1/dist/ort.min.js"></script>
  <style>
    body {
      font-family: sans-serif;
      max-width: 800px;
      margin: 0 auto;
      padding: 20px;
    }

    h1 {
      text-align: center;
    }

    .container {
      display: flex;
      flex-direction: column;
      align-items: center;
      gap: 20px;
    }

    canvas {
      border: 1px solid #ccc;
      max-width: 100%;
    }

    .controls {
      margin-bottom: 20px;
    }

    #status {
      font-weight: bold;
      margin-top: 10px;
    }
  </style>
</head>
|
|
|
|
|
<body> |
|
|
<h1>DeOldify Quantized Model</h1>
<p style="text-align: center;">Faster, smaller download (61MB), slightly lower quality.</p>
<div class="container">
  <div class="controls">
    <!-- There is no visible <label>, so give the control an accessible name. -->
    <input type="file" id="imageInput" accept="image/*" aria-label="Choose an image to colorize" />
  </div>
  <!-- Progress and error messages are written here by the script. -->
  <div id="status">Select an image to start...</div>
  <!-- Receives the model output, stretched to the selected image's size. -->
  <canvas id="outputCanvas" role="img" aria-label="Colorized output image"></canvas>
</div>
|
|
|
|
|
<script> |
|
|
/**
 * Shared ONNX Runtime inference session.
 * It is null until init() creates it. The file-input handler checks it
 * before running the model.
 */
let session = null;

/**
 * Hugging Face download URL of the quantized DeOldify ONNX model.
 * init() also uses it as the Cache Storage key for the downloaded bytes.
 */
const MODEL_URL = "https://huggingface.co/thookham/DeOldify-on-Browser/resolve/main/deoldify-quant.onnx";
|
|
|
|
|
/**
 * Converts RGBA pixel bytes into a planar (CHW) Float32Array scaled to [0, 1].
 *
 * The caller wraps the result as an ort.Tensor with dims [1, 3, height, width],
 * and postprocess() reads the model output back in the same planar layout.
 * The data must therefore be all R values, then all G values, then all B
 * values, not interleaved RGBRGB…
 *
 * @param {{data: ArrayLike<number>}} input_imageData - row-major RGBA bytes,
 *   4 per pixel (e.g. the result of CanvasRenderingContext2D.getImageData).
 * @param {number} width - image width in pixels.
 * @param {number} height - image height in pixels.
 * @returns {Float32Array} length 3 * width * height. Alpha is dropped.
 */
const preprocess = (input_imageData, width, height) => {
  const pixels = input_imageData.data;
  const planeSize = width * height;
  const floatArr = new Float32Array(planeSize * 3);

  // Never read past the buffer, and never write past a plane. An oversized
  // buffer would otherwise spill one channel into the next plane.
  const pixelCount = Math.min(planeSize, Math.floor(pixels.length / 4));

  for (let p = 0; p < pixelCount; p++) {
    const src = p * 4;
    floatArr[p] = pixels[src] / 255.0;                      // R plane
    floatArr[planeSize + p] = pixels[src + 1] / 255.0;      // G plane
    floatArr[2 * planeSize + p] = pixels[src + 2] / 255.0;  // B plane
  }

  return floatArr;
};
|
|
|
|
|
/**
 * Turns a planar float tensor into an opaque RGBA ImageData.
 *
 * dims[1] is read as the channel count, dims[2] as the height and dims[3]
 * as the width. Sample (c, h, w) sits at c*H*W + h*W + w. Only the first
 * three channels are used (as R, G, B). Each sample is multiplied by 255,
 * clamped to [0, 255] and rounded. Alpha is always 255. Missing channels
 * are left at 0.
 *
 * @param {{dims: number[], data: ArrayLike<number>}} tensor - model output tensor.
 * @returns {ImageData} an image of dims[3] x dims[2] pixels.
 */
const postprocess = (tensor) => {
  const [, channelCount, rows, cols] = tensor.dims;
  const out = new ImageData(cols, rows);
  const pixels = out.data;
  const samples = new Float32Array(tensor.data);
  const planeSize = rows * cols;
  const usedChannels = Math.min(channelCount, 3);

  // Scale a [0, 1] sample to a byte value, clamping out-of-range results.
  const toByte = (x) => Math.round(Math.min(255, Math.max(0, x * 255.0)));

  for (let p = 0; p < planeSize; p++) {
    const base = p * 4;
    for (let c = 0; c < usedChannels; c++) {
      pixels[base + c] = toByte(samples[c * planeSize + p]);
    }
    pixels[base + 3] = 255;
  }

  return out;
};
|
|
|
|
|
/**
 * In-flight model-loading promise shared by concurrent init() callers.
 * The page calls init() on load, and the file-input handler calls it again
 * whenever `session` is still null. Without this guard, picking an image
 * during the first download would start a second download and a second
 * session.
 * @type {Promise<void> | null}
 */
let initPromise = null;

/** Cache Storage bucket that holds the downloaded model bytes. */
const MODEL_CACHE_NAME = 'deoldify-models-v1';

/**
 * Error messages browsers use for network-level fetch failures
 * (Chromium: "Failed to fetch", Firefox: "NetworkError ...", Safari: "Load failed").
 */
const NETWORK_ERROR_PATTERN = /Failed to fetch|NetworkError|Load failed/;

/**
 * Returns the model bytes from Cache Storage, or null when they are not
 * cached or the Cache API is unavailable (e.g. an insecure context).
 * @param {HTMLElement} status - element that receives progress text.
 * @returns {Promise<ArrayBuffer | null>}
 */
async function readModelFromCache(status) {
  try {
    const cache = await caches.open(MODEL_CACHE_NAME);
    const cachedResponse = await cache.match(MODEL_URL);
    if (!cachedResponse) return null;
    status.innerText = "Loading model from cache...";
    return await cachedResponse.arrayBuffer();
  } catch (e) {
    console.warn("Cache API not supported or failed:", e);
    return null;
  }
}

/**
 * Downloads the model, reporting progress in `status` as it streams.
 * @param {HTMLElement} status - element that receives progress text.
 * @returns {Promise<Blob>} the downloaded model bytes.
 * @throws {Error} on a non-2xx response, or any network or stream error.
 */
async function downloadModel(status) {
  status.innerText = "Downloading model from Hugging Face... 0%";
  const response = await fetch(MODEL_URL);
  if (!response.ok) {
    // statusText is empty over HTTP/2, so report the numeric code. The wording
    // deliberately avoids "Failed to fetch" so that HTTP errors (404, 403, ...)
    // are not misreported as CORS problems.
    throw new Error(`Model download failed (HTTP ${response.status})`);
  }

  // No streaming body available: fall back to one read without progress.
  if (!response.body) {
    return await response.blob();
  }

  const contentLength = response.headers.get('content-length');
  const total = contentLength ? Number.parseInt(contentLength, 10) : 0;
  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;

  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    if (total > 0) {
      // content-length can be the encoded (compressed) size, so the decoded
      // byte count may exceed it. Cap the displayed value at 100%.
      const progress = Math.min(100, Math.round((loaded / total) * 100));
      status.innerText = `Downloading model from Hugging Face... ${progress}%`;
    } else {
      status.innerText = `Downloading model from Hugging Face... ${(loaded / 1024 / 1024).toFixed(1)} MB`;
    }
  }

  return new Blob(chunks);
}

/**
 * Best-effort write of the model bytes to Cache Storage. Failures are logged,
 * not thrown, because the page still works without a cache.
 * @param {Blob} blob - the downloaded model.
 */
async function saveModelToCache(blob) {
  try {
    const cache = await caches.open(MODEL_CACHE_NAME);
    await cache.put(MODEL_URL, new Response(blob));
    console.log("Model saved to cache");
  } catch (e) {
    console.warn("Failed to save to cache:", e);
  }
}

/**
 * Loads the model (from the cache, or by downloading it) and assigns the
 * global `session`. Errors are reported in #status and the console, not
 * thrown. `session` stays null on failure.
 */
async function loadModel() {
  const status = document.getElementById('status');
  status.innerText = "Checking cache...";

  try {
    let buffer = await readModelFromCache(status);

    if (!buffer) {
      const blob = await downloadModel(status);
      buffer = await blob.arrayBuffer();
      await saveModelToCache(blob);
    }

    status.innerText = "Initializing session...";
    session = await ort.InferenceSession.create(buffer);

    status.innerText = "Model loaded! Select an image.";
    console.log("Session created:", session);
  } catch (e) {
    const message = e instanceof Error ? e.message : String(e);
    status.innerText = "Error loading model: " + message;
    console.error(e);

    if (NETWORK_ERROR_PATTERN.test(message)) {
      status.innerHTML += "<br><br>⚠️ <b>CORS Error Detected</b>: If you are running this file directly (file://), you must use a local server.<br>Run <code>python -m http.server 8000</code> in the terminal and visit <code>http://localhost:8000/quantized.html</code>";
    }
  }
}

/**
 * Ensures the model session is loaded. This is safe to call repeatedly and
 * concurrently:
 * - it resolves immediately if `session` already exists;
 * - it joins an in-flight load instead of starting another;
 * - after a failed load, the next call retries.
 * It never rejects. Callers should check `session` afterwards to see
 * whether loading succeeded.
 * @returns {Promise<void>}
 */
async function init() {
  if (session) return;
  if (!initPromise) {
    initPromise = loadModel().finally(() => {
      initPromise = null;
    });
  }
  return initPromise;
}
|
|
|
|
|
/**
 * Handles a new file selection:
 * 1. checks that a file was chosen and that it has an image MIME type;
 * 2. makes sure the model session is loaded;
 * 3. runs the model on a 256x256 resize of the image;
 * 4. draws the result into #outputCanvas, stretched to the image's natural size.
 * Progress and errors are reported in #status.
 */
document.getElementById('imageInput').addEventListener('change', async function (e) {
  const status = document.getElementById('status');

  // Read the selection before any await, so a cancelled dialog does not
  // trigger a model load.
  const file = e.target.files[0];
  if (!file) return;

  // Basic MIME-type check on the user-selected file.
  if (!file.type.startsWith('image/')) {
    alert('Please select a valid image file.');
    return;
  }

  if (!session) {
    await init();
  }
  if (!session) {
    // Loading failed and init() has already written the reason into #status.
    // Continuing would only replace it with a null-session TypeError.
    return;
  }

  const image = new Image();
  const objectUrl = URL.createObjectURL(file);

  // Without this handler, an undecodable file would leak the object URL and
  // leave the status unchanged.
  image.onerror = function () {
    URL.revokeObjectURL(objectUrl);
    status.innerText = "Error processing: the selected file could not be decoded as an image.";
  };

  image.onload = async function () {
    status.innerText = "Processing...";

    try {
      // Square input resolution the tensor below is built for.
      const size = 256;

      const canvas = document.createElement("canvas");
      canvas.width = size;
      canvas.height = size;
      const ctx = canvas.getContext("2d");
      ctx.drawImage(image, 0, 0, size, size);

      const input_img = ctx.getImageData(0, 0, size, size);
      // preprocess() already returns a fresh Float32Array, so no copy is needed.
      const input = new ort.Tensor(preprocess(input_img, size, size), [1, 3, size, size]);

      const result = await session.run({ "input": input });

      // Accept the common output names, otherwise take the first output.
      const output = result["output"] || result["out"] || Object.values(result)[0];
      if (!output) throw new Error("No output tensor found in model result");

      const imgdata = postprocess(output);

      const outCanvas = document.getElementById('outputCanvas');
      outCanvas.width = image.naturalWidth;
      outCanvas.height = image.naturalHeight;
      const outCtx = outCanvas.getContext('2d');

      // putImageData cannot scale, so stage the result on a scratch canvas
      // sized to the actual output. drawImage then stretches it to full size.
      const tempCanvas = document.createElement('canvas');
      tempCanvas.width = imgdata.width;
      tempCanvas.height = imgdata.height;
      tempCanvas.getContext('2d').putImageData(imgdata, 0, 0);

      outCtx.drawImage(tempCanvas, 0, 0, outCanvas.width, outCanvas.height);

      status.innerText = "Done!";
    } catch (err) {
      status.innerText = "Error processing: " + err.message;
      console.error(err);
    } finally {
      URL.revokeObjectURL(objectUrl);
    }
  };

  // Assign src last so both handlers are attached before loading starts.
  image.src = objectUrl;
});
|
|
|
|
|
|
|
|
// Start loading the model as soon as the page runs, so it is usually ready
// before the user picks an image. init() reports its own errors in #status,
// so the returned promise is intentionally not awaited.
void init();
|
|
</script> |
|
|
</body> |
|
|
|
|
|
</html> |