| import re |
|
|
| import gradio as gr |
| from routellm.controller import Controller |
|
|
# Default sampling temperature shown in the UI slider.
TEMPERATURE = 0.8
# Default routing threshold shown in the UI slider; per the demo description
# it is calibrated so roughly half of calls route to the strong model —
# TODO confirm calibration source.
THRESHOLD = 0.11593
# Router identifier; must match one of the entries in `routers` below.
ROUTER = "mf"


# RouteLLM controller exposing an OpenAI-style chat API that routes each
# request between the strong and weak model via the "mf" router.
client = Controller(
    routers=["mf"],
    strong_model="gpt-4-1106-preview",
    weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
)
|
|
|
|
def predict(message, history, threshold, temperature):
    """Stream a routed chat completion for the Gradio ChatInterface.

    Args:
        message: Latest user message.
        history: Prior turns as (user, assistant) string pairs.
        threshold: Routing threshold; embedded in the model name so the
            RouteLLM controller can pick the strong or weak model.
        temperature: Sampling temperature forwarded to the completion call.

    Yields:
        Strings for Gradio to render: first a "**[model]**" prefix naming
        the model the router selected, then the cumulative response text
        (prefix + streamed content so far) after each chunk.
    """
    history_openai_format = [
        {"role": "system", "content": "You are a helpful AI assistant."}
    ]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append(
            {
                "role": "assistant",
                # Strip the "**[model]**" prefix this function prepended on
                # the earlier turn, so the model never sees it as content.
                "content": re.sub(r"^\*\*\[.*?\]\*\*\s*", "", assistant),
            }
        )
    history_openai_format.append({"role": "user", "content": message})

    # The threshold rides along in the model name; the controller parses it
    # to decide per-request between the strong and weak model.
    stream = client.chat.completions.create(
        model=f"router-{ROUTER}-{threshold}",
        messages=history_openai_format,
        temperature=temperature,
        stream=True,
        max_tokens=512,
    )

    partial_message = ""
    for i, chunk in enumerate(stream):
        if i == 0:
            # Announce which model the router picked before any content.
            if chunk.model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
                model_name = "Mixtral-8x7B-Instruct-v0.1"
            else:
                model_name = chunk.model
            model_prefix = f"**[{model_name}]**\n"
            yield model_prefix
            partial_message += model_prefix
        # Some OpenAI-compatible backends emit keep-alive chunks with an
        # empty `choices` list; skip them instead of raising IndexError.
        if not chunk.choices:
            continue
        partial_message += chunk.choices[0].delta.content or ""
        yield partial_message
|
|
|
|
| |
# Chat UI wired to the streaming `predict` generator. The two sliders are
# passed through to `predict` as its extra (threshold, temperature) args.
threshold_slider = gr.Slider(
    label="Threshold", minimum=0, maximum=1, value=THRESHOLD, step=0.01
)
temperature_slider = gr.Slider(
    label="Temperature", minimum=0, maximum=1, value=TEMPERATURE, step=0.1
)

demo = gr.ChatInterface(
    predict,
    additional_inputs=[threshold_slider, temperature_slider],
    title="RouteLLM",
    fill_height=True,
    description="This is a demo of our matrix factorization router, calibrated so that approximately 50% of calls (those that are harder) are routed to GPT-4, with remaining calls routed to Mixtral 8x7B.\n\nCheck out https://github.com/lm-sys/RouteLLM for details!",
)
|
|
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script, not on import.
    demo.launch()
|
|