thookham commited on
Commit
142634b
·
verified ·
1 Parent(s): 0ef06dd

Add quantized.html for browser implementation

Browse files
Files changed (1) hide show
  1. quantized.html +248 -0
quantized.html ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>DeOldify Quantized (Browser)</title>
8
+ <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
9
+ <style>
10
+ body {
11
+ font-family: sans-serif;
12
+ max-width: 800px;
13
+ margin: 0 auto;
14
+ padding: 20px;
15
+ }
16
+
17
+ h1 {
18
+ text-align: center;
19
+ }
20
+
21
+ .container {
22
+ display: flex;
23
+ flex-direction: column;
24
+ align-items: center;
25
+ gap: 20px;
26
+ }
27
+
28
+ canvas {
29
+ border: 1px solid #ccc;
30
+ max-width: 100%;
31
+ }
32
+
33
+ .controls {
34
+ margin-bottom: 20px;
35
+ }
36
+
37
+ #status {
38
+ font-weight: bold;
39
+ margin-top: 10px;
40
+ }
41
+ </style>
42
+ </head>
43
+
44
+ <body>
45
+ <h1>DeOldify Quantized Model</h1>
46
+ <p style="text-align: center;">Faster, smaller download (61MB), slightly lower quality.</p>
47
+ <div class="container">
48
+ <div class="controls">
49
+ <input type="file" id="imageInput" accept="image/*" />
50
+ </div>
51
+ <div id="status">Select an image to start...</div>
52
+ <canvas id="outputCanvas"></canvas>
53
+ </div>
54
+
55
+ <script>
56
+ const MODEL_URL = "https://huggingface.co/thookham/DeOldify-on-Browser/resolve/main/deoldify-quant.onnx";
57
+ let session = null;
58
+
59
+ const preprocess = (input_imageData, width, height) => {
60
+ const floatArr = new Float32Array(width * height * 3);
61
+ let j = 0;
62
+ for (let i = 0; i < input_imageData.data.length; i += 4) {
63
+ // Normalize to 0-1 range as expected by DeOldify
64
+ floatArr[j] = input_imageData.data[i] / 255.0; // red
65
+ floatArr[j + 1] = input_imageData.data[i + 1] / 255.0; // green
66
+ floatArr[j + 2] = input_imageData.data[i + 2] / 255.0; // blue
67
+ j += 3;
68
+ }
69
+ return floatArr;
70
+ };
71
+
72
+ const postprocess = (tensor) => {
73
+ const channels = tensor.dims[1];
74
+ const height = tensor.dims[2];
75
+ const width = tensor.dims[3];
76
+ const imageData = new ImageData(width, height);
77
+ const data = imageData.data;
78
+ const tensorData = new Float32Array(tensor.data);
79
+
80
+ for (let h = 0; h < height; h++) {
81
+ for (let w = 0; w < width; w++) {
82
+ let rgb = [];
83
+ for (let c = 0; c < channels; c++) {
84
+ const tensorIndex = (c * height + h) * width + w;
85
+ const value = tensorData[tensorIndex];
86
+ // Denormalize: multiply by 255 and clamp
87
+ let val = value * 255.0;
88
+ if (val < 0) val = 0;
89
+ if (val > 255) val = 255;
90
+ rgb.push(Math.round(val));
91
+ }
92
+ data[(h * width + w) * 4] = rgb[0];
93
+ data[(h * width + w) * 4 + 1] = rgb[1];
94
+ data[(h * width + w) * 4 + 2] = rgb[2];
95
+ data[(h * width + w) * 4 + 3] = 255;
96
+ }
97
+ }
98
+ return imageData;
99
+ };
100
+
101
+ async function init() {
102
+ const status = document.getElementById('status');
103
+ status.innerText = "Checking cache...";
104
+ try {
105
+ let buffer;
106
+ const cacheName = 'deoldify-models-v1';
107
+
108
+ // Try to load from cache first
109
+ try {
110
+ const cache = await caches.open(cacheName);
111
+ const cachedResponse = await cache.match(MODEL_URL);
112
+
113
+ if (cachedResponse) {
114
+ status.innerText = "Loading model from cache...";
115
+ const blob = await cachedResponse.blob();
116
+ buffer = await blob.arrayBuffer();
117
+ }
118
+ } catch (e) {
119
+ console.warn("Cache API not supported or failed:", e);
120
+ }
121
+
122
+ // If not in cache, download it
123
+ if (!buffer) {
124
+ status.innerText = "Downloading model from Hugging Face... 0%";
125
+ const response = await fetch(MODEL_URL);
126
+ if (!response.ok) throw new Error(`Failed to fetch model: ${response.statusText}`);
127
+
128
+ const contentLength = response.headers.get('content-length');
129
+ const total = contentLength ? parseInt(contentLength, 10) : 0;
130
+ let loaded = 0;
131
+
132
+ const reader = response.body.getReader();
133
+ const chunks = [];
134
+
135
+ while (true) {
136
+ const { done, value } = await reader.read();
137
+ if (done) break;
138
+ chunks.push(value);
139
+ loaded += value.length;
140
+ if (total) {
141
+ const progress = Math.round((loaded / total) * 100);
142
+ status.innerText = `Downloading model from Hugging Face... ${progress}%`;
143
+ } else {
144
+ status.innerText = `Downloading model from Hugging Face... ${(loaded / 1024 / 1024).toFixed(1)} MB`;
145
+ }
146
+ }
147
+
148
+ const blob = new Blob(chunks);
149
+ buffer = await blob.arrayBuffer();
150
+
151
+ // Save to cache for next time
152
+ try {
153
+ const cache = await caches.open(cacheName);
154
+ await cache.put(MODEL_URL, new Response(blob));
155
+ console.log("Model saved to cache");
156
+ } catch (e) {
157
+ console.warn("Failed to save to cache:", e);
158
+ }
159
+ }
160
+
161
+ status.innerText = "Initializing session...";
162
+ session = await ort.InferenceSession.create(buffer);
163
+
164
+ status.innerText = "Model loaded! Select an image.";
165
+ console.log("Session created:", session);
166
+ } catch (e) {
167
+ status.innerText = "Error loading model: " + e.message;
168
+ console.error(e);
169
+ if (e.message.includes("Failed to fetch")) {
170
+ status.innerHTML += "<br><br>⚠️ <b>CORS Error Detected</b>: If you are running this file directly (file://), you must use a local server.<br>Run <code>python -m http.server 8000</code> in the terminal and visit <code>http://localhost:8000/quantized.html</code>";
171
+ }
172
+ }
173
+ }
174
+
175
+ document.getElementById('imageInput').addEventListener('change', async function (e) {
176
+ if (!session) {
177
+ await init();
178
+ }
179
+
180
+ const file = e.target.files[0];
181
+ if (!file) return;
182
+
183
+ // Validate image type
184
+ if (!file.type.startsWith('image/')) {
185
+ alert('Please select a valid image file.');
186
+ return;
187
+ }
188
+
189
+ const image = new Image();
190
+ const objectUrl = URL.createObjectURL(file);
191
+ image.src = objectUrl;
192
+
193
+ image.onload = async function () {
194
+ document.getElementById('status').innerText = "Processing...";
195
+
196
+ // Pre-processing canvas (256x256)
197
+ let canvas = document.createElement("canvas");
198
+ const size = 256;
199
+ canvas.width = size;
200
+ canvas.height = size;
201
+ let ctx = canvas.getContext("2d");
202
+ ctx.drawImage(image, 0, 0, size, size);
203
+
204
+ const input_img = ctx.getImageData(0, 0, size, size);
205
+ const test = preprocess(input_img, size, size);
206
+ const input = new ort.Tensor(new Float32Array(test), [1, 3, size, size]);
207
+
208
+ try {
209
+ const result = await session.run({ "input": input });
210
+ // Handle potential output name differences
211
+ const output = result["output"] || result["out"] || Object.values(result)[0];
212
+
213
+ if (!output) throw new Error("No output tensor found in model result");
214
+
215
+ const imgdata = postprocess(output);
216
+
217
+ // Render to output canvas
218
+ const outCanvas = document.getElementById('outputCanvas');
219
+ outCanvas.width = image.width;
220
+ outCanvas.height = image.height;
221
+ const outCtx = outCanvas.getContext('2d');
222
+
223
+ // Draw 256x256 result to temp canvas
224
+ const tempCanvas = document.createElement('canvas');
225
+ tempCanvas.width = size;
226
+ tempCanvas.height = size;
227
+ tempCanvas.getContext('2d').putImageData(imgdata, 0, 0);
228
+
229
+ // Resize to original
230
+ outCtx.drawImage(tempCanvas, 0, 0, image.width, image.height);
231
+
232
+ document.getElementById('status').innerText = "Done!";
233
+ } catch (err) {
234
+ document.getElementById('status').innerText = "Error processing: " + err.message;
235
+ console.error(err);
236
+ } finally {
237
+ // Clean up memory
238
+ URL.revokeObjectURL(objectUrl);
239
+ }
240
+ };
241
+ });
242
+
243
+ // Start loading immediately
244
+ init();
245
+ </script>
246
+ </body>
247
+
248
+ </html>