# nextllama2 / app.py
# Author: shigure451 — update (commit cb43d18)
import json
import gradio as gr
import os
import requests
# Hugging Face token, used as the password half of HTTP basic auth on
# requests to the backend (see predict()); read from the environment so
# it is never committed to the repo.
hf_token = os.getenv('HF_TOKEN')

# OpenAI-compatible chat-completions endpoint of the backend model server.
api_url = "http://region-31.seetacloud.com:46766/v1/chat/completions"

headers = {
    'Content-Type': 'application/json',
}

# Per-request timeout (seconds) for the streaming POST to the backend.
TIME_OUT_SECONDS = 30

# UI copy passed to gr.ChatInterface below.
title = "Next llama Chatbot"
description = """..."""  # keep your original description here
css = """.toast-wrap { display: none !important } """
examples = [
    ['Hello there! How are you doing?'],
    # ... other examples
]
# Shared conversation transcript sent with every backend request.
# NOTE(review): module-level state is shared across all concurrent users
# of the app — consider per-session history via gr.State.
message_history = []


def process_api_response(response):
    """Stream assistant text out of an OpenAI-style SSE chat response.

    Parameters
    ----------
    response : requests.Response
        Streaming response whose body is server-sent events; payload
        lines look like ``data: {"choices": [{"delta": {...}}]}``.

    Yields
    ------
    str
        The assistant reply accumulated so far (grows with each chunk),
        which is what a streaming Gradio chat UI expects.

    Side effects
    ------------
    Appends the final complete reply to the module-level
    ``message_history``.
    """
    assistant_response = ""
    for raw_line in response.iter_lines():
        decoded_line = raw_line.decode('utf-8').strip()
        if decoded_line.startswith("data: "):
            decoded_line = decoded_line[6:]
        # Skip keep-alive blanks and the non-JSON stream terminator;
        # the original code fed "[DONE]" to json.loads and logged a
        # spurious decode failure on every request.
        if not decoded_line or decoded_line == "[DONE]":
            continue
        try:
            json_line = json.loads(decoded_line)
        except json.JSONDecodeError:
            print(f"Failed to decode line: {decoded_line}")
            continue
        # Guard against an empty or missing "choices" list, which would
        # otherwise raise IndexError on choices[0].
        choices = json_line.get("choices")
        if not choices or "delta" not in choices[0]:
            continue
        delta = choices[0]["delta"]
        if "content" in delta:
            assistant_response += delta["content"]  # accumulate this chunk
            yield assistant_response  # emit the running transcript
    # Persist only the final, complete reply in the history.
    if assistant_response:
        message_history.append({
            "role": "assistant",
            "content": assistant_response
        })
def predict(system_message, message, system_prompt="You are a helpful, respectful and honest assistant.", temperature=0.9, max_new_tokens=2048, top_p=0.6, repetition_penalty=1.0):
    """Send the conversation to the backend and stream the reply.

    Called by gr.ChatInterface; yields the assistant reply incrementally
    (each yield is the full text so far) or an error string.

    Parameters
    ----------
    system_message : str
        The new user message (ChatInterface passes the user input as the
        first positional argument).
    message : object
        Chat history supplied by ChatInterface; unused here because the
        module-level ``message_history`` is the source of truth.
    system_prompt, temperature, max_new_tokens, top_p, repetition_penalty
        Generation settings forwarded to the backend.
    """
    # The system prompt must carry the "system" role — the original code
    # tagged it "assistant", which makes the model read it as one of its
    # own prior turns instead of an instruction.
    message_history.append({
        "role": "system",
        "content": system_prompt
    })
    message_history.append({
        "role": "user",
        "content": system_message
    })
    data = {
        "model": "LLaMa-2-13B-chat",
        "messages": message_history,  # full conversation so far
        "stream": True,
        "temperature": temperature,
        "max_tokens": max_new_tokens,
        "top_p": top_p,  # was accepted by the function but never sent
        # NOTE(review): the UI calls this "repetition penalty" but it is
        # forwarded as OpenAI-style presence_penalty — confirm the backend
        # expects that scale (presence_penalty is typically -2..2, not >=1).
        "presence_penalty": repetition_penalty
    }
    # Log the outgoing payload for debugging.
    print("Sending the following data to the backend API:")
    print(json.dumps(data, indent=4))
    try:
        response = requests.post(api_url, headers=headers, data=json.dumps(data),
                                 auth=('hf', hf_token), stream=True,
                                 timeout=TIME_OUT_SECONDS)
        if response.status_code == 200:
            # Relay the streamed chunks straight to the UI.
            yield from process_api_response(response)
        elif response.status_code == 401:
            yield "Error: Unauthorized"
        else:
            yield f"Error with status code: {response.status_code}"
    except requests.Timeout:
        yield "Error: Request timed out"
    except requests.RequestException as e:
        yield f"Error: {e}"
# def vote(data: gr.LikeData):
# if data.liked:
# print("You upvoted this response: " + data.value)
# else:
# print("You downvoted this response: " + data.value)
# Extra UI controls intended to map positionally onto predict()'s keyword
# parameters: system_prompt, temperature, max_new_tokens, top_p,
# repetition_penalty.
# NOTE(review): several widget defaults disagree with predict()'s own
# defaults (max new tokens 256 vs 2048; repetition penalty 1.2 vs 1.0) —
# confirm which values are intended.
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
# Remove the unrecognized arguments from gr.Chatbot
chatbot_stream = gr.Chatbot()
# Since gr.ChatInterface doesn't support additional_inputs, we'll need to adjust our design.
# For now, I'm removing the additional_inputs argument. You might need to consider a different interface type if you want to use these inputs.
# NOTE(review): recent Gradio releases DO accept additional_inputs on
# ChatInterface — confirm the pinned gradio version; as written, the
# additional_inputs list above is defined but never wired into the UI,
# so predict() always runs with its own defaults.
# NOTE(review): cache_examples=True executes predict() on each example at
# startup, i.e. it hits the backend API before the app is even used —
# confirm that is intended.
chat_interface_stream = gr.ChatInterface(predict,
                                         title=title,
                                         description=description,
                                         chatbot=chatbot_stream,
                                         css=css,
                                         examples=examples,
                                         cache_examples=True)
# queue() enables streaming/generator responses; launch blocks here.
chat_interface_stream.queue().launch(debug=True)