# Scraped file-viewer header (not program code): File size: 5,573 Bytes; id 2909463.
import streamlit as st
from PIL import Image
import base64
import requests
import json
from voice_toolkit import voice_toolkit
# Load the emblem in two forms: a PIL image (handed to the page config as
# the page icon) and a base64 text encoding of the raw file bytes.
icon_path = "images/院徽.ico"
ICON = Image.open(icon_path)
with open(icon_path, "rb") as icon_file:
    raw_icon_bytes = icon_file.read()
ICON_base64 = base64.b64encode(raw_icon_bytes).decode()

# Page-level settings, gathered in one mapping and applied in a single call.
page_settings = {
    "page_title": "智课灵犀-对话",
    "layout": "centered",
    "page_icon": ICON,
}
st.set_page_config(**page_settings)
# Sidebar header: the base64-encoded icon rendered inline above the
# assistant's name, injected as raw HTML.
with st.sidebar:
    icon_text = (
        "\n"
        '<div class="icon-text-container" style="text-align: center;">\n'
        f"<img src='data:image/png;base64,{ICON_base64}' alt='icon' "
        "style='width: 70px; height: 70px; margin: 0 auto; display: block;'>\n"
        "<span style='font-size: 24px;'>湘潭大学课程助手--智课灵犀</span>\n"
        "</div>\n"
    )
    st.markdown(icon_text, unsafe_allow_html=True)

# Course selection; the chosen value is stored in `option1`.
st.sidebar.title("模型")
option1 = st.sidebar.selectbox("课程", ["数据结构", "软件工程与项目管理"])

# Input-method selection (keyboard or voice); stored in `option2`.
st.sidebar.title("输入")
option2 = st.sidebar.selectbox("方式", ["键盘", "语音"])

# Sliders for the text-generation parameters.
st.sidebar.title("参数")
with st.sidebar.expander("文本生成"):
    # All four defaults are seeded together, keyed on the presence of
    # "max_new_tokens" alone.
    if "max_new_tokens" not in st.session_state:
        for setting_name, default_value in (
            ("max_new_tokens", 500),
            ("top_p", 0.9),
            ("temperature", 0.1),
            ("repetition_penalty", 1.0),
        ):
            st.session_state[setting_name] = default_value

    parameter_1 = st.slider(
        "max_new_tokens",
        min_value=50,
        max_value=1000,
        value=st.session_state["max_new_tokens"],
        step=50,
    )
    parameter_2 = st.slider(
        "top_p",
        min_value=0.5,
        max_value=0.95,
        value=st.session_state["top_p"],
        step=0.01,
    )
    parameter_3 = st.slider(
        "temperature",
        min_value=0.1,
        max_value=5.0,
        value=st.session_state["temperature"],
        step=0.1,
    )
    parameter_4 = st.slider(
        "repetition_penalty",
        min_value=0.5,
        max_value=5.0,
        value=st.session_state["repetition_penalty"],
        step=0.1,
    )

    # Write the slider values back into session state.
    chosen_settings = {
        "max_new_tokens": parameter_1,
        "top_p": parameter_2,
        "temperature": parameter_3,
        "repetition_penalty": parameter_4,
    }
    for setting_name, chosen_value in chosen_settings.items():
        st.session_state[setting_name] = chosen_value
st.title("🪶 智课灵犀")
st.caption("🌈 一款由湘潭大学计算机学院开发的课程助手")

# Session state
# "chat_type" is forced to "chat" whenever it is missing or holds any other value.
chat_type_already_chat = (
    "chat_type" in st.session_state and st.session_state["chat_type"] == "chat"
)
if not chat_type_already_chat:
    st.session_state["chat_type"] = "chat"

# Remaining keys are only seeded when absent, so existing values are kept.
initial_state = {
    "is_recording": False,
    "user_input_area": "",
    "user_voice_value": "",
    "voice_flag": "",
    "messages": [
        {"role": "assistant", "message": "你好,我是湘潭大学课程知识答疑小助手“智课灵犀”"}
    ],
}
for state_key, initial_value in initial_state.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Render every stored message in order.
for past_message in st.session_state["messages"]:
    bubble = st.chat_message(past_message["role"])
    bubble.write(past_message["message"])
def send_message(timeout=(5, 300)):
    """Post the current conversation to the QA backend and return its reply.

    The request body is a JSON object built from ``st.session_state``:
    ``chat_type``, the full ``messages`` list, and the four generation
    parameters (``max_new_tokens``, ``top_p``, ``temperature``,
    ``repetition_penalty``). The endpoint is chosen from the module-level
    ``option1`` course selection.

    Args:
        timeout: Value forwarded to ``requests.post(timeout=...)``; either a
            single number of seconds or a ``(connect, read)`` tuple. Without
            a timeout the call would wait forever on a stalled backend and
            block the whole script run.

    Returns:
        The response body as text (``response.text``), whatever the HTTP
        status code is.

    Raises:
        requests.exceptions.RequestException: On connection failure or when
            the timeout elapses (raised by ``requests.post``).
    """
    payload = json.dumps({
        "chat_type": st.session_state.chat_type,
        "messages": st.session_state.messages,
        "max_new_tokens": st.session_state.max_new_tokens,
        "top_p": st.session_state.top_p,
        "temperature": st.session_state.temperature,
        "repetition_penalty": st.session_state.repetition_penalty,
    })
    headers = {'Content-Type': 'application/json'}
    # One endpoint per selectable course.
    url_map = {
        "数据结构": "http://localhost:5000/api-dev/qa/get_answer",
        "软件工程与项目管理": "http://localhost:5000/api-dev/qa/get_answer2",
    }
    url = url_map.get(option1)
    response = requests.post(url, data=payload, headers=headers, timeout=timeout)
    return response.text
if option2 == "键盘":
    # Keyboard mode: a chat input box; empty input is skipped by the walrus test.
    if prompt := st.chat_input(placeholder="输入..."):
        st.session_state.messages.append({"role": "user", "message": prompt})
        st.chat_message("user").write(prompt)
        answer = send_message()
        st.session_state.messages.append({"role": "assistant", "message": answer})
        st.chat_message("assistant").write(answer)
elif option2 == "语音":
    # Text input form, pre-filled with the latest recognised speech.
    with st.form("input_form", clear_on_submit=True):
        prompt = st.text_area(
            "**输入:**",
            key="user_input_area",
            value=st.session_state.user_voice_value,
            help="在此输入文本或通过语音输入。"
        )
        submitted = st.form_submit_button("确认提交")

    # Handle submission. Blank submissions are ignored so that an empty user
    # message is never stored or sent, matching the keyboard branch.
    if submitted and prompt and prompt.strip():
        st.session_state.messages.append({"role": "user", "message": prompt})
        st.chat_message("user").write(prompt)
        answer = send_message()
        st.session_state.messages.append({"role": "assistant", "message": answer})
        st.chat_message("assistant").write(answer)
        st.session_state.user_voice_value = ""
        st.rerun()

    # Voice input.
    # Original author's note: the component keeps returning its last result.
    voice_result = voice_toolkit()
    # Only read the result when the component actually returned one; the
    # stored "interim" flag alone is not proof that a result is available.
    if voice_result:
        recognised = voice_result["voice_result"]
        if recognised["flag"] == "interim" or st.session_state["voice_flag"] == "interim":
            st.session_state["voice_flag"] = "interim"
            st.session_state["user_voice_value"] = recognised["value"]
            # A "final" result ends the interim phase; rerun so the form is
            # redrawn with the value stored just above.
            if recognised["flag"] == "final":
                st.session_state["voice_flag"] = "final"
                st.rerun()
|