# InnoMegrez2 / app.py
# Source: SII-InnoMegrez Hugging Face Space — "Update app.py" (commit 1786137, verified)
from argparse import ArgumentParser
import gradio as gr
import requests
import json
import time
def get_streaming_response(response: "requests.Response"):
    """Yield incremental text fragments from an OpenAI-style SSE stream.

    Iterates the server-sent-event lines of *response*, decodes each
    ``data:`` payload as JSON, and yields ``choices[0].delta.content``
    whenever a chunk carries new text.  Iteration stops at the ``[DONE]``
    sentinel; malformed payloads are logged and skipped.

    Args:
        response: A streaming HTTP response — anything exposing
            ``iter_lines()`` that yields byte lines (e.g. ``requests.Response``).

    Yields:
        str: Each non-empty piece of newly generated text, in arrival order.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue  # keep-alive blank lines between SSE events
        line = raw_line.decode("utf-8")
        if not line.startswith("data:"):
            continue  # ignore non-data fields (event:, id:, comments)
        # The SSE spec permits "data:" with or without a following space,
        # so strip the prefix and any leading whitespace rather than
        # assuming a fixed 6-character "data: " prefix.
        json_str = line[5:].lstrip()
        if json_str == "[DONE]":
            break
        try:
            event = json.loads(json_str)
            delta = event.get("choices", [{}])[0].get("delta", {})
            new_text = delta.get("content", "")
            if new_text:
                yield new_text
        except (json.JSONDecodeError, IndexError):
            print(f"Skipping malformed SSE line: {json_str}")
            continue
def _chat_stream(model, tokenizer, query, history, temperature, top_p, max_output_tokens):
    """Stream a chat completion for *query* from the remote Megrez API.

    *model* and *tokenizer* are unused placeholders kept for interface
    compatibility.  *history* is a list of (user, assistant) turn pairs that
    is replayed into the request ahead of the new *query*.

    Yields text fragments as they arrive; on failure yields a single
    human-readable error message instead of raising.
    """
    messages = []
    for past_query, past_reply in history:
        messages.append({"role": "user", "content": past_query})
        messages.append({"role": "assistant", "content": past_reply})
    messages.append({"role": "user", "content": query})

    request_body = {
        "model": "megrez-moe-waic",
        "messages": messages,
        "max_tokens": max_output_tokens,
        # Guard against a negative slider value reaching the API.
        "temperature": max(temperature, 0),
        "top_p": top_p,
        "stream": True,
    }
    try:
        API_URL = "http://8.152.0.142:10021/v1/chat/completions"
        response = requests.post(
            API_URL,
            headers={"Content-Type": "application/json"},
            data=json.dumps(request_body),
            timeout=60,
            stream=True,
        )
        response.raise_for_status()
        for piece in get_streaming_response(response):
            yield piece
            # Tiny pause keeps the UI update cadence smooth.
            time.sleep(0.01)
    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        yield f"Error: Could not connect to the API. Details: {e}"
    except (KeyError, IndexError) as e:
        print(f"Failed to parse API response: {response.text}")
        yield f"Error: Invalid response format from the API. Details: {e}"
def predict(_query, _chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
    """Gradio handler: stream the model's reply for *_query* into the chatbot.

    Appends a new (query, partial_reply) pair to *_chatbot*, yields the
    updated chatbot after every received fragment, and records the finished
    turn in *_task_history* once the stream ends.
    """
    print(f"User: {_query}")
    _chatbot.append((_query, ""))
    accumulated = ""
    token_stream = _chat_stream(
        None,
        None,
        _query,
        history=_task_history,
        temperature=_temperature,
        top_p=_top_p,
        max_output_tokens=_max_output_tokens,
    )
    for fragment in token_stream:
        accumulated += fragment
        # Replace the last (query, partial) pair so the UI shows live text.
        _chatbot[-1] = (_query, accumulated)
        yield _chatbot
    print(f"History: {_task_history}")
    _task_history.append((_query, accumulated))
    print(f"Megrez (from API): {accumulated}")
def regenerate(_chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
    """Re-run the most recent user turn and stream a fresh reply."""
    if not _task_history:
        # Nothing to redo; just echo the current chatbot state.
        yield _chatbot
        return
    last_query, _ = _task_history.pop()
    _chatbot.pop()
    yield from predict(
        last_query, _chatbot, _task_history,
        _temperature, _top_p, _max_output_tokens,
    )
def reset_user_input():
    """Return a Gradio update that clears the prompt textbox."""
    cleared = gr.update(value="")
    return cleared
def reset_state(_chatbot, _task_history):
    """Empty both the visible chat log and the stored conversation history.

    Mutates the two lists in place (Gradio holds references to them) and
    returns the now-empty chatbot list for the UI to re-render.
    """
    del _task_history[:]
    del _chatbot[:]
    return _chatbot
if __name__ == "__main__":
    # Build the Gradio UI: a chatbot pane, a prompt box with a send button,
    # history controls, and a collapsible panel of sampling parameters.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # Page header with a link to the model's GitHub repository.
        gr.Markdown(
            f"""
# 🎱 Chat with Megrez2 <a href="https://github.com/infinigence/Infini-Megrez">
"""
        )
        # Main chat transcript; LaTeX delimiters enable inline/display math.
        chatbot = gr.Chatbot(label="Megrez2", elem_classes="control-height", height='48vh', show_copy_button=True,
                             latex_delimiters=[
                                 {"left": "$$", "right": "$$", "display": True},
                                 {"left": "$", "right": "$", "display": False},
                                 {"left": "\\(", "right": "\\)", "display": False},
                                 {"left": "\\[", "right": "\\]", "display": True},
                             ])
        with gr.Row():
            with gr.Column(scale=20):
                # Prompt input; pressing ENTER submits (wired via query.submit below).
                query = gr.Textbox(show_label=False, container=False, placeholder="Enter your prompt here and press ENTER")
            with gr.Column(scale=1, min_width=100):
                # NOTE: emoji labels were mojibake (UTF-8 decoded as latin-1,
                # e.g. "πŸš€") — restored to the intended characters.
                submit_btn = gr.Button("🚀 Send", variant="primary")
        # Server-side conversation history: list of (user, assistant) pairs.
        task_history = gr.State([])
        with gr.Row():
            empty_btn = gr.Button("🗑️ Clear History")
            regen_btn = gr.Button("🔄 Regenerate")
        # Sampling parameters, collapsed by default.
        with gr.Accordion("Parameters", open=False) as parameter_row:
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.2,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_output_tokens = gr.Slider(
                minimum=16,
                maximum=32768,
                value=4096,
                step=1024,
                interactive=True,
                label="Max output tokens",
            )
        # Wire events: both the Send button and ENTER in the textbox stream a
        # reply into the chatbot, then clear the input box.
        submit_btn.click(
            predict, [query, chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )
        query.submit(
            predict, [query, chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        query.submit(reset_user_input, [], [query])
        empty_btn.click(
            reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True
        )
        regen_btn.click(
            regenerate, [chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
        )
    # share=True exposes a public Gradio link; SSR disabled for compatibility.
    demo.launch(ssr_mode=False, share=True)