# NOTE(review): removed non-code residue from a file-viewer scrape
# (file-size header, commit-hash column, and line-number gutter) that
# preceded the actual module source and made the file unparseable.
import os
import gradio as gr
from huggingface_hub import InferenceClient

# Connection settings for the self-hosted vLLM servers, taken from the
# environment so no credentials live in the source.
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
VLLM_API_KEY = os.getenv("VLLM_API_KEY")

# Fail fast at import time if the deployment is misconfigured.
for _var, _val in (("VLLM_BASE_URL", VLLM_BASE_URL), ("VLLM_API_KEY", VLLM_API_KEY)):
    if not _val:
        raise ValueError(f"Missing env var: {_var}")

# Each fine-tuned model is served by its own vLLM instance; ports are
# assigned consecutively starting at 8000, in the order listed here.
model2port = {
    name: 8000 + offset
    for offset, name in enumerate(
        (
            "llama-3.2-1b-instruct-unsloth-bnb-16bit-FineTome-r32",
            "llama-3.2-3b-instruct-unsloth-bnb-16bit-FineTome-r32",
            "qwen-2.5-3b-instruct-unsloth-bnb-16bit-FineTome-r32",
        )
    )
}

def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    model_name,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,
):
    """Stream a chat completion for `message` from the vLLM server hosting `model_name`.

    For more information on `huggingface_hub` Inference API support, please check
    the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

    Args:
        message: The latest user message text.
        history: Prior turns as OpenAI-style ``{"role", "content"}`` dicts
            (Gradio ``type="messages"`` format); appended to the prompt as-is.
        system_message: System prompt placed first in the conversation.
        model_name: Key into ``model2port`` selecting the serving port; also
            expanded to the full ``Zephyroam/<name>`` model id.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature forwarded to the server.
        top_p: Nucleus-sampling cutoff forwarded to the server.
        hf_token: OAuth token injected by Gradio's login flow. Currently
            unused — the backend authenticates with ``VLLM_API_KEY`` — but the
            parameter must remain so Gradio keeps wiring the login state.

    Yields:
        The response text accumulated so far, after each streamed chunk.
    """
    # Each model is served on its own port; route the request by model name.
    client = InferenceClient(
        token=VLLM_API_KEY,
        model=f"{VLLM_BASE_URL}:{model2port[model_name]}",
    )

    full_model_id = f"Zephyroam/{model_name}"

    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    response = ""

    # BUG FIX: the loop variable was previously named `message`, shadowing
    # (and clobbering) the `message` parameter during streaming.
    for chunk in client.chat_completion(
        messages,
        model=full_model_id,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        # Some stream events carry no choices or an empty delta; skip those
        # but still yield so the UI refresh cadence is unchanged.
        if choices and choices[0].delta.content:
            response += choices[0].delta.content
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Dropdown(
            label="Model name",
            choices=[
                "llama-3.2-1b-instruct-unsloth-bnb-16bit-FineTome-r32",
                "llama-3.2-3b-instruct-unsloth-bnb-16bit-FineTome-r32",
                "qwen-2.5-3b-instruct-unsloth-bnb-16bit-FineTome-r32",
            ],
            value="llama-3.2-3b-instruct-unsloth-bnb-16bit-FineTome-r32",
            allow_custom_value=False,
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

# Top-level layout: a sidebar holding the Hugging Face login button
# (which populates the `hf_token` OAuth parameter of `respond`) next to
# the chat interface built above.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()


if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()