File size: 5,901 Bytes
bc9db78
 
72e6273
 
bc9db78
72e6273
 
 
bc9db78
 
 
 
 
 
 
 
72e6273
bc9db78
72e6273
bc9db78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72e6273
 
bc9db78
 
 
 
 
 
72e6273
 
bc9db78
1786137
bc9db78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfe75a1
bc9db78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72e6273
 
 
 
bc9db78
 
 
 
 
 
 
72e6273
 
 
 
bc9db78
72e6273
 
 
bc9db78
 
 
 
 
 
 
 
 
 
 
 
 
72e6273
bc9db78
 
 
 
 
 
 
 
 
 
72e6273
bc9db78
72e6273
bc9db78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72e6273
 
bc9db78
 
72e6273
bc9db78
 
72e6273
bc9db78
 
72e6273
bc9db78
 
72e6273
bc9db78
 
72e6273
 
bc9db78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from argparse import ArgumentParser
import gradio as gr
import requests
import json
import time



def get_streaming_response(response: requests.Response):
    """Yield incremental text fragments from an SSE chat-completion stream.

    Reads the HTTP response line by line, strips the ``data: `` framing,
    stops at the ``[DONE]`` sentinel, and emits every non-empty
    ``choices[0].delta.content`` string. Malformed lines are logged and
    skipped rather than aborting the stream.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        decoded = raw_line.decode("utf-8")
        if not decoded.startswith('data: '):
            continue
        json_str = decoded[6:]  # drop the 'data: ' SSE prefix
        if json_str == '[DONE]':
            break
        try:
            event = json.loads(json_str)
            delta = event.get('choices', [{}])[0].get('delta', {})
            piece = delta.get('content', '')
        except (json.JSONDecodeError, IndexError):
            # An empty 'choices' list raises IndexError; bad JSON raises
            # JSONDecodeError — either way, skip the line and keep streaming.
            print(f"Skipping malformed SSE line: {json_str}")
            continue
        if piece:
            yield piece

def _chat_stream(model, tokenizer, query, history, temperature, top_p, max_output_tokens):
    """Stream a chat completion for *query* from the remote OpenAI-style API.

    Args:
        model, tokenizer: unused; kept so the signature stays compatible with
            callers written for local inference.
        query: the new user message.
        history: list of (user, assistant) turn pairs.
        temperature: sampling temperature; clamped to >= 0 before sending.
        top_p: nucleus-sampling parameter forwarded to the API.
        max_output_tokens: cap on generated tokens.

    Yields:
        str: incremental content chunks as the server streams them, or a
        single error string if the request or response parsing fails.
    """
    # Rebuild the OpenAI-style message list from the stored turn pairs.
    conversation = []
    for query_h, response_h in history:
        conversation.append({"role": "user", "content": query_h})
        conversation.append({"role": "assistant", "content": response_h})
    conversation.append({"role": "user", "content": query})

    payload = {
        "model": "megrez-moe-waic",
        "messages": conversation,
        "max_tokens": max_output_tokens,
        "temperature": max(temperature, 0),  # negative temperatures are invalid
        "top_p": top_p,
        "stream": True,
    }

    try:
        API_URL = "http://8.152.0.142:10021/v1/chat/completions"
        # `json=` lets requests serialize the body and set the
        # application/json Content-Type header itself — equivalent to the
        # manual headers + json.dumps, but idiomatic.
        response = requests.post(API_URL, json=payload, timeout=60, stream=True)
        response.raise_for_status()
        for chunk in get_streaming_response(response):
            yield chunk
            time.sleep(0.01)  # small pause smooths the UI update cadence

    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        yield f"Error: Could not connect to the API. Details: {e}"
    except (KeyError, IndexError) as e:
        # Only reachable after `response` is bound, so `.text` is safe here.
        print(f"Failed to parse API response: {response.text}")
        yield f"Error: Invalid response format from the API. Details: {e}"

def predict(_query, _chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
    """Run one chat turn: stream the reply and update the chatbot in place.

    Appends a (query, partial_answer) pair to ``_chatbot``, yields the
    chatbot after every streamed fragment so Gradio re-renders, and records
    the finished turn in ``_task_history``.
    """
    print(f"User: {_query}")
    _chatbot.append((_query, ""))

    accumulated = ""
    token_stream = _chat_stream(
        None,
        None,
        _query,
        history=_task_history,
        temperature=_temperature,
        top_p=_top_p,
        max_output_tokens=_max_output_tokens,
    )
    for piece in token_stream:
        accumulated += piece
        _chatbot[-1] = (_query, accumulated)
        yield _chatbot

    print(f"History: {_task_history}")
    _task_history.append((_query, accumulated))
    print(f"Megrez (from API): {accumulated}")

def regenerate(_chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
    """Re-run the most recent user turn.

    Pops the last exchange from both the history and the visible chatbot,
    then delegates to :func:`predict` with the same query and parameters.
    Yields the chatbot unchanged when there is no history to regenerate.
    """
    if not _task_history:
        yield _chatbot
        return
    last_query, _ = _task_history.pop(-1)
    # Guard against a history/chatbot desync (e.g. a partially-failed turn):
    # an unconditional pop on an empty chatbot would raise IndexError.
    if _chatbot:
        _chatbot.pop(-1)
    yield from predict(last_query, _chatbot, _task_history, _temperature, _top_p, _max_output_tokens)

def reset_user_input():
    """Return a Gradio update that clears the prompt textbox."""
    return gr.update(value="")

def reset_state(_chatbot, _task_history):
    """Empty both the visible chat log and the stored history in place.

    Returns the (now empty) chatbot list so Gradio refreshes the display.
    """
    del _task_history[:]
    del _chatbot[:]
    return _chatbot



if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            f"""
# 🎱 Chat with Megrez2 <a href="https://github.com/infinigence/Infini-Megrez">
"""
        )

        chatbot = gr.Chatbot(label="Megrez2", elem_classes="control-height", height='48vh', show_copy_button=True,
            latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
            {"left": "\\(", "right": "\\)", "display": False},
            {"left": "\\[", "right": "\\]", "display": True},
        ])
        with gr.Row():
            with gr.Column(scale=20):
                query = gr.Textbox(show_label=False, container=False, placeholder="Enter your prompt here and press ENTER")
            with gr.Column(scale=1, min_width=100):
                submit_btn = gr.Button("πŸš€ Send", variant="primary")
        task_history = gr.State([])

        with gr.Row():
            empty_btn = gr.Button("πŸ—‘οΈ Clear History")
            regen_btn = gr.Button("πŸ”„ Regenerate")

        with gr.Accordion("Parameters", open=False) as parameter_row:
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.2,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_output_tokens = gr.Slider(
                minimum=16,
                maximum=32768,
                value=4096,
                step=1024,
                interactive=True,
                label="Max output tokens",
            )

        submit_btn.click(
            predict, [query, chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )
        query.submit(
            predict, [query, chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        query.submit(reset_user_input, [], [query])

        empty_btn.click(
            reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True
        )
        regen_btn.click(
            regenerate, [chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )

    demo.launch(ssr_mode=False, share=True)