from argparse import ArgumentParser

import gradio as gr
import requests
import json
import time

# OpenAI-compatible chat-completions endpoint for the Megrez model server.
API_URL = "http://8.152.0.142:10021/v1/chat/completions"


def get_streaming_response(response: requests.Response):
    """Yield content deltas from an OpenAI-style SSE chat-completions stream.

    Args:
        response: A streaming ``requests.Response`` whose body is a sequence
            of ``data: <json>`` server-sent-event lines, terminated by
            ``data: [DONE]``.

    Yields:
        str: Each non-empty ``choices[0].delta.content`` fragment, in order.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        data = raw_line.decode("utf-8")
        if not data.startswith('data: '):
            continue
        json_str = data[6:]  # strip the "data: " SSE prefix
        if json_str == '[DONE]':
            break
        try:
            event = json.loads(json_str)
            delta = event.get('choices', [{}])[0].get('delta', {})
            new_text = delta.get('content', '')
            if new_text:
                yield new_text
        except (json.JSONDecodeError, IndexError):
            # Best-effort: skip malformed events rather than abort the stream.
            print(f"Skipping malformed SSE line: {json_str}")
            continue


def _chat_stream(model, tokenizer, query, history, temperature, top_p, max_output_tokens):
    """Stream a model reply for ``query`` given prior ``history`` turns.

    Args:
        model: Unused; kept for signature compatibility (inference is remote).
        tokenizer: Unused; kept for signature compatibility.
        query: The new user message.
        history: List of ``(user, assistant)`` tuples of earlier turns.
        temperature: Sampling temperature (clamped to >= 0 before sending).
        top_p: Nucleus-sampling cutoff.
        max_output_tokens: Server-side generation cap.

    Yields:
        str: Incremental response text, or a human-readable error string if
        the request or response parsing fails.
    """
    conversation = []
    for query_h, response_h in history:
        conversation.append({"role": "user", "content": query_h})
        conversation.append({"role": "assistant", "content": response_h})
    conversation.append({"role": "user", "content": query})

    headers = {
        "Content-Type": "application/json"
    }
    payload = {
        "model": "megrez-moe-waic",
        "messages": conversation,
        "max_tokens": max_output_tokens,
        "temperature": max(temperature, 0),  # server rejects negative values
        "top_p": top_p,
        "stream": True
    }

    response = None
    try:
        response = requests.post(API_URL, headers=headers, data=json.dumps(payload),
                                 timeout=60, stream=True)
        response.raise_for_status()
        for chunk in get_streaming_response(response):
            yield chunk
            time.sleep(0.01)  # small pacing delay so the UI repaints smoothly
    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        yield f"Error: Could not connect to the API. Details: {e}"
    except (KeyError, IndexError) as e:
        # NOTE: ``response.text`` is unavailable on a consumed stream (and
        # ``response`` may be None if the POST itself failed), so guard it.
        if response is not None:
            try:
                print(f"Failed to parse API response: {response.text}")
            except Exception:
                print("Failed to parse API response (body unavailable)")
        yield f"Error: Invalid response format from the API. Details: {e}"


def predict(_query, _chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
    """Gradio handler: append the user turn and stream the reply into the chatbot.

    Yields the updated chatbot list after each received text fragment so the
    UI renders the answer incrementally; records the finished turn in
    ``_task_history`` afterwards.
    """
    print(f"User: {_query}")
    _chatbot.append((_query, ""))
    full_response = ""
    stream = _chat_stream(None, None, _query, history=_task_history,
                          temperature=_temperature, top_p=_top_p,
                          max_output_tokens=_max_output_tokens)
    for new_text in stream:
        full_response += new_text
        _chatbot[-1] = (_query, full_response)
        yield _chatbot
    print(f"History: {_task_history}")
    _task_history.append((_query, full_response))
    print(f"Megrez (from API): {full_response}")


def regenerate(_chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
    """Gradio handler: redo the last turn by replaying its query through ``predict``."""
    if not _task_history:
        yield _chatbot
        return
    item = _task_history.pop(-1)
    _chatbot.pop(-1)
    yield from predict(item[0], _chatbot, _task_history,
                       _temperature, _top_p, _max_output_tokens)


def reset_user_input():
    """Clear the prompt textbox after a message is sent."""
    return gr.update(value="")


def reset_state(_chatbot, _task_history):
    """Clear both the visible chat and the server-side conversation history."""
    _task_history.clear()
    _chatbot.clear()
    return _chatbot


if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            f"""
# 🎱 Chat with Megrez2
"""
        )
        chatbot = gr.Chatbot(label="Megrez2", elem_classes="control-height",
                             height='48vh', show_copy_button=True,
                             latex_delimiters=[
                                 {"left": "$$", "right": "$$", "display": True},
                                 {"left": "$", "right": "$", "display": False},
                                 {"left": "\\(", "right": "\\)", "display": False},
                                 {"left": "\\[", "right": "\\]", "display": True},
                             ])
        with gr.Row():
            with gr.Column(scale=20):
                query = gr.Textbox(show_label=False, container=False,
                                   placeholder="Enter your prompt here and press ENTER")
            with gr.Column(scale=1, min_width=100):
                submit_btn = gr.Button("🚀 Send", variant="primary")
        task_history = gr.State([])
        with gr.Row():
            empty_btn = gr.Button("🗑️ Clear History")
            regen_btn = gr.Button("🔄 Regenerate")
        with gr.Accordion("Parameters", open=False) as parameter_row:
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.2,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_output_tokens = gr.Slider(
                minimum=16,
                maximum=32768,
                value=4096,
                step=1024,
                interactive=True,
                label="Max output tokens",
            )

        submit_btn.click(
            predict,
            [query, chatbot, task_history, temperature, top_p, max_output_tokens],
            [chatbot],
            show_progress=True
        )
        query.submit(
            predict,
            [query, chatbot, task_history, temperature, top_p, max_output_tokens],
            [chatbot],
            show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        query.submit(reset_user_input, [], [query])
        empty_btn.click(
            reset_state,
            [chatbot, task_history],
            outputs=[chatbot],
            show_progress=True
        )
        regen_btn.click(
            regenerate,
            [chatbot, task_history, temperature, top_p, max_output_tokens],
            [chatbot],
            show_progress=True
        )

    demo.launch(ssr_mode=False, share=True)