|
|
import streamlit as st |
|
|
from streamlit_chat import message |
|
|
import os |
|
|
import openai |
|
|
|
|
|
|
|
|
# Browser-tab title and icon. Streamlit expects set_page_config to be the first
# st.* command of the script, so it stays ahead of everything else.
_PAGE_CONFIG = {"page_title": "AI对话一问一答模式", "page_icon": ":robot:"}
st.set_page_config(**_PAGE_CONFIG)

# Visible page heading.
st.header("🔥AI对话一问一答模式")
|
|
|
|
|
|
|
|
def get_text1():
    """Render the login form until an account key has been confirmed.

    While ``st.session_state`` has no ``'openai_key'`` entry, this shows a text
    input plus a "确认登陆!" button. Clicking the button with a non-empty value
    stores the key (surrounding whitespace stripped) in session state, and
    later reruns skip the form.

    Returns:
        The stored key once confirmed (including on the run where the button
        is clicked), otherwise ``None``.
    """
    if 'openai_key' in st.session_state:
        return st.session_state['openai_key']

    # Strip whitespace that is easily picked up when pasting a key.
    input_text1 = st.text_input("📫请输入你的账号: ", key="input").strip()
    if st.button("确认登陆!", key="input3"):
        # Never persist an empty key. Once 'openai_key' exists in session
        # state the form above is no longer rendered, so storing "" would
        # leave the user with no key and no way to enter one.
        if input_text1:
            st.session_state['openai_key'] = input_text1
            return input_text1
        st.warning("请输入账号后再登陆!")
    return None
|
|
|
|
|
# Configure the OpenAI client with whatever key the login form produced.
# Nothing is set while the form has not yielded a key (falsy value).
openai_key = get_text1()
if openai_key:
    openai.api_key = openai_key
    st.write("")
|
|
|
|
|
|
|
|
def openai_create(prompt):
    """Send *prompt* to the ``text-davinci-003`` completion endpoint.

    Args:
        prompt: The text sent as the completion prompt.

    Returns:
        The text of the first completion choice. On any failure a user-facing
        message is returned instead of raising, so the caller always has
        something to show in the chat:

        - an authentication failure asks the user to re-enter the account;
        - any other error (network, rate limit, server, SDK problem) asks the
          user to retry and names the exception type.
    """
    try:
        response = openai.Completion.create(
            model="text-davinci-003",
            prompt=prompt,
            temperature=0.5,
            max_tokens=1024,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0.6,
            # Generation halts if the model emits one of these speaker labels.
            stop=[" Human:", " AI:"],
        )
        return response.choices[0].text
    except Exception as exc:  # best-effort: report in the chat, never crash the page
        # openai<1.0 keeps its exception classes in `openai.error`; 1.x exposes
        # them on the package itself. Fall back to () so isinstance() is simply
        # False if neither location exists.
        sdk_errors = getattr(openai, "error", openai)
        auth_error = getattr(sdk_errors, "AuthenticationError", ())
        if isinstance(exc, auth_error):
            # Only an authentication failure means the account/key is at fault.
            return "你的账号填写有误,请刷新页面重新填写正确的账号!"
        # Don't tell users with a valid key to re-enter it when the problem is
        # elsewhere; surface the exception type for diagnosis instead.
        return f"请求失败,请稍后重试!({type(exc).__name__})"
|
|
|
|
|
|
|
|
def chatgpt_clone(input):
    """Return the model's reply to the user's message *input*.

    This is a thin pass-through to ``openai_create``. The parameter keeps its
    original name (which shadows the builtin) for caller compatibility.
    """
    return openai_create(input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Chat history lives in session state so it survives Streamlit reruns:
# 'generated' holds model replies and 'past' holds the user's messages.
# Each list is created empty only on the first run of a session.
for _history_key in ('generated', 'past'):
    if _history_key not in st.session_state:
        st.session_state[_history_key] = []
|
|
|
|
|
|
|
|
def get_text():
    """Render the message box and the "发送" button.

    Returns:
        The typed text on the run where the button was clicked, else ``None``.
    """
    typed = st.text_input("📫你想说的: ", key="input1")
    sent = st.button("发送", key="input2")
    return typed if sent else None
|
|
|
|
|
|
|
|
user_input = get_text()

if user_input:
    # Query the model first, then record the user's message and the reply
    # together so both history lists grow in lockstep.
    output = chatgpt_clone(user_input)
    st.session_state['past'].append(user_input)
    st.session_state['generated'].append(output)
|
|
|
|
|
if st.session_state['generated']:
    # Show the newest exchange first. Within each exchange the reply is drawn
    # before (above) the user message that prompted it, and the two messages
    # get distinct keys via the '_user' suffix.
    for i in reversed(range(len(st.session_state['generated']))):
        message(st.session_state['generated'][i], key=str(i))
        message(st.session_state['past'][i], is_user=True, key=f"{i}_user")
|
|
|