# RAG_Chatbot / app.py
# AD-Styles's picture
# Update app.py
# 030cb07 verified
import os
import uuid
import gradio as gr
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
# 1. Initialize the LLM.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)

# 2. Load the document and build the vector DB.
loader = PyPDFLoader("Maximizing Muscle Hypertrophy.pdf")
# Use load() rather than load_and_split(): load_and_split() already runs a
# default text splitter, so the pages were being chunked twice (once with the
# library defaults, then again below). Loading the raw pages makes the
# splitter configured here the only one that decides chunk boundaries.
pages = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(pages)
embeddings = GoogleGenerativeAIEmbeddings(model="gemini-embedding-001")
# No persist_directory is passed, so the index is rebuilt (and every chunk
# re-embedded) each time this module is imported.
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
retriever = vectorstore.as_retriever()
def format_docs(docs):
    """Join retrieved documents into a single context string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string attribute.

    Returns:
        The ``page_content`` values separated by one blank line; an empty
        string when ``docs`` is empty.
    """
    return "\n\n".join(doc.page_content for doc in docs)
# 3. Define the prompt.
# The system message (Korean) tells the model to act as a paper-review expert,
# to answer in Korean based on the provided document, and to say it does not
# know when the document does not cover the question. The text had been
# mis-encoded (UTF-8 bytes decoded with a legacy code page), so it is restored
# here. {context} receives the retrieved text, "chat_history" the prior turns,
# and {input} the current question.
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", """논문 리뷰 전문가입니다. 제공된 문서를 바탕으로 한국어로 답변하세요.
문서에 없는 내용은 모른다고 답하세요.
{context}"""),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])
# 4. Build the RAG chain with LCEL (pipeline) syntax instead of the `chains`
#    module, which had been raising errors.
def _retrieve_context(inputs):
    """Look up documents for ``inputs["input"]`` and merge them into one string."""
    matched_docs = retriever.invoke(inputs["input"])
    return format_docs(matched_docs)


# Add a "context" key next to the incoming keys, then prompt -> LLM -> plain text.
_add_context = RunnablePassthrough.assign(context=_retrieve_context)
rag_chain = _add_context | qa_prompt | llm | StrOutputParser()
# 5. Memory (conversation history) wiring.
# Maps session id -> chat history. Entries are never removed, so this dict
# grows for as long as the process keeps receiving new session ids.
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    Args:
        session_id: Key identifying one conversation.

    Returns:
        The history object stored for that key; repeated calls with the same
        key return the same object.
    """
    if session_id not in store:
        store[session_id] = InMemoryChatMessageHistory()
    return store[session_id]
# Wrap the RAG chain with per-session history: the question is read from the
# "input" key and earlier turns are supplied under "chat_history".
conversational_rag_chain = RunnableWithMessageHistory(
    runnable=rag_chain,
    get_session_history=get_session_history,
    history_messages_key="chat_history",
    input_messages_key="input",
)
# 6. Gradio callback.
# Reply (Korean: "Please enter a question.") used when the message is blank.
_EMPTY_QUESTION_REPLY = "질문을 입력해 주세요."


def chat_response(message, history, session_id):
    """Answer one chat message through the history-aware RAG chain.

    Args:
        message: The user's question text.
        history: Chat history supplied by the caller; unused here because the
            chain looks up its own history by ``session_id``.
        session_id: Key selecting which conversation history the chain uses.

    Returns:
        The answer text. The LCEL chain returns a string directly rather than
        a dict, so no ``["answer"]`` extraction is needed.
    """
    # A blank question has nothing to retrieve against; answer locally instead
    # of running retrieval and the LLM on empty text.
    if not message or not message.strip():
        return _EMPTY_QUESTION_REPLY
    return conversational_rag_chain.invoke(
        {"input": message},
        config={"configurable": {"session_id": session_id}},
    )
# 7. Build and launch the UI for a multi-user environment.
with gr.Blocks() as demo:
    # The state's initial value is a callable producing a fresh UUID string;
    # it is forwarded to chat_response as session_id.
    # NOTE(review): this relies on Gradio calling the callable on each app
    # load so that every visitor gets a separate id — confirm for the
    # installed version.
    session_state = gr.State(lambda: str(uuid.uuid4()))
    # Title / description are Korean user-facing text ("Muscle hypertrophy
    # maximization paper Q&A bot" / "Ask anything about the 'Maximizing Muscle
    # Hypertrophy' paper!"), restored from their mis-encoded form.
    gr.ChatInterface(
        fn=chat_response,
        additional_inputs=[session_state],
        title="💪 근비대 극대화 논문 Q&A 봇",
        description="'Maximizing Muscle Hypertrophy' 논문에 대해 궁금한 점을 물어보세요!",
    )

if __name__ == "__main__":
    demo.launch()