| | ''' |
| | This loader is not currently maintained as RWKV can now be loaded |
| | through the transformers library. |
| | ''' |
| |
|
| | import copy |
| | import os |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | from tokenizers import Tokenizer |
| | from transformers import is_torch_xpu_available |
| |
|
| | import modules.shared as shared |
| | from modules.callbacks import Iteratorize |
| |
|
# Compact numpy printing: 4 decimal places, no scientific notation,
# lines wrapped at 200 columns.
np.set_printoptions(linewidth=200, suppress=True, precision=4)

# RWKV feature switches, exported before the `rwkv` imports further down.
# NOTE(review): presumably `rwkv` reads these at import time -- confirm before
# moving these assignments relative to that import.
os.environ['RWKV_JIT_ON'] = '1'
if shared.args.rwkv_cuda_on:
    os.environ["RWKV_CUDA_ON"] = '1'
else:
    os.environ["RWKV_CUDA_ON"] = '0'
| |
|
| | from rwkv.model import RWKV |
| | from rwkv.utils import PIPELINE, PIPELINE_ARGS |
| |
|
| |
|
class RWKVModel:
    """RWKV model plus pipeline, with a cache of the last scanned prompt.

    After a prompt has been fed through the model, the prompt text, the
    resulting model state and the resulting logits are kept on the instance.
    When the next prompt starts with the cached text, only the new suffix is
    fed to the model.
    """

    def __init__(self):
        # Normally populated by `from_pretrained`; declared here so that a
        # directly constructed instance has every attribute the methods read.
        self.pipeline = None
        self.model = None
        self.cached_context = ""
        self.cached_model_state = None
        self.cached_output_logits = None

    @classmethod
    def from_pretrained(cls, path, dtype="bf16" if is_torch_xpu_available() else "fp16", device="xpu" if is_torch_xpu_available() else "cuda"):
        """Load the model at `path` and return a ready-to-use instance.

        Args:
            path: Location of the model file (`str` or `Path`). The tokenizer
                file `20B_tokenizer.json` is looked up in its parent directory.
            dtype: Precision part of the default strategy string.
            device: Device part of the default strategy string.

        `shared.args.rwkv_strategy`, when set, replaces the default
        `"<device> <dtype>"` strategy entirely.
        """
        path = Path(path)
        tokenizer_path = path.parent / "20B_tokenizer.json"

        strategy = shared.args.rwkv_strategy
        if strategy is None:
            strategy = f'{device} {dtype}'
        model = RWKV(model=str(path), strategy=strategy)

        pipeline = PIPELINE(model, str(tokenizer_path))
        result = cls()
        result.pipeline = pipeline
        result.model = model
        result.cached_context = ""
        result.cached_model_state = None
        result.cached_output_logits = None
        return result

    def _reset_cache(self):
        """Forget the cached prompt text, model state and logits."""
        self.cached_context = ""
        self.cached_model_state = None
        self.cached_output_logits = None

    def generate(self, prompt, state, callback=None):
        """Generate a completion for `prompt` and return it as a string.

        Args:
            prompt: Full prompt text.
            state: Mapping providing `temperature`, `top_p`, `top_k` and
                `max_new_tokens`.
            callback: Optional callable invoked with each decoded piece of
                text as it is produced.
        """
        args = PIPELINE_ARGS(
            temperature=state['temperature'],
            top_p=state['top_p'],
            top_k=state['top_k'],
            alpha_frequency=0.1,  # per-occurrence penalty on already generated tokens
            alpha_presence=0.1,  # flat penalty on already generated tokens
            token_ban=[0],  # token id 0 is never sampled
            token_stop=[]
        )

        # An empty cached context is a prefix of every prompt, so this branch
        # also covers the "nothing cached yet" case and leaves `prompt` intact.
        if prompt.startswith(self.cached_context):
            prompt = prompt[len(self.cached_context):]
        else:
            # The prompt diverged from what the cached state represents.
            self._reset_cache()

        return self.generate_from_cached_state(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)

    def generate_with_streaming(self, *args, **kwargs):
        """Yield the reply accumulated so far after every streamed piece.

        Arguments are forwarded to `generate`.
        NOTE(review): presumably `Iteratorize` supplies the `callback` and
        yields each value passed to it -- confirm in `modules.callbacks`.
        """
        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply

    def generate_from_cached_state(self, ctx="", token_count=20, args=None, callback=None):
        """Sample up to `token_count` tokens, continuing from the cached state.

        Args:
            ctx: Text to feed to the model on top of the cached state. May be
                empty, in which case the cached logits are reused as-is.
            token_count: Maximum number of tokens to sample.
            args: `PIPELINE_ARGS` with sampling settings; a default-constructed
                `PIPELINE_ARGS()` is used when omitted.
            callback: Optional callable invoked with each decoded piece.

        Returns:
            The concatenation of all decoded pieces.

        Raises:
            ValueError: If there is nothing to sample from, i.e. `ctx` yields
                no tokens and no logits are cached.
        """
        if args is None:
            args = PIPELINE_ARGS()

        all_tokens = []
        pieces = []
        occurrence = {}  # token id -> number of times it has been sampled
        state = copy.deepcopy(self.cached_model_state)

        # With an empty context no forward pass runs for the first token, so
        # sampling starts from the logits cached by the previous call.
        out = self.cached_output_logits if ctx == "" else None

        emitted_upto = 0  # index in all_tokens of the first token not yet emitted
        token = None
        for i in range(token_count):
            # Forward pass: the whole context on the first step (in chunks of
            # `args.chunk_len`), afterwards only the token sampled last.
            tokens = self.pipeline.encode(ctx) if i == 0 else [token]
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[:args.chunk_len], state)
                tokens = tokens[args.chunk_len:]

            if out is None:
                raise ValueError("Cannot generate: the context is empty and no logits are cached.")

            # Only the state reached after scanning the context is cached;
            # state produced by generated tokens is not.
            if i == 0:
                self.cached_context += ctx
                self.cached_model_state = copy.deepcopy(state)
                self.cached_output_logits = copy.deepcopy(out)

            # Adjust logits: banned tokens become impossible, tokens generated
            # in this call are penalised.
            for n in args.token_ban:
                out[n] = -float('inf')

            for n, count in occurrence.items():
                out[n] -= (args.alpha_presence + count * args.alpha_frequency)

            token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
            if token in args.token_stop:
                break

            all_tokens.append(token)
            occurrence[token] = occurrence.get(token, 0) + 1

            # Emit the pending tokens only when they decode without the
            # Unicode replacement character; otherwise keep them pending and
            # retry together with the next token.
            tmp = self.pipeline.decode(all_tokens[emitted_upto:])
            if '\ufffd' not in tmp:
                if callback:
                    callback(tmp)

                pieces.append(tmp)
                emitted_upto = len(all_tokens)
        return ''.join(pieces)
| |
|
| |
|
class RWKVTokenizer:
    """Thin wrapper exposing `encode`/`decode` over a loaded tokenizer."""

    def __init__(self):
        # Set by `from_pretrained`; declared here so the attribute always
        # exists on an instance.
        self.tokenizer = None

    @classmethod
    def from_pretrained(cls, path):
        """Load `20B_tokenizer.json` from the directory `path`.

        Args:
            path: Directory containing the tokenizer file (`str` or `Path`).

        Returns:
            A new instance whose `tokenizer` is the loaded `Tokenizer`.
        """
        tokenizer_path = Path(path) / "20B_tokenizer.json"
        result = cls()
        result.tokenizer = Tokenizer.from_file(str(tokenizer_path))
        return result

    def encode(self, prompt):
        """Return the token ids of `prompt`."""
        return self.tokenizer.encode(prompt).ids

    def decode(self, ids):
        """Return the text for the token ids `ids`."""
        return self.tokenizer.decode(ids)
| |
|