########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import gc
import math
import os
import sys
import time
import types

import numpy as np
import torch

from src.utils import TOKENIZER
# Optionally restrict the visible GPU(s) from the first command-line argument,
# e.g. `python run.py 0`. Without an argument the environment is left untouched.
# (An explicit length check replaces the former bare `except: pass`, which would
# also have hidden unrelated errors.)
if len(sys.argv) > 1:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]

# Enable cuDNN autotuning and allow TF32 in cuDNN and CUDA matmul kernels.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Compact, wide numpy printing for any arrays shown while debugging.
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# Plain attribute bag that collects the settings later passed to the model.
args = types.SimpleNamespace()
########################################################################################################
# Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible)
########################################################################################################

args.RUN_DEVICE = "cuda"  # 'cuda' // 'cpu' (already fast)
args.FLOAT_MODE = "fp16"  # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU)

# if args.RUN_DEVICE == "cuda":
#     os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
os.environ["RWKV_JIT_ON"] = '1'  # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!!

# Tokenizer selection: "pile" uses the 20B tokenizer JSON for both entries.
TOKEN_MODE = "pile"
WORD_NAME = [
    "20B_tokenizer.json",
    "20B_tokenizer.json",
]  # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
vocab_size = 50277
# Download Pile models: https://huggingface.co/BlinkDL
# or, set MODEL_NAME to your fine-tuned model
#
# Exactly one preset below should be active; the others are kept as ready-made alternatives.

# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# n_layer = 32
# n_embd = 2560
# ctx_len = 1024

MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
n_layer = 32
n_embd = 4096
ctx_len = 1024

# Copy the chosen preset onto the settings object that is passed to RWKV_RNN.
args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer
args.n_embd = n_embd
args.ctx_len = ctx_len
args.vocab_size = vocab_size
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0

# Exported before src.model_run is imported further down in this file.
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
########################################################################################################
# Step 2: set prompt & sampling settings
########################################################################################################

# The active prompt. Alternative prompts are kept below as comments; uncomment one to use it.

# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

# context = "\n深圳是" # test Chinese
# context = "\n東京は" # test Japanese

# ###### A good prompt for Q&A ######
# context = '''
# Questions & Helpful Answers
# Ask Research Experts
# Question:
# Can penguins fly?
# Full Answer:
# '''

# ###### A good prompt for chatbot ######
# context = '''
# The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.
# User: who is president of usa?
# Bot: It’s Joe Biden; he was sworn in earlier this year.
# User: french revolution what year
# Bot: It started in 1789, but it lasted 10 years until 1799.
# User: guess i marry who ?
# Bot: Only if you tell me more about yourself - what are your interests?
# User: wat is lhc
# Bot: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
# User:''' # type your question here
# Number of independent generations, and how many tokens each one samples.
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333

# Sampling settings forwarded to tokenizer.sample_logits().
TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9  # only used in TOKEN_MODE = char

# When True, only one trial with a single generated token is run.
DEBUG_DEBUG = False  # True False --> show softmax output
########################################################################################################

# NOTE(review): this import sits after the RWKV_JIT_ON / RWKV_RUN_DEVICE environment variables
# are set above — presumably src.model_run reads them at import time; confirm before moving it.
from src.model_run import RWKV_RNN

model = RWKV_RNN(args)

# One forward pass on the single token 187 with no prior state. Its result is not used
# later in this file (`out` is reset before generation) — presumably a warm-up run.
out, _ = model.forward([187], None)
# print(out)
gc.collect()
torch.cuda.empty_cache()
# input(0)

tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
    # Token 187 must be the newline for the Pile tokenizer. This is an explicit raise
    # rather than an `assert` so the check still runs under `python -O`.
    if tokenizer.tokenizer.decode([187]) != '\n':
        raise RuntimeError(
            "Tokenizer mismatch: token 187 does not decode to a newline in 'pile' mode"
        )
########################################################################################################

# Turn the prompt into a list of token ids.
if tokenizer.charMode:
    # Character mode: clean the prompt, then map each character through the vocabulary,
    # substituting the tokenizer's UNKNOWN_CHAR for characters it does not contain.
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
    ctx = tokenizer.tokenizer.encode(context)

src_len = len(ctx)
if src_len == 0:
    # With no prompt tokens the generation loop would have no initial output/state
    # to start from and would fail later on `None.clone()`; fail early and clearly.
    raise ValueError("context must encode to at least one token")

# Pristine copy of the prompt tokens; every trial restarts from this.
src_ctx = ctx.copy()
# Smallest elapsed time, in seconds, recorded so far for each named checkpoint.
time_slot = {}
# Reference timestamp in nanoseconds that record_time() measures from;
# the generation loop resets it at the start of every trial.
time_ref = time.time_ns()


def record_time(name):
    """Record the time elapsed since ``time_ref`` under ``name``.

    The elapsed time is measured in seconds. Only the smallest value seen
    for a given ``name`` is kept, so after several trials ``time_slot[name]``
    holds the best (fastest) measurement.

    Args:
        name: Label of the checkpoint, used as the key in ``time_slot``.
    """
    elapsed = (time.time_ns() - time_ref) / 1e9
    time_slot[name] = min(time_slot.get(name, float("inf")), elapsed)
# Output and state after the whole prompt has been consumed; computed once (trial 0)
# and cloned at the start of every trial.
init_state = None
init_out = None
# Running output and state of the current trial.
state = None
out = None

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    print(("-" * 50) + "\n" + context, end="")

    time_ref = time.time_ns()  # record_time() measures from here
    ctx = src_ctx.copy()

    if TRIAL == 0:
        # Feed the prompt prefix by prefix. Only the final step needs the output;
        # earlier steps just advance the state.
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_out, init_state = model.forward(x, init_state)
            else:
                init_state = model.forward(x, init_state, preprocess_only=True)
        gc.collect()
        torch.cuda.empty_cache()

    record_time('preprocess')

    # Index of the first token that has not been printed yet (non-char mode only).
    out_last = src_len
    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        # len(ctx) == i here, so the last ctx_len tokens of ctx are exactly the
        # window the original `ctx[: i + 1][-ctx_len:]` selected.
        x = ctx[-ctx_len:]

        if i == src_len:
            # First generated token: start from the cached prompt result. Clone so
            # the in-place edit of `out` below never touches the cached copy.
            out = init_out.clone()
            state = init_state.clone()
        else:
            out, state = model.forward(x, state)

        if DEBUG_DEBUG:
            # Show the raw model output for the current window.
            # NOTE(review): assumes `out` is a torch tensor (it is cloned above) — confirm.
            print("model", np.array(x), "==>", out.float().cpu().numpy())

        # Applied regardless of DEBUG_DEBUG: the mask must not depend on debug mode.
        if TOKEN_MODE == "pile":
            out[0] = -999999999  # disable <|endoftext|>

        ttt = tokenizer.sample_logits(
            out,
            x,
            ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        )
        ctx.append(ttt)

        if tokenizer.charMode:
            char = tokenizer.itos[ttt]
            print(char, end="", flush=True)
        else:
            # Decode everything not yet printed. A multi-byte character can span several
            # tokens, so hold the text back until it decodes without U+FFFD.
            char = tokenizer.tokenizer.decode(ctx[out_last:])
            if '\ufffd' not in char:  # is valid utf8 string?
                print(char, end="", flush=True)
                out_last = i + 1

    record_time('total')
    # print(f'\n\n{time_slot}\n\n')
    print(
        f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, "
        f"generation {round(time_slot['total'] - time_slot['preprocess'], 2)}s ",
        end="",
    )