File size: 3,661 Bytes
48886b7
 
 
 
 
 
e2e7b98
 
48886b7
 
 
 
 
e2e7b98
 
48886b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2e7b98
 
 
 
 
48886b7
e2e7b98
 
 
 
 
 
48886b7
e2e7b98
 
 
 
 
 
 
 
48886b7
 
 
e2e7b98
48886b7
 
e2e7b98
48886b7
 
e2e7b98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48886b7
 
 
 
 
 
 
 
 
 
 
 
 
 
e2e7b98
48886b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2e7b98
 
48886b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2e7b98
48886b7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#

import os
from config import MODEL, INFO, HOST
from openai import AsyncOpenAI
import gradio as gr

async def playground(
    message,
    history,
    num_ctx,
    max_tokens,
    temperature,
    repeat_penalty,
    top_k,
    top_p
):
    """Stream a chat completion for ``message`` given prior ``history``.

    The six trailing parameters mirror the sidebar sliders defined below.
    Yields the accumulated response text after every streamed chunk so the
    Gradio ChatInterface renders token-by-token output.
    """
    # Ignore empty or non-string submissions outright.
    if not isinstance(message, str) or not message.strip():
        yield []
        return

    # Rebuild the conversation from history entries that look like
    # OpenAI-style role/content dicts, then append the new user turn.
    messages = [
        {"role": item["role"], "content": item["content"]}
        for item in history
        if isinstance(item, dict) and "role" in item and "content" in item
    ]
    messages.append({"role": "user", "content": message})

    # Close the per-request client when streaming finishes (async with)
    # so HTTP connections are not leaked on every chat turn — the
    # original constructed the client inline and never closed it.
    async with AsyncOpenAI(
        base_url=os.getenv("OLLAMA_API_BASE_URL"),
        api_key=os.getenv("OLLAMA_API_KEY")
    ) as client:
        stream = await client.chat.completions.create(
            model=MODEL,
            messages=messages,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            stream=True,
            # Ollama-specific sampling knobs travel via extra_body.
            extra_body={
                "num_ctx": int(num_ctx),
                "repeat_penalty": float(repeat_penalty),
                "top_k": int(top_k)
            }
        )

        response = ""
        async for chunk in stream:
            # Some chunks carry no choices/content (e.g. role headers).
            delta = chunk.choices[0].delta if chunk.choices else None
            if delta and delta.content:
                response += delta.content
                yield response

# Assemble the UI: a sidebar of model-parameter sliders next to a
# streaming chat interface wired to ``playground``.
with gr.Blocks(
    fill_height=True,
    fill_width=False
) as app:
    with gr.Sidebar():
        gr.HTML(INFO)
        gr.Markdown("---")
        gr.Markdown("## Model Parameters")

        # (label, info, minimum, maximum, default, step) for each slider,
        # in the exact order expected by ``playground``'s parameters.
        _PARAM_SPECS = [
            ("Context Length", "Maximum context window size (memory)",
             512, 8192, 512, 128),
            ("Max Tokens", "Maximum number of tokens to generate",
             512, 8192, 512, 128),
            ("Temperature", "Controls randomness in generation",
             0.1, 1.0, 0.1, 0.1),
            ("Repetition Penalty", "Penalty for repeating tokens",
             0.1, 2.0, 1.05, 0.1),
            ("Top K", "Number of top tokens to consider",
             0, 100, 50, 1),
            ("Top P", "Cumulative probability threshold",
             0.0, 1.0, 0.1, 0.05),
        ]

        _sliders = []
        for _idx, (_label, _info, _lo, _hi, _default, _step) in enumerate(_PARAM_SPECS):
            if _idx:
                # Blank spacer between consecutive sliders (matches layout).
                gr.Markdown("")
            _sliders.append(
                gr.Slider(
                    minimum=_lo,
                    maximum=_hi,
                    value=_default,
                    step=_step,
                    label=_label,
                    info=_info
                )
            )
        num_ctx, max_tokens, temperature, repeat_penalty, top_k, top_p = _sliders

    # The chat surface streams ``playground`` output; slider components
    # are forwarded as additional inputs in signature order.
    gr.ChatInterface(
        fn=playground,
        additional_inputs=[
            num_ctx,
            max_tokens,
            temperature,
            repeat_penalty,
            top_k,
            top_p
        ],
        type="messages",
        examples=[
            ["Please introduce yourself."],
            ["What caused World War II?"],
            ["Give me a short introduction to large language model."],
            ["Explain about quantum computers."]
        ],
        cache_examples=False,
        show_api=False
    )

# Serve the app on the configured host; pwa=True makes it installable
# as a progressive web app.
app.launch(
    server_name=HOST,
    pwa=True
)