| import gradio as gr |
| import time |
| import argparse |
| from vllm import LLM, SamplingParams |
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for the WizardLM Gradio demo.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with ``model`` (str, required) and ``n_gpu``
        (int, default 1).
    """
    parser = argparse.ArgumentParser(
        description="Serve a WizardLM model through a Gradio chat UI."
    )
    # Required: without it ``model`` would be None and vllm.LLM would fail
    # much later with a far less helpful error than argparse's usage message.
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model path or name passed to vllm.LLM(model=...).",
    )
    parser.add_argument(
        "--n_gpu",
        type=int,
        default=1,
        help="tensor_parallel_size passed to vllm.LLM (default: 1).",
    )
    return parser.parse_args(argv)
|
|
def echo(message, history, system_prompt, temperature, max_tokens):
    """Stand-in chat handler that streams back a description of its inputs.

    Has the same signature as ``predict``, so it can be wired into the UI
    without a model. It formats the inputs into one summary string and yields
    progressively longer prefixes of it, pausing 0.05 s before each yield.
    At most ``int(max_tokens)`` prefixes are produced. ``history`` is unused.
    """
    summary = (
        f"System prompt: {system_prompt}\n Message: {message}. \n "
        f"Temperature: {temperature}. \n Max Tokens: {max_tokens}."
    )
    limit = min(len(summary), int(max_tokens))
    end = 1
    while end <= limit:
        time.sleep(0.05)  # simulate per-character generation latency
        yield summary[:end]
        end += 1
|
|
def predict(message, history, system_prompt, temperature, max_tokens):
    """Generate a WizardLM reply with vLLM and stream it into the chat UI.

    Builds one conversation prompt (fixed system message, every earlier turn,
    then the new message) and makes a single blocking ``llm.generate`` call.
    It then yields successively longer prefixes of the finished reply, so
    ``gr.ChatInterface`` renders it character by character. The text appears
    only after generation completes; tokens are not streamed live.

    Relies on the module-level ``llm`` (a ``vllm.LLM``) created in the
    ``__main__`` block.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user, assistant)`` pairs.
        system_prompt: Value of the "System Prompt" textbox.
            NOTE(review): currently unused. The fixed system message below
            is always sent; confirm whether the UI value should replace it.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate. It is converted
            with ``int()`` because the UI slider may deliver a float.

    Yields:
        str: growing prefixes of the generated reply.
    """
    # Fixed system message; each completed turn is terminated with "</s>".
    parts = [
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    ]
    for human, assistant in history:
        # A turn without a recorded reply (None) is rendered as an empty
        # answer instead of raising TypeError during concatenation.
        parts.append("USER: " + human + " ASSISTANT: " + (assistant or "") + "</s>")
    parts.append("USER: " + message + " ASSISTANT:")
    prompt = "".join(parts)

    # Generation ends as soon as any of these strings is produced.
    # NOTE(review): the bare words ("Question", "Response", "USER", ...) also
    # occur in ordinary prose and can cut a reply short; kept unchanged.
    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=1,
        # SamplingParams takes an integer token budget; the slider may send a float.
        max_tokens=int(max_tokens),
        stop=stop_tokens,
    )
    for output in llm.generate([prompt], sampling_params):
        generated_text = output.outputs[0].text
        for end in range(1, len(generated_text) + 1):
            yield generated_text[:end]
|
|
|
|
| """ |
| - Setup environment: |
| ```bash |
| conda create -n wizardweb python=3.8 -y |
| conda activate wizardweb |
| pip install vllm |
| pip install transformers==4.31.0 |
| pip install --upgrade gradio |
| pip install jsonlines |
| pip install ray==2.5.1 |
| ``` |
| - Launch the demo (replace `xxxx` with a model path or name): |
| ```bash |
| python gradio_wizardlm.py --model xxxx --n_gpu 1 |
| python gradio_wizardlm.py --model /workspaceblobstore/caxu/trained_models/13Bv2_v14continue_2048_e3_2e_5/checkpoint-850 --n_gpu 1 |
| ``` |
| |
| """ |
if __name__ == "__main__":
    import os  # only needed to read the optional server overrides below

    args = parse_args()
    # ``predict`` looks up ``llm`` as a module-level global, so it is bound
    # here at module scope rather than inside a function.
    llm = LLM(model=args.model, tensor_parallel_size=args.n_gpu)

    # Defaults are the original host/port. GRADIO_SERVER_NAME /
    # GRADIO_SERVER_PORT (Gradio's own variable names) override them, so the
    # demo can be launched on machines other than this one host.
    server_name = os.environ.get(
        "GRADIO_SERVER_NAME", "phlrr2019.guest.corp.microsoft.com"
    )
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))

    gr.ChatInterface(
        predict,
        title="LLM playground - WizardLM",
        description="This is a LLM playground for WizardLM.",
        theme="soft",
        chatbot=gr.Chatbot(height=300, label="Chat History"),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Passed to ``predict`` positionally after (message, history):
        # system_prompt, temperature, max_tokens.
        additional_inputs=[
            gr.Textbox("You are helpful AI.", label="System Prompt"),
            gr.Slider(0, 1, 0.9, label="Temperature"),
            gr.Slider(10, 1000, 800, label="Max Tokens"),
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch(share=False, server_name=server_name, server_port=server_port)
|
|