Upload 7 files
- app.py +131 -0
- gradio_ui.py +442 -0
- koboldai_client.py +117 -0
- model.py +110 -0
- parsing.py +40 -0
- prompting.py +54 -0
- utils.py +11 -0
app.py
ADDED
@@ -0,0 +1,131 @@
#!/usr/bin/env python3
import argparse
import logging
import typing as t

from gradio_ui import build_gradio_ui_for
from koboldai_client import run_raw_inference_on_kai, wait_for_kai_server
from parsing import parse_messages_from_str
from prompting import build_prompt_for
from utils import clear_stdout

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# For UI debugging purposes.
DONT_USE_MODEL = False


def main(server_port: int,
         share_gradio_link: bool = False,
         model_name: t.Optional[str] = None,
         koboldai_url: t.Optional[str] = None) -> None:
    '''Script entrypoint.'''
    if model_name and not DONT_USE_MODEL:
        from model import build_model_and_tokenizer_for, run_raw_inference
        model, tokenizer = build_model_and_tokenizer_for(model_name)
    else:
        model, tokenizer = None, None

    def inference_fn(history: t.List[str], user_input: str,
                     generation_settings: t.Dict[str, t.Any],
                     *char_settings: t.Any) -> str:
        if DONT_USE_MODEL:
            return "Mock response for UI tests."

        # Brittle: this unpacking depends on the order defined in gradio_ui.py.
        [
            char_name,
            _user_name,
            char_persona,
            char_greeting,
            world_scenario,
            example_dialogue,
        ] = char_settings

        # If we're just starting the conversation and the character has a
        # greeting configured, return that instead. This is a workaround for
        # the fact that Gradio assumes a chatbot cannot start a conversation,
        # so we can't just show the greeting automatically; it needs to be a
        # response to a user message.
        if len(history) == 0 and char_greeting is not None:
            return f"{char_name}: {char_greeting}"

        prompt = build_prompt_for(history=history,
                                  user_message=user_input,
                                  char_name=char_name,
                                  char_persona=char_persona,
                                  example_dialogue=example_dialogue,
                                  world_scenario=world_scenario)

        if model and tokenizer:
            model_output = run_raw_inference(model, tokenizer, prompt,
                                             user_input, **generation_settings)
        elif koboldai_url:
            model_output = f"{char_name}:"
            model_output += run_raw_inference_on_kai(koboldai_url, prompt,
                                                     **generation_settings)
        else:
            raise Exception(
                "Not using local inference, but no Kobold instance URL was"
                " given. Nowhere to perform inference on.")

        generated_messages = parse_messages_from_str(model_output,
                                                     ["You", char_name])
        logger.debug("Parsed model response is: `%s`", generated_messages)
        bot_message = generated_messages[0]

        return bot_message

    ui = build_gradio_ui_for(inference_fn, for_kobold=koboldai_url is not None)
    ui.launch(server_port=server_port, share=share_gradio_link)


def _parse_args_from_argv() -> argparse.Namespace:
    '''Parses arguments coming in from the command line.'''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model-name",
        help=
        "HuggingFace Transformers model name, if not using a KoboldAI instance as an inference server.",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=3000,
        help="Port to listen on.",
    )
    parser.add_argument(
        "-k",
        "--koboldai-url",
        help="URL to a KoboldAI instance to use as an inference server.",
    )
    parser.add_argument(
        "-s",
        "--share",
        action="store_true",
        help="Enable to generate a public link for the Gradio UI.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args_from_argv()

    if args.koboldai_url:
        # I have no idea how long a safe wait time is, but we'd rather wait
        # too long than cut the user off _right_ when the setup is about to
        # finish, so let's pick something absurd here.
        wait_for_kai_server(args.koboldai_url, max_wait_time_seconds=60 * 30)

        # Clear out any Kobold logs so the user can clearly see the Gradio
        # link that's about to show up afterwards.
        clear_stdout()

    main(model_name=args.model_name,
         server_port=args.port,
         koboldai_url=args.koboldai_url,
         share_gradio_link=args.share)
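For reference, a minimal sketch of launching the UI programmatically, mirroring what the argument parser and `main` above do; the KoboldAI URL below is a placeholder, not something shipped in this commit.

from app import main

# Hypothetical invocation; point koboldai_url at a real KoboldAI instance,
# or pass model_name to load a local Transformers model instead.
main(server_port=3000,
     share_gradio_link=False,
     koboldai_url="http://localhost:5000")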
gradio_ui.py
ADDED
@@ -0,0 +1,442 @@
import json
import logging

import gradio as gr


def get_generation_defaults(for_kobold):
    defaults = {
        "do_sample": True,
        "max_new_tokens": 196,
        "temperature": 0.5,
        "top_p": 0.9,
        "top_k": 0,
        "typical_p": 1.0,
        "repetition_penalty": 1.05,
    }

    if for_kobold:
        defaults.update({"max_context_length": 768})
    else:
        defaults.update({"penalty_alpha": 0.6})

    return defaults


logger = logging.getLogger(__name__)


def build_gradio_ui_for(inference_fn, for_kobold):
    '''
    Builds a Gradio UI to interact with the model. Big thanks to TearGosling for
    the initial version that inspired this.
    '''
    with gr.Blocks(title="Pygmalion", analytics_enabled=False) as interface:
        history_for_gradio = gr.State([])
        history_for_model = gr.State([])
        generation_settings = gr.State(
            get_generation_defaults(for_kobold=for_kobold))

        def _update_generation_settings(
            original_settings,
            param_name,
            new_value,
        ):
            '''
            Merges `{param_name: new_value}` into `original_settings` and
            returns a new dictionary.
            '''
            updated_settings = {**original_settings, param_name: new_value}
            logging.debug("Generation settings updated to: `%s`",
                          updated_settings)
            return updated_settings

        def _run_inference(
            model_history,
            gradio_history,
            user_input,
            generation_settings,
            *char_setting_states,
        ):
            '''
            Runs inference on the model, and formats the returned response for
            the Gradio state and chatbot component.
            '''
            char_name = char_setting_states[0]
            user_name = char_setting_states[1]

            # If user input is blank, format it as if the user stayed silent.
            if user_input is None or user_input.strip() == "":
                user_input = "..."

            inference_result = inference_fn(model_history, user_input,
                                            generation_settings,
                                            *char_setting_states)

            inference_result_for_gradio = inference_result \
                .replace(f"{char_name}:", f"**{char_name}:**") \
                .replace("<USER>", user_name) \
                .replace("\n", "<br>")  # The Gradio chatbot component renders <br> as a line break.

            model_history.append(f"You: {user_input}")
            model_history.append(inference_result)
            gradio_history.append((user_input, inference_result_for_gradio))

            return None, model_history, gradio_history, gradio_history

        def _regenerate(
            model_history,
            gradio_history,
            generation_settings,
            *char_setting_states,
        ):
            '''Regenerates the last response.'''
            return _run_inference(
                model_history[:-2],
                gradio_history[:-1],
                model_history[-2].replace("You: ", ""),
                generation_settings,
                *char_setting_states,
            )

        def _undo_last_exchange(model_history, gradio_history):
            '''Undoes the last exchange (message pair).'''
            return model_history[:-2], gradio_history[:-1], gradio_history[:-1]

        def _save_chat_history(model_history, *char_setting_states):
            '''Saves the current chat history to a .json file.'''
            char_name = char_setting_states[0]
            with open(f"{char_name}_conversation.json", "w") as f:
                f.write(json.dumps({"chat": model_history}))
            return f"{char_name}_conversation.json"

        def _load_chat_history(file_obj, *char_setting_states):
            '''Loads up a chat history from a .json file.'''
            # TODO(TG): Automatically detect and convert any CAI dump files
            # loaded in to Pygmalion format.

            # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
            def pairwise(iterable):
                # "s -> (s0, s1), (s2, s3), (s4, s5), ..."
                a = iter(iterable)
                return zip(a, a)

            char_name = char_setting_states[0]
            user_name = char_setting_states[1]

            file_data = json.loads(file_obj.decode('utf-8'))
            model_history = file_data["chat"]
            # Construct a new Gradio history.
            new_gradio_history = []
            for human_turn, bot_turn in pairwise(model_history):
                # Handle the situation where the conversation history is
                # loaded before the character definitions.
                if char_name == "":
                    # Grab the character name from the model history.
                    char_name = bot_turn.split(":")[0]
                # Format the user and bot utterances.
                user_turn = human_turn.replace("You: ", "")
                bot_turn = bot_turn.replace(f"{char_name}:", f"**{char_name}:**")

                # Somebody released a script on /g/ which tries to convert CAI
                # dump logs to Pygmalion character settings and chats. The
                # anonymization of the dumps, however, means that
                # [NAME_IN_MESSAGE_REDACTED] is left in the conversational
                # history, which we obviously don't want. This accommodates
                # users of that script, so [NAME_IN_MESSAGE_REDACTED] doesn't
                # have to be manually edited in the conversation JSON. The
                # model shouldn't generate [NAME_IN_MESSAGE_REDACTED] by itself.
                user_turn = user_turn.replace("[NAME_IN_MESSAGE_REDACTED]", user_name)
                bot_turn = bot_turn.replace("[NAME_IN_MESSAGE_REDACTED]", user_name)

                new_gradio_history.append((user_turn, bot_turn))

            return model_history, new_gradio_history, new_gradio_history

        with gr.Tab("Character Settings") as settings_tab:
            charfile, char_setting_states = _build_character_settings_ui()

        with gr.Tab("Chat Window"):
            chatbot = gr.Chatbot(
                label="Your conversation will show up here").style(
                    color_map=("#326efd", "#212528"))

            char_name, _user_name, char_persona, char_greeting, world_scenario, example_dialogue = char_setting_states
            charfile.upload(
                fn=_char_file_upload,
                inputs=[charfile, history_for_model, history_for_gradio],
                outputs=[history_for_model, history_for_gradio, chatbot, char_name, char_persona, char_greeting, world_scenario, example_dialogue]
            )

            message = gr.Textbox(
                label="Your message (hit Enter to send)",
                placeholder="Write a message...",
            )
            message.submit(
                fn=_run_inference,
                inputs=[
                    history_for_model, history_for_gradio, message,
                    generation_settings, *char_setting_states
                ],
                outputs=[
                    message, history_for_model, history_for_gradio, chatbot
                ],
            )

            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                send_btn.click(
                    fn=_run_inference,
                    inputs=[
                        history_for_model, history_for_gradio, message,
                        generation_settings, *char_setting_states
                    ],
                    outputs=[
                        message, history_for_model, history_for_gradio, chatbot
                    ],
                )

                regenerate_btn = gr.Button("Regenerate")
                regenerate_btn.click(
                    fn=_regenerate,
                    inputs=[
                        history_for_model, history_for_gradio,
                        generation_settings, *char_setting_states
                    ],
                    outputs=[
                        message, history_for_model, history_for_gradio, chatbot
                    ],
                )

                undo_btn = gr.Button("Undo last exchange")
                undo_btn.click(
                    fn=_undo_last_exchange,
                    inputs=[history_for_model, history_for_gradio],
                    outputs=[history_for_model, history_for_gradio, chatbot],
                )

            with gr.Row():
                with gr.Column():
                    chatfile = gr.File(type="binary", file_types=[".json"], interactive=True)
                    chatfile.upload(
                        fn=_load_chat_history,
                        inputs=[chatfile, *char_setting_states],
                        outputs=[history_for_model, history_for_gradio, chatbot]
                    )

                    save_char_btn = gr.Button(value="Save Conversation History")
                    save_char_btn.click(_save_chat_history, inputs=[history_for_model, *char_setting_states], outputs=[chatfile])
                with gr.Column():
                    gr.Markdown("""
                    ### To save a chat
                    Click "Save Conversation History". The file will appear above the button and you can click to download it.

                    ### To load a chat
                    Drag a valid .json file onto the upload box, or click the box to browse.

                    **Remember to fill out/load up your character definitions before resuming a chat!**
                    """)

        with gr.Tab("Generation Settings"):
            _build_generation_settings_ui(
                state=generation_settings,
                fn=_update_generation_settings,
                for_kobold=for_kobold,
            )

    return interface


def _char_file_upload(file_obj, history_model, history_gradio):
    file_data = json.loads(file_obj.decode('utf-8'))
    char_name = file_data["char_name"]
    greeting = file_data["char_greeting"]
    empty_history = not history_model or (len(history_model) <= 2 and history_model[0] == '')
    if empty_history and char_name and greeting:
        # If the chat history is empty so far and the character has a
        # greeting, add the greeting to the chat.
        s = f'{char_name}: {greeting}'
        t = f'**{char_name}**: {greeting}'
        history_model = ['', s]
        history_gradio = [('', t)]
    return history_model, history_gradio, history_gradio, char_name, file_data["char_persona"], greeting, file_data["world_scenario"], file_data["example_dialogue"]


def _build_character_settings_ui():

    def char_file_create(char_name, char_persona, char_greeting, world_scenario, example_dialogue):
        with open(char_name + ".json", "w") as f:
            f.write(json.dumps({"char_name": char_name, "char_persona": char_persona, "char_greeting": char_greeting, "world_scenario": world_scenario, "example_dialogue": example_dialogue}))
        return char_name + ".json"

    with gr.Column():
        with gr.Row():
            char_name = gr.Textbox(
                label="Character Name",
                placeholder="The character's name",
            )
            user_name = gr.Textbox(
                label="Your Name",
                placeholder="What the character should call you",
            )

        char_persona = gr.Textbox(
            label="Character Persona",
            placeholder=
            "Describe the character's persona here. Think of this as CharacterAI's description + definitions in one box.",
            lines=4,
        )
        char_greeting = gr.Textbox(
            label="Character Greeting",
            placeholder=
            "Write the character's greeting here. They will say this verbatim as their first response.",
            lines=3,
        )

        world_scenario = gr.Textbox(
            label="Scenario",
            placeholder=
            "Optionally, describe the starting scenario in a few short sentences.",
        )
        example_dialogue = gr.Textbox(
            label="Example Chat",
            placeholder=
            "Optionally, write in an example chat here. This is useful for showing how the character should behave, for example.",
            lines=4,
        )

        with gr.Row():
            with gr.Column():
                charfile = gr.File(type="binary", file_types=[".json"])

                save_char_btn = gr.Button(value="Generate Character File")
                save_char_btn.click(char_file_create, inputs=[char_name, char_persona, char_greeting, world_scenario, example_dialogue], outputs=[charfile])
            with gr.Column():
                gr.Markdown("""
                ### To save a character
                Click "Generate Character File". The file will appear above the button and you can click to download it.

                ### To upload a character
                Drag a valid .json file onto the upload box, or click the box to browse.
                """)

    return charfile, (char_name, user_name, char_persona, char_greeting, world_scenario, example_dialogue)


def _build_generation_settings_ui(state, fn, for_kobold):
    generation_defaults = get_generation_defaults(for_kobold=for_kobold)

    with gr.Row():
        with gr.Column():
            max_new_tokens = gr.Slider(
                16,
                512,
                value=generation_defaults["max_new_tokens"],
                step=4,
                label="max_new_tokens",
            )
            max_new_tokens.change(
                lambda state, value: fn(state, "max_new_tokens", value),
                inputs=[state, max_new_tokens],
                outputs=state,
            )

            temperature = gr.Slider(
                0.1,
                2,
                value=generation_defaults["temperature"],
                step=0.01,
                label="temperature",
            )
            temperature.change(
                lambda state, value: fn(state, "temperature", value),
                inputs=[state, temperature],
                outputs=state,
            )

            top_p = gr.Slider(
                0.0,
                1.0,
                value=generation_defaults["top_p"],
                step=0.01,
                label="top_p",
            )
            top_p.change(
                lambda state, value: fn(state, "top_p", value),
                inputs=[state, top_p],
                outputs=state,
            )

        with gr.Column():
            typical_p = gr.Slider(
                0.0,
                1.0,
                value=generation_defaults["typical_p"],
                step=0.01,
                label="typical_p",
            )
            typical_p.change(
                lambda state, value: fn(state, "typical_p", value),
                inputs=[state, typical_p],
                outputs=state,
            )

            repetition_penalty = gr.Slider(
                1.0,
                3.0,
                value=generation_defaults["repetition_penalty"],
                step=0.01,
                label="repetition_penalty",
            )
            repetition_penalty.change(
                lambda state, value: fn(state, "repetition_penalty", value),
                inputs=[state, repetition_penalty],
                outputs=state,
            )

            top_k = gr.Slider(
                0,
                100,
                value=generation_defaults["top_k"],
                step=1,
                label="top_k",
            )
            top_k.change(
                lambda state, value: fn(state, "top_k", value),
                inputs=[state, top_k],
                outputs=state,
            )

            if not for_kobold:
                penalty_alpha = gr.Slider(
                    0,
                    1,
                    value=generation_defaults["penalty_alpha"],
                    step=0.05,
                    label="penalty_alpha",
                )
                penalty_alpha.change(
                    lambda state, value: fn(state, "penalty_alpha", value),
                    inputs=[state, penalty_alpha],
                    outputs=state,
                )

    #
    # Some of these explanations are taken from Kobold:
    # https://github.com/KoboldAI/KoboldAI-Client/blob/main/gensettings.py
    #
    # They're passed directly into the `generate` call, so they should exist here:
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
    #
    with gr.Accordion(label="Helpful information", open=False):
        gr.Markdown("""
        Here's a basic rundown of each setting:

        - `max_new_tokens`: Number of tokens the AI should generate. Higher numbers take longer to generate.
        - `temperature`: Randomness of sampling. Higher values can increase creativity but may make text less sensible. Lower values make text more predictable but can become repetitious.
        - `top_p`: Used to discard unlikely text in the sampling process. Lower values make text more predictable but can become repetitious. (Set this to 1 to disable its effect.)
        - `top_k`: Alternative sampling method; can be combined with top_p. The number of highest-probability vocabulary tokens to keep for top-k filtering. (Set this to 0 to disable its effect.)
        - `typical_p`: Alternative sampling method described in the paper "Typical Decoding for Natural Language Generation" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. (Set this to 1 to disable its effect.)
        - `repetition_penalty`: Used to penalize words that were already generated or belong to the context. (Going over 1.2 breaks 6B models. Set to 1.0 to disable.)
        - `penalty_alpha`: The alpha coefficient when using contrastive search.

        Some settings might not show up depending on which inference backend is being used.
        """)
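For reference, `char_file_create` and `_char_file_upload` above imply a flat JSON schema for character files. A sketch of producing such a file; the character values are made-up examples, not part of this commit.

import json

# Hypothetical character file matching the keys read and written above.
example_character = {
    "char_name": "Alice",
    "char_persona": "A cheerful android who loves puns.",
    "char_greeting": "Hi! I'm Alice. What shall we talk about today?",
    "world_scenario": "You bump into Alice at a robotics convention.",
    "example_dialogue": "You: Tell me a pun.\nAlice: I would, but I don't want to short-circuit the conversation.",
}

with open("Alice.json", "w") as f:
    f.write(json.dumps(example_character))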
koboldai_client.py
ADDED
@@ -0,0 +1,117 @@
import datetime
import logging
import time

import requests

logger = logging.getLogger(__name__)


class KoboldApiServerException(Exception):
    pass


def wait_for_kai_server(koboldai_url: str, max_wait_time_seconds: int) -> None:
    '''Blocks until the KAI server is up.'''
    start_time = datetime.datetime.now()

    while True:
        try:
            requests.head(koboldai_url, timeout=(5, 5))
            break
        except requests.exceptions.ConnectionError as ex:
            if "Connection refused" not in str(ex):
                raise ex

            abort_at = start_time + datetime.timedelta(
                seconds=max_wait_time_seconds)

            if datetime.datetime.now() > abort_at:
                raise TimeoutError(
                    f"Waited for {max_wait_time_seconds} seconds but KoboldAI"
                    " server is still not up, aborting.")

            time.sleep(1)


def run_raw_inference_on_kai(
    koboldai_url: str,
    prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    typical_p: float,
    repetition_penalty: float,
    **kwargs,
) -> str:
    endpoint = f"{koboldai_url}/api/v1/generate"
    payload = {
        "prompt": prompt,

        # Incredibly low max length for reasons explained in the "while True"
        # loop below.
        "max_length": 32,

        # Take care of parameters which are named differently between Kobold
        # and HuggingFace.
        "sampler_full_determinism": not do_sample,
        "typical": typical_p,
        "rep_pen": repetition_penalty,

        # Disable any pre- or post-processing on the KoboldAI side; we'd
        # rather take care of things on our own.
        "frmttriminc": False,
        "frmtrmspch": False,
        "frmtrmblln": False,
        "frmtadsnsp": False,

        # Append any other generation parameters that we didn't handle manually.
        **kwargs,
    }
    generated_text = ""

    # Currently, Kobold doesn't support custom stopping criteria, and its chat
    # mode can't handle multi-line responses. To work around both of those, we
    # use the regular adventure mode generation but keep asking for more tokens
    # until the model starts trying to talk as the user, then we stop.
    attempts = 0
    max_extra_attempts = 4
    while attempts < (payload["max_length"] /
                      max_new_tokens) + max_extra_attempts:
        attempts += 1
        response = requests.post(endpoint, json=payload)
        if not response.ok:
            error_message = response.text
            raise KoboldApiServerException(
                "The KoboldAI API server returned an error"
                f" (HTTP status code {response.status_code}): {error_message}")

        inference_result = response.json()["results"][0]["text"]
        generated_text += inference_result

        # The model started to talk as us. Stop generating and return results;
        # the rest of the code will take care of trimming it properly.
        if "\nYou:" in generated_text:
            logger.debug("Hit `\nYou:`: `%s`", generated_text)
            return generated_text

        # For SFT: hit an EOS token. Trim and return.
        if generated_text.endswith("<|endoftext|>"):
            logger.debug("Got EOS token: `%s`", generated_text)

            # We add a fake generated "\nYou:" here so the trimming code
            # doesn't need to handle SFT and UFT models differently.
            return generated_text.replace("<|endoftext|>", "\nYou:")

        # Hit the configured generation limit.
        if len(generated_text.split()) >= max_new_tokens:
            logger.debug("Hit max length: `%s`", generated_text)
            return generated_text

        # The model still hasn't finished what it had to say. Append its
        # output to the prompt and feed it back in.
        logger.debug("Got another %s tokens, but still not done: `%s`",
                     payload["max_length"], generated_text)
        payload["prompt"] += inference_result

    logger.debug("Exhausted generation attempts: `%s`", generated_text)
    return generated_text
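A minimal sketch of how `run_raw_inference_on_kai` is driven in practice: `app.py` spreads the Gradio generation settings into it, so keys not named in the signature (temperature, top_p, max_context_length, etc.) flow through `**kwargs` straight into the Kobold payload. The URL and prompt below are placeholders.

from gradio_ui import get_generation_defaults
from koboldai_client import run_raw_inference_on_kai

# Hypothetical KoboldAI endpoint and example prompt, for illustration only.
settings = get_generation_defaults(for_kobold=True)
text = run_raw_inference_on_kai(
    "http://localhost:5000",
    "Alice's Persona: A cheerful android.\n<START>\nYou: Hello!\nAlice:",
    **settings,
)
print(text)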
model.py
ADDED
@@ -0,0 +1,110 @@
import logging
import typing as t

import torch
import transformers

logger = logging.getLogger(__name__)


def build_model_and_tokenizer_for(
    model_name: str
) -> t.Tuple[transformers.AutoModelForCausalLM, transformers.AutoTokenizer]:
    '''Sets up the model and accompanying objects.'''
    logger.info(f"Loading tokenizer for {model_name}")
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    # NOTE(11b): non-OPT models support passing this in at inference time;
    # might be worth refactoring into a debug version so we're able to
    # experiment on the fly.
    bad_words_ids = [
        tokenizer(bad_word, add_special_tokens=False).input_ids
        for bad_word in _build_bad_words_list_for(model_name)
    ]

    logger.info(f"Loading the {model_name} model")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, bad_words_ids=bad_words_ids)
    model.eval().half().to("cuda")

    logger.info("Model and tokenizer are ready")
    return model, tokenizer


def run_raw_inference(model: transformers.AutoModelForCausalLM,
                      tokenizer: transformers.AutoTokenizer, prompt: str,
                      user_message: str, **kwargs: t.Any) -> str:
    '''
    Runs inference on the model, and attempts to return only the newly
    generated text.

    :param model: Model to perform inference with.
    :param tokenizer: Tokenizer to tokenize input with.
    :param prompt: Input to feed to the model.
    :param user_message: The user's raw message, exactly as appended to the end
        of `prompt`. Used for trimming the original input from the model output.
    :return: Decoded model generation.
    '''
    tokenized_items = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Atrocious code to stop generation when the model outputs "\nYou: " in
    # freshly generated text. Feel free to send in a PR if you know of a
    # cleaner way to do this.
    stopping_criteria_list = transformers.StoppingCriteriaList([
        _SentinelTokenStoppingCriteria(
            sentinel_token_ids=tokenizer(
                "\nYou:",
                add_special_tokens=False,
                return_tensors="pt",
            ).input_ids.to("cuda"),
            starting_idx=tokenized_items.input_ids.shape[-1])
    ])

    logits = model.generate(stopping_criteria=stopping_criteria_list,
                            **tokenized_items,
                            **kwargs)
    output = tokenizer.decode(logits[0], skip_special_tokens=True)

    logger.debug("Before trimming, model output was: `%s`", output)

    # Trim out the input prompt from the generated output.
    if (idx := prompt.rfind(user_message)) != -1:
        trimmed_output = output[idx + len(user_message) - 1:].strip()
        logger.debug("After trimming, it became: `%s`", trimmed_output)

        return trimmed_output
    else:
        raise Exception(
            "Couldn't find user message in the model's output. What?")


def _build_bad_words_list_for(_model_name: str) -> t.List[str]:
    '''Builds a list of bad words for the given model.'''

    # NOTE(11b): This was implemented as a function because each model size
    # seems to have its quirks at the moment, but this is a rushed
    # implementation so I'm not handling that, hence the dumb return here.
    return ["Persona:", "Scenario:", "<START>"]


class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

    def __init__(self, sentinel_token_ids: torch.LongTensor,
                 starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor,
                 _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Can't unfold, output is still too tiny. Skip.
            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue

            for window in trimmed_sample.unfold(
                    0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False
parsing.py
ADDED
@@ -0,0 +1,40 @@
import re
import typing as t


def parse_messages_from_str(string: str, names: t.List[str]) -> t.List[str]:
    '''
    Given a big string containing raw chat history, this function attempts to
    parse it out into a list where each item is an individual message.
    '''
    sanitized_names = [
        re.escape(name) for name in names
    ]

    speaker_regex = re.compile(rf"^({'|'.join(sanitized_names)}): ?",
                               re.MULTILINE)

    message_start_indexes = []
    for match in speaker_regex.finditer(string):
        message_start_indexes.append(match.start())

    if len(message_start_indexes) < 2:
        # Single message in the string.
        return [string.strip()]

    prev_start_idx = message_start_indexes[0]
    messages = []

    for start_idx in message_start_indexes[1:]:
        message = string[prev_start_idx:start_idx].strip()
        messages.append(message)
        prev_start_idx = start_idx

    # Add the last message.
    messages.append(string[prev_start_idx:].strip())

    return messages


def serialize_chat_history(history: t.List[str]) -> str:
    '''Given a structured chat history object, collapses it down to a string.'''
    return "\n".join(history)
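A small sketch of what `parse_messages_from_str` does with raw model output; the strings are illustrative examples.

from parsing import parse_messages_from_str

raw = "Alice: Hello there!\nYou: Hi, how are you?\nAlice: Doing great."
print(parse_messages_from_str(raw, ["You", "Alice"]))
# -> ['Alice: Hello there!', 'You: Hi, how are you?', 'Alice: Doing great.']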
prompting.py
ADDED
@@ -0,0 +1,54 @@
import logging
import typing as t

from parsing import parse_messages_from_str

logger = logging.getLogger(__name__)


def build_prompt_for(
    history: t.List[str],
    user_message: str,
    char_name: str,
    char_persona: t.Optional[str] = None,
    example_dialogue: t.Optional[str] = None,
    world_scenario: t.Optional[str] = None,
) -> str:
    '''Converts all the given pieces into a proper input prompt for the model.'''

    # If example dialogue is given, parse the messages out of it and prepend
    # them to the dialogue history.
    example_history = parse_messages_from_str(
        example_dialogue, ["You", char_name]) if example_dialogue else []
    concatenated_history = [*example_history, *history]

    # Construct the base turns with the info we already have.
    prompt_turns = [
        # TODO(11b): Shouldn't be here on the original 350M.
        "<START>",

        # TODO(11b): Arbitrary limit. See if it's possible to vary this
        # based on available context size and VRAM instead.
        *concatenated_history[-8:],
        f"You: {user_message}",
        f"{char_name}:",
    ]

    # If we have a scenario or the character has a persona definition, add
    # those to the beginning of the prompt.
    if world_scenario:
        prompt_turns.insert(
            0,
            f"Scenario: {world_scenario}",
        )

    if char_persona:
        prompt_turns.insert(
            0,
            f"{char_name}'s Persona: {char_persona}",
        )

    # Done!
    logger.debug("Constructed prompt is: `%s`", prompt_turns)
    prompt_str = "\n".join(prompt_turns)
    return prompt_str
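A sketch of the resulting prompt layout, with illustrative character details: persona and scenario first, then `<START>`, the trailing turns of history, the new user message, and the character's name left open for the model to complete.

from prompting import build_prompt_for

prompt = build_prompt_for(
    history=["You: Hi!", "Alice: Hello! How can I help?"],
    user_message="Tell me a joke.",
    char_name="Alice",
    char_persona="A cheerful android who loves puns.",
    world_scenario="A quiet afternoon at a robotics lab.",
)
print(prompt)
# Alice's Persona: A cheerful android who loves puns.
# Scenario: A quiet afternoon at a robotics lab.
# <START>
# You: Hi!
# Alice: Hello! How can I help?
# You: Tell me a joke.
# Alice: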
utils.py
ADDED
@@ -0,0 +1,11 @@
def clear_stdout():
    '''
    Attempts to clear stdout, whether running in a notebook (IPython) or
    locally in a Unix environment.
    '''
    try:
        from IPython.display import clear_output
        clear_output(wait=True)
    except ImportError:
        import os
        os.system("clear")