"""Streamlit RAG chatbot: upload documents, index them, and chat over them.

Uploaded PDF/TXT/CSV/JSON files are converted to LangChain ``Document``
objects, split into chunks, embedded into a FAISS index, and queried through a
``ConversationalRetrievalChain`` backed by a Groq-hosted Llama model.
"""

import html
import json
import os
import tempfile

import streamlit as st
from dotenv import load_dotenv

# UI templates
from htmlTemplates import css, bot_template, user_template

# Text splitters
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter

# Vector store / embeddings
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

# Loaders
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.docstore.document import Document

# LLM + chain
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_groq import ChatGroq


# ---------- Shared upload helper ----------
def _load_uploaded_file(uploaded_file, loader_factory):
    """Write an uploaded file to a temporary path and load it with a loader.

    Args:
        uploaded_file: Upload object exposing ``.name`` and ``.getvalue()``.
        loader_factory: Callable taking a file path and returning an object
            with a ``load()`` method.

    Returns:
        Whatever ``loader_factory(path).load()`` returns.
    """
    # The upload name is client-controlled; keep only its final component so
    # the file is always written inside the temporary directory.
    safe_name = os.path.basename(uploaded_file.name) or "upload"
    # load() is called while the directory still exists; the directory is
    # removed on exit instead of being kept around in session state.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_filepath = os.path.join(temp_dir, safe_name)
        with open(temp_filepath, "wb") as f:
            f.write(uploaded_file.getvalue())
        return loader_factory(temp_filepath).load()


# ---------- PDF ----------
def get_pdf_text(pdf_docs):
    """Load one uploaded PDF file with ``PyPDFLoader``."""
    return _load_uploaded_file(pdf_docs, PyPDFLoader)


# ---------- TXT ----------
def get_text_file(docs):
    """Load one uploaded text file with ``TextLoader`` (UTF-8)."""
    return _load_uploaded_file(
        docs, lambda path: TextLoader(path, encoding="utf-8")
    )


# ---------- CSV ----------
def get_csv_file(docs):
    """Load one uploaded CSV file with ``CSVLoader`` (UTF-8)."""
    return _load_uploaded_file(
        docs, lambda path: CSVLoader(path, encoding="utf-8")
    )


# ---------- JSON ----------
def get_json_file(file) -> list[Document]:
    """Convert one uploaded JSON file into a list of documents.

    Each document's ``page_content`` is a JSON-serialised fragment:

    * top-level object with a ``"scans"`` list: one document per entry of every
      scan's ``"relationships"`` list, or the whole object if none are found;
    * top-level list: one document per item;
    * anything else: a single document holding the whole value.

    Raises:
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # "utf-8-sig" strips a leading BOM, which json.loads would otherwise reject.
    raw = file.getvalue().decode("utf-8-sig", errors="ignore")
    data = json.loads(raw)

    def to_doc(value):
        return Document(page_content=json.dumps(value, ensure_ascii=False))

    if isinstance(data, dict) and isinstance(data.get("scans"), list):
        docs = []
        for scan in data["scans"]:
            # Skip malformed scan entries instead of failing on .get().
            if not isinstance(scan, dict):
                continue
            rels = scan.get("relationships", [])
            if isinstance(rels, list):
                docs.extend(to_doc(rel) for rel in rels)
        return docs or [to_doc(data)]

    if isinstance(data, list):
        return [to_doc(item) for item in data]

    return [to_doc(data)]


# ---------- Chunking ----------
def get_text_chunks(documents):
    """Split documents into 1000-character chunks with 200 characters overlap."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return text_splitter.split_documents(documents)


# ---------- Vector store ----------
def get_vectorstore(text_chunks):
    """Embed the chunks on CPU and build a FAISS index from them."""
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L12-v2",
        model_kwargs={"device": "cpu"},
    )
    return FAISS.from_documents(text_chunks, embeddings)


# ---------- Conversation chain ----------
def get_conversation_chain(vectorstore):
    """Build a conversational retrieval chain over ``vectorstore``.

    The Groq API key is read from the ``GROQ_API_KEY`` environment variable.
    The retriever is configured with ``k=3`` and the chain keeps its history
    in a ``ConversationBufferMemory`` under the ``chat_history`` key.
    """
    llm = ChatGroq(
        groq_api_key=os.environ.get("GROQ_API_KEY"),
        model_name="llama-3.1-8b-instant",
        temperature=0.75,
        max_tokens=512,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
    )


# ---------- UI ----------
def handle_userinput(user_question):
    """Send the question to the chain and render the whole chat history.

    Shows a warning and returns if no documents have been processed yet.
    """
    if st.session_state.conversation is None:
        st.warning("먼저 문서를 업로드하고 Process 버튼을 눌러주세요.")
        return

    response = st.session_state.conversation.invoke({"question": user_question})
    st.session_state.chat_history = response["chat_history"]

    # Even positions are rendered with the user template, odd with the bot one.
    for i, message in enumerate(st.session_state.chat_history):
        template = user_template if i % 2 == 0 else bot_template
        # The templates are rendered as raw HTML, so message text must be
        # escaped to keep user/LLM content from being interpreted as markup.
        st.write(
            template.replace("{{MSG}}", html.escape(message.content)),
            unsafe_allow_html=True,
        )


# Per-mode configuration: accepted MIME types and the loader to apply.
_MODES = {
    "pdf": (("application/pdf", "application/octet-stream"), get_pdf_text),
    "txt": (("text/plain",), get_text_file),
    "csv": (("text/csv", "application/vnd.ms-excel"), get_csv_file),
    "json": (("application/json",), get_json_file),
}


def process_files(docs, mode: str):
    """Load the uploaded files for ``mode`` and build the conversation chain.

    Files whose MIME type does not match the mode, or that fail to load, are
    reported with ``st.error`` and skipped. If nothing could be loaded the
    script run is stopped via ``st.stop()``.

    Args:
        docs: Iterable of uploaded files (``None`` is treated as empty).
        mode: One of ``"pdf"``, ``"txt"``, ``"csv"``, ``"json"``.

    Raises:
        KeyError: If ``mode`` is not one of the supported modes.
    """
    valid_mimes, loader_fn = _MODES[mode]

    doc_list = []
    for file in docs or []:
        if file.type not in valid_mimes:
            st.error(f"{mode.upper()} 파일이 아닙니다. (받은 MIME: {file.type})")
            continue
        try:
            doc_list.extend(loader_fn(file))
        except Exception as exc:  # UI boundary: report and keep going
            st.error(f"{file.name} 파일을 처리하지 못했습니다: {exc}")

    if not doc_list:
        st.error("처리 가능한 문서를 찾지 못했습니다.")
        st.stop()

    text_chunks = get_text_chunks(doc_list)
    vectorstore = get_vectorstore(text_chunks)
    st.session_state.conversation = get_conversation_chain(vectorstore)
    st.success(f"{mode.upper()} 문서 처리 완료! 이제 질문을 입력해 보세요.")


def main():
    """Render the page: question box, chat history, and upload sidebar."""
    load_dotenv()
    st.set_page_config(page_title="Basic_RAG_AI_Chatbot_with_Llama", page_icon="📚")
    st.write(css, unsafe_allow_html=True)

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header("Basic_RAG_AI_Chatbot_with_Llama3 📚")
    user_question = st.text_input("Ask a question about your documents:")
    if user_question:
        handle_userinput(user_question)

    with st.sidebar:
        st.subheader("Your documents")
        st.markdown("파일을 업로드한 후 아래 버튼을 눌러 처리하세요.")
        docs = st.file_uploader(
            "Upload your Files here and click on 'Process'",
            accept_multiple_files=True,
        )

        # List the buttons vertically so that every button is clearly visible.
        for mode in _MODES:
            label = mode.upper()
            if st.button(f"Process[{label}]"):
                with st.spinner(f"Processing {label}..."):
                    process_files(docs, mode)


if __name__ == '__main__':
    main()