File size: 2,087 Bytes
53fc829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
常量定义 - ROSA-QKV-1bit 演示应用
"""

import re
import gradio as gr

# Demo configuration.
# NOTE(review): presumably the maximum sequence length the demo accepts — confirm against callers.
MAX_DEMO_LEN = 20
# Major version of the installed Gradio package: the text before the first "." of
# gr.__version__, falling back to "0" when __version__ is empty.
GRADIO_MAJOR = int((gr.__version__ or "0").partition(".")[0])

# Algorithm code samples. These are plain string literals; nothing here executes them.
# NOTE(review): presumably rendered as text by the demo UI — confirm against callers.
#
# ROSA_CODE is the brute-force reference: for each position i it searches for the
# longest suffix of qqq[:i+1] that also occurs in kkk ending strictly before i
# (kkk[j:j+w] with j <= i-w), taking the most recent such occurrence (j scans downward),
# and outputs vvv[j+w], the value right after that occurrence; -1 when nothing matches.
# .strip("\n") removes the newlines introduced by the triple-quote layout.
ROSA_CODE = """
def rosa_qkv_naive(qqq, kkk, vvv):
    n=len(qqq); out=[-1]*n
    for i in range(n):
        for w in range(i+1,0,-1):
            t=qqq[i+1-w:i+1]
            for j in range(i-w,-1,-1):
                if kkk[j:j+w]==t:
                    out[i]=vvv[j+w]
                    break
            if out[i]!=-1:
                break
    return out
""".strip(
    "\n"
)

# ROSA_QUICK_CODE is a faster formulation of the same query, kept as a string literal for display.
# It incrementally builds an automaton over kkk. Its states have transitions t, suffix links f,
# lengths m and last end positions r, with node cloning (t[d].copy()) as in a suffix automaton.
# For each position it reads the answer from that automaton before appending k.
# Its inline comment states that inputs never contain -1, because -1 doubles as the "no match" output.
# NOTE(review): presumably output-equivalent to ROSA_CODE — verify with a randomized comparison.
# .strip("\n") removes the newlines introduced by the triple-quote layout.
ROSA_QUICK_CODE = """
def rosa_qkv_ref_minus1(qqq, kkk, vvv): # note: input will never contain "-1"
    n=len(qqq); y=[-1]*n; s=2*n+1; t=[None]*s; f=[-1]*s; m=[0]*s; r=[-1]*s; t[0]={}; g=0; u=1; w=h=0; assert n==len(kkk)==len(vvv)
    for i,(q,k) in enumerate(zip(qqq,kkk)):
        p,x=w,h
        while p!=-1 and q not in t[p]: x=m[p] if x>m[p] else x; p=f[p]
        p,x=(t[p][q],x+1) if p!=-1 else (0,0); v=p
        while f[v]!=-1 and m[f[v]]>=x: v=f[v]
        while v!=-1 and (m[v]<=0 or r[v]<0): v=f[v]
        y[i]=vvv[r[v]+1] if v!=-1 else -1; w,h=p,x; j=u; u+=1; t[j]={}; m[j]=m[g]+1; p=g
        while p!=-1 and k not in t[p]: t[p][k]=j; p=f[p]
        if p==-1: f[j]=0
        else:
            d=t[p][k]
            if m[p]+1==m[d]: f[j]=d
            else:
                b=u; u+=1; t[b]=t[d].copy(); m[b]=m[p]+1; f[b]=f[d]; r[b]=r[d]; f[d]=f[j]=b
                while p!=-1 and t[p][k]==d: t[p][k]=b; p=f[p]
        v=g=j
        while v!=-1 and r[v]<i: r[v]=i; v=f[v]
    return y
""".strip(
    "\n"
)

# Python reserved words and built-in functions recognized by the code highlighter
# (KEYWORD_RE / BUILTIN_RE are compiled from these sets).
# KEYWORDS must include every reserved word used in the ROSA_CODE / ROSA_QUICK_CODE
# samples. That includes the boolean operators `and`, `or` and `not`, which
# ROSA_QUICK_CODE uses in its loop conditions.
KEYWORDS = {
    "def",
    "for",
    "in",
    "if",
    "elif",
    "else",
    "while",
    "break",
    "continue",
    "pass",
    "return",
    "assert",
    "and",
    "or",
    "not",
    "is",
    "None",
    "True",
    "False",
}

# Built-in callables that appear in the code samples.
BUILTINS = {"len", "range", "zip", "enumerate"}

# Regular expressions for syntax highlighting. Each one matches a whole word from
# its vocabulary and captures it in group 1. The words are sorted so the alternation,
# and therefore the compiled pattern text, is deterministic.
KEYWORD_RE = re.compile(rf"\b({'|'.join(sorted(KEYWORDS))})\b")
BUILTIN_RE = re.compile(rf"\b({'|'.join(sorted(BUILTINS))})\b")
# Matches an integer literal as a whole token, with an optional leading "-", and
# captures it in group 1.
# Digits glued to identifiers or dots (x1, 2.5) are rejected.
# A "-" directly after ")" or "]" is a binary minus (e.g. `m[f[v]]-1`, `len(q)-1`), so
# the lookbehind refuses to start a match there and only the digits are highlighted.
NUMBER_RE = re.compile(r"(?<![\w.)\]])(-?\d+)(?![\w.])")