Spaces:
Sleeping
Sleeping
File size: 5,925 Bytes
0668b2d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | import os
import gradio as gr
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
# 1. ๋ฌธ์ ๋ก๋ ๋ฐ ๋ฒกํฐ DB ๊ตฌ์ถ (์๋ฒ ๊ตฌ๋ ์ 1ํ ๊ณ ์ )
loader = PyPDFLoader("Maximizing Muscle Hypertrophy.pdf")
pages = loader.load_and_split()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(pages)
embeddings = GoogleGenerativeAIEmbeddings(model="gemini-embedding-001")
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
# ๋ฏธ์
3: ๋๋ฉ์ธ ๋ง์ถค ์์คํ
ํ๋กฌํํธ
SYSTEM_PROMPT = """๋น์ ์ ์คํฌ์ธ ์์ํ ๋ฐ ๊ทผ๋น๋(Muscle Hypertrophy) ํ๋ จ ๋ถ์ผ์ ์ต๊ณ ๊ถ์์์ด์ ๋
ผ๋ฌธ ๋ฆฌ๋ทฐ ์ ๋ฌธ๊ฐ์
๋๋ค.
์ ๊ณต๋ [๋
ผ๋ฌธ ์ปจํ
์คํธ]๋ฅผ ๋ฐํ์ผ๋ก ์ฌ์ฉ์์ ์ง๋ฌธ์ ์ ๋ฌธ์ ์ด๊ณ ๋ช
ํํ๋ฉฐ ๊ฐ๊ด์ ์ธ ์ด์กฐ๋ก ๋ต๋ณํ์ธ์.
[์ ์ฝ ์กฐ๊ฑด]
1. ๋ฐ๋์ ์ ๊ณต๋ ์ปจํ
์คํธ ๋ด์ ์ ๋ณด๋ง์ ์ฌ์ฉํ์ฌ ๋ต๋ณํ์ธ์.
2. ๋
ผ๋ฌธ์ ์๋ ๋ด์ฉ์ ์ง๋ฌธํ๋ฉด "ํด๋น ๋ด์ฉ์ ์ ๊ณต๋ ๋
ผ๋ฌธ์์ ํ์ธํ ์ ์์ต๋๋ค."๋ผ๊ณ ๋ช
ํํ ์ ์ ๊ทธ์ผ์ธ์.
3. ๊ทผ์ก ์ฑ์ฅ ๊ธฐ์ ์ด๋ ํ๋ จ๋ฒ์ ์ค๋ช
ํ ๋๋ ์ผ๋ฐ์ธ๋ ์ดํดํ๊ธฐ ์ฝ๊ฒ ๋จ๊ณ๋ณ๋ก ๊ตฌ์กฐํํ์ฌ ์ค๋ช
ํ์ธ์.
4. ๋ชจ๋ ๋ต๋ณ์ ํ๊ตญ์ด๋ก ์์ฑํ๋ฉฐ, ์ฃผ์ ์ํ ๋ฐ ์ด๋ํ ์ ๋ฌธ ์ฉ์ด๋ ๊ดํธ ์์ ์๋ฌธ์ ๋ณ๊ธฐํ์ธ์ (์: ๋จ๋ฐฑ์ง ํฉ์ฑ(Protein Synthesis)).
[๋
ผ๋ฌธ ์ปจํ
์คํธ]
{context}"""
qa_prompt = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
("placeholder", "{chat_history}"),
("human", "{input}"),
])
# Gradio์ ๋ํ ๊ธฐ๋ก ํ์์ LangChain์ด ์ดํดํ ์ ์๊ฒ ๋ณํํ๋ ํฌํผ ํจ์
def format_history(history):
formatted = []
for user_msg, ai_msg in history:
formatted.append(("human", user_msg))
formatted.append(("ai", ai_msg))
return formatted
# ๋ฏธ์
1, 2, 5 ํตํฉ: ์คํธ๋ฆฌ๋ฐ, ๋์ ์ค์ , ์ถ์ฒ ํ์ฑ
def chat_response(message, history, temperature, k, model_name):
# ๋ฏธ์
2: UI์์ ๋๊ฒจ๋ฐ์ k ๊ฐ์ผ๋ก ๊ฒ์ ๋ฒ์ ๋์ ์กฐ์
docs = vectorstore.similarity_search(message, k=k)
context = "\n\n".join(doc.page_content for doc in docs)
# ๋ฏธ์
2: UI์์ ๋๊ฒจ๋ฐ์ ๋ชจ๋ธ๊ณผ ์จ๋๋ก LLM ๋์ ์์ฑ
llm = ChatGoogleGenerativeAI(model=model_name, temperature=temperature)
# ํ๋กฌํํธ ์กฐ๋ฆฝ
prompt_value = qa_prompt.invoke({
"context": context,
"chat_history": format_history(history),
"input": message
})
partial_message = ""
# ๋ฏธ์
5: llm.stream()์ ํ์ฉํ ์ค์๊ฐ ์คํธ๋ฆฌ๋ฐ ์ถ๋ ฅ
for chunk in llm.stream(prompt_value):
partial_message += chunk.content
yield partial_message # ๊ธ์๊ฐ ์์ฑ๋ ๋๋ง๋ค UI๋ก ๋ฐ์ด๋
# ๋ฏธ์
1: PyPDFLoader ๋ฉํ๋ฐ์ดํฐ์์ ์ถ์ฒ ๋ฐ ํ์ด์ง ์ถ์ถ (page๋ 0๋ถํฐ ์์ํ๋ฏ๋ก +1)
sources = []
for doc in docs:
source_file = os.path.basename(doc.metadata.get('source', 'Unknown'))
page_num = doc.metadata.get('page', 0) + 1
sources.append(f"{source_file} (p.{page_num})")
# ๋ฆฌ์คํธ ์ค๋ณต ์ ๊ฑฐ ํ ์ต์ข
ํ
์คํธ ์กฐ๋ฆฝ
unique_sources = list(dict.fromkeys(sources))
source_str = "\n\n๐ **์ถ์ฒ:** " + ", ".join(unique_sources)
# ์ต์ข
์ ์ผ๋ก ๋ต๋ณ ๋์ ์ถ์ฒ๋ฅผ ๋ง๋ถ์ฌ์ ์ ์ก
yield partial_message + source_str
# ๋ฏธ์
4: ๋ํ ๋ด์ญ ๋ค์ด๋ก๋ ํ์ผ ์์ฑ ํจ์
def download_chat_history(history):
file_path = "chat_history.txt"
with open(file_path, "w", encoding="utf-8") as f:
for user_msg, ai_msg in history:
f.write(f"๐งโ๐ป ์ฌ์ฉ์: {user_msg}\n")
f.write(f"๐ค AI: {ai_msg}\n")
f.write("-" * 50 + "\n")
return file_path
# UI ๋ ์ด์์ ๊ตฌ์ฑ
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## ๐ช ๊ทผ๋น๋ ๊ทน๋ํ ๋
ผ๋ฌธ Q&A ๋ด (Pro Version)")
# ๋ฏธ์
2: ์ ์ ์ ์๋ ์ค์ ํจ๋
with gr.Accordion("โ๏ธ ์ฑ๋ด ์์ธ ์ค์ ", open=False):
with gr.Row():
model_dd = gr.Dropdown(choices=["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"], value="gemini-2.0-flash", label="๐ค ๋ชจ๋ธ ์ ํ")
temp_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="๐ก๏ธ Temperature (์ฐฝ์์ฑ/ํ๊ฐ ์กฐ์ )")
k_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="๐ ์ฐธ๊ณ ํ ๋ฌธ์ ์กฐ๊ฐ ์ (k)")
# ํต์ฌ ์ฑ๋ด ์ธํฐํ์ด์ค (์ค์ ํจ๋์ ๊ฐ๋ค์ additional_inputs๋ก ์ฐ๊ฒฐ)
chat_interface = gr.ChatInterface(
fn=chat_response,
additional_inputs=[temp_slider, k_slider, model_dd],
chatbot=gr.Chatbot(height=500),
title="",
description="'Maximizing Muscle Hypertrophy' ๋
ผ๋ฌธ ๋ด์ฉ์ ๋ฐํ์ผ๋ก ๊ทผ์ฑ์ฅ ๋ฉ์ปค๋์ฆ์ ์ง๋ฌธํด ๋ณด์ธ์."
)
# ๋ฏธ์
4: ๋ํ ๋ด์ญ ๋ค์ด๋ก๋ ์์ญ
with gr.Row():
download_btn = gr.Button("๐พ ํ์ฌ ๋ํ ๋ด์ญ ์ ์ฅ ๋ฐ ๋ค์ด๋ก๋", variant="primary")
download_file = gr.File(label="๋ค์ด๋ก๋ ์ค๋น ์๋ฃ (๋ฒํผ์ ๋๋ฅด์ธ์)")
# ๋ฒํผ ํด๋ฆญ ์ด๋ฒคํธ (์ฑํ
์ฐฝ์ ํ์คํ ๋ฆฌ๋ฅผ ๊ฐ์ ธ์ ํ์ผ๋ก ๋ณํ)
download_btn.click(
fn=download_chat_history,
inputs=[chat_interface.chatbot],
outputs=[download_file]
)
if __name__ == "__main__":
demo.launch() |