Spaces:
Sleeping
Sleeping
import os
import tempfile

import gradio as gr
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# 1. Load the source PDF and build the vector store. This runs once at import
#    time; the resulting `vectorstore` is reused for the life of the process.
loader = PyPDFLoader("Maximizing Muscle Hypertrophy.pdf")
# Use load() rather than load_and_split(): the latter already runs a default
# text splitter, so passing its output to `text_splitter` below would chunk
# the document twice with two different chunk sizes.
pages = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(pages)
embeddings = GoogleGenerativeAIEmbeddings(model="gemini-embedding-001")
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
# Mission 3: domain-specific system prompt.
# NOTE(review): the non-ASCII text in this literal looks mis-encoded in this
# copy of the file - confirm it against the original before shipping.
# The template ends with the `{context}` variable, which must be supplied when
# the prompt is invoked.
SYSTEM_PROMPT = """๋น์ ์ ์คํฌ์ธ ์์ํ ๋ฐ ๊ทผ๋น๋(Muscle Hypertrophy) ํ๋ จ ๋ถ์ผ์ ์ต๊ณ ๊ถ์์์ด์ ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ ์ ๋ฌธ๊ฐ์ ๋๋ค.
์ ๊ณต๋ [๋ ผ๋ฌธ ์ปจํ ์คํธ]๋ฅผ ๋ฐํ์ผ๋ก ์ฌ์ฉ์์ ์ง๋ฌธ์ ์ ๋ฌธ์ ์ด๊ณ ๋ช ํํ๋ฉฐ ๊ฐ๊ด์ ์ธ ์ด์กฐ๋ก ๋ต๋ณํ์ธ์.
[์ ์ฝ ์กฐ๊ฑด]
1. ๋ฐ๋์ ์ ๊ณต๋ ์ปจํ ์คํธ ๋ด์ ์ ๋ณด๋ง์ ์ฌ์ฉํ์ฌ ๋ต๋ณํ์ธ์.
2. ๋ ผ๋ฌธ์ ์๋ ๋ด์ฉ์ ์ง๋ฌธํ๋ฉด "ํด๋น ๋ด์ฉ์ ์ ๊ณต๋ ๋ ผ๋ฌธ์์ ํ์ธํ ์ ์์ต๋๋ค."๋ผ๊ณ ๋ช ํํ ์ ์ ๊ทธ์ผ์ธ์.
3. ๊ทผ์ก ์ฑ์ฅ ๊ธฐ์ ์ด๋ ํ๋ จ๋ฒ์ ์ค๋ช ํ ๋๋ ์ผ๋ฐ์ธ๋ ์ดํดํ๊ธฐ ์ฝ๊ฒ ๋จ๊ณ๋ณ๋ก ๊ตฌ์กฐํํ์ฌ ์ค๋ช ํ์ธ์.
4. ๋ชจ๋ ๋ต๋ณ์ ํ๊ตญ์ด๋ก ์์ฑํ๋ฉฐ, ์ฃผ์ ์ํ ๋ฐ ์ด๋ํ ์ ๋ฌธ ์ฉ์ด๋ ๊ดํธ ์์ ์๋ฌธ์ ๋ณ๊ธฐํ์ธ์ (์: ๋จ๋ฐฑ์ง ํฉ์ฑ(Protein Synthesis)).
[๋ ผ๋ฌธ ์ปจํ ์คํธ]
{context}"""

# Message order: system instructions, then prior turns, then the new question.
_QA_MESSAGES = [
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
]
qa_prompt = ChatPromptTemplate.from_messages(_QA_MESSAGES)
# Helper that converts Gradio's chat history into the (role, content) pairs
# that a LangChain chat prompt placeholder accepts.
def format_history(history):
    """Convert a Gradio chat history into LangChain ``(role, content)`` tuples.

    Two history layouts are accepted:

    * pair style: ``[(user_msg, ai_msg), ...]`` (tuples or lists), and
    * message style: ``[{"role": "user" | "assistant", "content": ...}, ...]``.

    Entries with missing or empty content are skipped, so an unanswered turn
    never produces an empty message. Message-style entries whose role is
    neither ``"user"`` nor ``"assistant"`` are skipped as well.

    Args:
        history: Iterable of history entries, or ``None``.

    Returns:
        A list of ``("human", content)`` / ``("ai", content)`` tuples in
        conversation order.
    """
    role_map = {"user": "human", "assistant": "ai"}
    formatted = []
    for entry in history or []:
        if isinstance(entry, dict):
            # Message-style entry: one speaker per item.
            role = role_map.get(entry.get("role"))
            content = entry.get("content")
            if role and content:
                formatted.append((role, content))
            continue
        # Pair-style entry: one full turn per item.
        user_msg, ai_msg = entry
        if user_msg:
            formatted.append(("human", user_msg))
        if ai_msg:
            formatted.append(("ai", ai_msg))
    return formatted
# Missions 1, 2 and 5 combined: source citation, runtime settings, streaming.
def chat_response(message, history, temperature, k, model_name):
    """Stream an answer to ``message`` built from passages in the vector store.

    Args:
        message: The user's current question.
        history: Prior conversation turns as supplied by the chat UI.
        temperature: Sampling temperature passed to the chat model.
        k: Number of passages to retrieve; coerced to ``int`` before use.
        model_name: Model identifier passed to ``ChatGoogleGenerativeAI``.

    Yields:
        The answer accumulated so far after each streamed chunk, then one
        final value that has the de-duplicated source list appended when at
        least one passage was retrieved.
    """
    # Mission 2: retrieval breadth is chosen in the UI. Cast defensively in
    # case the slider delivers a float (e.g. 3.0); top-k must be an integer.
    docs = vectorstore.similarity_search(message, k=int(k))
    context = "\n\n".join(doc.page_content for doc in docs)

    # Mission 2: the model and temperature are also chosen per request, so the
    # client is created here rather than once at import time.
    llm = ChatGoogleGenerativeAI(model=model_name, temperature=temperature)

    # Assemble the prompt.
    prompt_value = qa_prompt.invoke({
        "context": context,
        "chat_history": format_history(history),
        "input": message,
    })

    # Mission 5: stream with llm.stream(), yielding the running text each time
    # a new chunk arrives so the UI updates incrementally.
    partial_message = ""
    for chunk in llm.stream(prompt_value):
        partial_message += chunk.content
        yield partial_message

    # Mission 1: read file name and page from each document's metadata
    # (page is zero-based there, hence the +1).
    sources = []
    for doc in docs:
        source_file = os.path.basename(doc.metadata.get('source', 'Unknown'))
        page_num = doc.metadata.get('page', 0) + 1
        sources.append(f"{source_file} (p.{page_num})")

    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique_sources = list(dict.fromkeys(sources))

    # Always emit one final value; only add the footer when there is something
    # to cite, so an empty retrieval does not produce a dangling label.
    final_message = partial_message
    if unique_sources:
        final_message += "\n\n๐ **์ถ์ฒ:** " + ", ".join(unique_sources)
    yield final_message
# Mission 4: build the downloadable chat transcript.
def download_chat_history(history):
    """Write the conversation to a UTF-8 text file and return its path.

    Every call writes into a freshly created temporary directory, so
    concurrent sessions cannot overwrite each other's transcript the way a
    single fixed path in the working directory would. The file itself is
    still named ``chat_history.txt``.

    Both pair-style history (``[(user_msg, ai_msg), ...]``) and message-style
    history (``[{"role": ..., "content": ...}, ...]``) are accepted. A
    separator line is written after each assistant reply.

    Args:
        history: Iterable of history entries, or ``None``.

    Returns:
        Path of the written transcript file.
    """
    user_label = "๐งโ๐ป ์ฌ์ฉ์"
    ai_label = "๐ค AI"
    separator = "-" * 50 + "\n"

    file_path = os.path.join(
        tempfile.mkdtemp(prefix="chat_history_"), "chat_history.txt"
    )
    with open(file_path, "w", encoding="utf-8") as f:
        for entry in history or []:
            if isinstance(entry, dict):
                # Message-style entry: one speaker per item.
                if entry.get("role") == "user":
                    f.write(f"{user_label}: {entry.get('content')}\n")
                else:
                    f.write(f"{ai_label}: {entry.get('content')}\n")
                    f.write(separator)
                continue
            # Pair-style entry: one full turn per item.
            user_msg, ai_msg = entry
            f.write(f"{user_label}: {user_msg}\n")
            f.write(f"{ai_label}: {ai_msg}\n")
            f.write(separator)
    return file_path
# UI layout. Components render in the order they are created inside each
# context manager, so the statement order below defines the page layout.
# NOTE(review): nesting was reconstructed from a copy of the file that had
# lost its indentation - confirm the Row/Accordion membership.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ๐ช ๊ทผ๋น๋ ๊ทน๋ํ ๋ ผ๋ฌธ Q&A ๋ด (Pro Version)")
    # Mission 2: collapsible settings panel (closed by default).
    with gr.Accordion("โ๏ธ ์ฑ๋ด ์์ธ ์ค์ ", open=False):
        with gr.Row():
            model_dd = gr.Dropdown(choices=["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"], value="gemini-2.0-flash", label="๐ค ๋ชจ๋ธ ์ ํ")
            temp_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="๐ก๏ธ Temperature (์ฐฝ์์ฑ/ํ๊ฐ ์กฐ์ )")
            k_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="๐ ์ฐธ๊ณ ํ ๋ฌธ์ ์กฐ๊ฐ ์ (k)")
    # Main chat interface. The settings widgets are wired in through
    # additional_inputs; their order must match the extra parameters of
    # chat_response (temperature, k, model_name).
    chat_interface = gr.ChatInterface(
        fn=chat_response,
        additional_inputs=[temp_slider, k_slider, model_dd],
        chatbot=gr.Chatbot(height=500),
        title="",
        description="'Maximizing Muscle Hypertrophy' ๋ ผ๋ฌธ ๋ด์ฉ์ ๋ฐํ์ผ๋ก ๊ทผ์ฑ์ฅ ๋ฉ์ปค๋์ฆ์ ์ง๋ฌธํด ๋ณด์ธ์."
    )
    # Mission 4: transcript download area.
    with gr.Row():
        download_btn = gr.Button("๐พ ํ์ฌ ๋ํ ๋ด์ญ ์ ์ฅ ๋ฐ ๋ค์ด๋ก๋", variant="primary")
        download_file = gr.File(label="๋ค์ด๋ก๋ ์ค๋น ์๋ฃ (๋ฒํผ์ ๋๋ฅด์ธ์)")
    # Button click: pass the chatbot's current history to
    # download_chat_history and show the returned file in download_file.
    download_btn.click(
        fn=download_chat_history,
        inputs=[chat_interface.chatbot],
        outputs=[download_file]
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()