| import streamlit as st |
| import os |
| from langchain_google_genai import ChatGoogleGenerativeAI |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_community.document_loaders import TextLoader |
| from langchain_huggingface import HuggingFaceEmbeddings |
| from langchain.prompts import PromptTemplate |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain.vectorstores import Chroma |
| from chromadb.config import Settings |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
| |
| page = st.title("Chat with AskUSTH") |
|
|
| |
| if "gemini_api" not in st.session_state: |
| st.session_state.gemini_api = None |
|
|
| if "rag" not in st.session_state: |
| st.session_state.rag = None |
|
|
| if "llm" not in st.session_state: |
| st.session_state.llm = None |
|
|
| if "embd" not in st.session_state: |
| st.session_state.embd = None |
|
|
| if "model" not in st.session_state: |
| st.session_state.model = None |
|
|
| if "save_dir" not in st.session_state: |
| st.session_state.save_dir = None |
|
|
| if "uploaded_files" not in st.session_state: |
| st.session_state.uploaded_files = set() |
|
|
| if "chat_history" not in st.session_state: |
| st.session_state.chat_history = [] |
|
|
| |
| def load_txt(file_path): |
| loader = TextLoader(file_path=file_path, encoding="utf-8") |
| doc = loader.load() |
| return doc |
|
|
| |
| def format_docs(docs): |
| """Định dạng các tài liệu thành chuỗi văn bản.""" |
| return "\n\n".join(doc.page_content for doc in docs) |
|
|
| |
| @st.cache_resource |
| def get_chat_google_model(api_key): |
| os.environ["GOOGLE_API_KEY"] = api_key |
| return ChatGoogleGenerativeAI( |
| model="gemini-1.5-pro", |
| temperature=0, |
| max_tokens=None, |
| timeout=None, |
| max_retries=2, |
| ) |
|
|
| |
| @st.cache_resource |
| def get_embedding_model(): |
| model_name = "bkai-foundation-models/vietnamese-bi-encoder" |
| model_kwargs = {'device': 'cpu'} |
| encode_kwargs = {'normalize_embeddings': False} |
| |
| model = HuggingFaceEmbeddings( |
| model_name=model_name, |
| model_kwargs=model_kwargs, |
| encode_kwargs=encode_kwargs |
| ) |
| return model |
|
|
| |
| @st.cache_resource |
| def compute_rag_chain(_model, _embd, docs_texts): |
| if not docs_texts: |
| raise ValueError("Không có tài liệu nào để xử lý. Vui lòng tải lên các tệp hợp lệ.") |
| |
| combined_text = "\n\n".join(docs_texts) |
| text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) |
| texts = text_splitter.split_text(combined_text) |
| |
| if len(texts) > 5000: |
| raise ValueError("Tài liệu tạo ra quá nhiều đoạn. Vui lòng sử dụng tài liệu nhỏ hơn.") |
| |
| |
| persist_dir = "./chromadb_store" |
| if not os.path.exists(persist_dir): |
| os.makedirs(persist_dir) |
| |
| |
| settings = Settings(persist_directory=persist_dir) |
|
|
| |
| vectorstore = Chroma.from_texts(texts=texts, embedding=_embd, client_settings=settings) |
| retriever = vectorstore.as_retriever() |
|
|
| |
| template = """ |
| Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. |
| Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. |
| Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời. |
| Dưới đây là thông tin liên quan mà bạn cần sử dụng tới: |
| {context} |
| hãy trả lời: |
| {question} |
| """ |
| prompt = PromptTemplate(template=template, input_variables=["context", "question"]) |
| |
| rag_chain = ( |
| {"context": retriever | format_docs, "question": RunnablePassthrough()} |
| | prompt |
| | _model |
| | StrOutputParser() |
| ) |
| return rag_chain |
|
|
| |
| @st.dialog("Setup Gemini") |
| def setup_gemini(): |
| st.markdown( |
| """ |
| Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới. |
| """ |
| ) |
| key = st.text_input("Key:", "") |
| if st.button("Save") and key != "": |
| st.session_state.gemini_api = key |
| st.rerun() |
|
|
| if st.session_state.gemini_api is None: |
| setup_gemini() |
|
|
| if st.session_state.gemini_api and st.session_state.model is None: |
| st.session_state.model = get_chat_google_model(st.session_state.gemini_api) |
|
|
| if st.session_state.embd is None: |
| st.session_state.embd = get_embedding_model() |
|
|
| if st.session_state.save_dir is None: |
| save_dir = "./Documents" |
| if not os.path.exists(save_dir): |
| os.makedirs(save_dir) |
| st.session_state.save_dir = save_dir |
|
|
| |
| with st.sidebar: |
| uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"]) |
| max_file_size_mb = 5 |
| if uploaded_files: |
| documents = [] |
| for uploaded_file in uploaded_files: |
| if uploaded_file.size > max_file_size_mb * 1024 * 1024: |
| st.warning(f"Tệp {uploaded_file.name} vượt quá giới hạn {max_file_size_mb}MB.") |
| continue |
| |
| file_path = os.path.join(st.session_state.save_dir, uploaded_file.name) |
| with open(file_path, mode='wb') as w: |
| w.write(uploaded_file.getvalue()) |
| |
| doc = load_txt(file_path) |
| documents.extend([*doc]) |
| |
| if documents: |
| docs_texts = [d.page_content for d in documents] |
| st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts) |
|
|
| |
| for message in st.session_state.chat_history: |
| with st.chat_message(message["role"]): |
| st.write(message["content"]) |
|
|
| prompt = st.chat_input("Bạn muốn hỏi gì?") |
| if st.session_state.model is not None: |
| if prompt: |
| st.session_state.chat_history.append({"role": "user", "content": prompt}) |
| with st.chat_message("user"): |
| st.write(prompt) |
| |
| with st.chat_message("assistant"): |
| if st.session_state.rag is not None: |
| response = st.session_state.rag.invoke(prompt) |
| st.write(response) |
| else: |
| ans = st.session_state.llm.invoke(prompt) |
| response = ans.content |
| st.write(response) |
| |
| st.session_state.chat_history.append({"role": "assistant", "content": response}) |