Spaces:
Sleeping
Sleeping
Taylor committed on
Commit ·
e5e1d2b
1
Parent(s): 90bf42d
fix: remove missing WASM flashAttention + stream results independently
Browse files
1. flashAttentionMultiHead not in standalone WASM binary -- use JS
attention fallback (matVec is the real bottleneck, not attention)
2. Use Gradio generator pattern so PyTorch result shows immediately
when done, then Aether result shows when it finishes. No more
both flashing at the same time.
- aether-server.mjs +2 -13
- app.py +18 -15
aether-server.mjs
CHANGED
|
@@ -124,19 +124,8 @@ async function loadSIMD() {
|
|
| 124 |
wasm.resetHeap(saved);
|
| 125 |
return result;
|
| 126 |
},
|
| 127 |
-
flashAttentionMultiHead
|
| 128 |
-
|
| 129 |
-
const scale = 1.0 / Math.sqrt(headDim);
|
| 130 |
-
const qPtr = wasm.allocate(query.byteLength);
|
| 131 |
-
const kPtr = wasm.allocate(keys.byteLength);
|
| 132 |
-
const vPtr = wasm.allocate(values.byteLength);
|
| 133 |
-
const rPtr = wasm.allocate(numHeads * headDim * 4);
|
| 134 |
-
copyTo(qPtr, query); copyTo(kPtr, keys); copyTo(vPtr, values);
|
| 135 |
-
wasm.flashAttentionMultiHead(qPtr, kPtr, vPtr, rPtr, seqLen, numHeads, numKvHeads, headDim, scale);
|
| 136 |
-
const result = copyFrom(rPtr, numHeads * headDim);
|
| 137 |
-
wasm.resetHeap(saved);
|
| 138 |
-
return result;
|
| 139 |
-
},
|
| 140 |
};
|
| 141 |
} catch (e) {
|
| 142 |
console.warn(`[Aether] WASM SIMD failed: ${e.message}, using JS fallbacks`);
|
|
|
|
| 124 |
wasm.resetHeap(saved);
|
| 125 |
return result;
|
| 126 |
},
|
| 127 |
+
// flashAttentionMultiHead: not in standalone WASM -- use JS attention
|
| 128 |
+
flashAttentionMultiHead: null,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
};
|
| 130 |
} catch (e) {
|
| 131 |
console.warn(`[Aether] WASM SIMD failed: ${e.message}, using JS fallbacks`);
|
app.py
CHANGED
|
@@ -108,17 +108,20 @@ def gen_aether(prompt):
|
|
| 108 |
|
| 109 |
|
| 110 |
def compare(prompt):
|
|
|
|
| 111 |
if not prompt or not prompt.strip():
|
| 112 |
-
|
|
|
|
| 113 |
|
| 114 |
-
# Run
|
| 115 |
base_text, base_time, base_toks, base_ms = gen_pytorch(prompt)
|
| 116 |
-
aether_text, aether_time, aether_toks, aether_ms = gen_aether(prompt)
|
| 117 |
-
|
| 118 |
base_stats = f"{base_toks} tokens in {base_time:.1f}s ({base_ms:.0f}ms/tok)"
|
| 119 |
-
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
CSS = """
|
|
@@ -193,13 +196,13 @@ with gr.Blocks(css=CSS, theme=gr.themes.Base(primary_hue="blue", neutral_hue="zi
|
|
| 193 |
aether_stats = gr.HTML('<p class="stats-text">--</p>')
|
| 194 |
|
| 195 |
def run_compare(prompt_text):
|
| 196 |
-
base_text, aether_text, b_stats, a_stats
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
|
| 204 |
btn.click(run_compare, [prompt], [base_out, aether_out, base_stats, aether_stats])
|
| 205 |
prompt.submit(run_compare, [prompt], [base_out, aether_out, base_stats, aether_stats])
|
|
@@ -208,8 +211,8 @@ with gr.Blocks(css=CSS, theme=gr.themes.Base(primary_hue="blue", neutral_hue="zi
|
|
| 208 |
with gr.Row():
|
| 209 |
for p in ["hello", "How are you feeling?", "I've been anxious lately.", "Write a haiku about failure.", "What is the meaning of life?"]:
|
| 210 |
gr.Button(p, size="sm", elem_classes=["prompt-chip"]).click(
|
| 211 |
-
fn=lambda x=p:
|
| 212 |
-
).then(fn=
|
| 213 |
|
| 214 |
gr.HTML("""
|
| 215 |
<div id="footer">
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
def compare(prompt):
|
| 111 |
+
"""Generator: yields results as each engine finishes."""
|
| 112 |
if not prompt or not prompt.strip():
|
| 113 |
+
yield "", "", "", ""
|
| 114 |
+
return
|
| 115 |
|
| 116 |
+
# Run PyTorch first, show immediately
|
| 117 |
base_text, base_time, base_toks, base_ms = gen_pytorch(prompt)
|
|
|
|
|
|
|
| 118 |
base_stats = f"{base_toks} tokens in {base_time:.1f}s ({base_ms:.0f}ms/tok)"
|
| 119 |
+
yield base_text, "generating...", base_stats, "running..."
|
| 120 |
|
| 121 |
+
# Then run Aether, show when done
|
| 122 |
+
aether_text, aether_time, aether_toks, aether_ms = gen_aether(prompt)
|
| 123 |
+
aether_stats = f"{aether_toks} tokens in {aether_time:.1f}s ({aether_ms:.0f}ms/tok)"
|
| 124 |
+
yield base_text, aether_text, base_stats, aether_stats
|
| 125 |
|
| 126 |
|
| 127 |
CSS = """
|
|
|
|
| 196 |
aether_stats = gr.HTML('<p class="stats-text">--</p>')
|
| 197 |
|
| 198 |
def run_compare(prompt_text):
|
| 199 |
+
for base_text, aether_text, b_stats, a_stats in compare(prompt_text):
|
| 200 |
+
yield (
|
| 201 |
+
base_text,
|
| 202 |
+
aether_text,
|
| 203 |
+
f'<p class="stats-text">{b_stats}</p>',
|
| 204 |
+
f'<p class="stats-text">{a_stats}</p>',
|
| 205 |
+
)
|
| 206 |
|
| 207 |
btn.click(run_compare, [prompt], [base_out, aether_out, base_stats, aether_stats])
|
| 208 |
prompt.submit(run_compare, [prompt], [base_out, aether_out, base_stats, aether_stats])
|
|
|
|
| 211 |
with gr.Row():
|
| 212 |
for p in ["hello", "How are you feeling?", "I've been anxious lately.", "Write a haiku about failure.", "What is the meaning of life?"]:
|
| 213 |
gr.Button(p, size="sm", elem_classes=["prompt-chip"]).click(
|
| 214 |
+
fn=lambda x=p: x, outputs=[prompt]
|
| 215 |
+
).then(fn=run_compare, inputs=[prompt], outputs=[base_out, aether_out, base_stats, aether_stats])
|
| 216 |
|
| 217 |
gr.HTML("""
|
| 218 |
<div id="footer">
|