| from model.fastchat.conversation import (Conversation, SeparatorStyle, |
| compute_skip_echo_len) |
| from model.fastchat.serve.inference import ChatIO, generate_stream, load_model |
|
|
|
|
class SimpleChatIO(ChatIO):
    """ Console implementation of ChatIO: reads with ``input`` and writes
    streamed replies to stdout. """

    def prompt_for_input(self, role) -> str:
        """ Show ``"<role>: "`` as the input prompt and return the typed line. """
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        """ Print the ``"<role>: "`` prefix without a trailing newline. """
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        """ Print a streamed reply word by word and return the final text.

        Args:
            output_stream: iterable of strings. Each item is sliced with
                ``[skip_echo_len:]`` and stripped before use; presumably each
                item is the full text generated so far (TODO confirm against
                ``generate_stream``).
            skip_echo_len: number of leading characters dropped from every
                item before printing.

        Returns:
            The last item of the stream after slicing and stripping, or an
            empty string when the stream yields nothing.
        """
        # Initialised up front so that an empty stream falls through to the
        # final print/return instead of raising UnboundLocalError.
        words = []
        printed = 0
        for snapshot in output_stream:
            words = snapshot[skip_echo_len:].strip().split(" ")
            # The last word is held back: it is only printed once another
            # word follows it or the stream ends.
            complete = len(words) - 1
            if complete > printed:
                print(" ".join(words[printed:complete]), end=" ", flush=True)
                printed = complete
        # Flush whatever has not been printed yet and end the line.
        print(" ".join(words[printed:]), flush=True)
        return " ".join(words)
|
|
|
|
class VicunaChatBot:
    """ Holds a model/tokenizer pair obtained from ``load_model`` and runs
    single chat turns against it. """

    def __init__(
        self,
        model_path: str,
        device: str,
        num_gpus: str,
        max_gpu_memory: str,
        load_8bit: bool,
        ChatIO: ChatIO,
        debug: bool,
    ):
        """ Store the settings and load the model.

        Args:
            model_path: passed to ``load_model`` and later sent as the
                ``"model"`` generation parameter.
            device: passed to ``load_model`` and to ``generate_stream``.
            num_gpus: forwarded to ``load_model``.
            max_gpu_memory: forwarded to ``load_model``.
            load_8bit: forwarded to ``load_model``.
            ChatIO: I/O object used to announce and stream the reply.
            debug: forwarded to ``load_model`` and kept on the instance.
        """
        self.debug = debug
        self.chatio = ChatIO
        self.device = device
        self.model_path = model_path

        loaded = load_model(self.model_path, device, num_gpus,
                            max_gpu_memory, load_8bit, debug)
        self.model, self.tokenizer = loaded

    def chat(self, inp: str, temperature: float, max_new_tokens: int,
             conv: Conversation):
        """ Run one chat turn with Vicuna.

        Appends ``inp`` under ``conv.roles[0]`` and an empty slot under
        ``conv.roles[1]``, generates the reply, streams it through the
        ChatIO object and writes the stripped reply into that slot.

        Args:
            inp: the user's message for this turn.
            temperature: sent as the ``"temperature"`` generation parameter.
            max_new_tokens: sent as the ``"max_new_tokens"`` parameter.
            conv: conversation to extend; it is modified in place.

        Returns:
            A ``(reply, conv)`` tuple, where ``reply`` is the value returned
            by ``stream_output``.
        """
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)

        full_prompt = conv.get_prompt()
        echo_len = compute_skip_echo_len(self.model_path, conv, full_prompt)

        # Only these two separator styles use ``conv.sep`` as a stop string.
        if conv.sep_style in (SeparatorStyle.SINGLE, SeparatorStyle.BAIZE):
            stop = conv.sep
        else:
            stop = None

        gen_params = {
            "model": self.model_path,
            "prompt": full_prompt,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "stop": stop,
        }

        # NOTE(review): the prompt is printed on every call regardless of
        # ``self.debug`` - confirm whether this is intended.
        print(full_prompt)
        self.chatio.prompt_for_output(conv.roles[1])
        stream = generate_stream(self.model, self.tokenizer, gen_params,
                                 self.device)
        reply = self.chatio.stream_output(stream, echo_len)

        # Fill the empty assistant slot appended above.
        conv.messages[-1][-1] = reply.strip()
        return reply, conv
|
|
|
|
class VicunaHandler:
    """ VicunaHandler is a class that handles the communication between the
    frontend and the backend. """

    def __init__(self, config):
        """ Build the chat I/O object and the chatbot from ``config``.

        Args:
            config: mapping read here for the keys ``model_path``,
                ``device``, ``num_gpus``, ``max_gpu_memory``, ``load_8bit``
                and ``debug``; ``gr_chat`` additionally reads
                ``temperature`` and ``max_new_tokens``.
        """
        self.config = config
        self.chat_io = SimpleChatIO()
        self.chatbot = VicunaChatBot(
            self.config['model_path'],
            self.config['device'],
            self.config['num_gpus'],
            self.config['max_gpu_memory'],
            self.config['load_8bit'],
            self.chat_io,
            self.config['debug'],
        )

    def chat(self):
        """ Chat with the Vicuna. Not implemented; does nothing. """
        pass

    def gr_chatbot_init(self, caption: str):
        """ Initialise the chatbot for gradio.

        Args:
            caption: video description text embedded in the first user
                message of the conversation template.

        Returns:
            A tuple holding two separate ``copy()`` results of the
            conversation template.
        """
        template = self._construct_conversation(caption)
        print("Chatbot initialised.")
        return template.copy(), template.copy()

    def gr_chat(self, inp, conv: Conversation):
        """ Chat using gradio as the frontend.

        Args:
            inp: the user's message.
            conv: conversation to continue.

        Returns:
            Whatever ``VicunaChatBot.chat`` returns for this turn, called
            with ``config['temperature']`` and ``config['max_new_tokens']``.
        """
        return self.chatbot.chat(inp, self.config['temperature'],
                                 self.config['max_new_tokens'], conv)

    def _construct_conversation(self, prompt):
        """ Construct a conversation template.

        The template starts with one example exchange: a user message that
        wraps ``prompt`` between ``<video begin>`` and ``<video end>``
        markers and ends with an example question, answered by the
        assistant with "yes".

        Args:
            prompt: the prompt for the conversation.

        Returns:
            A ``Conversation`` using ``SeparatorStyle.TWO``.
        """
        # NOTE(review): "<video end>" is appended directly after ``prompt``;
        # if ``prompt`` does not end with a newline the marker will not start
        # on its own line - confirm against the caller.
        user_message = (
            "The following text described what you have "
            "seen, found, heard and notice from a consecutive video."
            " Some of the texts may not be accurate. "
            "Try to conclude what happens in the video, "
            "then answer my question based on your conclusion.\n"
            "<video begin>\n" + prompt + "<video end>\n"
            "Example: Is this a Video?"
        ).strip()

        print(user_message)

        # The two sentences are separate literals joined by implicit
        # concatenation, so the separating space has to be written out.
        system_prompt = (
            "A chat between a curious user and an artificial intelligence "
            "assistant answering questions on videos. "
            "The assistant answers the questions based on the given video "
            "captions and speech in time order."
        )

        return Conversation(
            system=system_prompt,
            roles=("USER", "ASSISTANT"),
            messages=(("USER", user_message), ("ASSISTANT", "yes")),
            offset=0,
            sep_style=SeparatorStyle.TWO,
            sep=" ",
            sep2="</s>",
        )
|
|