# "Spaces: Sleeping" — Hugging Face Space status banner captured during page
# scrape; not part of the program. Kept as a comment so the file stays valid.
| from argparse import ArgumentParser | |
| import gradio as gr | |
| import requests | |
| import json | |
| import time | |
def get_streaming_response(response: requests.Response):
    """Yield text fragments parsed from an OpenAI-style SSE stream.

    Walks the raw response lines, strips the ``data: `` framing, stops at
    the ``[DONE]`` sentinel, and yields every non-empty
    ``choices[0].delta.content`` string. Malformed events are logged and
    skipped rather than aborting the stream.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        decoded = raw_line.decode("utf-8")
        if not decoded.startswith('data: '):
            continue
        json_str = decoded[6:]  # drop the "data: " SSE prefix
        if json_str == '[DONE]':
            break
        try:
            event = json.loads(json_str)
            delta_block = event.get('choices', [{}])[0].get('delta', {})
            fragment = delta_block.get('content', '')
        except (json.JSONDecodeError, IndexError):
            # An empty "choices" list or truncated JSON line is skippable.
            print(f"Skipping malformed SSE line: {json_str}")
            continue
        if fragment:
            yield fragment
def _chat_stream(model, tokenizer, query, history, temperature, top_p, max_output_tokens):
    """Stream an assistant reply for ``query`` from the remote chat API.

    Args:
        model: Unused; retained for signature compatibility with
            local-inference callers.
        tokenizer: Unused; retained for signature compatibility.
        query: Latest user message.
        history: List of ``(user, assistant)`` turn pairs replayed as context.
        temperature: Sampling temperature; negative values are clamped to 0.
        top_p: Nucleus-sampling cutoff.
        max_output_tokens: Upper bound on generated tokens.

    Yields:
        Text fragments as they arrive, or a single error string on failure.
    """
    api_url = "http://8.152.0.142:10021/v1/chat/completions"
    # Rebuild the OpenAI-style message list from the stored turn pairs.
    conversation = []
    for query_h, response_h in history:
        conversation.append({"role": "user", "content": query_h})
        conversation.append({"role": "assistant", "content": response_h})
    conversation.append({"role": "user", "content": query})
    payload = {
        "model": "megrez-moe-waic",
        "messages": conversation,
        "max_tokens": max_output_tokens,
        "temperature": max(temperature, 0),  # API rejects negative temperatures
        "top_p": top_p,
        "stream": True,
    }
    try:
        # `json=` serializes the payload and sets the Content-Type header,
        # replacing the manual headers dict + json.dumps round trip.
        response = requests.post(api_url, json=payload, timeout=60, stream=True)
        response.raise_for_status()
        for chunk in get_streaming_response(response):
            yield chunk
            time.sleep(0.01)  # tiny pause smooths the UI streaming cadence
    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        yield f"Error: Could not connect to the API. Details: {e}"
    except (KeyError, IndexError) as e:
        # Defensive: get_streaming_response already swallows parse errors,
        # but keep a readable failure mode if the response shape changes.
        print(f"Failed to parse API response: {response.text}")
        yield f"Error: Invalid response format from the API. Details: {e}"
def predict(_query, _chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
    """Append the user turn and stream the assistant reply into the chat box.

    Yields the chatbot list after each received fragment so Gradio can
    re-render incrementally; commits the finished turn to the history state.
    """
    print(f"User: {_query}")
    _chatbot.append((_query, ""))  # placeholder row filled in as tokens arrive
    reply = ""
    token_stream = _chat_stream(
        None,
        None,
        _query,
        history=_task_history,
        temperature=_temperature,
        top_p=_top_p,
        max_output_tokens=_max_output_tokens,
    )
    for fragment in token_stream:
        reply = reply + fragment
        _chatbot[-1] = (_query, reply)
        yield _chatbot
    print(f"History: {_task_history}")
    _task_history.append((_query, reply))
    print(f"Megrez (from API): {reply}")
def regenerate(_chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
    """Drop the most recent exchange and replay its query through ``predict``."""
    if not _task_history:
        # Nothing to redo; hand the unchanged chat back.
        yield _chatbot
        return
    last_turn = _task_history.pop()
    _chatbot.pop()
    yield from predict(last_turn[0], _chatbot, _task_history, _temperature, _top_p, _max_output_tokens)
def reset_user_input():
    """Blank out the prompt textbox after a message is submitted."""
    return gr.update(value="")
def reset_state(_chatbot, _task_history):
    """Empty both the server-side history and the visible chat, in place.

    Returns the (now empty) chatbot list so Gradio refreshes the widget.
    """
    _task_history.clear()
    del _chatbot[:]  # in-place wipe, same object handed back to the UI
    return _chatbot
# Script entry point: builds the Gradio UI and wires it to the streaming
# callbacks defined above. All model work happens remotely via _chat_stream.
if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # Page header. NOTE(review): the f-string has no placeholders and the
        # emoji bytes look mojibake-encoded — confirm intended rendering.
        gr.Markdown(
            f"""
# π± Chat with Megrez2 <a href="https://github.com/infinigence/Infini-Megrez">
"""
        )
        # Chat transcript panel; latex_delimiters enables inline ($, \( \))
        # and display ($$, \[ \]) math rendering.
        chatbot = gr.Chatbot(label="Megrez2", elem_classes="control-height", height='48vh', show_copy_button=True,
                             latex_delimiters=[
                                 {"left": "$$", "right": "$$", "display": True},
                                 {"left": "$", "right": "$", "display": False},
                                 {"left": "\\(", "right": "\\)", "display": False},
                                 {"left": "\\[", "right": "\\]", "display": True},
                             ])
        with gr.Row():
            with gr.Column(scale=20):
                query = gr.Textbox(show_label=False, container=False, placeholder="Enter your prompt here and press ENTER")
            with gr.Column(scale=1, min_width=100):
                submit_btn = gr.Button("π Send", variant="primary")
        # Server-side conversation state: one (user, assistant) pair per turn.
        task_history = gr.State([])
        with gr.Row():
            empty_btn = gr.Button("ποΈ Clear History")
            regen_btn = gr.Button("π Regenerate")
        # Sampling controls passed straight through to the API payload.
        with gr.Accordion("Parameters", open=False) as parameter_row:
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.2,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_output_tokens = gr.Slider(
                minimum=16,
                maximum=32768,
                value=4096,
                step=1024,
                interactive=True,
                label="Max output tokens",
            )
        # Sending via the button or pressing ENTER both stream predict()'s
        # incremental chatbot updates into the transcript.
        submit_btn.click(
            predict, [query, chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )
        query.submit(
            predict, [query, chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )
        # Second handler on the same events clears the textbox after send.
        submit_btn.click(reset_user_input, [], [query])
        query.submit(reset_user_input, [], [query])
        empty_btn.click(
            reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True
        )
        regen_btn.click(
            regenerate, [chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )
    # share=True publishes a temporary public Gradio link; SSR disabled.
    demo.launch(ssr_mode=False, share=True)