thookham commited on
Commit
6858afa
·
verified ·
1 Parent(s): 3c49d0a

Add artistic.html for browser implementation

Browse files
Files changed (1) hide show
  1. artistic.html +247 -0
artistic.html ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>DeOldify Artistic (Browser)</title>
8
+ <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
9
+ <style>
10
+ body {
11
+ font-family: sans-serif;
12
+ max-width: 800px;
13
+ margin: 0 auto;
14
+ padding: 20px;
15
+ }
16
+
17
+ h1 {
18
+ text-align: center;
19
+ }
20
+
21
+ .container {
22
+ display: flex;
23
+ flex-direction: column;
24
+ align-items: center;
25
+ gap: 20px;
26
+ }
27
+
28
+ canvas {
29
+ border: 1px solid #ccc;
30
+ max-width: 100%;
31
+ }
32
+
33
+ .controls {
34
+ margin-bottom: 20px;
35
+ }
36
+
37
+ #status {
38
+ font-weight: bold;
39
+ margin-top: 10px;
40
+ }
41
+ </style>
42
+ </head>
43
+
44
+ <body>
45
+ <h1>DeOldify Artistic Model</h1>
46
+ <div class="container">
47
+ <div class="controls">
48
+ <input type="file" id="imageInput" accept="image/*" />
49
+ </div>
50
+ <div id="status">Select an image to start...</div>
51
+ <canvas id="outputCanvas"></canvas>
52
+ </div>
53
+
54
+ <script>
55
+ const MODEL_URL = "https://huggingface.co/thookham/DeOldify-on-Browser/resolve/main/deoldify-art.onnx";
56
+ let session = null;
57
+
58
+ const preprocess = (input_imageData, width, height) => {
59
+ const floatArr = new Float32Array(width * height * 3);
60
+ let j = 0;
61
+ for (let i = 0; i < input_imageData.data.length; i += 4) {
62
+ // Normalize to 0-1 range as expected by DeOldify
63
+ floatArr[j] = input_imageData.data[i] / 255.0; // red
64
+ floatArr[j + 1] = input_imageData.data[i + 1] / 255.0; // green
65
+ floatArr[j + 2] = input_imageData.data[i + 2] / 255.0; // blue
66
+ j += 3;
67
+ }
68
+ return floatArr;
69
+ };
70
+
71
+ const postprocess = (tensor) => {
72
+ const channels = tensor.dims[1];
73
+ const height = tensor.dims[2];
74
+ const width = tensor.dims[3];
75
+ const imageData = new ImageData(width, height);
76
+ const data = imageData.data;
77
+ const tensorData = new Float32Array(tensor.data);
78
+
79
+ for (let h = 0; h < height; h++) {
80
+ for (let w = 0; w < width; w++) {
81
+ let rgb = [];
82
+ for (let c = 0; c < channels; c++) {
83
+ const tensorIndex = (c * height + h) * width + w;
84
+ const value = tensorData[tensorIndex];
85
+ // Denormalize: multiply by 255 and clamp
86
+ let val = value * 255.0;
87
+ if (val < 0) val = 0;
88
+ if (val > 255) val = 255;
89
+ rgb.push(Math.round(val));
90
+ }
91
+ data[(h * width + w) * 4] = rgb[0];
92
+ data[(h * width + w) * 4 + 1] = rgb[1];
93
+ data[(h * width + w) * 4 + 2] = rgb[2];
94
+ data[(h * width + w) * 4 + 3] = 255;
95
+ }
96
+ }
97
+ return imageData;
98
+ };
99
+
100
+ async function init() {
101
+ const status = document.getElementById('status');
102
+ status.innerText = "Checking cache...";
103
+ try {
104
+ let buffer;
105
+ const cacheName = 'deoldify-models-v1';
106
+
107
+ // Try to load from cache first
108
+ try {
109
+ const cache = await caches.open(cacheName);
110
+ const cachedResponse = await cache.match(MODEL_URL);
111
+
112
+ if (cachedResponse) {
113
+ status.innerText = "Loading model from cache...";
114
+ const blob = await cachedResponse.blob();
115
+ buffer = await blob.arrayBuffer();
116
+ }
117
+ } catch (e) {
118
+ console.warn("Cache API not supported or failed:", e);
119
+ }
120
+
121
+ // If not in cache, download it
122
+ if (!buffer) {
123
+ status.innerText = "Downloading model from Hugging Face... 0%";
124
+ const response = await fetch(MODEL_URL);
125
+ if (!response.ok) throw new Error(`Failed to fetch model: ${response.statusText}`);
126
+
127
+ const contentLength = response.headers.get('content-length');
128
+ const total = contentLength ? parseInt(contentLength, 10) : 0;
129
+ let loaded = 0;
130
+
131
+ const reader = response.body.getReader();
132
+ const chunks = [];
133
+
134
+ while (true) {
135
+ const { done, value } = await reader.read();
136
+ if (done) break;
137
+ chunks.push(value);
138
+ loaded += value.length;
139
+ if (total) {
140
+ const progress = Math.round((loaded / total) * 100);
141
+ status.innerText = `Downloading model from Hugging Face... ${progress}%`;
142
+ } else {
143
+ status.innerText = `Downloading model from Hugging Face... ${(loaded / 1024 / 1024).toFixed(1)} MB`;
144
+ }
145
+ }
146
+
147
+ const blob = new Blob(chunks);
148
+ buffer = await blob.arrayBuffer();
149
+
150
+ // Save to cache for next time
151
+ try {
152
+ const cache = await caches.open(cacheName);
153
+ await cache.put(MODEL_URL, new Response(blob));
154
+ console.log("Model saved to cache");
155
+ } catch (e) {
156
+ console.warn("Failed to save to cache:", e);
157
+ }
158
+ }
159
+
160
+ status.innerText = "Initializing session...";
161
+ session = await ort.InferenceSession.create(buffer);
162
+
163
+ status.innerText = "Model loaded! Select an image.";
164
+ console.log("Session created:", session);
165
+ } catch (e) {
166
+ status.innerText = "Error loading model: " + e.message;
167
+ console.error(e);
168
+ if (e.message.includes("Failed to fetch")) {
169
+ status.innerHTML += "<br><br>⚠️ <b>CORS Error Detected</b>: If you are running this file directly (file://), you must use a local server.<br>Run <code>python -m http.server 8000</code> in the terminal and visit <code>http://localhost:8000/artistic.html</code>";
170
+ }
171
+ }
172
+ }
173
+
174
+ document.getElementById('imageInput').addEventListener('change', async function (e) {
175
+ if (!session) {
176
+ await init();
177
+ }
178
+
179
+ const file = e.target.files[0];
180
+ if (!file) return;
181
+
182
+ // Validate image type
183
+ if (!file.type.startsWith('image/')) {
184
+ alert('Please select a valid image file.');
185
+ return;
186
+ }
187
+
188
+ const image = new Image();
189
+ const objectUrl = URL.createObjectURL(file);
190
+ image.src = objectUrl;
191
+
192
+ image.onload = async function () {
193
+ document.getElementById('status').innerText = "Processing...";
194
+
195
+ // Pre-processing canvas (256x256)
196
+ let canvas = document.createElement("canvas");
197
+ const size = 256;
198
+ canvas.width = size;
199
+ canvas.height = size;
200
+ let ctx = canvas.getContext("2d");
201
+ ctx.drawImage(image, 0, 0, size, size);
202
+
203
+ const input_img = ctx.getImageData(0, 0, size, size);
204
+ const test = preprocess(input_img, size, size);
205
+ const input = new ort.Tensor(new Float32Array(test), [1, 3, size, size]);
206
+
207
+ try {
208
+ const result = await session.run({ "input": input });
209
+ // Handle potential output name differences
210
+ const output = result["output"] || result["out"] || Object.values(result)[0];
211
+
212
+ if (!output) throw new Error("No output tensor found in model result");
213
+
214
+ const imgdata = postprocess(output);
215
+
216
+ // Render to output canvas
217
+ const outCanvas = document.getElementById('outputCanvas');
218
+ outCanvas.width = image.width;
219
+ outCanvas.height = image.height;
220
+ const outCtx = outCanvas.getContext('2d');
221
+
222
+ // Draw 256x256 result to temp canvas
223
+ const tempCanvas = document.createElement('canvas');
224
+ tempCanvas.width = size;
225
+ tempCanvas.height = size;
226
+ tempCanvas.getContext('2d').putImageData(imgdata, 0, 0);
227
+
228
+ // Resize to original
229
+ outCtx.drawImage(tempCanvas, 0, 0, image.width, image.height);
230
+
231
+ document.getElementById('status').innerText = "Done!";
232
+ } catch (err) {
233
+ document.getElementById('status').innerText = "Error processing: " + err.message;
234
+ console.error(err);
235
+ } finally {
236
+ // Clean up memory
237
+ URL.revokeObjectURL(objectUrl);
238
+ }
239
+ };
240
+ });
241
+
242
+ // Start loading immediately
243
+ init();
244
+ </script>
245
+ </body>
246
+
247
+ </html>