# Spaces: Running
"""
Constants for the ROSA-QKV-1bit demo application.
"""
import re

import gradio as gr
# Demo configuration.
# NOTE(review): presumably the maximum input length the demo accepts; the unit
# and the place where it is enforced are not visible in this module — confirm.
MAX_DEMO_LEN = 20
# Major version of the installed Gradio package as an int (text before the
# first "."). An empty or None version string falls back to "0", i.e. 0.
# Presumably used elsewhere to choose version-specific Gradio APIs — verify.
GRADIO_MAJOR = int((gr.__version__ or "0").split(".", maxsplit=1)[0])
# Algorithm source samples. These are plain strings: nothing in this module
# executes them. Presumably they are rendered in the demo UI and highlighted
# with KEYWORD_RE / BUILTIN_RE / NUMBER_RE. Surrounding newlines are stripped
# so each sample starts at its `def` line and ends at its `return` line.

# Reference implementation. For every position i it tries the suffixes of
# qqq[:i+1] from longest to shortest. For the first length w that also occurs
# in kkk ending strictly before i, it takes the most recent occurrence j and
# emits vvv[j+w], the value right after that occurrence. Positions with no
# match stay -1. Because -1 doubles as the "not found" sentinel, the search
# assumes vvv never contains -1.
ROSA_CODE = """
def rosa_qkv_naive(qqq, kkk, vvv):
    n=len(qqq); out=[-1]*n
    for i in range(n):
        for w in range(i+1,0,-1):
            t=qqq[i+1-w:i+1]
            for j in range(i-w,-1,-1):
                if kkk[j:j+w]==t:
                    out[i]=vvv[j+w]
                    break
            if out[i]!=-1:
                break
    return out
""".strip("\n")
# Fast variant, intended to produce the same output as rosa_qkv_naive.
# It extends a suffix-automaton-style structure over kkk by one symbol per step:
#   t = transitions, f = suffix links, m = state lengths,
#   r = latest end index in kkk, g = last state, u = next free state id.
# It also tracks the current longest match of qqq in that structure
# (state w, length h). The -1 sentinel note is part of the displayed sample.
ROSA_QUICK_CODE = """
def rosa_qkv_ref_minus1(qqq, kkk, vvv): # note: input will never contain "-1"
    n=len(qqq); y=[-1]*n; s=2*n+1; t=[None]*s; f=[-1]*s; m=[0]*s; r=[-1]*s; t[0]={}; g=0; u=1; w=h=0; assert n==len(kkk)==len(vvv)
    for i,(q,k) in enumerate(zip(qqq,kkk)):
        p,x=w,h
        while p!=-1 and q not in t[p]: x=m[p] if x>m[p] else x; p=f[p]
        p,x=(t[p][q],x+1) if p!=-1 else (0,0); v=p
        while f[v]!=-1 and m[f[v]]>=x: v=f[v]
        while v!=-1 and (m[v]<=0 or r[v]<0): v=f[v]
        y[i]=vvv[r[v]+1] if v!=-1 else -1; w,h=p,x; j=u; u+=1; t[j]={}; m[j]=m[g]+1; p=g
        while p!=-1 and k not in t[p]: t[p][k]=j; p=f[p]
        if p==-1: f[j]=0
        else:
            d=t[p][k]
            if m[p]+1==m[d]: f[j]=d
            else:
                b=u; u+=1; t[b]=t[d].copy(); m[b]=m[p]+1; f[b]=f[d]; r[b]=r[d]; f[d]=f[j]=b
                while p!=-1 and t[p][k]==d: t[p][k]=b; p=f[p]
        v=g=j
        while v!=-1 and r[v]<i: r[v]=i; v=f[v]
    return y
""".strip("\n")
# Python keywords and built-in function names. KEYWORD_RE and BUILTIN_RE are
# built from these sets for code highlighting. "and", "or" and "not" are
# included because ROSA_QUICK_CODE uses them; without them those words were
# left unhighlighted.
KEYWORDS = {
    "def",
    "for",
    "in",
    "if",
    "else",
    "while",
    "break",
    "return",
    "assert",
    "and",
    "or",
    "not",
    "None",
    "True",
    "False",
}
BUILTINS = {"len", "range", "zip", "enumerate"}
# Regular expressions used for code highlighting.
# Each matches any listed name as a whole word; capture group 1 is the word.
# The names are sorted so the pattern is deterministic (set iteration order of
# str varies between runs because of hash randomization), and each is escaped
# so a future entry containing regex metacharacters cannot corrupt the pattern.
KEYWORD_RE = re.compile(r"\b(" + "|".join(map(re.escape, sorted(KEYWORDS))) + r")\b")
BUILTIN_RE = re.compile(r"\b(" + "|".join(map(re.escape, sorted(BUILTINS))) + r")\b")
# Integer literal, optionally with a unary minus, that is not part of an
# identifier or a float: no word character or "." on either side.
# A "-" directly after ")" or "]" is a binary minus (e.g. `t[p]-1`), so it is
# left out of the match there. Capture group 1 is the whole highlighted number.
NUMBER_RE = re.compile(r"(?<![\w.])((?:(?<![)\]])-)?\d+)(?![\w.])")