Taylor committed on
Commit
e5e1d2b
·
1 Parent(s): 90bf42d

fix: remove missing WASM flashAttention + stream results independently

Browse files

1. flashAttentionMultiHead not in standalone WASM binary -- use JS
attention fallback (matVec is the real bottleneck, not attention)

2. Use Gradio generator pattern so PyTorch result shows immediately
when done, then the Aether result shows when it finishes. Both
results no longer appear at the same time.

Files changed (2) hide show
  1. aether-server.mjs +2 -13
  2. app.py +18 -15
aether-server.mjs CHANGED
@@ -124,19 +124,8 @@ async function loadSIMD() {
124
  wasm.resetHeap(saved);
125
  return result;
126
  },
127
- flashAttentionMultiHead(query, keys, values, seqLen, numHeads, numKvHeads, headDim) {
128
- const saved = wasm.getHeapPtr();
129
- const scale = 1.0 / Math.sqrt(headDim);
130
- const qPtr = wasm.allocate(query.byteLength);
131
- const kPtr = wasm.allocate(keys.byteLength);
132
- const vPtr = wasm.allocate(values.byteLength);
133
- const rPtr = wasm.allocate(numHeads * headDim * 4);
134
- copyTo(qPtr, query); copyTo(kPtr, keys); copyTo(vPtr, values);
135
- wasm.flashAttentionMultiHead(qPtr, kPtr, vPtr, rPtr, seqLen, numHeads, numKvHeads, headDim, scale);
136
- const result = copyFrom(rPtr, numHeads * headDim);
137
- wasm.resetHeap(saved);
138
- return result;
139
- },
140
  };
141
  } catch (e) {
142
  console.warn(`[Aether] WASM SIMD failed: ${e.message}, using JS fallbacks`);
 
124
  wasm.resetHeap(saved);
125
  return result;
126
  },
127
+ // flashAttentionMultiHead: not in standalone WASM -- use JS attention
128
+ flashAttentionMultiHead: null,
 
 
 
 
 
 
 
 
 
 
 
129
  };
130
  } catch (e) {
131
  console.warn(`[Aether] WASM SIMD failed: ${e.message}, using JS fallbacks`);
app.py CHANGED
@@ -108,17 +108,20 @@ def gen_aether(prompt):
108
 
109
 
110
  def compare(prompt):
 
111
  if not prompt or not prompt.strip():
112
- return "", "", "", ""
 
113
 
114
- # Run both
115
  base_text, base_time, base_toks, base_ms = gen_pytorch(prompt)
116
- aether_text, aether_time, aether_toks, aether_ms = gen_aether(prompt)
117
-
118
  base_stats = f"{base_toks} tokens in {base_time:.1f}s ({base_ms:.0f}ms/tok)"
119
- aether_stats = f"{aether_toks} tokens in {aether_time:.1f}s ({aether_ms:.0f}ms/tok)"
120
 
121
- return base_text, aether_text, base_stats, aether_stats
 
 
 
122
 
123
 
124
  CSS = """
@@ -193,13 +196,13 @@ with gr.Blocks(css=CSS, theme=gr.themes.Base(primary_hue="blue", neutral_hue="zi
193
  aether_stats = gr.HTML('<p class="stats-text">--</p>')
194
 
195
  def run_compare(prompt_text):
196
- base_text, aether_text, b_stats, a_stats = compare(prompt_text)
197
- return (
198
- base_text,
199
- aether_text,
200
- f'<p class="stats-text">{b_stats}</p>',
201
- f'<p class="stats-text">{a_stats}</p>',
202
- )
203
 
204
  btn.click(run_compare, [prompt], [base_out, aether_out, base_stats, aether_stats])
205
  prompt.submit(run_compare, [prompt], [base_out, aether_out, base_stats, aether_stats])
@@ -208,8 +211,8 @@ with gr.Blocks(css=CSS, theme=gr.themes.Base(primary_hue="blue", neutral_hue="zi
208
  with gr.Row():
209
  for p in ["hello", "How are you feeling?", "I've been anxious lately.", "Write a haiku about failure.", "What is the meaning of life?"]:
210
  gr.Button(p, size="sm", elem_classes=["prompt-chip"]).click(
211
- fn=lambda x=p: run_compare(x), outputs=[base_out, aether_out, base_stats, aether_stats]
212
- ).then(fn=lambda x=p: x, outputs=[prompt])
213
 
214
  gr.HTML("""
215
  <div id="footer">
 
108
 
109
 
110
  def compare(prompt):
111
+ """Generator: yields results as each engine finishes."""
112
  if not prompt or not prompt.strip():
113
+ yield "", "", "", ""
114
+ return
115
 
116
+ # Run PyTorch first, show immediately
117
  base_text, base_time, base_toks, base_ms = gen_pytorch(prompt)
 
 
118
  base_stats = f"{base_toks} tokens in {base_time:.1f}s ({base_ms:.0f}ms/tok)"
119
+ yield base_text, "generating...", base_stats, "running..."
120
 
121
+ # Then run Aether, show when done
122
+ aether_text, aether_time, aether_toks, aether_ms = gen_aether(prompt)
123
+ aether_stats = f"{aether_toks} tokens in {aether_time:.1f}s ({aether_ms:.0f}ms/tok)"
124
+ yield base_text, aether_text, base_stats, aether_stats
125
 
126
 
127
  CSS = """
 
196
  aether_stats = gr.HTML('<p class="stats-text">--</p>')
197
 
198
  def run_compare(prompt_text):
199
+ for base_text, aether_text, b_stats, a_stats in compare(prompt_text):
200
+ yield (
201
+ base_text,
202
+ aether_text,
203
+ f'<p class="stats-text">{b_stats}</p>',
204
+ f'<p class="stats-text">{a_stats}</p>',
205
+ )
206
 
207
  btn.click(run_compare, [prompt], [base_out, aether_out, base_stats, aether_stats])
208
  prompt.submit(run_compare, [prompt], [base_out, aether_out, base_stats, aether_stats])
 
211
  with gr.Row():
212
  for p in ["hello", "How are you feeling?", "I've been anxious lately.", "Write a haiku about failure.", "What is the meaning of life?"]:
213
  gr.Button(p, size="sm", elem_classes=["prompt-chip"]).click(
214
+ fn=lambda x=p: x, outputs=[prompt]
215
+ ).then(fn=run_compare, inputs=[prompt], outputs=[base_out, aether_out, base_stats, aether_stats])
216
 
217
  gr.HTML("""
218
  <div id="footer">