# Llama2 / src / streamlit_app.py
# jamjammin's picture
# Update src/streamlit_app.py
# 088661b verified
import os
import json
import tempfile
import streamlit as st
from dotenv import load_dotenv
# UI templates
from htmlTemplates import css, bot_template, user_template
# Text splitters
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
# Vector store / embeddings
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
# Loaders
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.docstore.document import Document
# LLM + chain
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_groq import ChatGroq
# ---------- PDF ----------
def get_pdf_text(pdf_docs):
    """Load one uploaded PDF into documents via ``PyPDFLoader``.

    The upload is copied into a fresh temporary directory because the
    loader is given a file path rather than the in-memory upload.

    Args:
        pdf_docs: A single uploaded file object exposing ``name`` and
            ``getvalue()`` (one file per call, despite the plural name).

    Returns:
        The result of ``PyPDFLoader(...).load()`` for the temporary copy.
    """
    temp_dir = tempfile.TemporaryDirectory()
    # The file name is supplied by the client. ``os.path.join`` discards the
    # directory when a later component is absolute, and ".." segments walk
    # upwards, so keep only the final path component.
    safe_name = os.path.basename(pdf_docs.name) or "upload.pdf"
    temp_filepath = os.path.join(temp_dir.name, safe_name)
    with open(temp_filepath, "wb") as f:
        f.write(pdf_docs.getvalue())
    pdf_loader = PyPDFLoader(temp_filepath)
    pdf_doc = pdf_loader.load()
    # Keep a reference to the TemporaryDirectory in the session so it is not
    # cleaned up as soon as this function returns.
    if "temp_dirs" not in st.session_state:
        st.session_state["temp_dirs"] = []
    st.session_state["temp_dirs"].append(temp_dir)
    return pdf_doc
# ---------- TXT ----------
def get_text_file(docs):
    """Load one uploaded plain-text file into documents via ``TextLoader``.

    The upload is copied into a fresh temporary directory and read back as
    UTF-8 by the loader.

    Args:
        docs: A single uploaded file object exposing ``name`` and
            ``getvalue()``.

    Returns:
        The result of ``TextLoader(...).load()`` for the temporary copy.
    """
    temp_dir = tempfile.TemporaryDirectory()
    # The file name is supplied by the client. ``os.path.join`` discards the
    # directory when a later component is absolute, and ".." segments walk
    # upwards, so keep only the final path component.
    safe_name = os.path.basename(docs.name) or "upload.txt"
    temp_filepath = os.path.join(temp_dir.name, safe_name)
    with open(temp_filepath, "wb") as f:
        f.write(docs.getvalue())
    text_loader = TextLoader(temp_filepath, encoding="utf-8")
    text_doc = text_loader.load()
    # Keep a reference to the TemporaryDirectory in the session so it is not
    # cleaned up as soon as this function returns.
    if "temp_dirs" not in st.session_state:
        st.session_state["temp_dirs"] = []
    st.session_state["temp_dirs"].append(temp_dir)
    return text_doc
# ---------- CSV ----------
def get_csv_file(docs):
    """Load one uploaded CSV file into documents via ``CSVLoader``.

    The upload is copied into a fresh temporary directory and read back as
    UTF-8 by the loader.

    Args:
        docs: A single uploaded file object exposing ``name`` and
            ``getvalue()``.

    Returns:
        The result of ``CSVLoader(...).load()`` for the temporary copy.
    """
    temp_dir = tempfile.TemporaryDirectory()
    # The file name is supplied by the client. ``os.path.join`` discards the
    # directory when a later component is absolute, and ".." segments walk
    # upwards, so keep only the final path component.
    safe_name = os.path.basename(docs.name) or "upload.csv"
    temp_filepath = os.path.join(temp_dir.name, safe_name)
    with open(temp_filepath, "wb") as f:
        f.write(docs.getvalue())
    csv_loader = CSVLoader(temp_filepath, encoding="utf-8")
    csv_doc = csv_loader.load()
    # Keep a reference to the TemporaryDirectory in the session so it is not
    # cleaned up as soon as this function returns.
    if "temp_dirs" not in st.session_state:
        st.session_state["temp_dirs"] = []
    st.session_state["temp_dirs"].append(temp_dir)
    return csv_doc
# ---------- JSON ----------
def get_json_file(file) -> list[Document]:
    """Convert an uploaded JSON file into a list of documents.

    Each produced document holds one JSON value re-serialised with
    ``json.dumps(..., ensure_ascii=False)`` as its ``page_content``:

    * object with a list under ``"scans"``: one document per entry of every
      scan's ``"relationships"`` list; if that yields nothing, the whole
      object becomes a single document;
    * top-level list: one document per item;
    * anything else: a single document for the whole value.

    Args:
        file: Uploaded file object exposing ``getvalue()`` returning bytes.

    Returns:
        The documents built from the parsed JSON (empty for an empty list).

    Raises:
        json.JSONDecodeError: If the decoded text is not valid JSON.
    """
    # "utf-8-sig" drops a leading byte-order mark; plain "utf-8" keeps it and
    # json.loads then rejects the text.
    raw = file.getvalue().decode("utf-8-sig", errors="ignore")
    data = json.loads(raw)

    def to_doc(value):
        return Document(page_content=json.dumps(value, ensure_ascii=False))

    if isinstance(data, dict) and isinstance(data.get("scans"), list):
        docs = []
        for scan in data["scans"]:
            # A scan that is not an object has no "relationships" to read;
            # skip it instead of failing on ``.get``.
            if not isinstance(scan, dict):
                continue
            rels = scan.get("relationships", [])
            if isinstance(rels, list):
                docs.extend(to_doc(rel) for rel in rels)
        # No relationships anywhere: index the whole payload instead.
        return docs or [to_doc(data)]

    if isinstance(data, list):
        return [to_doc(item) for item in data]

    return [to_doc(data)]
# ---------- Chunking ----------
def get_text_chunks(documents):
    """Split documents into overlapping chunks.

    Uses a ``RecursiveCharacterTextSplitter`` configured with a chunk size
    of 1000 and an overlap of 200, both measured with ``len``.

    Args:
        documents: Documents to pass to the splitter's ``split_documents``.

    Returns:
        The chunked documents returned by the splitter.
    """
    splitter_options = {
        "chunk_size": 1000,
        "chunk_overlap": 200,
        "length_function": len,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    chunks = splitter.split_documents(documents)
    return chunks
# ---------- Vector store ----------
def get_vectorstore(text_chunks):
    """Build a FAISS vector store from document chunks.

    Embeddings come from the ``sentence-transformers/all-MiniLM-L12-v2``
    model, requested on the CPU device.

    Args:
        text_chunks: Chunked documents to embed and index.

    Returns:
        The store created by ``FAISS.from_documents``.
    """
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L12-v2",
        model_kwargs={"device": "cpu"},
    )
    return FAISS.from_documents(text_chunks, embedder)
# ---------- Conversation chain ----------
def get_conversation_chain(vectorstore):
    """Create the conversational retrieval chain for a vector store.

    The chain combines a Groq-hosted chat model (API key read from the
    ``GROQ_API_KEY`` environment variable), a buffer memory stored under
    ``chat_history``, and a retriever configured with ``k=3``.

    Args:
        vectorstore: Store exposing ``as_retriever``.

    Returns:
        The chain built by ``ConversationalRetrievalChain.from_llm``.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    chat_model = ChatGroq(
        groq_api_key=api_key,
        model_name="llama-3.1-8b-instant",
        temperature=0.75,
        max_tokens=512,
    )
    history = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    top_k_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    return ConversationalRetrievalChain.from_llm(
        llm=chat_model,
        retriever=top_k_retriever,
        memory=history,
    )
# ---------- UI ----------
def handle_userinput(user_question):
    """Send a question to the conversation chain and render the chat.

    If no chain has been created yet, a warning is shown and nothing else
    happens. Otherwise the chain is called with the question, the returned
    ``chat_history`` is stored in the session, and every message is rendered
    through the user/bot HTML templates (even indices use the user
    template, odd indices the bot template).

    Args:
        user_question: The question text entered by the user.
    """
    # Function-scope import: this is the only place the module is needed.
    import html

    if st.session_state.conversation is None:
        st.warning("λ¨Όμ € λ¬Έμ„œλ₯Ό μ—…λ‘œλ“œν•˜κ³  Process λ²„νŠΌμ„ λˆŒλŸ¬μ£Όμ„Έμš”.")
        return

    response = st.session_state.conversation({'question': user_question})
    st.session_state.chat_history = response['chat_history']

    for i, message in enumerate(st.session_state.chat_history):
        template = user_template if i % 2 == 0 else bot_template
        # The templates are rendered with unsafe_allow_html=True, so message
        # text must be escaped; otherwise "<", ">" or "&" typed by the user
        # or produced by the model would be interpreted as markup.
        safe_text = html.escape(message.content)
        st.write(template.replace("{{MSG}}", safe_text), unsafe_allow_html=True)
def process_files(docs, mode: str):
    """Load the uploaded files for one format and build the chat chain.

    Files whose MIME type is accepted for ``mode`` are loaded with the
    matching loader; other files are reported with ``st.error`` and skipped.
    The loaded documents are chunked, indexed, and the resulting
    conversation chain is stored in ``st.session_state.conversation``.

    Args:
        docs: Uploaded files (may be ``None`` or empty).
        mode: One of ``"pdf"``, ``"txt"``, ``"csv"`` or ``"json"``.

    Raises:
        KeyError: If ``mode`` is not one of the supported formats.
    """
    mime_map = {
        "pdf": ["application/pdf", "application/octet-stream"],
        "txt": ["text/plain"],
        "csv": ["text/csv", "application/vnd.ms-excel"],
        "json": ["application/json"],
    }
    loader_map = {
        "pdf": get_pdf_text,
        "txt": get_text_file,
        "csv": get_csv_file,
        "json": get_json_file,
    }
    valid_mimes = mime_map[mode]
    loader_fn = loader_map[mode]
    # Shown both when nothing could be loaded and when nothing could be chunked.
    nothing_to_index = "처리 κ°€λŠ₯ν•œ λ¬Έμ„œλ₯Ό μ°Ύμ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€."

    doc_list = []
    for file in docs or []:
        if file.type in valid_mimes:
            doc_list.extend(loader_fn(file))
        else:
            st.error(f"{mode.upper()} 파일이 μ•„λ‹™λ‹ˆλ‹€. (받은 MIME: {file.type})")

    if not doc_list:
        st.error(nothing_to_index)
        st.stop()

    text_chunks = get_text_chunks(doc_list)
    # Documents without any text (for example a PDF with no extractable text)
    # produce no chunks; a vector index cannot be built from an empty list,
    # so report it the same way as "no documents".
    if not text_chunks:
        st.error(nothing_to_index)
        st.stop()

    vectorstore = get_vectorstore(text_chunks)
    st.session_state.conversation = get_conversation_chain(vectorstore)
    st.success(f"{mode.upper()} λ¬Έμ„œ 처리 μ™„λ£Œ! 이제 μ§ˆλ¬Έμ„ μž…λ ₯ν•΄ λ³΄μ„Έμš”.")
def main():
    """Run the Streamlit page: chat input, sidebar uploader, process buttons.

    Loads environment variables, configures the page, makes sure the
    session holds ``conversation`` and ``chat_history`` entries, answers the
    current question (if any), and renders one "Process" button per
    supported file format in the sidebar.
    """
    load_dotenv()
    st.set_page_config(page_title="Basic_RAG_AI_Chatbot_with_Llama", page_icon="πŸ“š")
    st.write(css, unsafe_allow_html=True)

    # Both session entries start out as None until documents are processed.
    for state_key in ("conversation", "chat_history"):
        if state_key not in st.session_state:
            st.session_state[state_key] = None

    st.header("Basic_RAG_AI_Chatbot_with_Llama3 πŸ“š")
    question = st.text_input("Ask a question about your documents:")
    if question:
        handle_userinput(question)

    # (button label, spinner text, mode passed to process_files)
    actions = (
        ("Process[PDF]", "Processing PDF...", "pdf"),
        ("Process[TXT]", "Processing TXT...", "txt"),
        ("Process[CSV]", "Processing CSV...", "csv"),
        ("Process[JSON]", "Processing JSON...", "json"),
    )

    with st.sidebar:
        st.subheader("Your documents")
        st.markdown("νŒŒμΌμ„ μ—…λ‘œλ“œν•œ ν›„ μ•„λž˜ λ²„νŠΌμ„ 눌러 μ²˜λ¦¬ν•˜μ„Έμš”.")
        uploads = st.file_uploader(
            "Upload your Files here and click on 'Process'",
            accept_multiple_files=True
        )
        # Buttons are listed vertically so that every button is clearly visible.
        for label, spinner_text, mode in actions:
            if st.button(label):
                with st.spinner(spinner_text):
                    process_files(uploads, mode)
# Start the app only when this file is executed as the main script.
if "__main__" == __name__:
    main()