# Korean document QA chatbot — Streamlit application (Hugging Face Space).
import re

import streamlit as st
import tiktoken
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.document_loaders import Docx2txtLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.memory import StreamlitChatMessageHistory
from langchain.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline
| def preprocess_korean_text(text): | |
| """ํ๊ตญ์ด ํ ์คํธ ์ ์ฒ๋ฆฌ ํจ์""" | |
| # ๋ถํ์ํ ํน์๋ฌธ์ ์ ๊ฑฐ (ํ๊ตญ์ด, ์์ด, ์ซ์, ๊ณต๋ฐฑ๋ง ์ ์ง) | |
| text = re.sub(r'[^๊ฐ-ํฃa-zA-Z0-9\s.,!?]', ' ', text) | |
| # ์ฐ์๋ ๊ณต๋ฐฑ์ ํ๋๋ก ํตํฉ | |
| text = re.sub(r'\s+', ' ', text).strip() | |
| return text | |
| def main(): | |
| st.set_page_config( | |
| page_title="ํ๊ตญ์ด ๋ฌธ์ QA ์ฑ๋ด", | |
| page_icon="๐ฐ๐ท", | |
| layout="wide" | |
| ) | |
| st.title("๐ฐ๐ท _ํ๊ตญ์ด ์ ์ฉ ๋ฌธ์ :red[QA ์ฑ๋ด]_ ๐") | |
| st.markdown("**์ต๊ณ ์ฑ๋ฅ์ ํ๊ตญ์ด AI ๋ชจ๋ธ๋ก ๊ตฌ๋๋๋ ๋ฌธ์ ์ง์์๋ต ์์คํ **") | |
| if "conversation" not in st.session_state: | |
| st.session_state.conversation = None | |
| if "chat_history" not in st.session_state: | |
| st.session_state.chat_history = None | |
| if "processComplete" not in st.session_state: | |
| st.session_state.processComplete = None | |
| with st.sidebar: | |
| st.header("โ๏ธ ์ค์ ") | |
| uploaded_files = st.file_uploader( | |
| "๐ ํ๊ตญ์ด ๋ฌธ์ ์ ๋ก๋", | |
| type=['pdf','docx'], | |
| accept_multiple_files=True, | |
| help="PDF, DOCX ํ์์ ํ๊ตญ์ด ๋ฌธ์๋ฅผ ์ ๋ก๋ํ์ธ์." | |
| ) | |
| st.subheader("๐ค AI ๋ชจ๋ธ ์ ํ") | |
| # ์ต๊ณ ์ฑ๋ฅ ํ๊ตญ์ด ๋ชจ๋ธ๋ค๋ก ๊ต์ฒด | |
| model_options = { | |
| "๐ฅ EEVE-Korean-10.8B (์ต๊ณ ์ฑ๋ฅ)": "yanolja/EEVE-Korean-Instruct-10.8B-v1.0", | |
| "๐ฅ Llama3-Korean-Bllossom-8B": "MLP-KTLim/llama-3-Korean-Bllossom-8B", | |
| "๐ฅ KoAlpaca-Polyglot-12.8B": "beomi/KoAlpaca-Polyglot-12.8B", | |
| "โก Kullm-Polyglot-5.8B (๋น ๋ฆ)": "nlpai-lab/kullm-polyglot-5.8b-v2", | |
| "๐ Korean-Vicuna-13B": "kfkas/Llama-2-ko-7b-Chat" | |
| } | |
| selected_model_name = st.selectbox( | |
| "๋ชจ๋ธ ์ ํ:", | |
| list(model_options.keys()), | |
| help="EEVE ๋ชจ๋ธ์ด ํ๊ตญ์ด ์ง์์ฌํญ ์ดํด์ ๊ฐ์ฅ ๋ฐ์ด๋ฉ๋๋ค." | |
| ) | |
| selected_model = model_options[selected_model_name] | |
| st.subheader("๐ ์๋ฒ ๋ฉ ๋ชจ๋ธ") | |
| embedding_options = { | |
| "๐ฐ๐ท KoSBERT (์ถ์ฒ)": "jhgan/ko-sroberta-multitask", | |
| "๐ฅ KoSimCSE": "BM-K/KoSimCSE-roberta-multitask", | |
| "โญ KR-SBERT": "snunlp/KR-SBERT-V40K-klueNLI-augSTS" | |
| } | |
| selected_embedding_name = st.selectbox( | |
| "์๋ฒ ๋ฉ ๋ชจ๋ธ:", | |
| list(embedding_options.keys()) | |
| ) | |
| selected_embedding = embedding_options[selected_embedding_name] | |
| st.subheader("โ๏ธ ๊ณ ๊ธ ์ค์ ") | |
| chunk_size = st.slider("์ฒญํฌ ํฌ๊ธฐ", 200, 1000, 400, help="ํ๊ตญ์ด๋ 400-600์๊ฐ ์ต์ ์ ๋๋ค.") | |
| chunk_overlap = st.slider("์ฒญํฌ ๊ฒน์นจ", 20, 200, 40, help="๊ฒน์นจ์ด ํด์๋ก ๋ฌธ๋งฅ ์ฐ๊ฒฐ์ฑ์ด ํฅ์๋ฉ๋๋ค.") | |
| temperature = st.slider("์ฐฝ์์ฑ (Temperature)", 0.1, 1.0, 0.3, help="๋ฎ์์๋ก ์ ํ, ๋์์๋ก ์ฐฝ์์ ") | |
| process = st.button("๐ ๋ฌธ์ ์ฒ๋ฆฌ ์์", type="primary") | |
| if process: | |
| if uploaded_files: | |
| with st.spinner("๐ฅ ์ต๊ณ ์ฑ๋ฅ ํ๊ตญ์ด AI๋ก ๋ฌธ์๋ฅผ ๋ถ์ ์ค์ ๋๋ค..."): | |
| try: | |
| files_text = get_text(uploaded_files) | |
| text_chunks = get_text_chunks(files_text, chunk_size, chunk_overlap) | |
| vectorstore = get_vectorstore(text_chunks, selected_embedding) | |
| st.session_state.conversation = get_conversation_chain(vectorstore, selected_model, temperature) | |
| st.session_state.processComplete = True | |
| st.success(f"โ {len(files_text)}๊ฐ ๋ฌธ์, {len(text_chunks)}๊ฐ ์ฒญํฌ๋ก ์ฒ๋ฆฌ ์๋ฃ!") | |
| st.balloons() | |
| except Exception as e: | |
| st.error(f"โ ๋ฌธ์ ์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}") | |
| else: | |
| st.error("๐ ํ์ผ์ ๋จผ์ ์ ๋ก๋ํด์ฃผ์ธ์!") | |
| if 'messages' not in st.session_state: | |
| st.session_state['messages'] = [{ | |
| "role": "assistant", | |
| "content": "์๋ ํ์ธ์! ๐ฐ๐ท **ํ๊ตญ์ด ์ ์ฉ ๊ณ ์ฑ๋ฅ AI ์ฑ๋ด**์ ๋๋ค.\n\n๐ **ํน์ง:**\n- ์ต์ ํ๊ตญ์ด ํนํ AI ๋ชจ๋ธ ์ฌ์ฉ\n- ๋ณต์กํ ์ง์์ฌํญ ์๋ฒฝ ์ดํด\n- ์ ํํ๊ณ ์์ฐ์ค๋ฌ์ด ํ๊ตญ์ด ๋ต๋ณ\n\n๐ ๋ฌธ์๋ฅผ ์ ๋ก๋ํ๊ณ '๋ฌธ์ ์ฒ๋ฆฌ ์์'์ ๋๋ฌ์ฃผ์ธ์!" | |
| }] | |
| # ์ฑํ ์ธํฐํ์ด์ค | |
| st.subheader("๐ฌ ๋ํ") | |
| for message in st.session_state.messages: | |
| with st.chat_message(message["role"]): | |
| st.markdown(message["content"]) | |
| if query := st.chat_input("๐ค ๋ฌธ์์ ๋ํด ๋ฌด์์ด๋ ๋ฌผ์ด๋ณด์ธ์... (๋ณต์กํ ์ง๋ฌธ๋ ํ์!)"): | |
| if st.session_state.conversation is None: | |
| st.error("๋จผ์ ํ์ผ์ ์ ๋ก๋ํ๊ณ '๋ฌธ์ ์ฒ๋ฆฌ ์์' ๋ฒํผ์ ๋๋ฌ์ฃผ์ธ์!") | |
| st.stop() | |
| st.session_state.messages.append({"role": "user", "content": query}) | |
| with st.chat_message("user"): | |
| st.markdown(query) | |
| with st.chat_message("assistant"): | |
| with st.spinner("๐ง ํ๊ตญ์ด AI๊ฐ ๊น์ด ๋ถ์ํ๊ณ ์์ต๋๋ค..."): | |
| try: | |
| # ํ๊ตญ์ด ํ๋กฌํํธ ์ต์ ํ | |
| enhanced_query = f"๋ค์ ์ง๋ฌธ์ ๋ํด ๋ฌธ์ ๋ด์ฉ์ ๋ฐํ์ผ๋ก ์ ํํ๊ณ ์์ธํ๊ฒ ํ๊ตญ์ด๋ก ๋ต๋ณํด์ฃผ์ธ์: {query}" | |
| result = st.session_state.conversation({"question": enhanced_query}) | |
| response = result['answer'] | |
| source_documents = result.get('source_documents', []) | |
| # ๋ต๋ณ ํ์ฒ๋ฆฌ | |
| if response: | |
| # ๋ถํ์ํ ์์ด ์ ๊ฑฐ ๋ฐ ํ๊ตญ์ด ๋ต๋ณ ์ถ์ถ | |
| response = clean_korean_response(response) | |
| st.markdown(response) | |
| else: | |
| st.markdown("์ฃ์กํฉ๋๋ค. ํด๋น ์ง๋ฌธ์ ๋ํ ๋ต๋ณ์ ๋ฌธ์์์ ์ฐพ์ ์ ์์ต๋๋ค.") | |
| if source_documents: | |
| with st.expander("๐ ์ฐธ๊ณ ๋ฌธ์ ๋ฐ ๊ทผ๊ฑฐ"): | |
| for i, doc in enumerate(source_documents[:3]): | |
| st.markdown(f"**๐ ๋ฌธ์ {i+1}:** {doc.metadata.get('source', 'Unknown')}") | |
| with st.container(): | |
| st.text_area( | |
| f"๊ด๋ จ ๋ด์ฉ {i+1}", | |
| doc.page_content[:400] + "...", | |
| height=120, | |
| disabled=True | |
| ) | |
| st.session_state.messages.append({"role": "assistant", "content": response}) | |
| except Exception as e: | |
| error_msg = f"โ ๋ต๋ณ ์์ฑ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}" | |
| st.error(error_msg) | |
| st.session_state.messages.append({"role": "assistant", "content": "์ฃ์กํฉ๋๋ค. ์ผ์์ ์ธ ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค. ๋ค์ ์๋ํด์ฃผ์ธ์."}) | |
def clean_korean_response(response):
    """Post-process a model answer for display.

    Removes standalone Latin-alphabet words, strips bracket characters
    (``[](){}``), collapses runs of whitespace, and trims the result.

    NOTE(review): deleting *every* English word is aggressive and can erase
    technical terms from otherwise-correct answers — confirm this is the
    intended behavior.
    """
    no_english = re.sub(r'\b[A-Za-z]+\b', '', response)
    no_brackets = re.sub(r'[\[\]\(\)\{\}]', '', no_english)
    return re.sub(r'\s+', ' ', no_brackets).strip()
def get_text(docs):
    """Extract and preprocess text from uploaded PDF/DOCX files.

    Each upload is written to the working directory (the langchain loaders
    need a real file path), loaded and split, preprocessed with
    ``preprocess_korean_text``, and fragments shorter than 50 characters are
    dropped.  Returns a flat list of langchain ``Document`` objects.

    Fixes over the original: files matching neither ``.pdf`` nor ``.docx``
    hit an undefined ``documents`` variable (NameError) — they are now
    skipped explicitly; a dead trailing ``continue`` is removed; the filter
    comprehension no longer shadows the outer loop variable.
    """
    MIN_LEN = 50  # drop near-empty pages/snippets
    doc_list = []
    for uploaded in docs:
        file_name = uploaded.name
        # Persist the in-memory upload so the loaders can open it by path.
        with open(file_name, "wb") as fh:
            fh.write(uploaded.getvalue())
        logger.info(f"Uploaded {file_name}")
        try:
            if '.pdf' in file_name:
                loader = PyPDFLoader(file_name)
            elif '.docx' in file_name:
                loader = Docx2txtLoader(file_name)
            else:
                # Unsupported format — skip instead of crashing on an
                # undefined variable.
                st.error(f"ํŒŒ์ผ {file_name} ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜: ์ง€์›ํ•˜์ง€ ์•Š๋Š” ํ˜•์‹์ž…๋‹ˆ๋‹ค")
                continue
            documents = loader.load_and_split()
            # Normalize each fragment's text in place.
            for document in documents:
                document.page_content = preprocess_korean_text(document.page_content)
            doc_list.extend(
                d for d in documents if len(d.page_content.strip()) >= MIN_LEN
            )
        except Exception as e:
            st.error(f"ํŒŒ์ผ {file_name} ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
    return doc_list
def get_text_chunks(text, chunk_size=400, chunk_overlap=40):
    """Split langchain Documents into overlapping chunks sized for Korean.

    The separator priority (paragraph, line, sentence punctuation, clause
    punctuation, space, character) keeps splits on natural boundaries.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ":", ",", " ", ""],
    )
    return splitter.split_documents(text)
def get_vectorstore(text_chunks, embedding_model):
    """Build a FAISS index over the chunks with a Korean embedding model."""
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model,
        model_kwargs={'device': 'cpu'},
        # Normalized vectors make inner-product search behave like cosine.
        encode_kwargs={'normalize_embeddings': True},
    )
    return FAISS.from_documents(text_chunks, embeddings)
def get_conversation_chain(vectorstore, model_name, temperature):
    """Create a ConversationalRetrievalChain backed by a local HF causal LM.

    Loads the tokenizer and model on CPU, wraps them in a text-generation
    pipeline, and combines them with an MMR retriever and conversation
    memory.  Returns ``None`` (after surfacing an error in the UI) when the
    model cannot be loaded, e.g. for memory reasons.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Some causal LMs ship without a pad token; reuse EOS so the
        # generation pipeline can pad.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map=None,  # no explicit GPU placement
        )

        generation_pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            device=-1,  # run on CPU
            pad_token_id=tokenizer.eos_token_id,
        )
        llm = HuggingFacePipeline(pipeline=generation_pipe)

        # MMR retrieval trades off relevance against diversity of the
        # returned chunks.
        retriever = vectorstore.as_retriever(
            search_type='mmr',
            search_kwargs={'k': 4, 'fetch_k': 8, 'lambda_mult': 0.7},
        )
        memory = ConversationBufferMemory(
            memory_key='chat_history',
            return_messages=True,
            output_key='answer',
        )
        return ConversationalRetrievalChain.from_llm(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            memory=memory,
            return_source_documents=True,
            verbose=True,
        )
    except Exception as e:
        st.error(f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
        st.info("๋” ๊ฐ€๋ฒผ์šด ๋ชจ๋ธ์„ ์„ ํƒํ•˜๊ฑฐ๋‚˜ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ํ™•์ธํ•ด์ฃผ์„ธ์š”.")
        return None
# Script entry point.
if __name__ == "__main__":
    main()