import os, time, json, re, gc, subprocess
import argparse
import copy
from copy import deepcopy
from datetime import datetime
from enum import Enum
from typing import Dict, List

import gradio as gr
import torch
import numpy as np
from huggingface_hub import hf_hub_download, snapshot_download, InferenceClient
from pynvml import *
from transformers import AutoTokenizer
from transformers.agents import PythonInterpreterTool

import sampling
import rwkv_world_tokenizer
from tokenizer_util import add_tokenizer_argument, get_tokenizer

# Fetch the quantized RWKV weights; hf_hub_download does not expand '~', so do it explicitly.
download_dir = os.path.expanduser('~/app/Downloads')
hf_hub_download(repo_id="JoPmt/RWKV-5-3B-V2-Quant", filename="rwkv-5-world-3b-v2-20231118-ctx16k.Q4_0.bin", local_dir=download_dir)
model_path = os.path.join(download_dir, 'rwkv-5-world-3b-v2-20231118-ctx16k.Q4_0.bin')

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B", revision="pr/13")
tools = [PythonInterpreterTool()]

# Build rwkv.cpp with CUDA support at startup; '-y' keeps apt-get non-interactive.
os.system("apt-get update && apt-get install -y cmake gcc g++")
os.system("git clone --recursive https://github.com/JoPmt/rwkv.cpp.git && cd rwkv.cpp && mkdir build && cd build && cmake .. -DRWKV_CUBLAS=ON -DRWKV_BUILD_SHARED_LIBRARY=ON -DGGML_CUDA=ON -DRWKV_BUILD_PYTHON_MODULE=ON -DRWKV_BUILD_TOOLS=ON -DRWKV_BUILD_EXTRAS=ON && cmake --build . --config Release && make RWKV_CUBLAS=1 GGML_CUDA=1")

# These bindings can only be imported once the shared library above has been built.
import rwkv_cpp_model
import rwkv_cpp_shared_library


def find_lib():
    """Locate the librwkv.so shared library built above by walking the filesystem."""
    for root, dirs, files in os.walk("/"):
        for file in files:
            if file == "librwkv.so":
                return os.path.join(root, file)
    return None

library_path = find_lib()
rwkv_lib = rwkv_cpp_shared_library.RWKVSharedLibrary(library_path)
print('Loading RWKV model')
modal = rwkv_cpp_model.RWKVModel(rwkv_lib, model_path, thread_count=2)
tokenizer_decode, tokenizer_encode = get_tokenizer('auto', modal.n_vocab)
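# Quick sanity check of the encode/eval/decode round trip (illustrative only,
# not executed as part of the app):
#   tokens = tokenizer_encode("Hello")
#   logits, state = modal.eval_sequence_in_chunks(tokens, None, None, None, use_numpy=True)
#   print(logits.shape)  # one logit per vocabulary entry, i.e. (modal.n_vocab,)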

out_str = ''
prompt = out_str
token_count = 1200
temperature = 1.0
top_p = 0.7
presence_penalty = 0.1
count_penalty = 0.4
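# How the repetition penalties are applied during generation (see the sampling
# loop in HfEngine.__call__ below): at each step every previously sampled token t
# gets logits[t] -= presence_penalty + occurrence[t] * count_penalty, so e.g. a
# token already seen 3 times is pushed down by 0.1 + 3 * 0.4 = 1.3 logits.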

def generate_prompt(instruction, zput=""):
    """Build an RWKV-World-style prompt: Instruction/Input/Response when an input is given, otherwise a short chat preamble."""
    instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    zput = zput.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    if zput:
        return f"""Instruction: {instruction}
Input: {zput}
Response:"""
    else:
        return f"""User: hi
Assistant: Hi. I am your assistant and I will provide expert, fully detailed responses. Please feel free to ask any question and I will always answer it.
User: {instruction}
Assistant:"""
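# Illustrative examples of the two prompt shapes (not executed):
#   generate_prompt("Summarize this.", "RWKV is an RNN.")
#     -> "Instruction: Summarize this.\nInput: RWKV is an RNN.\nResponse:"
#   generate_prompt("What is RWKV?")
#     -> the chat preamble above, ending in "User: What is RWKV?\nAssistant:"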

class MessageRole(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL_CALL = "tool-call"
    TOOL_RESPONSE = "tool-response"

    @classmethod
    def roles(cls):
        return [r.value for r in cls]

def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
    """
    Validate and normalize a chat message list. Subsequent messages with the
    same role will be concatenated into a single message.

    Args:
        message_list (`List[Dict[str, str]]`): List of chat messages.
        role_conversions (`Dict[str, str]`): Mapping from unsupported roles to replacement roles.
    """
    final_message_list = []
    message_list = deepcopy(message_list)  # Avoid mutating the caller's messages.
    for message in message_list:
        if set(message.keys()) != {"role", "content"}:
            raise ValueError("Message should contain only 'role' and 'content' keys!")

        role = message["role"]
        if role not in MessageRole.roles():
            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")

        if role in role_conversions:
            message["role"] = role_conversions[role]

        if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
            # Append (+=) rather than overwrite, so same-role messages are actually concatenated.
            final_message_list[-1]["content"] += "\n=======\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list
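# Illustrative example (not executed): consecutive same-role messages are merged:
#   get_clean_message_list([{"role": "user", "content": "a"},
#                           {"role": "user", "content": "b"}])
#     -> [{"role": "user", "content": "a\n=======\nb"}]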

# Tool-call and tool-response turns are presented to the model as user turns.
llama_role_conversions = {
    MessageRole.TOOL_RESPONSE: MessageRole.USER,
    MessageRole.TOOL_CALL: MessageRole.USER,
}

class HfEngine:
    """LLM engine with the transformers.agents call interface, backed by the local rwkv.cpp model."""
    def __init__(self, model: str = "JoPmt/JoPmt"):
        self.model = model
        self.client = modal

    def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
        # stop_sequences is accepted for interface compatibility but unused here.
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
        print(messages)
        # Collect system and user content; they become the instruction and input of the RWKV prompt.
        pret = ''
        prut = ''
        for message in messages:
            print(message['content'])
            if message['role'].lower() == 'system':
                pret += message['content']
            if message['role'].lower() == 'user':
                prut += message['content']

        # The chat-template rendering is only printed for inspection; generation
        # uses the RWKV prompt built by generate_prompt() below.
        print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))

        # Per-call sampling settings.
        token_count = 1200
        temperature = 1.0
        top_p = 0.7
        presencePenalty = 0.1
        countPenalty = 0.4
        stop_token = [0]

        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        ctx = generate_prompt(pret, prut)
        prompt_tokens = tokenizer_encode(ctx)
        # Evaluate the whole prompt in chunks to obtain the initial logits and RNN state.
        init_logits, init_state = modal.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
        logits, state = init_logits.copy(), init_state.copy()
        for i in range(token_count):
            # Penalize already-generated tokens (presence + frequency penalties).
            for n in occurrence:
                logits[n] -= (presencePenalty + occurrence[n] * countPenalty)
            token = sampling.sample_logits(logits, temperature, top_p)

            if token in stop_token:
                break
            all_tokens += [token]

            # Decay the penalty weights so earlier tokens count less over time.
            for xxx in occurrence:
                occurrence[xxx] *= 0.996
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            # Only emit decoded text once it forms valid UTF-8 (no replacement character).
            tmp = tokenizer_decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                out_str += tmp
                out_last = i + 1

            # Feed the sampled token back; the state and logits buffers are reused in place.
            logits, state = modal.eval(token, state, state, logits, use_numpy=True)
        del state
        gc.collect()
        return out_str.strip()
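
# Minimal usage sketch (an assumption about how this engine is meant to be driven,
# e.g. by a transformers.agents agent; not part of the code above):
#   engine = HfEngine()
#   print(engine([{"role": "user", "content": "What is the capital of France?"}]))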