# immunochat/app.py
import re

import streamlit as st
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS

def preprocess_korean_text(text):
    """Preprocess Korean text before indexing."""
    # Drop unwanted special characters (keep Korean, English, digits,
    # whitespace, and basic punctuation).
    text = re.sub(r'[^๊ฐ€-ํžฃa-zA-Z0-9\s.,!?]', ' ', text)
    # Collapse consecutive whitespace into a single space.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
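
# Illustrative behaviour (a sketch; the output follows the regex above):
# symbols outside the kept ranges become spaces, then whitespace runs collapse.
#   preprocess_korean_text("์•ˆ๋…•ํ•˜์„ธ์š”!!  ๐Ÿ˜€  Hello   123")
#   -> '์•ˆ๋…•ํ•˜์„ธ์š”!! Hello 123'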

def main():
    st.set_page_config(
        page_title="ํ•œ๊ตญ์–ด ๋ฌธ์„œ QA ์ฑ—๋ด‡",
        page_icon="๐Ÿ‡ฐ๐Ÿ‡ท",
        layout="wide"
    )

    st.title("๐Ÿ‡ฐ๐Ÿ‡ท _ํ•œ๊ตญ์–ด ์ „์šฉ ๋ฌธ์„œ :red[QA ์ฑ—๋ด‡]_ ๐Ÿ“š")
    st.markdown("**์ตœ๊ณ  ์„ฑ๋Šฅ์˜ ํ•œ๊ตญ์–ด AI ๋ชจ๋ธ๋กœ ๊ตฌ๋™๋˜๋Š” ๋ฌธ์„œ ์งˆ์˜์‘๋‹ต ์‹œ์Šคํ…œ**")

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
    if "processComplete" not in st.session_state:
        st.session_state.processComplete = None

    with st.sidebar:
        st.header("โš™๏ธ ์„ค์ •")
        uploaded_files = st.file_uploader(
            "๐Ÿ“ ํ•œ๊ตญ์–ด ๋ฌธ์„œ ์—…๋กœ๋“œ",
            type=['pdf', 'docx'],
            accept_multiple_files=True,
            help="PDF, DOCX ํ˜•์‹์˜ ํ•œ๊ตญ์–ด ๋ฌธ์„œ๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”."
        )

        st.subheader("๐Ÿค– AI ๋ชจ๋ธ ์„ ํƒ")
        # High-performing Korean instruction-tuned models.
        model_options = {
            "๐Ÿฅ‡ EEVE-Korean-10.8B (์ตœ๊ณ  ์„ฑ๋Šฅ)": "yanolja/EEVE-Korean-Instruct-10.8B-v1.0",
            "๐Ÿฅˆ Llama3-Korean-Bllossom-8B": "MLP-KTLim/llama-3-Korean-Bllossom-8B",
            "๐Ÿฅ‰ KoAlpaca-Polyglot-12.8B": "beomi/KoAlpaca-Polyglot-12.8B",
            "โšก Kullm-Polyglot-5.8B (๋น ๋ฆ„)": "nlpai-lab/kullm-polyglot-5.8b-v2",
            "๐Ÿ’Ž Llama-2-Ko-7B-Chat": "kfkas/Llama-2-ko-7b-Chat"
        }
        selected_model_name = st.selectbox(
            "๋ชจ๋ธ ์„ ํƒ:",
            list(model_options.keys()),
            help="EEVE ๋ชจ๋ธ์ด ํ•œ๊ตญ์–ด ์ง€์‹œ์‚ฌํ•ญ ์ดํ•ด์— ๊ฐ€์žฅ ๋›ฐ์–ด๋‚ฉ๋‹ˆ๋‹ค."
        )
        selected_model = model_options[selected_model_name]

        st.subheader("๐Ÿ“Š ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ")
        embedding_options = {
            "๐Ÿ‡ฐ๐Ÿ‡ท KoSBERT (์ถ”์ฒœ)": "jhgan/ko-sroberta-multitask",
            "๐Ÿ”ฅ KoSimCSE": "BM-K/KoSimCSE-roberta-multitask",
            "โญ KR-SBERT": "snunlp/KR-SBERT-V40K-klueNLI-augSTS"
        }
        selected_embedding_name = st.selectbox(
            "์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ:",
            list(embedding_options.keys())
        )
        selected_embedding = embedding_options[selected_embedding_name]

        st.subheader("โš™๏ธ ๊ณ ๊ธ‰ ์„ค์ •")
        chunk_size = st.slider("์ฒญํฌ ํฌ๊ธฐ", 200, 1000, 400, help="ํ•œ๊ตญ์–ด๋Š” 400-600์ž๊ฐ€ ์ตœ์ ์ž…๋‹ˆ๋‹ค.")
        chunk_overlap = st.slider("์ฒญํฌ ๊ฒน์นจ", 20, 200, 40, help="๊ฒน์นจ์ด ํด์ˆ˜๋ก ๋ฌธ๋งฅ ์—ฐ๊ฒฐ์„ฑ์ด ํ–ฅ์ƒ๋ฉ๋‹ˆ๋‹ค.")
        temperature = st.slider("์ฐฝ์˜์„ฑ (Temperature)", 0.1, 1.0, 0.3, help="๋‚ฎ์„์ˆ˜๋ก ์ •ํ™•, ๋†’์„์ˆ˜๋ก ์ฐฝ์˜์ ")

        process = st.button("๐Ÿš€ ๋ฌธ์„œ ์ฒ˜๋ฆฌ ์‹œ์ž‘", type="primary")

    if process:
        if uploaded_files:
            with st.spinner("๐Ÿ”ฅ ์ตœ๊ณ  ์„ฑ๋Šฅ ํ•œ๊ตญ์–ด AI๋กœ ๋ฌธ์„œ๋ฅผ ๋ถ„์„ ์ค‘์ž…๋‹ˆ๋‹ค..."):
                try:
                    files_text = get_text(uploaded_files)
                    text_chunks = get_text_chunks(files_text, chunk_size, chunk_overlap)
                    vectorstore = get_vectorstore(text_chunks, selected_embedding)
                    st.session_state.conversation = get_conversation_chain(vectorstore, selected_model, temperature)
                    st.session_state.processComplete = True
                    st.success(f"โœ… {len(uploaded_files)}๊ฐœ ๋ฌธ์„œ, {len(text_chunks)}๊ฐœ ์ฒญํฌ๋กœ ์ฒ˜๋ฆฌ ์™„๋ฃŒ!")
                    st.balloons()
                except Exception as e:
                    st.error(f"โŒ ๋ฌธ์„œ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
        else:
            st.error("๐Ÿ“ ํŒŒ์ผ์„ ๋จผ์ € ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”!")

    if 'messages' not in st.session_state:
        st.session_state['messages'] = [{
            "role": "assistant",
            "content": "์•ˆ๋…•ํ•˜์„ธ์š”! ๐Ÿ‡ฐ๐Ÿ‡ท **ํ•œ๊ตญ์–ด ์ „์šฉ ๊ณ ์„ฑ๋Šฅ AI ์ฑ—๋ด‡**์ž…๋‹ˆ๋‹ค.\n\n๐Ÿ“š **ํŠน์ง•:**\n- ์ตœ์‹  ํ•œ๊ตญ์–ด ํŠนํ™” AI ๋ชจ๋ธ ์‚ฌ์šฉ\n- ๋ณต์žกํ•œ ์ง€์‹œ์‚ฌํ•ญ ์™„๋ฒฝ ์ดํ•ด\n- ์ •ํ™•ํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด ๋‹ต๋ณ€\n\n๐Ÿ“ ๋ฌธ์„œ๋ฅผ ์—…๋กœ๋“œํ•˜๊ณ  '๋ฌธ์„œ ์ฒ˜๋ฆฌ ์‹œ์ž‘'์„ ๋ˆŒ๋Ÿฌ์ฃผ์„ธ์š”!"
        }]

    # Chat interface
    st.subheader("๐Ÿ’ฌ ๋Œ€ํ™”")
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if query := st.chat_input("๐Ÿค” ๋ฌธ์„œ์— ๋Œ€ํ•ด ๋ฌด์—‡์ด๋“  ๋ฌผ์–ด๋ณด์„ธ์š”... (๋ณต์žกํ•œ ์งˆ๋ฌธ๋„ ํ™˜์˜!)"):
        if st.session_state.conversation is None:
            st.error("๋จผ์ € ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•˜๊ณ  '๋ฌธ์„œ ์ฒ˜๋ฆฌ ์‹œ์ž‘' ๋ฒ„ํŠผ์„ ๋ˆŒ๋Ÿฌ์ฃผ์„ธ์š”!")
            st.stop()

        st.session_state.messages.append({"role": "user", "content": query})
        with st.chat_message("user"):
            st.markdown(query)

        with st.chat_message("assistant"):
            with st.spinner("๐Ÿง  ํ•œ๊ตญ์–ด AI๊ฐ€ ๊นŠ์ด ๋ถ„์„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค..."):
                try:
                    # Prompt wrapper that steers the model toward detailed Korean answers.
                    enhanced_query = f"๋‹ค์Œ ์งˆ๋ฌธ์— ๋Œ€ํ•ด ๋ฌธ์„œ ๋‚ด์šฉ์„ ๋ฐ”ํƒ•์œผ๋กœ ์ •ํ™•ํ•˜๊ณ  ์ƒ์„ธํ•˜๊ฒŒ ํ•œ๊ตญ์–ด๋กœ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”: {query}"
                    result = st.session_state.conversation({"question": enhanced_query})
                    response = result['answer']
                    source_documents = result.get('source_documents', [])

                    # Post-process the answer.
                    if response:
                        # Strip unwanted English and keep the Korean answer.
                        response = clean_korean_response(response)
                    else:
                        response = "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ํ•ด๋‹น ์งˆ๋ฌธ์— ๋Œ€ํ•œ ๋‹ต๋ณ€์„ ๋ฌธ์„œ์—์„œ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
                    st.markdown(response)

                    if source_documents:
                        with st.expander("๐Ÿ“– ์ฐธ๊ณ  ๋ฌธ์„œ ๋ฐ ๊ทผ๊ฑฐ"):
                            for i, doc in enumerate(source_documents[:3]):
                                st.markdown(f"**๐Ÿ“„ ๋ฌธ์„œ {i+1}:** {doc.metadata.get('source', 'Unknown')}")
                                with st.container():
                                    st.text_area(
                                        f"๊ด€๋ จ ๋‚ด์šฉ {i+1}",
                                        doc.page_content[:400] + "...",
                                        height=120,
                                        disabled=True
                                    )

                    st.session_state.messages.append({"role": "assistant", "content": response})
                except Exception as e:
                    error_msg = f"โŒ ๋‹ต๋ณ€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
                    st.error(error_msg)
                    st.session_state.messages.append({"role": "assistant", "content": "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ผ์‹œ์ ์ธ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค. ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."})

def clean_korean_response(response):
    """Clean up a generated Korean answer."""
    # Remove English words.
    response = re.sub(r'\b[A-Za-z]+\b', '', response)
    # Strip leftover brackets.
    response = re.sub(r'[\[\]\(\)\{\}]', '', response)
    # Collapse consecutive whitespace.
    response = re.sub(r'\s+', ' ', response).strip()
    return response
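
# Note: the English-word removal above is aggressive; it also drops proper
# nouns and acronyms. Illustrative example (a sketch of the behaviour):
#   clean_korean_response("GPU(Graphics Processing Unit)๋Š” ๋น ๋ฆ…๋‹ˆ๋‹ค")
#   -> '๋Š” ๋น ๋ฆ…๋‹ˆ๋‹ค'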

def get_text(docs):
    """Extract and preprocess text from the uploaded documents."""
    doc_list = []
    for doc in docs:
        file_name = doc.name
        # Persist the upload to disk so the loaders can read it.
        with open(file_name, "wb") as file:
            file.write(doc.getvalue())
        logger.info(f"Uploaded {file_name}")
        try:
            if '.pdf' in doc.name:
                loader = PyPDFLoader(file_name)
            elif '.docx' in doc.name:
                loader = Docx2txtLoader(file_name)
            else:
                # Skip unsupported file types.
                continue
            documents = loader.load_and_split()

            # Preprocess each page's text.
            for document in documents:
                document.page_content = preprocess_korean_text(document.page_content)

            # Drop chunks that are too short to be useful (< 50 characters).
            doc_list.extend(
                [d for d in documents if len(d.page_content.strip()) >= 50]
            )
        except Exception as e:
            st.error(f"ํŒŒ์ผ {file_name} ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
    return doc_list
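
# Illustrative result shape (a sketch; exact metadata depends on the loader):
# get_text returns LangChain Document objects, one per page/section, and the
# "์ฐธ๊ณ  ๋ฌธ์„œ" expander later reads the 'source' metadata recorded here.
#   docs = get_text(uploaded_files)
#   docs[0].page_content            # preprocessed Korean text
#   docs[0].metadata.get('source')  # the uploaded file name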

def get_text_chunks(text, chunk_size=400, chunk_overlap=40):
    """Split documents into chunks tuned for Korean text."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ":", ",", " ", ""]  # separator order suited to Korean punctuation
    )
    chunks = text_splitter.split_documents(text)
    return chunks
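
# Illustrative behaviour (a sketch): the splitter tries the separators in
# order, so text breaks at paragraph and sentence boundaries where possible,
# and consecutive chunks share up to chunk_overlap characters of context.
#   chunks = get_text_chunks(docs, chunk_size=400, chunk_overlap=40)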

def get_vectorstore(text_chunks, embedding_model):
    """Build a vector store with a Korean-specialised embedding model."""
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model,
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True}
    )
    vectordb = FAISS.from_documents(text_chunks, embeddings)
    return vectordb
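
# Illustrative direct query (a sketch): with normalize_embeddings=True,
# nearest-neighbour ranking in FAISS is equivalent to cosine similarity.
#   vectordb = get_vectorstore(text_chunks, "jhgan/ko-sroberta-multitask")
#   hits = vectordb.similarity_search("์—ฐ์ฐจ ์‚ฌ์šฉ ๊ทœ์ •", k=4)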

def get_conversation_chain(vectorstore, model_name, temperature):
    """Create a conversational retrieval chain specialised for Korean."""
    try:
        # Load the Korean tokenizer and model.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Make sure a padding token is set.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map=None  # no GPU placement; run on CPU
        )

        # Text-generation pipeline tuned for Korean answers.
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            device=-1,  # CPU
            pad_token_id=tokenizer.eos_token_id
        )
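
        # With do_sample=True, top_p=0.9 samples from the smallest token set
        # covering 90% of probability mass (nucleus sampling), and
        # repetition_penalty > 1.0 penalises already-generated tokens,
        # discouraging repetitive output.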
        llm = HuggingFacePipeline(pipeline=pipe)

        # Retrieval settings tuned for Korean documents.
        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(
                search_type='mmr',
                search_kwargs={
                    'k': 4,             # chunks passed to the LLM
                    'fetch_k': 8,       # candidates fetched before MMR reranking
                    'lambda_mult': 0.7  # balance between relevance and diversity
                }
            ),
            memory=ConversationBufferMemory(
                memory_key='chat_history',
                return_messages=True,
                output_key='answer'
            ),
            return_source_documents=True,
            verbose=True
        )
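
        # Illustrative call pattern (a sketch): the chain takes a "question"
        # key and, because return_source_documents=True, returns both the
        # answer text and the retrieved chunks:
        #   result = conversation_chain({"question": "ํœด๊ฐ€ ๊ทœ์ •์„ ์š”์•ฝํ•ด์ฃผ์„ธ์š”"})
        #   result["answer"], result["source_documents"]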
        return conversation_chain
    except Exception as e:
        st.error(f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
        st.info("๋” ๊ฐ€๋ฒผ์šด ๋ชจ๋ธ์„ ์„ ํƒํ•˜๊ฑฐ๋‚˜ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ํ™•์ธํ•ด์ฃผ์„ธ์š”.")
        return None

if __name__ == '__main__':
    main()