File size: 5,573 Bytes
2909463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import streamlit as st
from PIL import Image
import base64
import requests
import json
from voice_toolkit import voice_toolkit


# One icon file serves two purposes: as a PIL image it becomes the page
# icon, and base64-encoded it is inlined into the sidebar header <img>.
icon_path = "images/院徽.ico"

with open(icon_path, "rb") as icon_fh:
    ICON_base64 = base64.b64encode(icon_fh.read()).decode()
ICON = Image.open(icon_path)

# Page-level settings (title, layout, icon).
st.set_page_config(page_title="智课灵犀-对话", layout="centered", page_icon=ICON)

# Sidebar header: centred logo with the application name underneath.
# Raw HTML is used, hence unsafe_allow_html below.
icon_text = f"""
        <div class="icon-text-container" style="text-align: center;">
            <img src='data:image/png;base64,{ICON_base64}' alt='icon' style='width: 70px; height: 70px; margin: 0 auto; display: block;'>
            <span style='font-size: 24px;'>湘潭大学课程助手--智课灵犀</span>
        </div>
        """
st.sidebar.markdown(icon_text, unsafe_allow_html=True)

with st.sidebar:
    # Which course the question is routed to.
    st.title('模型')
    option1 = st.selectbox('课程', ['数据结构', '软件工程与项目管理'])

    # How the user enters the question (keyboard or voice).
    st.title('输入')
    option2 = st.selectbox('方式', ['键盘', '语音'])

    # Sliders for the text-generation parameters.
    st.title('参数')
    with st.expander("文本生成"):
        # Seed all four generation parameters together the first time the
        # session runs; "max_new_tokens" acts as the marker for "seeded".
        if "max_new_tokens" not in st.session_state:
            for _name, _initial in (
                ("max_new_tokens", 500),
                ("top_p", 0.9),
                ("temperature", 0.1),
                ("repetition_penalty", 1.0),
            ):
                st.session_state[_name] = _initial

        # Each slider starts from the stored value and its current position
        # is written straight back into the session state.
        parameter_1 = st.slider(
            'max_new_tokens',
            min_value=50,
            max_value=1000,
            value=st.session_state["max_new_tokens"],
            step=50,
        )
        st.session_state["max_new_tokens"] = parameter_1

        parameter_2 = st.slider(
            'top_p',
            min_value=0.5,
            max_value=0.95,
            value=st.session_state["top_p"],
            step=0.01,
        )
        st.session_state["top_p"] = parameter_2

        parameter_3 = st.slider(
            'temperature',
            min_value=0.1,
            max_value=5.0,
            value=st.session_state["temperature"],
            step=0.1,
        )
        st.session_state["temperature"] = parameter_3

        parameter_4 = st.slider(
            'repetition_penalty',
            min_value=0.5,
            max_value=5.0,
            value=st.session_state["repetition_penalty"],
            step=0.1,
        )
        st.session_state["repetition_penalty"] = parameter_4

st.title("🪶 智课灵犀")
st.caption("🌈 一款由湘潭大学计算机学院开发的课程助手")

# --- Session state ---------------------------------------------------------
# This page always talks to the backend in "chat" mode, whatever value an
# earlier run may have left behind.
if st.session_state.get("chat_type") != "chat":
    st.session_state["chat_type"] = "chat"

# Keys that only need a value the first time the session runs.
_state_defaults = {
    "is_recording": False,
    "user_input_area": "",
    "user_voice_value": "",
    "voice_flag": "",
    # The conversation opens with a greeting from the assistant.
    "messages": [{"role": "assistant", "message": "你好,我是湘潭大学课程知识答疑小助手“智课灵犀”"}],
}
for _key, _default in _state_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Replay the conversation so far; the script re-executes on every
# interaction, so the history has to be drawn again each time.
for past_turn in st.session_state["messages"]:
    with st.chat_message(past_turn["role"]):
        st.write(past_turn["message"])


def send_message(timeout=(5, 300)):
    """Post the current conversation to the QA backend and return its reply.

    The request body is a JSON object built from the session state: the chat
    type, the full message history and the four generation parameters. The
    endpoint is chosen from the course currently selected in the sidebar
    (module-level ``option1``).

    Args:
        timeout: Passed to ``requests.post``. Either a single number of
            seconds or a ``(connect, read)`` tuple. The default allows a
            long read because the backend generates text, while still
            preventing the page from hanging forever if it never answers.

    Returns:
        The response body as text.

    Raises:
        ValueError: If no endpoint is configured for the selected course.
        requests.RequestException: On connection failure, timeout, or a
            non-2xx HTTP status. Raising on a bad status keeps an error page
            from being stored as an assistant answer and re-sent to the
            backend as chat history on later turns.
    """
    payload = json.dumps({
        "chat_type": st.session_state.chat_type,
        "messages": st.session_state.messages,
        "max_new_tokens": st.session_state.max_new_tokens,
        "top_p": st.session_state.top_p,
        "temperature": st.session_state.temperature,
        "repetition_penalty": st.session_state.repetition_penalty,
    })
    headers = {'Content-Type': 'application/json'}
    url_map = {
        "数据结构": "http://localhost:5000/api-dev/qa/get_answer",
        "软件工程与项目管理": "http://localhost:5000/api-dev/qa/get_answer2",
    }
    url = url_map.get(option1)
    if url is None:
        # Fail with a clear message instead of handing None to requests.
        raise ValueError(f"no backend endpoint configured for course {option1!r}")

    response = requests.post(url, data=payload, headers=headers, timeout=timeout)
    response.raise_for_status()
    return response.text



def _submit_prompt(text):
    """Record *text* as a user turn, ask the backend, and show both turns."""
    st.session_state.messages.append({"role": "user", "message": text})
    st.chat_message("user").write(text)
    answer = send_message()
    st.session_state.messages.append({"role": "assistant", "message": answer})
    st.chat_message("assistant").write(answer)


if option2 == "键盘":
    # Keyboard mode: a regular chat input box.
    if prompt := st.chat_input(placeholder="输入..."):
        _submit_prompt(prompt)

elif option2 == "语音":
    # Voice mode: a text form that is pre-filled with the recognised speech
    # and can still be edited by hand before submitting.
    # NOTE(review): the key "user_input_area" is also pre-seeded in the
    # session state earlier in this file; confirm that Streamlit still
    # honours `value=` for this widget.
    with st.form("input_form", clear_on_submit=True):
        prompt = st.text_area(
            "**输入:**",
            key="user_input_area",
            value=st.session_state.user_voice_value,
            help="在此输入文本或通过语音输入。"
        )
        submitted = st.form_submit_button("确认提交")

    # Handle the submission. Blank submissions are skipped so that no empty
    # user turn is sent to the backend or stored in the history.
    if submitted and prompt.strip():
        _submit_prompt(prompt)

        # Clear the recognised text so the next run starts with an empty form.
        st.session_state.user_voice_value = ""
        st.rerun()

    # Voice input component.
    voice_result = voice_toolkit()
    # The component keeps returning its most recent result. It may also
    # return nothing, so every access below is guarded; without the guard a
    # stale "interim" flag in the session state would lead to indexing None.
    if voice_result:
        recognised = voice_result["voice_result"]
        if recognised["flag"] == "interim" or st.session_state["voice_flag"] == "interim":
            st.session_state["voice_flag"] = "interim"
            st.session_state["user_voice_value"] = recognised["value"]
            if recognised["flag"] == "final":
                # Recognition finished: rerun so the form shows the final text.
                st.session_state["voice_flag"] = "final"
                st.rerun()