File size: 7,418 Bytes
b7a77a6
 
 
 
 
 
 
 
 
 
 
 
 
40150eb
b7a77a6
 
40150eb
 
b7a77a6
 
 
 
 
40150eb
b7a77a6
 
 
 
 
088661b
 
b7a77a6
088661b
 
 
 
 
 
b7a77a6
40150eb
 
088661b
b7a77a6
 
088661b
 
 
 
 
 
 
 
 
 
 
 
b7a77a6
 
088661b
 
 
 
 
 
 
 
 
 
 
 
 
b7a77a6
088661b
b7a77a6
 
 
 
088661b
b7a77a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
088661b
 
 
b7a77a6
 
 
 
088661b
b7a77a6
088661b
 
 
b7a77a6
 
 
 
088661b
 
b7a77a6
 
088661b
 
b7a77a6
 
 
40150eb
 
b7a77a6
 
 
 
088661b
b7a77a6
40150eb
b7a77a6
088661b
b7a77a6
 
 
 
088661b
b7a77a6
 
088661b
 
b7a77a6
 
 
 
088661b
 
b7a77a6
 
 
 
 
 
 
 
40150eb
 
 
 
 
 
 
 
 
 
 
 
b7a77a6
 
40150eb
 
b7a77a6
088661b
b7a77a6
40150eb
b7a77a6
 
40150eb
b7a77a6
 
 
 
 
 
 
 
40150eb
b7a77a6
 
 
 
088661b
b7a77a6
 
 
 
 
 
 
 
 
 
 
 
 
 
088661b
b7a77a6
 
 
 
 
088661b
 
 
 
b7a77a6
088661b
 
 
b7a77a6
088661b
 
 
b7a77a6
088661b
 
 
b7a77a6
 
088661b
40150eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import html
import json
import os
import tempfile

import streamlit as st
from dotenv import load_dotenv

# UI templates
from htmlTemplates import css, bot_template, user_template

# Text splitters
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter

# Vector store / embeddings
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

# Loaders
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.docstore.document import Document

# LLM + chain
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_groq import ChatGroq


# ---------- PDF ----------
def get_pdf_text(pdf_docs):
    """Load an uploaded PDF into LangChain documents via ``PyPDFLoader``.

    Args:
        pdf_docs: A Streamlit ``UploadedFile`` holding a PDF.

    Returns:
        The list of ``Document`` objects returned by ``PyPDFLoader.load()``.
    """
    # Strip any directory components from the client-supplied name so that a
    # crafted name ("../x.pdf", "/abs/x.pdf") cannot escape the temp directory.
    filename = os.path.basename(pdf_docs.name) or "upload.pdf"
    # PyPDFLoader needs a real file path. load() returns fully built Documents,
    # so the temporary directory can be removed as soon as it returns instead
    # of being kept alive in session_state for the rest of the session.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_filepath = os.path.join(temp_dir, filename)
        with open(temp_filepath, "wb") as f:
            f.write(pdf_docs.getvalue())
        return PyPDFLoader(temp_filepath).load()


# ---------- TXT ----------
def get_text_file(docs):
    """Load an uploaded plain-text file into LangChain documents via ``TextLoader``.

    Args:
        docs: A Streamlit ``UploadedFile`` holding UTF-8 text.

    Returns:
        The list of ``Document`` objects returned by ``TextLoader.load()``.
    """
    # Strip directory components from the client-supplied name (path traversal).
    filename = os.path.basename(docs.name) or "upload.txt"
    # load() returns fully built Documents, so the temp directory is cleaned up
    # immediately rather than accumulating in session_state.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_filepath = os.path.join(temp_dir, filename)
        with open(temp_filepath, "wb") as f:
            f.write(docs.getvalue())
        # "utf-8-sig" decodes plain UTF-8 identically but drops a leading BOM
        # instead of leaking U+FEFF into the document text.
        return TextLoader(temp_filepath, encoding="utf-8-sig").load()


# ---------- CSV ----------
def get_csv_file(docs):
    """Load an uploaded CSV file into LangChain documents via ``CSVLoader``.

    Args:
        docs: A Streamlit ``UploadedFile`` holding UTF-8 CSV data.

    Returns:
        The list of ``Document`` objects returned by ``CSVLoader.load()``.
    """
    # Strip directory components from the client-supplied name (path traversal).
    filename = os.path.basename(docs.name) or "upload.csv"
    # load() returns fully built Documents, so the temp directory is cleaned up
    # immediately rather than accumulating in session_state.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_filepath = os.path.join(temp_dir, filename)
        with open(temp_filepath, "wb") as f:
            f.write(docs.getvalue())
        # "utf-8-sig" drops a leading BOM (typical of Excel CSV exports), which
        # would otherwise be glued onto the first column header.
        return CSVLoader(temp_filepath, encoding="utf-8-sig").load()


# ---------- JSON ----------
def get_json_file(file) -> list[Document]:
    """Turn an uploaded JSON file into Documents whose text is re-serialized JSON.

    Layouts handled:
      * ``{"scans": [{"relationships": [...]}, ...]}``: one Document per
        relationship entry. If no scan contributes any relationship, the whole
        object becomes a single Document.
      * A top-level list: one Document per element.
      * Anything else: a single Document for the whole value.

    Args:
        file: A Streamlit ``UploadedFile`` containing JSON.

    Returns:
        The list of ``Document`` objects (empty only for a top-level ``[]``).

    Raises:
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # "utf-8-sig" strips a leading byte-order mark. With plain "utf-8" the BOM
    # survives as U+FEFF and json.loads() rejects the whole file.
    raw = file.getvalue().decode("utf-8-sig", errors="ignore")
    data = json.loads(raw)

    def to_doc(value):
        # ensure_ascii=False keeps non-ASCII text readable for the embedder.
        return Document(page_content=json.dumps(value, ensure_ascii=False))

    if isinstance(data, dict) and isinstance(data.get("scans"), list):
        docs = []
        for scan in data["scans"]:
            # Skip malformed scan entries instead of failing on .get().
            if not isinstance(scan, dict):
                continue
            relationships = scan.get("relationships")
            if isinstance(relationships, list):
                docs.extend(to_doc(rel) for rel in relationships)
        return docs or [to_doc(data)]

    if isinstance(data, list):
        return [to_doc(item) for item in data]

    return [to_doc(data)]


# ---------- Chunking ----------
def get_text_chunks(documents, *, chunk_size=1000, chunk_overlap=200):
    """Split documents into overlapping character-based chunks for embedding.

    Args:
        documents: LangChain ``Document`` objects to split.
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: Characters shared by consecutive chunks, so text that
            straddles a chunk boundary is still retrievable from one chunk.

    Returns:
        The list of chunked ``Document`` objects from ``split_documents``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    return text_splitter.split_documents(documents)


# ---------- Vector store ----------
@st.cache_resource(show_spinner=False)
def _load_embeddings(model_name="sentence-transformers/all-MiniLM-L12-v2"):
    """Load the CPU sentence-transformer embedding model once per process."""
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": "cpu"},
    )


def get_vectorstore(text_chunks):
    """Embed *text_chunks* and index them in an in-memory FAISS store.

    The embedding model is cached via ``st.cache_resource``, so repeated
    Process clicks (and other sessions) reuse the loaded model instead of
    reloading it from disk every time.

    Args:
        text_chunks: Non-empty list of chunked ``Document`` objects.

    Returns:
        A ``FAISS`` vector store built from the chunks.
    """
    return FAISS.from_documents(text_chunks, _load_embeddings())


# ---------- Conversation chain ----------
def get_conversation_chain(vectorstore):
    """Build a memory-backed conversational retrieval chain over *vectorstore*.

    Answers come from Groq's ``llama-3.1-8b-instant`` model (temperature 0.75,
    at most 512 output tokens). Each question is grounded in the top 3 chunks
    from the retriever. The dialogue is kept in a ``ConversationBufferMemory``
    under the ``chat_history`` key, as message objects.

    Args:
        vectorstore: The vector store produced by ``get_vectorstore``.

    Returns:
        A ``ConversationalRetrievalChain`` to be called with ``{"question": ...}``.
    """
    chat_model = ChatGroq(
        groq_api_key=os.environ.get("GROQ_API_KEY"),  # None when the env var is unset
        model_name="llama-3.1-8b-instant",
        temperature=0.75,
        max_tokens=512,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=chat_model,
        retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
        memory=ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        ),
    )


# ---------- UI ----------
def handle_userinput(user_question):
    """Answer *user_question* with the current chain and render the chat log.

    Shows a warning and returns if no documents have been processed yet.

    Args:
        user_question: The text typed into the question box.
    """
    conversation = st.session_state.conversation
    if conversation is None:
        st.warning("λ¨Όμ € λ¬Έμ„œλ₯Ό μ—…λ‘œλ“œν•˜κ³  Process λ²„νŠΌμ„ λˆŒλŸ¬μ£Όμ„Έμš”.")
        return

    # Streamlit reruns the whole script on every widget interaction (file
    # upload, button click, ...) while text_input keeps its value. Only query
    # the LLM when the question or the chain actually changed. Otherwise just
    # re-render, so memory is not filled with duplicate turns.
    last = st.session_state.get("_last_question")
    if last is None or last[0] is not conversation or last[1] != user_question:
        response = conversation.invoke({"question": user_question})
        st.session_state.chat_history = response["chat_history"]
        st.session_state["_last_question"] = (conversation, user_question)

    for message in st.session_state.chat_history or []:
        # Pick the template by message role rather than by list position.
        template = user_template if message.type == "human" else bot_template
        # The content is user / LLM text rendered with unsafe_allow_html, so
        # escape it to keep it from injecting markup or scripts into the page.
        safe_text = html.escape(str(message.content))
        st.write(template.replace("{{MSG}}", safe_text), unsafe_allow_html=True)


def process_files(docs, mode: str):
    """Load uploaded files of one format, index them, and install a new chain.

    A file is accepted when its browser-reported MIME type is known for *mode*
    or its extension matches *mode*, because browsers report CSV/JSON MIME
    types inconsistently. Files that fail to load are reported and skipped
    instead of aborting the whole run.

    Args:
        docs: Uploaded files from ``st.file_uploader`` (may be ``None``).
        mode: One of ``"pdf"``, ``"txt"``, ``"csv"``, ``"json"``.

    Side effects:
        On success, sets ``st.session_state.conversation`` to a fresh chain and
        clears ``st.session_state.chat_history``.
    """
    mime_map = {
        "pdf": ["application/pdf", "application/octet-stream"],
        "txt": ["text/plain"],
        "csv": ["text/csv", "application/vnd.ms-excel"],
        "json": ["application/json"],
    }
    loader_map = {
        "pdf": get_pdf_text,
        "txt": get_text_file,
        "csv": get_csv_file,
        "json": get_json_file,
    }

    valid_mimes = mime_map[mode]
    loader_fn = loader_map[mode]
    extension = f".{mode}"

    doc_list = []
    for file in docs or []:
        if file.type not in valid_mimes and not file.name.lower().endswith(extension):
            st.error(f"{mode.upper()} 파일이 μ•„λ‹™λ‹ˆλ‹€. (받은 MIME: {file.type})")
            continue
        try:
            doc_list.extend(loader_fn(file))
        except Exception as exc:  # UI boundary: report the bad file, keep the rest
            st.error(f"{file.name}: {exc}")

    # Check the chunks rather than the raw documents: documents with no
    # extractable text (e.g. scanned PDFs) yield no chunks, and an empty list
    # must not be handed to the vector store.
    text_chunks = get_text_chunks(doc_list)
    if not text_chunks:
        st.error("처리 κ°€λŠ₯ν•œ λ¬Έμ„œλ₯Ό μ°Ύμ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€.")
        # Return instead of st.stop() so the rest of the page still renders.
        return

    vectorstore = get_vectorstore(text_chunks)
    st.session_state.conversation = get_conversation_chain(vectorstore)
    # The old history belonged to the previous chain's memory.
    st.session_state.chat_history = None
    st.success(f"{mode.upper()} λ¬Έμ„œ 처리 μ™„λ£Œ! 이제 μ§ˆλ¬Έμ„ μž…λ ₯ν•΄ λ³΄μ„Έμš”.")


def main():
    """Render the Streamlit RAG chatbot page.

    The sidebar (upload plus per-format Process buttons) is handled before the
    question box is answered. That way, a chain built by a Process click on
    this rerun is already in place when a pending question is handled.
    """
    load_dotenv()
    st.set_page_config(page_title="Basic_RAG_AI_Chatbot_with_Llama", page_icon="πŸ“š")
    st.write(css, unsafe_allow_html=True)

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header("Basic_RAG_AI_Chatbot_with_Llama3 πŸ“š")
    user_question = st.text_input("Ask a question about your documents:")

    # Sidebar elements render in the sidebar regardless of code order, so
    # handling it first does not move anything in the main area.
    with st.sidebar:
        st.subheader("Your documents")
        st.markdown("νŒŒμΌμ„ μ—…λ‘œλ“œν•œ ν›„ μ•„λž˜ λ²„νŠΌμ„ 눌러 μ²˜λ¦¬ν•˜μ„Έμš”.")
        docs = st.file_uploader(
            "Upload your Files here and click on 'Process'",
            accept_multiple_files=True,
        )

        # One button per format, stacked vertically so every button stays visible.
        for mode in ("pdf", "txt", "csv", "json"):
            label = mode.upper()
            if st.button(f"Process[{label}]"):
                with st.spinner(f"Processing {label}..."):
                    process_files(docs, mode)

    if user_question:
        handle_userinput(user_question)


if __name__ == "__main__":
    # Launch the app only when this file is executed as a script
    # (e.g. `streamlit run`), not when it is imported.
    main()