|
|
import os
import tempfile

import streamlit as st
from streamlit_chat import message

from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain

from langchain_g4f import G4FLLM
from g4f import Provider, models
|
|
|
|
|
|
|
|
DB_FAISS_PATH = 'vectorstore/db_faiss' |
|
|
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2' |
|
|
LLM_MODEL = models.gpt_35_long |
|
|
LLM_PROVIDER = Provider.OpenaiChat |
|
|
|
|
|
|
|
|
def configure_ui(): |
|
|
"""Configure Streamlit UI settings""" |
|
|
st.set_page_config(page_title="Zendo AI Assistant", page_icon="📄") |
|
|
hide_streamlit_style = """ |
|
|
<style> |
|
|
#MainMenu {visibility: hidden;} |
|
|
footer {visibility: hidden;} |
|
|
.stTextInput input {font-size: 16px;} |
|
|
</style> |
|
|
""" |
|
|
st.markdown(hide_streamlit_style, unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
def init_session_state(): |
|
|
"""Initialize session state variables""" |
|
|
if 'history' not in st.session_state: |
|
|
st.session_state['history'] = [] |
|
|
if 'generated' not in st.session_state: |
|
|
st.session_state['generated'] = ["こんにちは!Zendoアシスタントです。PDFの内容について何でも聞いてください 🤗"] |
|
|
if 'past' not in st.session_state: |
|
|
st.session_state['past'] = ["ようこそ!"] |
|
|
|
|
|
|
|
|
def load_llm(): |
|
|
"""Load the language model""" |
|
|
return G4FLLM( |
|
|
model=LLM_MODEL, |
|
|
provider=LLM_PROVIDER, |
|
|
) |
|
|
|
|
|
|
|
|
def process_pdf(uploaded_file): |
|
|
"""Process the uploaded PDF file""" |
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmpfile: |
|
|
tmpfile.write(uploaded_file.getvalue()) |
|
|
tmpfile_path = tmpfile.name |
|
|
|
|
|
loader = PyPDFLoader(tmpfile_path) |
|
|
pdf_data = loader.load() |
|
|
|
|
|
embeddings = HuggingFaceEmbeddings( |
|
|
model_name=EMBEDDING_MODEL, |
|
|
model_kwargs={'device': 'cpu'} |
|
|
) |
|
|
|
|
|
db = FAISS.from_documents(pdf_data, embeddings) |
|
|
db.save_local(DB_FAISS_PATH) |
|
|
return db |
|
|
|
|
|
|
|
|
def conversational_chat(query, chain): |
|
|
"""Handle conversational chat with memory""" |
|
|
result = chain({ |
|
|
"question": query, |
|
|
"chat_history": st.session_state['history'] |
|
|
}) |
|
|
st.session_state['history'].append((query, result["answer"])) |
|
|
return result["answer"] |
|
|
|
|
|
|
|
|
def main(): |
|
|
configure_ui() |
|
|
init_session_state() |
|
|
|
|
|
st.title("📄 Zendo AI Assistant - PDFチャットボット") |
|
|
|
|
|
|
|
|
col1, col2 = st.columns([1, 3]) |
|
|
with col1: |
|
|
language = st.selectbox("言語/Language", ["日本語", "English", "Tiếng Việt"]) |
|
|
|
|
|
|
|
|
uploaded_file = st.file_uploader( |
|
|
"PDFファイルをアップロードしてください (Upload PDF file)", |
|
|
type="pdf", |
|
|
help="PDFをアップロードすると、その内容について質問できます" |
|
|
) |
|
|
|
|
|
if uploaded_file: |
|
|
with st.spinner("PDFを処理中...少々お待ちください"): |
|
|
db = process_pdf(uploaded_file) |
|
|
llm = load_llm() |
|
|
chain = ConversationalRetrievalChain.from_llm( |
|
|
llm=llm, |
|
|
retriever=db.as_retriever() |
|
|
) |
|
|
st.success("PDFの処理が完了しました!質問をどうぞ") |
|
|
|
|
|
|
|
|
response_container = st.container() |
|
|
|
|
|
with st.form(key='chat_form', clear_on_submit=True): |
|
|
user_input = st.text_input( |
|
|
"メッセージを入力...", |
|
|
key='input', |
|
|
placeholder="PDFについて質問してください" |
|
|
) |
|
|
submit_button = st.form_submit_button(label='送信') |
|
|
|
|
|
if submit_button and user_input: |
|
|
output = conversational_chat(user_input, chain) |
|
|
st.session_state['past'].append(user_input) |
|
|
st.session_state['generated'].append(output) |
|
|
|
|
|
|
|
|
if st.session_state['generated']: |
|
|
with response_container: |
|
|
for i in range(len(st.session_state['generated'])): |
|
|
message( |
|
|
st.session_state["past"][i], |
|
|
is_user=True, |
|
|
key=str(i) + '_user', |
|
|
avatar_style="big-smile" |
|
|
) |
|
|
message( |
|
|
st.session_state["generated"][i], |
|
|
key=str(i), |
|
|
avatar_style="thumbs" |
|
|
) |
|
|
else: |
|
|
st.info("PDFファイルをアップロードしてチャットを開始してください") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |