Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,24 +25,36 @@ pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
|
| 25 |
|
| 26 |
# Build the initial recurrent states for the two state-tuned checkpoints
# (English and Chinese single-round Q/A).
#
# Each state is a flat list holding three entries per layer:
#   [i*3+0]  zero vector of length n_embd in the layer's activation dtype
#   [i*3+1]  the checkpoint's `blocks.{i}.att.time_state` tensor, with dims 1
#            and 2 swapped, cast to float32 and moved to the layer's device
#   [i*3+2]  zero vector of length n_embd in the layer's activation dtype
args = model.args

# English checkpoint.
eng_name = 'rwkv-x060-eng_single_round_qa-7B-20240516-ctx2048'
eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
state_eng_raw = torch.load(eng_file)
state_eng = [None] * args.n_layer * 3

# Chinese checkpoint.
chn_name = 'rwkv-x060-chn_single_round_qa-7B-20240516-ctx2048'
chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
state_chn_raw = torch.load(chn_file)
state_chn = [None] * args.n_layer * 3

for i in range(args.n_layer):
    # Device and activation dtype come from the model's per-layer strategy.
    dd = model.strategy[i]
    dev = dd.device
    atype = dd.atype
    base = i * 3
    for _state, _raw in ((state_eng, state_eng_raw), (state_chn, state_chn_raw)):
        _state[base + 0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        tuned = _raw[f'blocks.{i}.att.time_state'].transpose(1,2)
        tuned = tuned.to(dtype=torch.float, device=dev).requires_grad_(False)
        _state[base + 1] = tuned.contiguous()
        _state[base + 2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def generate_prompt(instruction, input=""):
|
| 47 |
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
| 48 |
input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
|
|
@@ -208,6 +220,56 @@ def evaluate_chn(
|
|
| 208 |
torch.cuda.empty_cache()
|
| 209 |
yield out_str.strip()
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
examples = [
|
| 212 |
["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
|
| 213 |
["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
|
|
@@ -242,6 +304,14 @@ examples_chn = [
|
|
| 242 |
["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
| 243 |
]
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
##########################################################################
|
| 246 |
|
| 247 |
with gr.Blocks(title=title) as demo:
|
|
@@ -307,6 +377,26 @@ with gr.Blocks(title=title) as demo:
|
|
| 307 |
clear.click(lambda: None, [], [output])
|
| 308 |
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
# Enable Gradio's request queue on the demo before launching it.
# NOTE(review): presumably concurrency_count=1 processes one request at a time
# and max_size=10 caps the number of waiting requests -- confirm against the
# Gradio version this Space pins.
demo.queue(concurrency_count=1, max_size=10)
|
| 312 |
# Start serving the demo. share=False is passed explicitly; presumably this
# means no public share link is requested -- confirm against the Gradio docs.
demo.launch(share=False)
|
|
|
|
| 25 |
|
| 26 |
# Build the initial recurrent states for the three state-tuned checkpoints
# (English Q/A, Chinese Q/A, WenYanWen Q/A).
#
# Each state is a flat list holding three entries per layer:
#   [i*3+0]  zero vector of length n_embd in the layer's activation dtype
#   [i*3+1]  the checkpoint's `blocks.{i}.att.time_state` tensor, with dims 1
#            and 2 swapped, cast to float32 and moved to the layer's device
#   [i*3+2]  zero vector of length n_embd in the layer's activation dtype
#
# NOTE(review): torch.load unpickles the downloaded files; only load
# checkpoints from a repository you trust.
args = model.args

eng_name = 'rwkv-x060-eng_single_round_qa-7B-20240516-ctx2048'
eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
state_eng_raw = torch.load(eng_file)
state_eng = [None] * args.n_layer * 3

chn_name = 'rwkv-x060-chn_single_round_qa-7B-20240516-ctx2048'
chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
state_chn_raw = torch.load(chn_file)
state_chn = [None] * args.n_layer * 3

wyw_name = 'rwkv-x060-chn_文言文和古典名著_single_round_qa-7B-20240601-ctx2048'
wyw_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{wyw_name}.pth")
state_wyw_raw = torch.load(wyw_file)
state_wyw = [None] * args.n_layer * 3

for i in range(args.n_layer):
    # Device and activation dtype come from the model's per-layer strategy.
    dd = model.strategy[i]
    dev = dd.device
    atype = dd.atype
    # Fill every state from ITS OWN raw checkpoint. Pairing them here keeps a
    # copy-paste slip from loading one checkpoint's tensors into another state
    # (state_wyw was previously filled from state_chn_raw).
    for _state, _raw in (
        (state_eng, state_eng_raw),
        (state_chn, state_chn_raw),
        (state_wyw, state_wyw_raw),
    ):
        _state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        _state[i*3+1] = _raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        _state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
| 57 |
+
|
| 58 |
def generate_prompt(instruction, input=""):
|
| 59 |
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
| 60 |
input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
|
|
|
|
| 220 |
torch.cuda.empty_cache()
|
| 221 |
yield out_str.strip()
|
| 222 |
|
| 223 |
+
def evaluate_wyw(
    ctx,
    token_count=gen_limit,
    temperature=1.0,
    top_p=0.3,
    presencePenalty=0.3,
    countPenalty=0.3,
):
    """Stream a completion for ``ctx`` starting from the WenYanWen tuned state.

    This is a generator: it yields the stripped text produced so far each time
    newly decoded text is appended, and yields the final stripped text once
    more after cleanup.

    Args:
        ctx: User prompt; it is passed through ``qa_prompt`` before encoding.
        token_count: Upper bound on the number of sampling steps (cast to int).
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Nucleus sampling threshold (cast to float).
        presencePenalty: Flat amount subtracted from the logit of every token
            that has already been generated.
        countPenalty: Amount subtracted per (decayed) occurrence of a token
            that has already been generated.

    Yields:
        The accumulated output string, stripped of surrounding whitespace.
    """
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                     alpha_frequency = countPenalty,
                     alpha_presence = presencePenalty,
                     token_ban = [], # ban the generation of some tokens
                     token_stop = [0]) # stop generation whenever you see any token here
    ctx = qa_prompt(ctx)
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    # Pre-bind `out` so the `del out` below cannot raise UnboundLocalError when
    # the loop body never runs (int(token_count) < 1).
    out = None
    # Work on a private copy so the module-level tuned state is never mutated.
    state = copy.deepcopy(state_wyw)
    for i in range(int(token_count)):
        # First step feeds the (right-truncated) prompt; later steps feed only
        # the token sampled in the previous step.
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        # Penalise tokens that were already generated.
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        # Decay earlier counts, then count the new token.
        for xxx in occurrence:
            occurrence[xxx] *= penalty_decay
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1

        tmp = pipeline.decode(all_tokens[out_last:])
        # Text containing U+FFFD is held back and retried with more tokens;
        # presumably this avoids emitting a partially decoded character.
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1

    # Log GPU memory usage, then release per-request tensors.
    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
    del out
    del state
    gc.collect()
    torch.cuda.empty_cache()
    yield out_str.strip()
|
| 272 |
+
|
| 273 |
examples = [
|
| 274 |
["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
|
| 275 |
["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
|
|
|
|
| 304 |
["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
| 305 |
]
|
| 306 |
|
| 307 |
+
# Example rows for the WenYanWen tab. Every row uses the same settings after
# the prompt: gen_limit_long, 1, 0.2, 0.3, 0.3.
examples_wyw = [
    [wyw_prompt, gen_limit_long, 1, 0.2, 0.3, 0.3]
    for wyw_prompt in (
        "我和前男友分手了",
        "量子计算机的原理",
        "李白和杜甫的结拜故事",
        "林黛玉和伏地魔的关系是什么?",
        "我被同事陷害了,帮我写一篇文言文骂他",
    )
]
|
| 314 |
+
|
| 315 |
##########################################################################
|
| 316 |
|
| 317 |
with gr.Blocks(title=title) as demo:
|
|
|
|
| 377 |
clear.click(lambda: None, [], [output])
|
| 378 |
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
|
| 379 |
|
| 380 |
+
# WenYanWen Q/A tab: prompt and sampling controls on the left, submit/clear
# buttons and the streamed output on the right, clickable examples below.
# NOTE(review): in app.py this block sits inside the enclosing
# `with gr.Blocks(title=title) as demo:` context -- indent it to match.
with gr.Tab("=== WenYanWen Q/A ==="):
    gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [WenYanWen 文言文 Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{wyw_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=2, label="Prompt", value="我和前男友分手了")
            token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
            temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
            top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
            presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
            count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear", variant="secondary")
            output = gr.Textbox(label="Output", lines=30)
    # Show the WenYanWen examples here (this tab previously reused examples_chn,
    # leaving examples_wyw unused).
    data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_wyw, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
    # Submit streams evaluate_wyw into the output box; Clear empties it.
    submit.click(evaluate_wyw, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
    clear.click(lambda: None, [], [output])
    # Clicking an example row copies its values into the input controls.
    data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
|
| 399 |
+
|
| 400 |
|
| 401 |
# Enable Gradio's request queue on the demo before launching it.
# NOTE(review): presumably concurrency_count=1 processes one request at a time
# and max_size=10 caps the number of waiting requests -- confirm against the
# Gradio version this Space pins.
demo.queue(concurrency_count=1, max_size=10)
|
| 402 |
# Start serving the demo. share=False is passed explicitly; presumably this
# means no public share link is requested -- confirm against the Gradio docs.
demo.launch(share=False)
|