# langchain-prac-copy / src / streamlit_app.py
import streamlit as st
from dotenv import load_dotenv
# from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
# from langchain.vectorstores import FAISS
# from langchain.embeddings import HuggingFaceEmbeddings # General embeddings from HuggingFace models.
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from htmlTemplates import css, bot_template, user_template
# from langchain.llms import LlamaCpp # For loading transformer models.
# from langchain.document_loaders import PyPDFLoader, TextLoader, JSONLoader, CSVLoader
# Text splitters
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
# Vector store / embeddings / LLM
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
# ๋กœ๋”๋“ค (pebblo/pwd ๋Œ๋ ค์˜ค์ง€ ์•Š๊ฒŒ ์„œ๋ธŒ๋ชจ๋“ˆ๋กœ)
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders.json_loader import JSONLoader
import tempfile  # For writing uploads out as temporary files
import os
import json
from langchain.docstore.document import Document
from langchain_groq import ChatGroq
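
# load_dotenv() in main() reads a local .env file; the Groq key used below is
# expected there (name taken from the os.environ lookup in get_conversation_chain):
#   GROQ_API_KEY=<your key>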

# Extract text from an uploaded PDF document.
def get_pdf_text(pdf_docs):
    temp_dir = tempfile.TemporaryDirectory()
    temp_filepath = os.path.join(temp_dir.name, pdf_docs.name)
    with open(temp_filepath, "wb") as f:
        f.write(pdf_docs.getvalue())
    pdf_loader = PyPDFLoader(temp_filepath)
    pdf_doc = pdf_loader.load()
    return pdf_doc
# txt ํŒŒ์ผ๋กœ๋ถ€ํ„ฐ text ์ถ”์ถœ
def get_text_file(txt_docs):
temp_dir = tempfile.TemporaryDirectory()
temp_filepath = os.path.join(temp_dir.name, txt_docs.name)
with open(temp_filepath, "wb") as f:
f.write(txt_docs.getvalue())
text_loader = TextLoader(temp_filepath)
text_doc = text_loader.load()
return text_doc
# csv ํŒŒ์ผ๋กœ๋ถ€ํ„ฐ text ์ถ”์ถœ
def get_csv_file(csv_docs):
temp_dir = tempfile.TemporaryDirectory()
temp_filepath = os.path.join(temp_dir.name, csv_docs.name)
with open(temp_filepath,"wb") as f:
f.write(csv_docs.getvalue())
csv_loader = CSVLoader(temp_filepath)
csv_doc = csv_loader.load()
return csv_doc
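
# The three loaders above repeat the same save-then-load pattern. A minimal
# generic sketch of that pattern, assuming `loader_cls` is any LangChain
# document loader constructed from a file path (not wired into the UI):
def load_uploaded_file(uploaded_file, loader_cls):
    # Write the Streamlit upload to a temp file so path-based loaders can read it.
    temp_dir = tempfile.TemporaryDirectory()
    temp_filepath = os.path.join(temp_dir.name, uploaded_file.name)
    with open(temp_filepath, "wb") as f:
        f.write(uploaded_file.getvalue())
    return loader_cls(temp_filepath).load()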

def get_json_file(file) -> list[Document]:
    # Streamlit UploadedFile -> str
    raw = file.getvalue().decode("utf-8", errors="ignore")
    data = json.loads(raw)
    docs = []

    # The previous JSONLoader used jq path '.scans[].relationships'; replicate
    # the same semantics here: pull those entries if present, otherwise fall
    # back to documenting the payload as a whole.
    def add_doc(x):
        docs.append(Document(page_content=json.dumps(x, ensure_ascii=False)))

    if isinstance(data, dict) and "scans" in data and isinstance(data["scans"], list):
        for s in data["scans"]:
            rels = s.get("relationships", [])
            if isinstance(rels, list) and rels:
                for r in rels:
                    add_doc(r)
        if not docs:  # Nothing extracted; use the whole payload as one document
            add_doc(data)
    elif isinstance(data, list):
        for item in data:
            add_doc(item)
    else:
        add_doc(data)
    return docs
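
# For illustration, with the hypothetical input
#   {"scans": [{"relationships": [{"src": "a", "dst": "b"}]}]}
# the branch above yields one Document per relationship:
#   Document(page_content='{"src": "a", "dst": "b"}')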
# ๋ฌธ์„œ๋“ค์„ ์ฒ˜๋ฆฌํ•˜์—ฌ ํ…์ŠคํŠธ ์ฒญํฌ๋กœ ๋‚˜๋ˆ„๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
def get_text_chunks(documents):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, # ์ฒญํฌ์˜ ํฌ๊ธฐ๋ฅผ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค.
chunk_overlap=200, # ์ฒญํฌ ์‚ฌ์ด์˜ ์ค‘๋ณต์„ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค.
length_function=len # ํ…์ŠคํŠธ์˜ ๊ธธ์ด๋ฅผ ์ธก์ •ํ•˜๋Š” ํ•จ์ˆ˜๋ฅผ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค.
)
documents = text_splitter.split_documents(documents) # ๋ฌธ์„œ๋“ค์„ ์ฒญํฌ๋กœ ๋‚˜๋ˆ•๋‹ˆ๋‹ค.
return documents # ๋‚˜๋ˆˆ ์ฒญํฌ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
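
# With these settings a ~2,500-character document splits into roughly three
# chunks, each re-reading about the last 200 characters of the previous one
# (exact boundaries depend on the splitter's separators, so this is approximate).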
# ํ…์ŠคํŠธ ์ฒญํฌ๋“ค๋กœ๋ถ€ํ„ฐ ๋ฒกํ„ฐ ์Šคํ† ์–ด๋ฅผ ์ƒ์„ฑํ•˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
def get_vectorstore(text_chunks):
# ์›ํ•˜๋Š” ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/clip-ViT-B-32-multilingual-v1',
model_kwargs={'device': 'cpu'}) # ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
vectorstore = FAISS.from_documents(text_chunks, embeddings) # FAISS ๋ฒกํ„ฐ ์Šคํ† ์–ด๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
return vectorstore # ์ƒ์„ฑ๋œ ๋ฒกํ„ฐ ์Šคํ† ์–ด๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
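
# The store can also be queried directly, outside the chain; a minimal sketch
# (the query string is a made-up example):
#   hits = vectorstore.similarity_search("What do the documents cover?", k=3)
#   for doc in hits:
#       print(doc.page_content[:100])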

# Wire the retriever, memory, and Groq LLM into a conversational chain.
def get_conversation_chain(vectorstore):
    # Groq LLM
    llm = ChatGroq(
        groq_api_key=os.environ.get("GROQ_API_KEY"),
        model_name="llama-3.1-8b-instant",
        temperature=0.75,  # Tune as needed
        max_tokens=512     # Guards against context overflow (adjust if needed)
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
    )
    return conversation_chain
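
# Minimal usage sketch: the chain takes a dict with a "question" key and, with
# return_messages=True above, returns the answer plus accumulated chat history:
#   chain = get_conversation_chain(vectorstore)
#   result = chain({"question": "Summarize the uploaded documents."})
#   print(result["answer"])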

# Handle a user question from the text input.
def handle_userinput(user_question):
    # Guard: the chain only exists after documents have been processed.
    if st.session_state.conversation is None:
        st.warning("Please upload and process documents first.")
        return
    print('user_question => ', user_question)
    # Run the conversation chain to answer the question.
    response = st.session_state.conversation({'question': user_question})
    # Persist the chat history.
    st.session_state.chat_history = response['chat_history']
    # Messages alternate user/bot, so render even indices with the user
    # template and odd indices with the bot template.
    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            st.write(user_template.replace(
                "{{MSG}}", message.content), unsafe_allow_html=True)
        else:
            st.write(bot_template.replace(
                "{{MSG}}", message.content), unsafe_allow_html=True)

def main():
    load_dotenv()
    st.set_page_config(page_title="Basic_RAG_AI_Chatbot_with_Llama3",
                       page_icon=":books:")
    st.write(css, unsafe_allow_html=True)

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header("Basic_RAG_AI_Chatbot_with_Llama3 :books:")
    user_question = st.text_input("Ask a question about your documents:")
    if user_question:
        handle_userinput(user_question)

    with st.sidebar:
        st.subheader("Your documents")
        docs = st.file_uploader(
            "Upload your Files here and click on 'Process'", accept_multiple_files=True)

        if st.button("Process[PDF]"):
            with st.spinner("Processing"):
                # Collect text from the uploaded PDFs.
                doc_list = []
                for file in docs:
                    print('file - type : ', file.type)
                    if file.type in ['application/octet-stream', 'application/pdf']:
                        # file is .pdf
                        doc_list.extend(get_pdf_text(file))
                    else:
                        st.error("Not a PDF file.")
                if not doc_list:
                    st.error("No processable documents were found.")
                    st.stop()
                text_chunks = get_text_chunks(doc_list)
                vectorstore = get_vectorstore(text_chunks)
                st.session_state.conversation = get_conversation_chain(vectorstore)

        ################## TXT and CSV buttons (implemented below)
        # TXT button hint: if file.type == 'text/plain':
        # CSV button hint: if file.type == 'text/csv':
        if st.button("Process[JSON]"):
            with st.spinner("Processing"):
                doc_list = []
                for file in docs:
                    print('file - type : ', file.type)
                    if file.type == 'application/json':
                        # file is .json
                        doc_list.extend(get_json_file(file))
                    else:
                        st.error("Not a JSON file.")
                if not doc_list:
                    st.error("No processable documents were found.")
                    st.stop()
                text_chunks = get_text_chunks(doc_list)
                vectorstore = get_vectorstore(text_chunks)
                st.session_state.conversation = get_conversation_chain(vectorstore)

        if st.button("Process[TXT]"):
            with st.spinner("Processing"):
                # Collect text from the uploaded .txt files.
                doc_list = []
                for file in docs:
                    print('file - type : ', file.type)
                    if file.type == 'text/plain':
                        doc_list.extend(get_text_file(file))
                    else:
                        st.error("Not a TXT file.")
                if not doc_list:
                    st.error("No processable documents were found.")
                    st.stop()
                text_chunks = get_text_chunks(doc_list)
                vectorstore = get_vectorstore(text_chunks)
                st.session_state.conversation = get_conversation_chain(vectorstore)

        if st.button("Process[CSV]"):
            with st.spinner("Processing"):
                # Collect text from the uploaded .csv files.
                doc_list = []
                for file in docs:
                    print('file - type : ', file.type)
                    if file.type == 'text/csv':
                        doc_list.extend(get_csv_file(file))
                    else:
                        st.error("Not a CSV file.")
                if not doc_list:
                    st.error("No processable documents were found.")
                    st.stop()
                text_chunks = get_text_chunks(doc_list)
                vectorstore = get_vectorstore(text_chunks)
                st.session_state.conversation = get_conversation_chain(vectorstore)
if __name__ == '__main__':
    main()
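
# To run locally (assuming Streamlit is installed and a .env with GROQ_API_KEY exists):
#   streamlit run src/streamlit_app.py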