Jellyfish042 commited on
Commit
7dff699
·
1 Parent(s): 05b8d8c

Sync playback speed and add quick-code panel

Browse files
Files changed (1) hide show
  1. app.py +122 -43
app.py CHANGED
@@ -10,7 +10,7 @@ MAX_DEMO_LEN = 20
10
  GRADIO_MAJOR = int((gr.__version__ or "0").split(".", maxsplit=1)[0])
11
 
12
  ROSA_CODE = """
13
- def rosa_qkv(qqq, kkk, vvv):
14
  n=len(qqq); out=[-1]*n
15
  for i in range(n):
16
  for w in range(i+1,0,-1):
@@ -22,7 +22,34 @@ def rosa_qkv(qqq, kkk, vvv):
22
  if out[i]!=-1:
23
  break
24
  return out
25
- """.strip("\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  KEYWORDS = {
@@ -30,11 +57,16 @@ KEYWORDS = {
30
  "for",
31
  "in",
32
  "if",
 
 
33
  "break",
34
  "return",
35
  "assert",
 
 
 
36
  }
37
- BUILTINS = {"len", "range"}
38
  KEYWORD_RE = re.compile(r"\b(" + "|".join(sorted(KEYWORDS)) + r")\b")
39
  BUILTIN_RE = re.compile(r"\b(" + "|".join(sorted(BUILTINS)) + r")\b")
40
  NUMBER_RE = re.compile(r"(?<![\w.])(-?\d+)(?![\w.])")
@@ -130,15 +162,9 @@ def build_code_html(code: str) -> str:
130
  if index == LINE_ASSIGN:
131
  line_with_markers = line_with_markers.replace("vvv[j+w]", marker_v, 1)
132
  highlighted = highlight_python_line(line_with_markers)
133
- highlighted = highlighted.replace(
134
- marker_t, '<span class="code-token" data-token="t">t</span>'
135
- )
136
- highlighted = highlighted.replace(
137
- marker_k, '<span class="code-token" data-token="k">kkk[j:j+w]</span>'
138
- )
139
- highlighted = highlighted.replace(
140
- marker_v, '<span class="code-token" data-token="v">vvv[j+w]</span>'
141
- )
142
  rendered.append(
143
  '<div class="code-line" data-line="{line}">'
144
  '<span class="line-no">{line}</span>'
@@ -149,6 +175,21 @@ def build_code_html(code: str) -> str:
149
  return "\n".join(rendered)
150
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def find_line_number(lines: list[str], needle: str, fallback: int = 1) -> int:
153
  for index, line in enumerate(lines, start=1):
154
  if needle in line:
@@ -157,6 +198,8 @@ def find_line_number(lines: list[str], needle: str, fallback: int = 1) -> int:
157
 
158
 
159
  ROSA_LINES = ROSA_CODE.splitlines()
 
 
160
  def find_line_after(lines: list[str], start_line: int, needle: str, fallback: int) -> int:
161
  start_index = min(max(start_line, 0), len(lines))
162
  for index in range(start_index, len(lines)):
@@ -176,6 +219,7 @@ LINE_IF_OUT = find_line_number(ROSA_LINES, "if out[i]!=-1:")
176
  LINE_BREAK_OUTER = find_line_after(ROSA_LINES, LINE_IF_OUT, "break", LINE_IF_OUT)
177
  LINE_RETURN = find_line_number(ROSA_LINES, "return out")
178
  CODE_HTML = build_code_html(ROSA_CODE)
 
179
 
180
 
181
  def parse_bits(text: str, name: str) -> list[int]:
@@ -350,6 +394,28 @@ CSS = """
350
  flex: 2 1 320px;
351
  min-width: 300px;
352
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  #rosa-vis {
354
  --rosa-blue: #3b82f6;
355
  --rosa-sky: #38bdf8;
@@ -599,7 +665,7 @@ CSS = """
599
  color: #64748b;
600
  font-variant-numeric: tabular-nums;
601
  }
602
- #rosa-code {
603
  background: #ffffff;
604
  border: 1px solid #e2e8f0;
605
  border-radius: 12px;
@@ -610,51 +676,53 @@ CSS = """
610
  max-height: 520px;
611
  overflow: auto;
612
  }
613
- #rosa-code .code-line {
614
  display: flex;
615
  gap: 10px;
616
  padding: 2px 6px;
617
  border-radius: 6px;
618
  }
619
- #rosa-code .line-no {
620
- width: 28px;
621
- text-align: right;
 
622
  color: #94a3b8;
623
  user-select: none;
 
624
  }
625
- #rosa-code .line-text {
626
  white-space: pre;
627
  flex: 1;
628
  }
629
- #rosa-code .code-line.active {
630
  background: #dbeafe;
631
  color: #1d4ed8;
632
  }
633
- #rosa-code .code-line.active .line-no {
634
  color: #1d4ed8;
635
  }
636
- #rosa-code .tok-keyword {
637
  color: #7c3aed;
638
  font-weight: 600;
639
  }
640
- #rosa-code .tok-builtin {
641
  color: #0ea5e9;
642
  }
643
- #rosa-code .tok-number {
644
  color: #f97316;
645
  }
646
- #rosa-code .tok-string {
647
  color: #10b981;
648
  }
649
- #rosa-code .tok-comment {
650
  color: #64748b;
651
  font-style: italic;
652
  }
653
- #rosa-code .code-token {
654
  border-bottom: 1px dashed #94a3b8;
655
  padding: 0 1px;
656
  }
657
- #rosa-code .code-token.active {
658
  background: #fef3c7;
659
  border-radius: 4px;
660
  border-bottom-color: #f59e0b;
@@ -703,7 +771,8 @@ JS_FUNC = """
703
 
704
  function getSpeed() {
705
  const slider = findInputById(speedId, 'input[type="range"], input');
706
- const value = slider ? parseFloat(slider.value) : 1;
 
707
  if (!Number.isFinite(value) || value <= 0) return 1;
708
  return value;
709
  }
@@ -718,7 +787,7 @@ JS_FUNC = """
718
  return root;
719
  }
720
 
721
- function animateTransfer(fromEl, toEl, value, onFinish) {
722
  if (!fromEl || !toEl) {
723
  if (debugAnim) {
724
  console.warn("[ROSA] animateTransfer missing element", {
@@ -787,7 +856,8 @@ JS_FUNC = """
787
  const dx = toRect.left - fromRect.left;
788
  const dy = toRect.top - fromRect.top;
789
  const toTransform = `translate(${dx}px, ${dy}px) scale(0.95)`;
790
- const duration = transferMs;
 
791
  const startAnimation = () => {
792
  if (bubble.animate) {
793
  const anim = bubble.animate(
@@ -1164,11 +1234,11 @@ JS_FUNC = """
1164
  const legend = document.createElement("div");
1165
  legend.className = "rosa-legend";
1166
  legend.innerHTML = `
1167
- <span class="legend-item"><span class="legend-dot legend-suffix"></span>Current suffix</span>
1168
- <span class="legend-item"><span class="legend-dot legend-window"></span>k window</span>
1169
- <span class="legend-item"><span class="legend-dot legend-match"></span>Match</span>
1170
- <span class="legend-item"><span class="legend-dot legend-v"></span>Read v</span>
1171
- <span class="legend-item"><span class="legend-dot legend-out"></span>Output</span>
1172
  `;
1173
  card.appendChild(legend);
1174
 
@@ -1328,17 +1398,20 @@ JS_FUNC = """
1328
  vCells[step.v_index] &&
1329
  outCells[step.i]
1330
  ) {
 
1331
  const holdToken = state.runToken;
1332
  if (state.outPending) {
1333
  state.outPending.add(step.i);
1334
  }
1335
- const totalWait = transferPauseMs + transferMs;
 
 
1336
  if (state.outPending) {
1337
  state.outPending.add(step.i);
1338
  }
1339
  const startTransfer = () => {
1340
  if (state.runToken !== holdToken) return;
1341
- animateTransfer(vCells[step.v_index], outCells[step.i], step.value, () => {
1342
  if (state.runToken !== holdToken) return;
1343
  if (state.outPending) {
1344
  state.outPending.delete(step.i);
@@ -1353,8 +1426,8 @@ JS_FUNC = """
1353
  }
1354
  });
1355
  };
1356
- if (transferPauseMs > 0) {
1357
- setTimeout(startTransfer, transferPauseMs);
1358
  } else {
1359
  startTransfer();
1360
  }
@@ -1458,8 +1531,8 @@ demo_context = gr.Blocks(css=CSS, js=JS_BOOT)
1458
  with demo_context as demo:
1459
  gr.HTML(
1460
  '<div class="page-header">'
1461
- '<div class="page-title">ROSA QKV Demo</div>'
1462
- '<div class="page-subtitle">Enter or randomize q/k/v (0/1 only). Click Start Demo to see suffix matching and retrieval.</div>'
1463
  "</div>"
1464
  )
1465
 
@@ -1474,7 +1547,7 @@ with demo_context as demo:
1474
  demo_btn = gr.Button("Start Demo", variant="primary")
1475
  speed = gr.Slider(
1476
  0.1,
1477
- 5.0,
1478
  value=2.0,
1479
  step=0.05,
1480
  label="Playback speed",
@@ -1488,7 +1561,13 @@ with demo_context as demo:
1488
  gr.HTML(
1489
  f'<div id="rosa-shell" class="rosa-shell">'
1490
  f'<div class="rosa-pane"><div id="rosa-vis"></div></div>'
1491
- f'<div class="rosa-code-pane">{CODE_HTML}</div>'
 
 
 
 
 
 
1492
  f"</div>"
1493
  )
1494
 
 
10
  GRADIO_MAJOR = int((gr.__version__ or "0").split(".", maxsplit=1)[0])
11
 
12
  ROSA_CODE = """
13
+ def rosa_qkv_naive(qqq, kkk, vvv):
14
  n=len(qqq); out=[-1]*n
15
  for i in range(n):
16
  for w in range(i+1,0,-1):
 
22
  if out[i]!=-1:
23
  break
24
  return out
25
+ """.strip(
26
+ "\n"
27
+ )
28
+
29
+ ROSA_QUICK_CODE = """
30
+ def rosa_qkv_ref_minus1(qqq, kkk, vvv): # note: input will never contain "-1"
31
+ n=len(qqq); y=[-1]*n; s=2*n+1; t=[None]*s; f=[-1]*s; m=[0]*s; r=[-1]*s; t[0]={}; g=0; u=1; w=h=0; assert n==len(kkk)==len(vvv)
32
+ for i,(q,k) in enumerate(zip(qqq,kkk)):
33
+ p,x=w,h
34
+ while p!=-1 and q not in t[p]: x=m[p] if x>m[p] else x; p=f[p]
35
+ p,x=(t[p][q],x+1) if p!=-1 else (0,0); v=p
36
+ while f[v]!=-1 and m[f[v]]>=x: v=f[v]
37
+ while v!=-1 and (m[v]<=0 or r[v]<0): v=f[v]
38
+ y[i]=vvv[r[v]+1] if v!=-1 else -1; w,h=p,x; j=u; u+=1; t[j]={}; m[j]=m[g]+1; p=g
39
+ while p!=-1 and k not in t[p]: t[p][k]=j; p=f[p]
40
+ if p==-1: f[j]=0
41
+ else:
42
+ d=t[p][k]
43
+ if m[p]+1==m[d]: f[j]=d
44
+ else:
45
+ b=u; u+=1; t[b]=t[d].copy(); m[b]=m[p]+1; f[b]=f[d]; r[b]=r[d]; f[d]=f[j]=b
46
+ while p!=-1 and t[p][k]==d: t[p][k]=b; p=f[p]
47
+ v=g=j
48
+ while v!=-1 and r[v]<i: r[v]=i; v=f[v]
49
+ return y
50
+ """.strip(
51
+ "\n"
52
+ )
53
 
54
 
55
  KEYWORDS = {
 
57
  "for",
58
  "in",
59
  "if",
60
+ "else",
61
+ "while",
62
  "break",
63
  "return",
64
  "assert",
65
+ "None",
66
+ "True",
67
+ "False",
68
  }
69
+ BUILTINS = {"len", "range", "zip", "enumerate"}
70
  KEYWORD_RE = re.compile(r"\b(" + "|".join(sorted(KEYWORDS)) + r")\b")
71
  BUILTIN_RE = re.compile(r"\b(" + "|".join(sorted(BUILTINS)) + r")\b")
72
  NUMBER_RE = re.compile(r"(?<![\w.])(-?\d+)(?![\w.])")
 
162
  if index == LINE_ASSIGN:
163
  line_with_markers = line_with_markers.replace("vvv[j+w]", marker_v, 1)
164
  highlighted = highlight_python_line(line_with_markers)
165
+ highlighted = highlighted.replace(marker_t, '<span class="code-token" data-token="t">t</span>')
166
+ highlighted = highlighted.replace(marker_k, '<span class="code-token" data-token="k">kkk[j:j+w]</span>')
167
+ highlighted = highlighted.replace(marker_v, '<span class="code-token" data-token="v">vvv[j+w]</span>')
 
 
 
 
 
 
168
  rendered.append(
169
  '<div class="code-line" data-line="{line}">'
170
  '<span class="line-no">{line}</span>'
 
175
  return "\n".join(rendered)
176
 
177
 
178
+ def build_plain_code_html(code: str, block_id: str) -> str:
179
+ lines = code.splitlines()
180
+ rendered = [f'<div id="{block_id}" class="rosa-code">']
181
+ for index, line in enumerate(lines, start=1):
182
+ highlighted = highlight_python_line(line)
183
+ rendered.append(
184
+ '<div class="code-line">'
185
+ '<span class="line-no">{line}</span>'
186
+ '<span class="line-text">{text}</span>'
187
+ "</div>".format(line=index, text=highlighted)
188
+ )
189
+ rendered.append("</div>")
190
+ return "\n".join(rendered)
191
+
192
+
193
  def find_line_number(lines: list[str], needle: str, fallback: int = 1) -> int:
194
  for index, line in enumerate(lines, start=1):
195
  if needle in line:
 
198
 
199
 
200
  ROSA_LINES = ROSA_CODE.splitlines()
201
+
202
+
203
  def find_line_after(lines: list[str], start_line: int, needle: str, fallback: int) -> int:
204
  start_index = min(max(start_line, 0), len(lines))
205
  for index in range(start_index, len(lines)):
 
219
  LINE_BREAK_OUTER = find_line_after(ROSA_LINES, LINE_IF_OUT, "break", LINE_IF_OUT)
220
  LINE_RETURN = find_line_number(ROSA_LINES, "return out")
221
  CODE_HTML = build_code_html(ROSA_CODE)
222
+ QUICK_CODE_HTML = build_plain_code_html(ROSA_QUICK_CODE, "rosa-code-quick")
223
 
224
 
225
  def parse_bits(text: str, name: str) -> list[int]:
 
394
  flex: 2 1 320px;
395
  min-width: 300px;
396
  }
397
+ .quick-code-details {
398
+ margin-top: 8px;
399
+ }
400
+ .quick-code-details > summary {
401
+ cursor: pointer;
402
+ user-select: none;
403
+ padding: 8px 10px;
404
+ border: 1px dashed #cbd5e1;
405
+ border-radius: 12px;
406
+ background: #f8fafc;
407
+ color: #0f172a;
408
+ font-weight: 600;
409
+ }
410
+ .quick-code-details[open] > summary {
411
+ margin-bottom: 10px;
412
+ }
413
+ .quick-code-details > summary::-webkit-details-marker {
414
+ display: none;
415
+ }
416
+ .quick-code-details .rosa-code {
417
+ max-height: 420px;
418
+ }
419
  #rosa-vis {
420
  --rosa-blue: #3b82f6;
421
  --rosa-sky: #38bdf8;
 
665
  color: #64748b;
666
  font-variant-numeric: tabular-nums;
667
  }
668
+ .rosa-code {
669
  background: #ffffff;
670
  border: 1px solid #e2e8f0;
671
  border-radius: 12px;
 
676
  max-height: 520px;
677
  overflow: auto;
678
  }
679
+ .rosa-code .code-line {
680
  display: flex;
681
  gap: 10px;
682
  padding: 2px 6px;
683
  border-radius: 6px;
684
  }
685
+ .rosa-code .line-no {
686
+ flex: 0 0 36px;
687
+ width: 36px;
688
+ text-align: left;
689
  color: #94a3b8;
690
  user-select: none;
691
+ font-variant-numeric: tabular-nums;
692
  }
693
+ .rosa-code .line-text {
694
  white-space: pre;
695
  flex: 1;
696
  }
697
+ .rosa-code .code-line.active {
698
  background: #dbeafe;
699
  color: #1d4ed8;
700
  }
701
+ .rosa-code .code-line.active .line-no {
702
  color: #1d4ed8;
703
  }
704
+ .rosa-code .tok-keyword {
705
  color: #7c3aed;
706
  font-weight: 600;
707
  }
708
+ .rosa-code .tok-builtin {
709
  color: #0ea5e9;
710
  }
711
+ .rosa-code .tok-number {
712
  color: #f97316;
713
  }
714
+ .rosa-code .tok-string {
715
  color: #10b981;
716
  }
717
+ .rosa-code .tok-comment {
718
  color: #64748b;
719
  font-style: italic;
720
  }
721
+ .rosa-code .code-token {
722
  border-bottom: 1px dashed #94a3b8;
723
  padding: 0 1px;
724
  }
725
+ .rosa-code .code-token.active {
726
  background: #fef3c7;
727
  border-radius: 4px;
728
  border-bottom-color: #f59e0b;
 
771
 
772
  function getSpeed() {
773
  const slider = findInputById(speedId, 'input[type="range"], input');
774
+ // If the speed slider isn't mounted yet (e.g. on first load), default to x2.
775
+ const value = slider ? parseFloat(slider.value) : 2;
776
  if (!Number.isFinite(value) || value <= 0) return 1;
777
  return value;
778
  }
 
787
  return root;
788
  }
789
 
790
+ function animateTransfer(fromEl, toEl, value, durationMs, onFinish) {
791
  if (!fromEl || !toEl) {
792
  if (debugAnim) {
793
  console.warn("[ROSA] animateTransfer missing element", {
 
856
  const dx = toRect.left - fromRect.left;
857
  const dy = toRect.top - fromRect.top;
858
  const toTransform = `translate(${dx}px, ${dy}px) scale(0.95)`;
859
+ const duration =
860
+ Number.isFinite(durationMs) && durationMs > 0 ? durationMs : transferMs;
861
  const startAnimation = () => {
862
  if (bubble.animate) {
863
  const anim = bubble.animate(
 
1234
  const legend = document.createElement("div");
1235
  legend.className = "rosa-legend";
1236
  legend.innerHTML = `
1237
+ <span class="legend-item"><span class="legend-dot legend-suffix"></span>Current suffix (t)</span>
1238
+ <span class="legend-item"><span class="legend-dot legend-window"></span>k window (kkk[j:j+w])</span>
1239
+ <span class="legend-item"><span class="legend-dot legend-match"></span>Match (kkk[j:j+w]==t)</span>
1240
+ <span class="legend-item"><span class="legend-dot legend-v"></span>Read v (vvv[j+w])</span>
1241
+ <span class="legend-item"><span class="legend-dot legend-out"></span>Output (out)</span>
1242
  `;
1243
  card.appendChild(legend);
1244
 
 
1398
  vCells[step.v_index] &&
1399
  outCells[step.i]
1400
  ) {
1401
+ const speed = getSpeed();
1402
  const holdToken = state.runToken;
1403
  if (state.outPending) {
1404
  state.outPending.add(step.i);
1405
  }
1406
+ const pauseMs = transferPauseMs / speed;
1407
+ const durationMs = transferMs / speed;
1408
+ const totalWait = pauseMs + durationMs;
1409
  if (state.outPending) {
1410
  state.outPending.add(step.i);
1411
  }
1412
  const startTransfer = () => {
1413
  if (state.runToken !== holdToken) return;
1414
+ animateTransfer(vCells[step.v_index], outCells[step.i], step.value, durationMs, () => {
1415
  if (state.runToken !== holdToken) return;
1416
  if (state.outPending) {
1417
  state.outPending.delete(step.i);
 
1426
  }
1427
  });
1428
  };
1429
+ if (pauseMs > 0) {
1430
+ setTimeout(startTransfer, pauseMs);
1431
  } else {
1432
  startTransfer();
1433
  }
 
1531
  with demo_context as demo:
1532
  gr.HTML(
1533
  '<div class="page-header">'
1534
+ '<div class="page-title">RWKV-8 ROSA-QKV-1bit Demo</div>'
1535
+ '<div class="page-subtitle">This is using naive algorithm (not suffix automaton). Enter or randomize q/k/v (0/1 only), then click [Start Demo].</div>'
1536
  "</div>"
1537
  )
1538
 
 
1547
  demo_btn = gr.Button("Start Demo", variant="primary")
1548
  speed = gr.Slider(
1549
  0.1,
1550
+ 10.0,
1551
  value=2.0,
1552
  step=0.05,
1553
  label="Playback speed",
 
1561
  gr.HTML(
1562
  f'<div id="rosa-shell" class="rosa-shell">'
1563
  f'<div class="rosa-pane"><div id="rosa-vis"></div></div>'
1564
+ f'<div class="rosa-code-pane">'
1565
+ f"{CODE_HTML}"
1566
+ f'<details class="quick-code-details">'
1567
+ f"<summary>Fast version (click to expand)</summary>"
1568
+ f"{QUICK_CODE_HTML}"
1569
+ f"</details>"
1570
+ f"</div>"
1571
  f"</div>"
1572
  )
1573