sahirp committed on
Commit
5bcc422
·
verified ·
1 Parent(s): 4f7a5ba

Add v15 ONNX + WebGPU viewer bundle

Browse files
.gitattributes CHANGED
@@ -43,3 +43,8 @@ koi/v14_scare/onnx/ae_decode_fp16.onnx.data filter=lfs diff=lfs merge=lfs -text
43
  koi/v14_scare/onnx/ae_encode.onnx.data filter=lfs diff=lfs merge=lfs -text
44
  koi/v14_scare/onnx/edm_denoise.onnx.data filter=lfs diff=lfs merge=lfs -text
45
  koi/v14_scare/onnx/edm_denoise_fp16.onnx.data filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
43
  koi/v14_scare/onnx/ae_encode.onnx.data filter=lfs diff=lfs merge=lfs -text
44
  koi/v14_scare/onnx/edm_denoise.onnx.data filter=lfs diff=lfs merge=lfs -text
45
  koi/v14_scare/onnx/edm_denoise_fp16.onnx.data filter=lfs diff=lfs merge=lfs -text
46
+ koi/v15_scare/onnx/ae_decode.onnx.data filter=lfs diff=lfs merge=lfs -text
47
+ koi/v15_scare/onnx/ae_decode_fp16.onnx.data filter=lfs diff=lfs merge=lfs -text
48
+ koi/v15_scare/onnx/ae_encode.onnx.data filter=lfs diff=lfs merge=lfs -text
49
+ koi/v15_scare/onnx/edm_denoise.onnx.data filter=lfs diff=lfs merge=lfs -text
50
+ koi/v15_scare/onnx/edm_denoise_fp16.onnx.data filter=lfs diff=lfs merge=lfs -text
koi/v15_scare/onnx/ae_decode.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1eece1505ae09c2cce4f4237793691d59b4875e30a7352a536688397e263872f
3
+ size 255916
koi/v15_scare/onnx/ae_decode.onnx.data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecff1f0217e2a280b226a88003ab4ed0f30a55cb6fe71b11140391e565a48f82
3
+ size 25755648
koi/v15_scare/onnx/ae_decode_fp16.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:786edf9be7de763efb72453983e5dcc8a745df6c6306cdebce76aa5b1f957676
3
+ size 271865
koi/v15_scare/onnx/ae_decode_fp16.onnx.data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f3f932c4167400121f69ff82e4f6276b199964b794c61174fc65638a37d414f
3
+ size 12836224
koi/v15_scare/onnx/ae_encode.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44de0def6223faa4d0c529547a6da9e13b3aa55309fa9a466a86a1dccbc7866a
3
+ size 251463
koi/v15_scare/onnx/ae_encode.onnx.data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3229b442f69daaa1d46ed49ebcd446601b9321b396dffdb781f6ce8d5610daff
3
+ size 28704768
koi/v15_scare/onnx/ae_weights_only.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfda8641f269b32936a3933c9ba5422db8f008997242d29a94ab62fe46f916ca
3
+ size 54450037
koi/v15_scare/onnx/edm_denoise.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8572fb58f0c6041d43ee2bc17dded739797d9bc538cb114bd863c8fc842b218e
3
+ size 568476
koi/v15_scare/onnx/edm_denoise.onnx.data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3dbd8ce53b951b109e1e4229d28fb4a776e775f3dfdd3bcbccacb97c04be9108
3
+ size 190906368
koi/v15_scare/onnx/edm_denoise_fp16.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65a021050912128fd6eacfbb5f5f5044bc8f558cfd7177d33cd40186010e36df
3
+ size 594378
koi/v15_scare/onnx/edm_denoise_fp16.onnx.data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6dd037798567c40f972c34f67b51d1a8448d486820fcfc43de24aa72ef8a201f
3
+ size 95421952
koi/v15_scare/onnx/edm_weights_only.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fc4599b69c78e3a44584d3020e90da1606d929216a3123474b8a9b0aede34e0
3
+ size 190953835
koi/v15_scare/onnx/index.html ADDED
@@ -0,0 +1,1194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <title>Koi Pond - ONNX WebGPU Demo</title>
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
6
+ <style>
7
+ * { box-sizing: border-box; }
8
+ body {
9
+ margin: 0; padding: 10px;
10
+ background: #000;
11
+ display: flex; justify-content: center; align-items: center;
12
+ min-height: 100vh; flex-direction: column;
13
+ font-family: monospace;
14
+ }
15
+ .canvas-container { width: 100%; max-width: 512px; }
16
+ canvas {
17
+ width: 100%; height: auto; aspect-ratio: 1;
18
+ border: 2px solid #ccc; cursor: crosshair;
19
+ image-rendering: pixelated; display: block;
20
+ background: #000;
21
+ }
22
+ .info { color: #aaa; margin-top: 10px; font-size: 12px; text-align: center; }
23
+ .controls { color: #ddd; margin-top: 10px; font-size: 12px; text-align: center; }
24
+ button { margin: 0 5px; padding: 5px 10px; }
25
+ .status { color: #888; margin-top: 10px; font-size: 11px; max-width: 512px; text-align: center; }
26
+ .diag { color: #9aa; margin-top: 8px; font-size: 11px; max-width: 700px; text-align: center; white-space: pre-wrap; }
27
+ .progress { width: 100%; max-width: 300px; height: 20px; margin: 10px auto; }
28
+ @media (max-width: 540px) {
29
+ body { padding: 0; }
30
+ .canvas-container { max-width: 100%; }
31
+ canvas { border-width: 0; }
32
+ }
33
+ </style>
34
+ </head>
35
+ <body>
36
+ <div class="canvas-container">
37
+ <canvas id="c" width="256" height="256"></canvas>
38
+ </div>
39
+ <div class="info">
40
+ fps: <span id="fps">0</span> |
41
+ frame: <span id="frame">0</span> |
42
+ backend: <span id="backend">?</span> |
43
+ precision: <span id="precision">?</span>
44
+ <span style="opacity:0.7">| click to tap</span>
45
+ </div>
46
+ <div class="controls">
47
+ <button onclick="resetWorld()">Reset</button>
48
+ <label>Mode:
49
+ <select id="mode">
50
+ <option value="onnx">local (onnx)</option>
51
+ <option value="ws">server (ws)</option>
52
+ </select>
53
+ </label>
54
+ <label>CFG: <input type="range" id="cfg" min="1" max="2.5" step="0.1" value="1.5"></label>
55
+ <span id="cfgVal">1.5</span>
56
+ <button onclick="togglePause()" id="pauseBtn">Pause</button>
57
+ </div>
58
+ <div class="status" id="status">Loading ONNX Runtime...</div>
59
+ <div class="diag" id="diag"></div>
60
+ <progress class="progress" id="progress" value="0" max="100" style="display:none;"></progress>
61
+
62
+ <!-- onnxruntime-web (WebGPU-enabled build) from CDN -->
63
+ <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.2/dist/ort.webgpu.min.js"></script>
64
+
65
+ <script>
66
+ // =====================
67
+ // config
68
+ // =====================
69
+ const CONFIG = {
70
+ imgSize: 256,
71
+ latentSize: 32,
72
+ latentCh: 8,
73
+ history: 4,
74
+ condSigmaCh: 4, // overridden by model_meta.json
75
+ sigmaData: 0.125, // overridden by model_meta.json
76
+ eulerSteps: 4,
77
+ rho: 7.0, // Karras schedule rho (matches koi.train.model_latent_edm)
78
+ sigmaMax: 1.0, // tap frames
79
+ sigmaMaxNoTap: 1.0,
80
+ sigmaMin: 0.002,
81
+ noiseScale: 1.0,
82
+ noiseScaleNoTap: 1.0,
83
+ tapHoldFrames: 2,
84
+ tapStrength: 1.0,
85
+ heatmapSigmaPx: 1.25,
86
+ // Playback limiter for ONNX mode. The world model is "1 step = 1 frame";
87
+ // running faster changes perceived world speed. Set `playbackFps=0` for unlimited.
88
+ playbackFps: 15,
89
+ };
90
+
91
+ // =====================
92
+ // globals
93
+ // =====================
94
// --- ONNX sessions -------------------------------------------------------
let aeDecode = null;
let edmDenoise = null;

// --- tensor names used when building feeds / reading outputs --------------
let aeInputName = 'z';
let aeOutputName = 'img';
let edmInputX = 'x_noisy';
let edmInputSigma = 'sigma';
let edmInputCond = 'cond';
let edmOutputName = 'x_hat';
let edmInputMeta = null;
let aeInputMeta = null;

// --- world state -----------------------------------------------------------
let zHistory = null; // Float32Array[history * latentCh * latentSize * latentSize]
let tapQueue = [];
let lastTap = { x: -1, y: -1 };
let tapHoldLeft = 0;

// --- playback / bookkeeping --------------------------------------------------
let paused = true;
let frameCount = 0;
let totalFrames = 0;
let lastFpsTime = Date.now();
let running = false;
let lastStepStartMs = 0;
let ortBackend = 'unknown';
let ortPrecision = 'fp32';
let _imageData = null;

// --- runtime switches --------------------------------------------------------
let GRAPH_CAPTURE = false;
let FORCE_FP16 = false;
let FORCE_SIGMA_SCALAR = false;
let FALLBACK_ON_ERROR = true;
let fallbackAttempted = false;

// base url for model/asset files (set by ?hf=1 to load from huggingface cdn)
// Tip: pin to a commit with ?hf_rev=<commit_sha> to avoid breaking changes on main.
const HF_BASE = 'https://huggingface.co/sahirp/tap-conditioned-world-model/resolve/main/koi/v15_scare/onnx/';
let MODEL_BASE_URL = null; // null = relative to current page

// WebSocket backend globals (for ?mode=ws)
let ws = null;
let wsServerMode = 'unknown'; // 'push' | 'pull' | 'unknown'
let pendingTap = null; // {x,y,t_ms}

// --- DOM handles -------------------------------------------------------------
const canvas = document.getElementById('c');
const ctx = canvas.getContext('2d');
const fpsEl = document.getElementById('fps');
const frameEl = document.getElementById('frame');
const cfgSlider = document.getElementById('cfg');
const cfgValEl = document.getElementById('cfgVal');
const statusEl = document.getElementById('status');
const progressEl = document.getElementById('progress');
const pauseBtn = document.getElementById('pauseBtn');
const backendEl = document.getElementById('backend');
const precisionEl = document.getElementById('precision');
const modeEl = document.getElementById('mode');

// Mirror the CFG slider position into its numeric label.
cfgSlider.oninput = () => {
  cfgValEl.textContent = cfgSlider.value;
};
145
+
146
// Build a copy of the current page URL with one query parameter changed.
// A `null`/`undefined` value removes the parameter; any other value is
// stringified and set. Returns the resulting URL as a string; the page's
// own location is not modified.
function setQueryParam(key, value) {
  const url = new URL(location.href);
  const shouldRemove = value === null || value === undefined;
  if (shouldRemove) {
    url.searchParams.delete(key);
    return url.toString();
  }
  url.searchParams.set(key, String(value));
  return url.toString();
}
155
+
156
+ // =====================
157
+ // tensor helpers
158
+ // =====================
159
+ // note: fp16 models use keep_io_types=True, so inputs/outputs stay fp32.
160
+ // no type conversion needed - just load the different .onnx files.
161
// Wrap `data` in a float32 ort.Tensor with the given dims.
function createTensor(data, dims) {
  const dtype = 'float32';
  const tensor = new ort.Tensor(dtype, data, dims);
  return tensor;
}
164
+
165
// Resolve a tensor to its backing array.
// Order of preference: an already-present typed-array `data`, then the
// result of an async `getData()` if the tensor offers one, then whatever
// `tensor?.data` evaluates to (possibly undefined).
async function getTensorData(tensor) {
  const hasViewData = Boolean(tensor && tensor.data && ArrayBuffer.isView(tensor.data));
  if (hasViewData) {
    return tensor.data;
  }
  const canFetch = Boolean(tensor) && typeof tensor.getData === 'function';
  if (canFetch) {
    return await tensor.getData();
  }
  return tensor?.data;
}
173
+
174
// Pick one output tensor from an ONNX run result map.
// Returns `results[preferredName]` when that entry is truthy, otherwise the
// value of the first key. Throws Error('ONNX run returned no outputs') when
// the map is empty or missing (null/undefined).
function firstOutput(results, preferredName) {
  // Normalize first so a null/undefined result reaches the descriptive error
  // below rather than failing with a TypeError on the property lookup.
  const outputs = results || {};
  if (preferredName && outputs[preferredName]) return outputs[preferredName];
  const keys = Object.keys(outputs);
  if (keys.length === 0) throw new Error('ONNX run returned no outputs');
  return outputs[keys[0]];
}
180
+
181
// Choose a tensor name from `names`.
// Candidates are tried in order; for each one the first name whose
// lower-cased form contains the candidate wins. If nothing matches, falls
// back to names[fallbackIndex], and to names[0] if that slot is empty.
function pickName(names, candidates, fallbackIndex = 0) {
  for (const needle of candidates) {
    const containsNeedle = (name) => name.toLowerCase().includes(needle);
    const hit = names.find(containsNeedle);
    if (hit) {
      return hit;
    }
  }
  const fallback = names[fallbackIndex];
  return fallback ? fallback : names[0];
}
188
+
189
// Read the dimension list from a tensor-metadata object.
// Looks at `dimensions` first, then `dim`; yields null when neither is
// truthy, when `meta` is null/undefined, or when reading the property throws.
function getMetaDims(meta) {
  let dims = null;
  try {
    dims = meta?.dimensions || meta?.dim || null;
  } catch (err) {
    dims = null;
  }
  return dims;
}
196
+
197
// Build the sigma input tensor for the denoiser.
// The value is emitted as a rank-0 tensor (dims []) when FORCE_SIGMA_SCALAR
// is set or when the recorded input metadata for the sigma input reports an
// empty dimension list; otherwise it is emitted with dims [1].
function makeSigmaTensor(sigma) {
  let asScalar = FORCE_SIGMA_SCALAR;
  if (!asScalar) {
    const sigmaMeta = edmInputMeta && edmInputMeta[edmInputSigma];
    const metaDims = getMetaDims(sigmaMeta);
    asScalar = Array.isArray(metaDims) && metaDims.length === 0;
  }
  const dims = asScalar ? [] : [1];
  return createTensor(new Float32Array([sigma]), dims);
}
208
+
209
// Allocate a zero-filled Float32Array whose length is the product of `shape`
// (an empty shape yields a single element).
function zeros(shape) {
  let count = 1;
  for (const dim of shape) {
    count = count * dim;
  }
  return new Float32Array(count);
}
213
+
214
// Draw standard-normal samples via the Box-Muller transform.
// Returns a Float32Array whose length is the product of `shape`. Samples are
// generated in pairs from two uniform draws; for an odd length the second
// member of the last pair is discarded.
function randn(shape) {
  const size = shape.reduce((a, b) => a * b, 1);
  const arr = new Float32Array(size);
  for (let i = 0; i < size; i += 2) {
    // Math.random() is in [0, 1). Flipping it to (0, 1] keeps log() finite and
    // non-positive. The previous `log(u1 + 1e-10)` became positive whenever
    // u1 > 1 - 1e-10, producing sqrt(negative) = NaN.
    const u1 = 1 - Math.random();
    const u2 = Math.random();
    const r = Math.sqrt(-2 * Math.log(u1));
    const theta = 2 * Math.PI * u2;
    arr[i] = r * Math.cos(theta);
    if (i + 1 < size) arr[i + 1] = r * Math.sin(theta);
  }
  return arr;
}
226
+
227
+ // =====================
228
+ // action heatmap
229
+ // =====================
230
// Unnormalized 2-D Gaussian evaluated at the center of pixel (px, py),
// i.e. at (px + 0.5, py + 0.5), for a bump centered on (cx, cy) with
// standard deviation sigmaPx. The peak value is 1.
function _gaussAtPixel(px, py, cx, cy, sigmaPx) {
  const offsetX = px + 0.5 - cx;
  const offsetY = py + 0.5 - cy;
  const sqDist = offsetX * offsetX + offsetY * offsetY;
  const twoVariance = 2 * sigmaPx * sigmaPx;
  return Math.exp(-sqDist / twoVariance);
}
236
+
237
// Sample the Gaussian tap heatmap (defined on an inSize x inSize pixel grid)
// at output cell (ox, oy) of an outSize x outSize grid, using 4-tap bilinear
// interpolation.
// Match torch.nn.functional.interpolate(..., mode="bilinear", align_corners=False)
function _bilinearSampleGaussian(ox, oy, cx, cy, sigmaPx, inSize, outSize) {
  const ratio = inSize / outSize;
  const srcX = (ox + 0.5) * ratio - 0.5;
  const srcY = (oy + 0.5) * ratio - 0.5;

  // Interpolation weights come from the unclamped floor positions.
  const left = Math.floor(srcX);
  const top = Math.floor(srcY);
  const fracX = srcX - left;
  const fracY = srcY - top;
  const keepX = 1.0 - fracX;
  const keepY = 1.0 - fracY;

  // clamp to valid pixels
  const clampIndex = (v) => Math.max(0, Math.min(inSize - 1, v));
  const xa = clampIndex(left);
  const xb = clampIndex(left + 1);
  const ya = clampIndex(top);
  const yb = clampIndex(top + 1);

  const upperRow = _gaussAtPixel(xa, ya, cx, cy, sigmaPx) * keepX
    + _gaussAtPixel(xb, ya, cx, cy, sigmaPx) * fracX;
  const lowerRow = _gaussAtPixel(xa, yb, cx, cy, sigmaPx) * keepX
    + _gaussAtPixel(xb, yb, cx, cy, sigmaPx) * fracX;
  return upperRow * keepY + lowerRow * fracY;
}
263
+
264
// Build the 2-channel action map at latent resolution for a tap at the
// normalized position (tapX, tapY).
//   channel 0: Gaussian heatmap around the tap, scaled by CONFIG.tapStrength
//              and clamped to [0, 1]
//   channel 1: tap flag, 1.0 in every cell
// A negative tapX or tapY means "no tap" and yields an all-zero map.
function makeActionMap(tapX, tapY) {
  const { latentSize, imgSize, heatmapSigmaPx } = CONFIG;
  const plane = latentSize * latentSize;
  const actionMap = new Float32Array(2 * plane);

  const noTap = tapX < 0 || tapY < 0;
  if (noTap) {
    return actionMap; // all zeros
  }

  const sigmaPx = Math.max(heatmapSigmaPx, 1e-6);
  const cx = tapX * imgSize;
  const cy = tapY * imgSize;

  // channel 0: heatmap, walked in row-major order
  for (let idx = 0; idx < plane; idx++) {
    const px = idx % latentSize;
    const py = (idx - px) / latentSize;
    const raw = _bilinearSampleGaussian(px, py, cx, cy, sigmaPx, imgSize, latentSize);
    const scaled = raw * CONFIG.tapStrength;
    actionMap[idx] = Math.min(1, Math.max(0, scaled));
  }

  // channel 1: tap_flag
  actionMap.fill(1.0, plane);
  return actionMap;
}
288
+
289
+ // =====================
290
+ // build conditioning tensor
291
+ // =====================
292
// Assemble the denoiser conditioning buffer for one frame.
// Channel layout, each channel latentSize x latentSize:
//   [ history * latentCh latent-history channels | 2 action channels | condSigmaCh sigma channels ]
// Reads the global `zHistory`; `actionMap` supplies the two action channels.
// Returns a new Float32Array; nothing is modified in place.
function buildCond(actionMap) {
  const { latentSize, latentCh, history, condSigmaCh } = CONFIG;
  const spatialSize = latentSize * latentSize;
  const condCh = history * latentCh + 2 + condSigmaCh;
  const cond = new Float32Array(condCh * spatialSize);

  // copy history latents (bulk copy instead of a per-element loop)
  const histSize = history * latentCh * spatialSize;
  cond.set(zHistory.subarray(0, histSize), 0);

  // copy action map (2 channels), placed right after the history block
  cond.set(actionMap.subarray(0, 2 * spatialSize), histSize);

  // sigma channels (zeros at inference = clean context)
  // already zeros from initialization

  return cond;
}
315
+
316
+ // =====================
317
+ // euler sampler
318
+ // =====================
319
// Karras-style noise schedule.
// Returns `num` sigmas running from sigmaMax (index 0) to sigmaMin (last
// index), linearly spaced in sigma^(1/rho) space.
// Throws Error when num < 2.
function karrasSigmas(num, sigmaMin, sigmaMax, rho) {
  if (num < 2) throw new Error('karrasSigmas: num must be >= 2');
  const invRho = 1.0 / rho;
  const minInv = Math.pow(sigmaMin, invRho);
  const maxInv = Math.pow(sigmaMax, invRho);
  const last = num - 1; // >= 1 after the guard above, so the division is safe
  const sigmas = new Array(num);
  for (let i = 0; i < num; i++) {
    const t = i / last;
    const s = maxInv + t * (minInv - maxInv);
    sigmas[i] = Math.pow(s, rho);
  }
  return sigmas;
}
331
+
332
// Run the EDM denoiser once on `x` at noise level `sigma` and return the
// denoised estimate as a fresh Float32Array.
// With cfgScale === 1.0 a single conditional pass is run. Otherwise the
// conditional and unconditional passes are both run (in that order) and
// blended as: uncond + cfgScale * (cond - uncond).
async function _edmDenoiseGuided(x, sigma, shape, condTensor, condUncondTensor, cfgScale) {
  const xTensor = createTensor(x, shape);
  const sigmaTensor = makeSigmaTensor(sigma);

  if (cfgScale === 1.0) {
    const feed = { [edmInputX]: xTensor, [edmInputSigma]: sigmaTensor, [edmInputCond]: condTensor };
    const out = firstOutput(await edmDenoise.run(feed), edmOutputName);
    const outData = await getTensorData(out);
    // Copy so the caller owns the buffer.
    return new Float32Array(outData);
  }

  // cfg: run both conditional and unconditional
  const feedC = { [edmInputX]: xTensor, [edmInputSigma]: sigmaTensor, [edmInputCond]: condTensor };
  const feedU = { [edmInputX]: xTensor, [edmInputSigma]: sigmaTensor, [edmInputCond]: condUncondTensor };
  const outC = firstOutput(await edmDenoise.run(feedC), edmOutputName);
  const outU = firstOutput(await edmDenoise.run(feedU), edmOutputName);
  const denC = await getTensorData(outC);
  const denU = await getTensorData(outU);

  const guided = new Float32Array(x.length);
  for (let i = 0; i < guided.length; i++) {
    guided[i] = denU[i] + cfgScale * (denC[i] - denU[i]);
  }
  return guided;
}

// Sample the next latent frame with an Euler sampler over a Karras schedule.
//   cond       - conditioning buffer for the conditional pass
//   condUncond - conditioning buffer for the unconditional pass (only used
//                when cfgScale !== 1.0)
//   cfgScale   - classifier-free guidance scale; 1.0 disables guidance
//   isTap      - selects the tap / no-tap variants of sigmaMax and noiseScale
// Returns a Float32Array of latentCh * latentSize * latentSize values.
// Reads the global `zHistory`; does not modify it.
async function eulerSample(cond, condUncond, cfgScale, isTap) {
  const { latentCh, latentSize, eulerSteps, sigmaMax, sigmaMaxNoTap, sigmaMin, history, rho, noiseScale, noiseScaleNoTap } = CONFIG;
  const sigmaMaxFrame = isTap ? sigmaMax : sigmaMaxNoTap;
  const noiseScaleFrame = isTap ? noiseScale : noiseScaleNoTap;
  const shape = [1, latentCh, latentSize, latentSize];
  const spatialSize = latentSize * latentSize;
  const condCh = cond.length / spatialSize;
  const condTensor = createTensor(cond, [1, condCh, latentSize, latentSize]);
  const condUncondTensor = (cfgScale !== 1.0) ? createTensor(condUncond, [1, condCh, latentSize, latentSize]) : null;

  // initialize from last history frame + sigma_max noise (img2img-style)
  const lastFrameOffset = (history - 1) * latentCh * spatialSize;
  const x = new Float32Array(latentCh * spatialSize);
  const noise = randn([latentCh * spatialSize]);
  for (let i = 0; i < x.length; i++) {
    x[i] = zHistory[lastFrameOffset + i] + sigmaMaxFrame * noise[i] * noiseScaleFrame;
  }

  // Karras noise schedule (matches koi.train.model_latent_edm.karras_sigmas)
  const sigmas = karrasSigmas(eulerSteps + 1, sigmaMin, sigmaMaxFrame, rho);

  // euler steps (matches koi.train.model_latent_edm.edm_sample_euler):
  //   d = (x - denoised) / sigma
  //   x = x + d * dt
  for (let step = 0; step < eulerSteps; step++) {
    const sigma = sigmas[step];
    const dt = sigmas[step + 1] - sigma;
    const denoised = await _edmDenoiseGuided(x, sigma, shape, condTensor, condUncondTensor, cfgScale);
    for (let i = 0; i < x.length; i++) {
      const d = (x[i] - denoised[i]) / sigma;
      x[i] = x[i] + d * dt;
    }
  }

  // final denoise at sigma_min
  return await _edmDenoiseGuided(x, sigmaMin, shape, condTensor, condUncondTensor, cfgScale);
}
417
+
418
+ // =====================
419
+ // decode latent to image
420
+ // =====================
421
// Decode one latent frame to an image with the autoencoder decoder session.
//   z - latent values, fed with dims [1, latentCh, latentSize, latentSize]
// Returns { data, dims }: the decoder output's array plus the dims reported
// by the output tensor. The memory layout is not normalized here; the
// consumer is expected to inspect `dims`.
async function decodeLatent(z) {
  const { latentCh, latentSize } = CONFIG;
  const zTensor = createTensor(z, [1, latentCh, latentSize, latentSize]);
  const imgTensor = firstOutput(await aeDecode.run({ [aeInputName]: zTensor }), aeOutputName);
  const imgData = await getTensorData(imgTensor);
  return { data: imgData, dims: imgTensor.dims };
}
428
+
429
+ // =====================
430
+ // render to canvas
431
+ // =====================
432
// Paint a decoded frame onto the canvas.
// `imgOut` is either { data, dims } or a bare array of channel values, which
// are multiplied by 255, truncated and clamped to [0, 255]. When `dims` is
// 4-D with a last dimension of 3 the data is read as interleaved RGB;
// otherwise it is read as three consecutive planes (R, G, B).
// The ImageData scratch buffer is cached in `_imageData` and reallocated
// only when its size does not match CONFIG.imgSize.
function renderImage(imgOut) {
  const { imgSize } = CONFIG;
  const imgData = imgOut?.data || imgOut;
  const dims = imgOut?.dims || null;

  const cacheStale = !_imageData || _imageData.width !== imgSize || _imageData.height !== imgSize;
  if (cacheStale) {
    _imageData = ctx.createImageData(imgSize, imgSize);
  }

  const dst = _imageData.data;
  const spatial = imgSize * imgSize;
  const isNHWC = dims && dims.length === 4 && dims[3] === 3;

  // Express both layouts as (pixel stride, green offset, blue offset).
  const pixelStride = isNHWC ? 3 : 1;
  const greenOffset = isNHWC ? 1 : spatial;
  const blueOffset = isNHWC ? 2 : 2 * spatial;
  const toByte = (v) => Math.min(255, Math.max(0, (v * 255) | 0));

  for (let i = 0; i < spatial; i++) {
    const src = i * pixelStride;
    const out = i * 4;
    dst[out + 0] = toByte(imgData[src]);
    dst[out + 1] = toByte(imgData[src + greenOffset]);
    dst[out + 2] = toByte(imgData[src + blueOffset]);
    dst[out + 3] = 255;
  }
  ctx.putImageData(_imageData, 0, 0);
}
462
+
463
+ // =====================
464
+ // update history buffer
465
+ // =====================
466
// Advance the rolling latent history by one frame: drop the oldest frame,
// shift the remaining frames toward the front, and store `newZ` as the
// newest (last) frame. Mutates the global `zHistory` in place.
function updateHistory(newZ) {
  const { latentCh, latentSize, history } = CONFIG;
  const frameSize = latentCh * latentSize * latentSize;

  // shift left by one frame (bulk move instead of a per-element loop)
  zHistory.copyWithin(0, frameSize, history * frameSize);

  // copy new frame to end; only the first frameSize values are used
  const offset = (history - 1) * frameSize;
  const frame = ArrayBuffer.isView(newZ)
    ? newZ.subarray(0, frameSize)
    : Array.prototype.slice.call(newZ, 0, frameSize);
  zHistory.set(frame, offset);
}
481
+
482
+ // =====================
483
+ // main step
484
+ // =====================
485
// Advance the local (ONNX) world by one frame: consume a tap if one is
// queued or still held, build conditioning, sample the next latent, decode
// and draw it, push it into the history, update the counters, then schedule
// the following step. Does nothing while paused or while another step is in
// flight.
async function step() {
  if (paused || running) return;
  running = true;
  if (ortBackend !== 'ws' && CONFIG.playbackFps > 0) {
    // Timestamp consumed by scheduleNextStep() to pace playback.
    lastStepStartMs = performance.now();
  }

  try {
    // A fresh tap is applied now and then re-applied for tapHoldFrames more
    // frames; otherwise (-1, -1) marks "no tap".
    let tap = { x: -1, y: -1 };
    if (tapQueue.length) {
      tap = tapQueue.shift();
      lastTap = tap;
      tapHoldLeft = CONFIG.tapHoldFrames;
    } else if (tapHoldLeft > 0) {
      tap = lastTap;
      tapHoldLeft -= 1;
    }
    const wantCfgScale = parseFloat(cfgSlider.value);
    const hasTap = tap.x >= 0 && tap.y >= 0;
    // CFG is only meaningful when cond != uncond. If there's no tap, action maps are all-zero anyway.
    const cfgScale = hasTap ? wantCfgScale : 1.0;

    // build action map and conditioning
    const actionMap = makeActionMap(tap.x, tap.y);
    const cond = buildCond(actionMap);

    // unconditional cond (zero action) for CFG
    let condUncond = cond;
    if (cfgScale !== 1.0) {
      const zeroAction = makeActionMap(-1, -1);
      condUncond = buildCond(zeroAction);
    }

    // sample next latent
    const zNext = await eulerSample(cond, condUncond, cfgScale, hasTap);

    // decode to image
    const imgOut = await decodeLatent(zNext);
    renderImage(imgOut);

    // update history
    updateHistory(zNext);

    // fps counter: frames completed since the last ~1s tick
    frameCount++;
    totalFrames++;
    frameEl.textContent = totalFrames;

    const now = Date.now();
    if (now - lastFpsTime > 1000) {
      fpsEl.textContent = frameCount;
      frameCount = 0;
      lastFpsTime = now;
    }
  } catch (e) {
    console.error('step error:', e);
    const msg = (e && e.message) ? e.message : String(e);
    if (ortBackend === 'webgpu' && FALLBACK_ON_ERROR) {
      try {
        await switchToWasmFallback(msg);
      } catch (fallbackErr) {
        // NOTE(review): switchToWasmFallback is defined outside this block;
        // if it rejects, surface the original error instead of leaving the
        // loop wedged with `running` stuck at true.
        console.error('wasm fallback failed:', fallbackErr);
        statusEl.textContent = 'Error: ' + msg;
      }
    } else {
      statusEl.textContent = 'Error: ' + msg;
    }
  } finally {
    // Always release the re-entrancy guard, whatever happened above.
    running = false;
  }

  if (!paused && ortBackend !== 'ws') scheduleNextStep();
}
552
+
553
// Queue the next call to step() for local playback.
// No-op while paused or when the backend is 'ws'. With a positive
// CONFIG.playbackFps the call is delayed so consecutive step starts are at
// least 1000 / playbackFps ms apart (measured from lastStepStartMs), then
// handed to requestAnimationFrame unless playback was paused in the
// meantime; with playbackFps <= 0 it goes straight to requestAnimationFrame.
function scheduleNextStep() {
  if (paused || ortBackend === 'ws') return;

  const limited = CONFIG.playbackFps > 0;
  if (!limited) {
    requestAnimationFrame(step);
    return;
  }

  const targetDt = 1000 / CONFIG.playbackFps;
  const elapsed = performance.now() - lastStepStartMs;
  const delay = Math.max(0, targetDt - elapsed);
  setTimeout(() => {
    if (!paused) requestAnimationFrame(step);
  }, delay);
}
563
+
564
+ // =====================
565
+ // controls
566
+ // =====================
567
// Flip the paused flag and relabel the button ('Play' while paused, 'Pause'
// while running). When resuming on a non-'ws' backend, kick the local step
// loop.
function togglePause() {
  paused = !paused;
  pauseBtn.textContent = paused ? 'Play' : 'Pause';
  const resumedLocally = !paused && ortBackend !== 'ws';
  if (resumedLocally) {
    scheduleNextStep();
  }
}
575
+
576
// Reset the world to its starting state.
// 'ws' backend: send {reset: true} over the socket if it is open, zero the
// frame counter, and stop.
// Local backend: allocate a fresh zHistory, zero the frame counter, then try
// to seed the history from init_frames.bin. On success the last frame is
// decoded and drawn as a best-effort preview (not awaited). If the file is
// missing, unreadable, or the wrong length, the history is filled with small
// uniform noise and a placeholder message is drawn on the canvas.
async function resetWorld() {
  if (ortBackend === 'ws') {
    try {
      const socketOpen = ws && ws.readyState === WebSocket.OPEN;
      if (socketOpen) ws.send(JSON.stringify({ reset: true }));
    } catch (_) {}
    totalFrames = 0;
    frameEl.textContent = '0';
    return;
  }

  const { latentCh, latentSize, history, imgSize } = CONFIG;
  const size = history * latentCh * latentSize * latentSize;
  zHistory = new Float32Array(size);
  totalFrames = 0;
  frameEl.textContent = '0';

  // try to load initial frames if available
  let seeded = false;
  try {
    statusEl.textContent = 'Loading initial frames...';
    const resp = await fetch(resolveAssetUrl('init_frames.bin'));
    if (resp.ok) {
      const initData = new Float32Array(await resp.arrayBuffer());
      if (initData.length !== size) {
        console.warn(`init_frames.bin length mismatch: got ${initData.length}, expected ${size}`);
      } else {
        zHistory.set(initData);
        statusEl.textContent = 'Ready! Click Play to start.';

        // decode and show last frame (best-effort, non-blocking)
        try {
          const frameSize = latentCh * latentSize * latentSize;
          const lastZ = zHistory.slice((history - 1) * frameSize);
          const showPreview = (imgOut) => {
            const imgData = imgOut?.data || imgOut;
            const expected = 3 * imgSize * imgSize;
            const lengthOk = Boolean(imgData) && imgData.length === expected;
            if (lengthOk) {
              renderImage(imgOut);
            } else {
              console.warn(`Init preview decode bad length: ${imgData?.length} (expected ${expected})`);
            }
          };
          decodeLatent(lastZ).then(showPreview).catch((e) => {
            console.warn('Init preview decode failed (continuing):', e);
          });
        } catch (e) {
          console.warn('Init preview setup failed (continuing):', e);
        }
        seeded = true;
      }
    }
  } catch (e) {
    console.log('No init_frames.bin, using random init', e);
  }
  if (seeded) return;

  // fallback: use small random noise (model will need to "warm up")
  for (let i = 0; i < size; i++) {
    zHistory[i] = (Math.random() - 0.5) * 0.1;
  }

  // placeholder text so the canvas is not left blank
  ctx.fillStyle = '#222';
  ctx.fillRect(0, 0, canvas.width, canvas.height);
  ctx.fillStyle = '#666';
  ctx.font = '14px monospace';
  ctx.textAlign = 'center';
  ctx.fillText('No init data - will generate from noise', canvas.width / 2, canvas.height / 2);
  statusEl.textContent = 'Ready (no init frames). Click Play.';
}
643
+
644
// Turn a canvas click into a tap at normalized coordinates (0..1 across the
// canvas's on-screen rectangle). Local backend: the tap is queued for the
// next step. 'ws' backend: the tap is remembered in pendingTap and sent to
// the server together with the current CFG value, provided the socket is
// open; send failures are ignored.
canvas.onclick = (e) => {
  const bounds = canvas.getBoundingClientRect();
  const x = (e.clientX - bounds.left) / bounds.width;
  const y = (e.clientY - bounds.top) / bounds.height;

  if (ortBackend !== 'ws') {
    tapQueue.push({ x, y });
    return;
  }

  const socketOpen = ws && ws.readyState === WebSocket.OPEN;
  if (!socketOpen) return;

  pendingTap = { x, y, t_ms: Date.now() };
  try {
    ws.send(JSON.stringify({ tap_x: x, tap_y: y, cfg_scale: parseFloat(cfgSlider.value) }));
  } catch (_) {}
};
659
+
660
+ // =====================
661
+ // init
662
+ // =====================
663
// Apply global onnxruntime-web settings. Every group is best-effort: an
// exception in one group is swallowed and the remaining groups still run.
function configureOrtRuntime() {
  const bestEffort = (apply) => {
    try {
      apply();
    } catch (e) {
      // best-effort
    }
  };

  // Enable WASM SIMD when available. Threads require cross-origin isolation (SharedArrayBuffer).
  bestEffort(() => {
    const wasmEnv = ort?.env?.wasm;
    if (!wasmEnv) return;
    wasmEnv.simd = true;
    // Force single-thread WASM to avoid worker issues on this server.
    wasmEnv.numThreads = 1;
    // Use CDN paths for wasm binaries to avoid path issues in workers.
    wasmEnv.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.2/dist/';
    // Proxy worker disabled (requires COOP/COEP and can fail in this server).
    wasmEnv.proxy = false;
  });

  // Prefer high-performance adapter when using WebGPU (if available).
  bestEffort(() => {
    const gpuEnv = ort?.env?.webgpu;
    if (gpuEnv) {
      gpuEnv.powerPreference = 'high-performance';
    }
  });

  // Optional debug flags via ?debug=1
  bestEffort(() => {
    const query = new URLSearchParams(location.search);
    if (query.get('debug') !== '1') return;
    ort.env.debug = true;
    ort.env.logLevel = 'verbose';
  });
}
697
+
698
// Download `url` and resolve with its body as an ArrayBuffer.
// Rejects with Error(`fetch failed <status> for <url>`) on a non-OK response.
async function fetchArrayBuffer(url) {
  const response = await fetch(url);
  if (response.ok) {
    return await response.arrayBuffer();
  }
  throw new Error(`fetch failed ${response.status} for ${url}`);
}
705
+
706
// Turn a model filename into an absolute URL.
// Without MODEL_BASE_URL the file is resolved relative to the current page.
// With it, the file is placed under 'onnx/fp16/' when the filename contains
// 'fp16' and under 'onnx/fp32/' otherwise.
// NOTE(review): the model files sitting next to this page are stored flat
// (no fp16/ or fp32/ subfolder) - confirm the remote base really uses the
// subfolder layout before relying on this mapping.
function resolveModelUrl(filename) {
  if (MODEL_BASE_URL) {
    // map filenames to hf repo layout: fp16/ or fp32/ subfolder under onnx/
    const precisionDir = filename.includes('fp16') ? 'onnx/fp16/' : 'onnx/fp32/';
    return MODEL_BASE_URL + precisionDir + filename;
  }
  const pageRelative = new URL(filename, location.href);
  return pageRelative.toString();
}
713
+
714
/**
 * Build the URL of a viewer asset (metadata JSON, initial frames, ...).
 *
 * With no MODEL_BASE_URL set, the asset is resolved relative to the current
 * page; otherwise it is looked up in the `viewer/` folder under that base.
 *
 * @param {string} filename - Asset file name, e.g. `model_meta.json`.
 * @returns {string} URL of the asset.
 */
function resolveAssetUrl(filename) {
  if (MODEL_BASE_URL) {
    return `${MODEL_BASE_URL}viewer/${filename}`;
  }
  return new URL(filename, location.href).toString();
}
718
+
719
/**
 * Download an ONNX graph together with its optional external-data sidecar.
 *
 * The graph (`modelFile`) and the weights file (`modelFile + '.data'`) are
 * independent downloads, so they are requested concurrently rather than one
 * after the other.
 *
 * @param {string} modelFile - Model file name, resolved via resolveModelUrl().
 * @returns {Promise<{model: Uint8Array, externalData?: Array<{path: string, data: Uint8Array}>}>}
 *   The graph bytes, plus an `externalData` list when the sidecar was fetched.
 * @throws {Error} When the graph itself cannot be fetched.
 */
async function loadModelWithExternal(modelFile) {
  const dataFile = modelFile + '.data';
  // Resolve both URLs up front so a bad URL fails before any request starts.
  const modelUrl = resolveModelUrl(modelFile);
  const dataUrl = resolveModelUrl(dataFile);

  const [modelBuf, dataBuf] = await Promise.all([
    fetchArrayBuffer(modelUrl),
    // A missing sidecar is not an error: external data may be inlined
    // (monolithic ONNX), so any failure here just means "no sidecar".
    fetchArrayBuffer(dataUrl).catch(() => null),
  ]);

  const loaded = { model: new Uint8Array(modelBuf) };
  if (dataBuf) {
    loaded.externalData = [{ path: dataFile, data: new Uint8Array(dataBuf) }];
  }
  return loaded;
}
739
+
740
/**
 * Create the autoencoder-decode and EDM-denoise inference sessions and publish
 * them, with their resolved input/output names, to the file-level state
 * (`aeDecode`, `edmDenoise`, `ortPrecision`, `aeInput*`, `edmInput*`, ...).
 *
 * The new sessions are built in local variables and only swapped into the
 * globals once both exist. This matters when the function is called a second
 * time (e.g. for a WASM fallback): sessions from an earlier call must not be
 * mistaken for the result of this call, and a failed attempt must not leave a
 * half-updated pair behind.
 *
 * @param {string[]} executionProviders - ORT execution providers, e.g. ['webgpu'].
 * @param {boolean} preferFp16 - Try the `*_fp16.onnx` models first and fall
 *   back to the fp32 models if they fail to load.
 * @returns {Promise<void>}
 * @throws When the fp32 models cannot be loaded or their sessions cannot be created.
 */
async function createSessionsWith(executionProviders, preferFp16) {
  const options = { executionProviders, graphOptimizationLevel: 'all' };
  if (executionProviders && executionProviders.includes('webgpu')) {
    options.enableGraphCapture = false;
  }

  // Free a session's resources without letting a failure escape.
  const releaseQuietly = async (session) => {
    try {
      if (session && typeof session.release === 'function') {
        await session.release();
      }
    } catch (_) {
      // best-effort
    }
  };

  // Load both model files and create both sessions, or throw leaving nothing behind.
  const createPair = async (aeFile, edmFile) => {
    const aeLoaded = await loadModelWithExternal(aeFile);
    const edmLoaded = await loadModelWithExternal(edmFile);
    const aeOpts = { ...options };
    if (aeLoaded.externalData) aeOpts.externalData = aeLoaded.externalData;
    const edmOpts = { ...options };
    if (edmLoaded.externalData) edmOpts.externalData = edmLoaded.externalData;

    const ae = await ort.InferenceSession.create(aeLoaded.model, aeOpts);
    try {
      const edm = await ort.InferenceSession.create(edmLoaded.model, edmOpts);
      return { ae, edm };
    } catch (e) {
      // The second session failed: do not leak the first one.
      void releaseQuietly(ae);
      throw e;
    }
  };

  let pair = null;
  let precision = 'fp32';

  // try fp16 models if requested (keep_io_types=True means I/O stays fp32)
  if (preferFp16) {
    try {
      pair = await createPair('ae_decode_fp16.onnx', 'edm_denoise_fp16.onnx');
      precision = 'fp16';
      console.log('fp16 models loaded successfully');
    } catch (e) {
      console.warn('fp16 load failed, falling back to fp32:', e);
    }
  }

  // fp32 fallback (also the default path when fp16 was not requested)
  if (!pair) {
    pair = await createPair('ae_decode.onnx', 'edm_denoise.onnx');
  }

  // Swap in the new sessions, remembering the old ones so they can be freed.
  const staleAe = aeDecode;
  const staleEdm = edmDenoise;
  aeDecode = pair.ae;
  edmDenoise = pair.edm;
  ortPrecision = precision;

  aeInputName = pickName(aeDecode.inputNames || [], ['z', 'latent'], 0);
  aeOutputName = (aeDecode.outputNames && aeDecode.outputNames[0]) || 'img';
  aeInputMeta = aeDecode.inputMetadata || null;
  const edmInputs = edmDenoise.inputNames || [];
  edmInputX = pickName(edmInputs, ['x_noisy', 'x'], 0);
  edmInputSigma = pickName(edmInputs, ['sigma', 't'], 1);
  edmInputCond = pickName(edmInputs, ['cond', 'c'], 2);
  edmOutputName = (edmDenoise.outputNames && edmDenoise.outputNames[0]) || 'x_hat';
  edmInputMeta = edmDenoise.inputMetadata || null;

  // Not awaited: a stuck release of a broken session must not block the caller.
  if (staleAe && staleAe !== aeDecode) void releaseQuietly(staleAe);
  if (staleEdm && staleEdm !== edmDenoise) void releaseQuietly(staleEdm);
}
789
+
790
/**
 * One-shot runtime fallback from WebGPU to the WASM backend.
 *
 * Only the first call does any work; later calls return immediately. On
 * success the backend/precision labels, the diagnostics panel and the status
 * line are updated; on failure a fatal status is shown and the error is logged.
 *
 * @param {string} [reason] - Text appended to the diagnostics panel; defaults
 *   to 'runtime error' when empty.
 * @returns {Promise<void>}
 */
async function switchToWasmFallback(reason) {
  if (!fallbackAttempted) {
    fallbackAttempted = true;
    try {
      statusEl.textContent = `WebGPU error, switching to WASM...`;
      await createSessionsWith(['wasm'], false);

      ortBackend = 'wasm';
      backendEl.textContent = ortBackend;
      precisionEl.textContent = ortPrecision;

      const diagnostics = document.getElementById('diag');
      if (diagnostics) {
        const why = reason || 'runtime error';
        diagnostics.textContent += `\nfallback: wasm (${why})`;
      }
      statusEl.textContent = `Ready! backend=${ortBackend}. Click Play.`;
    } catch (err) {
      statusEl.textContent = `Fatal: WebGPU failed and WASM fallback failed`;
      console.error('fallback failed:', err);
    }
  }
}
807
+
808
/**
 * Fetch and parse a JSON document, treating every failure as "not available".
 *
 * @param {string} path - URL of the JSON resource.
 * @returns {Promise<*>} The parsed value, or null when the request fails, the
 *   status is not OK, or the body is not valid JSON.
 */
async function loadJSON(path) {
  let parsed = null;
  try {
    const response = await fetch(path);
    if (response.ok) {
      parsed = await response.json();
    }
  } catch (_) {
    parsed = null;
  }
  return parsed;
}
817
+
818
/**
 * Probe a required resource with a HEAD request and describe the outcome.
 *
 * @param {string} url - Resource to probe.
 * @param {string} label - Prefix used in the returned line.
 * @returns {Promise<string>} `"<label>: ok"` (with the content-length when the
 *   header is present), `"<label>: <status>"` for a non-OK response, or
 *   `"<label>: err <message>"` when the request itself throws.
 */
async function diagFetch(url, label) {
  let outcome;
  try {
    const head = await fetch(url, { method: 'HEAD' });
    if (head.ok) {
      const bytes = head.headers.get('content-length');
      const sizeNote = bytes ? ` (${bytes} bytes)` : '';
      outcome = `ok${sizeNote}`;
    } else {
      outcome = `${head.status}`;
    }
  } catch (err) {
    const detail = (err && err.message) ? err.message : err;
    outcome = `err ${detail}`;
  }
  return `${label}: ${outcome}`;
}
828
+
829
/**
 * Probe an optional resource with a HEAD request and describe the outcome.
 *
 * Same report format as diagFetch(), except that a 404 is reported as
 * `"<label>: n/a (monolithic)"` instead of as a failure status.
 *
 * @param {string} url - Resource to probe.
 * @param {string} label - Prefix used in the returned line.
 * @returns {Promise<string>} One human-readable diagnostic line.
 */
async function diagFetchOptional(url, label) {
  let outcome;
  try {
    const head = await fetch(url, { method: 'HEAD' });
    if (head.status === 404) {
      outcome = 'n/a (monolithic)';
    } else if (head.ok) {
      const bytes = head.headers.get('content-length');
      const sizeNote = bytes ? ` (${bytes} bytes)` : '';
      outcome = `ok${sizeNote}`;
    } else {
      outcome = `${head.status}`;
    }
  } catch (err) {
    const detail = (err && err.message) ? err.message : err;
    outcome = `err ${detail}`;
  }
  return `${label}: ${outcome}`;
}
840
+
841
/**
 * Fill the `#diag` element with one reachability line per runtime dependency:
 * the ORT script and WASM binaries, the viewer assets, and the two model files
 * with their optional `.data` sidecars.
 *
 * Does nothing when the page has no `#diag` element.
 *
 * @param {boolean} wantFp16 - Probe the `*_fp16.onnx` models instead of the fp32 ones.
 * @returns {Promise<void>}
 */
async function runDiagnostics(wantFp16) {
  const panel = document.getElementById('diag');
  if (!panel) return;

  const cdnDist = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.2/dist/';
  const wasmBase = ort?.env?.wasm?.wasmPaths || cdnDist;
  const suffix = wantFp16 ? '_fp16' : '';
  const aeFile = `ae_decode${suffix}.onnx`;
  const edmFile = `edm_denoise${suffix}.onnx`;

  // Each probe is a thunk so URLs are resolved and requests started in list order.
  const probes = [
    () => diagFetch('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.2/dist/ort.webgpu.min.js', 'ort.webgpu.min.js'),
    () => diagFetch(wasmBase + 'ort-wasm-simd-threaded.wasm', 'ort-wasm-simd-threaded.wasm'),
    () => diagFetch(wasmBase + 'ort-wasm-simd-threaded.jsep.wasm', 'ort-wasm-simd-threaded.jsep.wasm'),
    () => diagFetch(resolveAssetUrl('model_meta.json'), 'model_meta.json'),
    () => diagFetch(resolveAssetUrl('init_frames.bin'), 'init_frames.bin'),
    () => diagFetch(resolveAssetUrl('init_frames_meta.json'), 'init_frames_meta.json'),
    () => diagFetch(resolveModelUrl(aeFile), aeFile),
    () => diagFetchOptional(resolveModelUrl(`${aeFile}.data`), `${aeFile}.data`),
    () => diagFetch(resolveModelUrl(edmFile), edmFile),
    () => diagFetchOptional(resolveModelUrl(`${edmFile}.data`), `${edmFile}.data`),
  ];

  const report = await Promise.all(probes.map((probe) => probe()));
  panel.textContent = report.join('\n');
}
864
+
865
/**
 * Page entry point: read the query string, pick a backend, load the models and
 * leave the viewer ready to play.
 *
 * Query parameters read here:
 *   hf=1, hf_rev=<rev>     load models/assets from HF_BASE (optionally at a revision)
 *   mode=ws|onnx           server streaming vs. in-browser inference
 *   ep=webgpu|wasm         force an execution provider
 *   fp16=1                 prefer the fp16 models (WebGPU path only)
 *   debug=1                also request a WebGPU device and report its limits
 *   sigma=scalar, fallback=0, nofallback=1
 *   steps, rho, sigmaMax, sigmaMaxNoTap, sigmaMin, noiseScale, noiseScaleNoTap,
 *   tapHoldFrames, playbackFps, tapStrength, heatmapSigmaPx, cfg
 *
 * Any error is caught, logged and shown in the status line; nothing is rethrown.
 *
 * @returns {Promise<void>}
 */
async function init() {
  try {
    const params = new URLSearchParams(location.search);
    const wantFp16 = (params.get('fp16') === '1');

    if (params.get('hf') === '1') {
      const rev = params.get('hf_rev');
      if (rev && HF_BASE.includes('/resolve/main/')) {
        // Encode the revision: names such as "refs/pr/1" must travel as a single
        // path segment, and encoding also neutralises "$" replacement patterns
        // that String.prototype.replace would otherwise interpret.
        MODEL_BASE_URL = HF_BASE.replace('/resolve/main/', `/resolve/${encodeURIComponent(rev)}/`);
      } else {
        MODEL_BASE_URL = HF_BASE;
      }
    }

    const mode = (params.get('mode') || 'onnx').toLowerCase();
    if (modeEl) {
      modeEl.value = (mode === 'ws') ? 'ws' : 'onnx';
      // Changing the mode reloads the page with the new `mode` parameter.
      modeEl.onchange = () => {
        const next = modeEl.value;
        location.href = setQueryParam('mode', next === 'ws' ? 'ws' : 'onnx');
      };
    }

    if (mode === 'ws') {
      // WebSocket mode: use the same canvas/controls but render server frames.
      await initWebSocketMode();
      return;
    }

    configureOrtRuntime();
    await runDiagnostics(wantFp16);

    // Prefer WebGPU when available; fall back to WASM.
    // WebGPU will typically be 10-50x faster than WASM for this size model.
    const forceEp = (params.get('ep') || '').toLowerCase(); // 'webgpu' | 'wasm' | ''
    const debug = (params.get('debug') === '1');
    FORCE_FP16 = wantFp16;
    FORCE_SIGMA_SCALAR = (params.get('sigma') === 'scalar');
    FALLBACK_ON_ERROR = (params.get('fallback') || '1') !== '0';
    const noFallback = (params.get('nofallback') === '1');
    GRAPH_CAPTURE = false; // graph capture disabled

    let provider = 'wasm';
    let webgpuOk = false;
    let webgpuMsg = '';

    if (navigator.gpu) {
      try {
        const adapter = await navigator.gpu.requestAdapter();
        if (adapter) {
          webgpuOk = true;
          if (debug) {
            const device = await adapter.requestDevice();
            webgpuMsg = `adapter ok, device ok`;
            if (device?.limits) {
              webgpuMsg += `, maxBuf=${device.limits.maxStorageBufferBindingSize || 'n/a'}`;
            }
            // The device was only needed to read its limits; free it right away.
            try {
              device?.destroy?.();
            } catch (_) {
              // best-effort
            }
          }
        } else {
          webgpuMsg = 'no adapter';
        }
      } catch (e) {
        webgpuOk = false;
        webgpuMsg = `adapter error: ${(e && e.message) ? e.message : e}`;
      }
    }

    // Load optional model meta to override CONFIG (history/latent size/etc).
    const modelMeta = await loadJSON(resolveAssetUrl('model_meta.json'));
    if (modelMeta) {
      const metaKeys = ['history', 'condSigmaCh', 'latentCh', 'latentSize', 'imgSize', 'sigmaData'];
      for (const key of metaKeys) {
        if (typeof modelMeta[key] === 'number') CONFIG[key] = modelMeta[key];
      }
    }

    // Query param overrides (take precedence over model_meta defaults).
    // Both readers return null for a missing, empty or non-numeric value.
    const _getFloat = (key) => {
      const v = params.get(key);
      if (v == null || v === '') return null;
      const f = parseFloat(v);
      return Number.isFinite(f) ? f : null;
    };
    const _getInt = (key) => {
      const v = params.get(key);
      if (v == null || v === '') return null;
      const n = parseInt(v, 10);
      return Number.isFinite(n) ? n : null;
    };
    // Copy `value` into CONFIG[configKey] only when it is present and acceptable.
    const override = (configKey, value, accept) => {
      if (value != null && accept(value)) CONFIG[configKey] = value;
    };
    const positive = (v) => v > 0;
    const nonNegative = (v) => v >= 0;

    override('eulerSteps', _getInt('steps'), (v) => v >= 1);
    override('rho', _getFloat('rho'), positive);
    override('sigmaMax', _getFloat('sigmaMax'), positive);
    override('sigmaMaxNoTap', _getFloat('sigmaMaxNoTap'), positive);
    // Without an explicit no-tap value, the no-tap setting follows the tap one.
    if (!params.has('sigmaMaxNoTap')) CONFIG.sigmaMaxNoTap = CONFIG.sigmaMax;
    override('sigmaMin', _getFloat('sigmaMin'), positive);
    override('noiseScale', _getFloat('noiseScale'), nonNegative);
    override('noiseScaleNoTap', _getFloat('noiseScaleNoTap'), nonNegative);
    if (!params.has('noiseScaleNoTap')) CONFIG.noiseScaleNoTap = CONFIG.noiseScale;
    override('tapHoldFrames', _getInt('tapHoldFrames'), nonNegative);
    override('playbackFps', _getFloat('playbackFps'), nonNegative);
    override('tapStrength', _getFloat('tapStrength'), nonNegative);
    override('heatmapSigmaPx', _getFloat('heatmapSigmaPx'), positive);

    // `cfg` is only honoured inside the slider's own range.
    const qsCfg = _getFloat('cfg');
    if (qsCfg != null && qsCfg >= parseFloat(cfgSlider.min) && qsCfg <= parseFloat(cfgSlider.max)) {
      cfgSlider.value = String(qsCfg);
      cfgValEl.textContent = String(qsCfg);
    }

    // Resize canvas to match model image size.
    canvas.width = CONFIG.imgSize;
    canvas.height = CONFIG.imgSize;

    statusEl.textContent = 'Loading models...';
    progressEl.style.display = 'block';
    progressEl.value = 0;

    if (forceEp === 'webgpu' && !webgpuOk) {
      throw new Error('WebGPU not available (no adapter)');
    }

    if (webgpuOk && forceEp !== 'wasm') {
      try {
        provider = 'webgpu';
        const fp16Label = wantFp16 ? ', fp16' : '';
        statusEl.textContent = `Loading models (WebGPU${fp16Label})...`;
        await createSessionsWith(['webgpu'], wantFp16);
      } catch (e) {
        console.warn('WebGPU init failed, falling back to WASM:', e);
        if (noFallback) throw e;
        provider = 'wasm';
        statusEl.textContent = 'Loading models (WASM fallback)...';
        await createSessionsWith(['wasm'], false);
      }
    } else {
      provider = 'wasm';
      statusEl.textContent = 'Loading models (WASM)...';
      await createSessionsWith(['wasm'], false);
    }

    progressEl.value = 100;
    ortBackend = provider;
    backendEl.textContent = ortBackend;
    precisionEl.textContent = ortPrecision;

    // Append session and config details to the diagnostics panel (best-effort).
    try {
      const diagEl = document.getElementById('diag');
      if (diagEl && aeDecode && edmDenoise) {
        diagEl.textContent += `\nae inputs: ${JSON.stringify(aeDecode.inputNames || [])} outputs: ${JSON.stringify(aeDecode.outputNames || [])}`;
        diagEl.textContent += `\nedm inputs: ${JSON.stringify(edmDenoise.inputNames || [])} outputs: ${JSON.stringify(edmDenoise.outputNames || [])}`;
        if (edmInputMeta && edmInputMeta[edmInputSigma]) {
          diagEl.textContent += `\nsigma dims: ${JSON.stringify(getMetaDims(edmInputMeta[edmInputSigma]))}`;
        }
        diagEl.textContent += `\nsigma override: ${FORCE_SIGMA_SCALAR ? 'scalar' : 'auto'}`;
        diagEl.textContent += `\nconfig: history=${CONFIG.history} condSigmaCh=${CONFIG.condSigmaCh} steps=${CONFIG.eulerSteps} rho=${CONFIG.rho} ` +
          `sigmaMax=${CONFIG.sigmaMax} sigmaMaxNoTap=${CONFIG.sigmaMaxNoTap} sigmaMin=${CONFIG.sigmaMin} ` +
          `noiseScale=${CONFIG.noiseScale} noiseScaleNoTap=${CONFIG.noiseScaleNoTap} tapHold=${CONFIG.tapHoldFrames} tapStrength=${CONFIG.tapStrength} heatSigmaPx=${CONFIG.heatmapSigmaPx}`;
      }
    } catch (e) {
      // best-effort
    }
    console.log('Using backend:', ortBackend);

    // init history buffer
    await resetWorld();

    progressEl.style.display = 'none';
    const wasmThreads = (ortBackend === 'wasm') ? (ort?.env?.wasm?.numThreads || 1) : null;
    const extra = (ortBackend === 'wasm')
      ? ` (threads=${wasmThreads}, crossOriginIsolated=${!!self.crossOriginIsolated})`
      : (webgpuMsg ? ` (${webgpuMsg})` : '');
    statusEl.textContent = `Ready! backend=${ortBackend}${extra}. Click Play.`;
    pauseBtn.textContent = paused ? 'Play' : 'Pause';
  } catch (e) {
    console.error('init error:', e);
    let msg = (e && e.message) ? e.message : String(e);
    // Add the error's own properties when they can be serialised.
    try {
      msg += ' | ' + JSON.stringify(e, Object.getOwnPropertyNames(e));
    } catch (err) {
      // best-effort
    }
    statusEl.textContent = 'Failed to load: ' + msg;
    progressEl.style.display = 'none';
  }
}
1061
+
1062
+ init();
1063
+
1064
+ // =====================
1065
+ // WebSocket backend (server streaming)
1066
+ // =====================
1067
/**
 * Start the server-streaming backend: open a WebSocket to `/ws` on the current
 * host, draw every binary frame it delivers on the canvas, and send the CFG
 * scale (and, in pull mode, frame requests) back to the server.
 *
 * Text messages are JSON control messages that may carry `mode`, `format`
 * ('raw_rgba', 'webp', anything else is treated as JPEG) and `width`/`height`.
 * Binary messages are frames in the most recently announced format.
 *
 * @returns {Promise<void>} Resolves once the socket and handlers are set up
 *   (not when the connection opens).
 */
async function initWebSocketMode() {
  backendEl.textContent = 'ws';
  precisionEl.textContent = 'server';
  statusEl.textContent = 'Connecting to server...';
  progressEl.style.display = 'none';
  ortBackend = 'ws';
  ortPrecision = 'server';
  paused = false;
  pauseBtn.textContent = 'Pause';

  const proto = (location.protocol === 'https:') ? 'wss:' : 'ws:';
  const wsUrl = proto + '//' + location.host + '/ws';
  ws = new WebSocket(wsUrl);
  ws.binaryType = 'arraybuffer';

  const socketIsOpen = () => !!(ws && ws.readyState === WebSocket.OPEN);

  // Send a control message carrying the current CFG scale; failures are ignored.
  const sendControl = (fields) => {
    try {
      ws.send(JSON.stringify({ ...fields, cfg_scale: parseFloat(cfgSlider.value) }));
    } catch (_) {
      // best-effort
    }
  };

  // Draw an uncompressed RGBA frame, re-sizing the canvas when the byte count
  // does not match the cached ImageData (a square frame is assumed then).
  const drawRawRgba = (buffer) => {
    const u8 = new Uint8ClampedArray(buffer);
    let imd = window.__WS_IMD__;
    if (!imd || imd.data.length !== u8.length) {
      const px = Math.floor(Math.sqrt(u8.length / 4));
      canvas.width = px;
      canvas.height = px;
      imd = new ImageData(px, px);
      window.__WS_IMD__ = imd;
    }
    imd.data.set(u8);
    ctx.putImageData(imd, 0, 0);
  };

  // Decode a JPEG/WebP frame and draw it scaled to the canvas.
  const drawEncoded = async (buffer, fmt) => {
    const mimeType = (fmt === 'webp') ? 'image/webp' : 'image/jpeg';
    const blob = new Blob([buffer], { type: mimeType });
    if ('createImageBitmap' in window) {
      const bmp = await createImageBitmap(blob);
      try {
        ctx.drawImage(bmp, 0, 0, canvas.width, canvas.height);
      } finally {
        // Always free the bitmap, even if drawing failed.
        if (bmp && bmp.close) bmp.close();
      }
    } else {
      const url = URL.createObjectURL(blob);
      try {
        const img = new Image();
        await new Promise((resolve) => {
          img.onload = resolve;
          img.onerror = resolve;
          img.src = url;
        });
        // Throws for an image that failed to decode; the object URL must
        // still be revoked, hence the finally below.
        ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
      } finally {
        URL.revokeObjectURL(url);
      }
    }
  };

  ws.onopen = () => {
    statusEl.textContent = `Connected (ws). Tap to interact.`;
    sendControl({});
  };

  cfgSlider.oninput = () => {
    cfgValEl.textContent = cfgSlider.value;
    if (socketIsOpen()) sendControl({ tap_x: -1, tap_y: -1 });
  };

  ws.onmessage = async (e) => {
    // Text frames are JSON control messages describing the stream.
    if (typeof e.data === 'string') {
      try {
        const msg = JSON.parse(e.data);
        if (msg && msg.mode) wsServerMode = msg.mode;
        if (msg && msg.format) {
          window.__WS_FMT__ = msg.format;
        }
        if (msg && msg.width && msg.height) {
          canvas.width = msg.width;
          canvas.height = msg.height;
          window.__WS_W__ = msg.width;
          window.__WS_H__ = msg.height;
          window.__WS_IMD__ = new ImageData(msg.width, msg.height);
        }
      } catch (_) {
        // ignore malformed control messages
      }
      return;
    }

    if (paused) return;

    try {
      const fmt = window.__WS_FMT__ || 'jpeg';
      if (fmt === 'raw_rgba') {
        drawRawRgba(e.data);
      } else {
        await drawEncoded(e.data, fmt);
      }
    } catch (err) {
      console.warn('ws decode/draw failed', err);
    }

    // local tap echo
    if (pendingTap && (Date.now() - pendingTap.t_ms) < 750) {
      ctx.save();
      ctx.strokeStyle = 'rgba(255,255,255,0.9)';
      ctx.lineWidth = 2;
      const x = pendingTap.x * canvas.width;
      const y = pendingTap.y * canvas.height;
      ctx.beginPath();
      ctx.arc(x, y, 10, 0, Math.PI * 2);
      ctx.stroke();
      ctx.restore();
    }

    frameCount++;
    totalFrames++;
    frameEl.textContent = totalFrames;

    // Refresh the FPS readout once more than a second has elapsed.
    const now = Date.now();
    if (now - lastFpsTime > 1000) {
      fpsEl.textContent = frameCount;
      frameCount = 0;
      lastFpsTime = now;
    }

    // In pull mode, request another frame. In push mode, frames arrive continuously.
    if (!paused && socketIsOpen() && wsServerMode !== 'push') {
      sendControl({ tap_x: -1, tap_y: -1 });
    }
  };

  ws.onclose = () => {
    statusEl.textContent = 'WS disconnected.';
  };
  ws.onerror = () => {
    statusEl.textContent = 'WS error.';
  };
}
1192
+ </script>
1193
+ </body>
1194
+ </html>
koi/v15_scare/onnx/init_frames.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47724be1af5e460b18151ab7bb7f2b92af4d2393719bce53525885954230d6e9
3
+ size 262144
koi/v15_scare/onnx/init_frames_meta.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk": "v6_scare_temporal_w2_1771775365451_87059415_chunk_0186",
3
+ "start": 53,
4
+ "history": 8,
5
+ "img_size": 256,
6
+ "latent_ch": 8,
7
+ "latent_size": 32
8
+ }
koi/v15_scare/onnx/model_meta.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "history": 8,
3
+ "condSigmaCh": 8,
4
+ "latentCh": 8,
5
+ "latentSize": 32,
6
+ "imgSize": 256,
7
+ "sigmaData": 0.154
8
+ }