# NOTE(review): removed paste/extraction artifacts ("Spaces:" / "Runtime error"
# lines) that were not part of the original script.
#!/usr/bin/env python3
"""Entrypoint script: serves a Gradio chat UI backed by either a local
HuggingFace model or a remote KoboldAI inference server."""
import argparse
import logging
import typing as t

from gradio_ui import build_gradio_ui_for
from koboldai_client import run_raw_inference_on_kai, wait_for_kai_server
from parsing import parse_messages_from_str
from prompting import build_prompt_for
from utils import clear_stdout

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# For UI debugging purposes. When True, no model is loaded and inference
# returns a canned mock response.
DONT_USE_MODEL = False
def main(server_port: int,
         share_gradio_link: bool = False,
         model_name: t.Optional[str] = None,
         koboldai_url: t.Optional[str] = None) -> None:
    '''Script entrypoint.

    Args:
        server_port: Port for the Gradio UI to listen on.
        share_gradio_link: If True, Gradio generates a public share link.
        model_name: HuggingFace Transformers model name for local inference.
        koboldai_url: URL of a KoboldAI instance to use as a remote inference
            server instead of a local model.

    One of `model_name` or `koboldai_url` must be usable, otherwise inference
    requests raise RuntimeError.
    '''
    if model_name and not DONT_USE_MODEL:
        # Imported lazily so UI-debugging runs and KoboldAI-backed runs don't
        # pay the cost of pulling in the local model machinery.
        from model import build_model_and_tokenizer_for, run_raw_inference
        model, tokenizer = build_model_and_tokenizer_for(model_name)
    else:
        model, tokenizer = None, None

    def inference_fn(history: t.List[str], user_input: str,
                     generation_settings: t.Dict[str, t.Any],
                     *char_settings: t.Any) -> str:
        '''Generates a single bot message for the given chat state.'''
        if DONT_USE_MODEL:
            return "Mock response for UI tests."

        # Brittle. Comes from the order defined in gradio_ui.py.
        [
            char_name,
            _user_name,
            char_persona,
            char_greeting,
            world_scenario,
            example_dialogue,
        ] = char_settings

        # If we're just starting the conversation and the character has a
        # greeting configured, return that instead. This is a workaround for
        # the fact that Gradio assumed that a chatbot cannot possibly start a
        # conversation, so we can't just have the greeting there
        # automatically, it needs to be in response to a user message.
        if len(history) == 0 and char_greeting is not None:
            return f"{char_name}: {char_greeting}"

        prompt = build_prompt_for(history=history,
                                  user_message=user_input,
                                  char_name=char_name,
                                  char_persona=char_persona,
                                  example_dialogue=example_dialogue,
                                  world_scenario=world_scenario)

        if model and tokenizer:
            model_output = run_raw_inference(model, tokenizer, prompt,
                                             user_input,
                                             **generation_settings)
        elif koboldai_url:
            # The KoboldAI server returns only the continuation, so re-attach
            # the character-name prefix before parsing messages out of it.
            model_output = f"{char_name}:"
            model_output += run_raw_inference_on_kai(koboldai_url, prompt,
                                                     **generation_settings)
        else:
            # Was `raise Exception(...)`: RuntimeError is more specific while
            # remaining a subclass of Exception, so existing broad handlers
            # still catch it.
            raise RuntimeError(
                "Not using local inference, but no Kobold instance URL was"
                " given. Nowhere to perform inference on.")

        generated_messages = parse_messages_from_str(model_output,
                                                     ["You", char_name])
        logger.debug("Parsed model response is: `%s`", generated_messages)
        # Only the first parsed message is surfaced; anything after it is the
        # model speaking out of turn.
        bot_message = generated_messages[0]

        return bot_message

    ui = build_gradio_ui_for(inference_fn, for_kobold=koboldai_url is not None)
    ui.launch(server_port=server_port, share=share_gradio_link)
| def _parse_args_from_argv() -> argparse.Namespace: | |
| '''Parses arguments coming in from the command line.''' | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "-m", | |
| "--model-name", | |
| help= | |
| "HuggingFace Transformers model name, if not using a KoboldAI instance as an inference server.", | |
| ) | |
| parser.add_argument( | |
| "-p", | |
| "--port", | |
| type=int, | |
| default=3000, | |
| help="Port to listen on.", | |
| ) | |
| parser.add_argument( | |
| "-k", | |
| "--koboldai-url", | |
| help="URL to a KoboldAI instance to use as an inference server.", | |
| ) | |
| parser.add_argument( | |
| "-s", | |
| "--share", | |
| action="store_true", | |
| help="Enable to generate a public link for the Gradio UI.", | |
| ) | |
| return parser.parse_args() | |
if __name__ == "__main__":
    args = _parse_args_from_argv()

    if args.koboldai_url:
        # I have no idea how long a safe wait time is, but we'd rather wait
        # for too long rather than just cut the user off _right_ when the
        # setup is about to finish, so let's pick something absurd here.
        wait_for_kai_server(args.koboldai_url, max_wait_time_seconds=60 * 30)

        # Clear out any Kobold logs so the user can clearly see the Gradio
        # link that's about to show up afterwards.
        # NOTE(review): indentation was lost in this capture; this statement
        # is placed inside the KoboldAI branch per the comment above — confirm
        # against the original file.
        clear_stdout()

    main(model_name=args.model_name,
         server_port=args.port,
         koboldai_url=args.koboldai_url,
         share_gradio_link=args.share)