# LangChain-documentation QA chatbot — Gradio app for Hugging Face Spaces.
# Builds a hybrid (BM25 + FAISS + Chroma, Cohere-reranked) RAG pipeline over
# the LangChain repository docs and serves it through gr.ChatInterface.
# Standard library
import os

# Third-party
import gradio as gr
import pandas as pd
import torch
from dotenv import load_dotenv

# LangChain (core / community / partner packages)
from langchain.callbacks.base import BaseCallbackHandler
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.storage import LocalFileStore
from langchain_anthropic import ChatAnthropic
from langchain_cohere import CohereRerank
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import (
    DataFrameLoader,
    NotebookLoader,
    TextLoader,
)
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
    LanguageParser,
)
from langchain_community.document_transformers import LongContextReorder
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
    ConfigurableField,
    RunnableLambda,
    RunnablePassthrough,
)
from langchain_google_genai import GoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
| # Load environment variables | |
| load_dotenv() | |
| # Repository directories | |
| repo_root_dir = "./docs/langchain" | |
| repo_dirs = [ | |
| "libs/core/langchain_core", | |
| "libs/community/langchain_community", | |
| "libs/experimental/langchain_experimental", | |
| "libs/partners", | |
| "libs/cookbook", | |
| ] | |
| repo_dirs = [os.path.join(repo_root_dir, repo) for repo in repo_dirs] | |
| # Load Python documents | |
| py_documents = [] | |
| for path in repo_dirs: | |
| py_loader = GenericLoader.from_filesystem( | |
| path, | |
| glob="**/*", | |
| suffixes=[".py"], | |
| parser=LanguageParser(language=Language.PYTHON, parser_threshold=30), | |
| ) | |
| py_documents.extend(py_loader.load()) | |
| print(f".py ํ์ผ์ ๊ฐ์: {len(py_documents)}") | |
| # Load Markdown documents | |
| mdx_documents = [] | |
| for dirpath, _, filenames in os.walk(repo_root_dir): | |
| for file in filenames: | |
| if file.endswith(".mdx") and "*venv/" not in dirpath: | |
| try: | |
| mdx_loader = TextLoader(os.path.join(dirpath, file), encoding="utf-8") | |
| mdx_documents.extend(mdx_loader.load()) | |
| except Exception: | |
| pass | |
| print(f".mdx ํ์ผ์ ๊ฐ์: {len(mdx_documents)}") | |
| # Load Jupyter Notebook documents | |
| ipynb_documents = [] | |
| for dirpath, _, filenames in os.walk(repo_root_dir): | |
| for file in filenames: | |
| if file.endswith(".ipynb") and "*venv/" not in dirpath: | |
| try: | |
| ipynb_loader = NotebookLoader( | |
| os.path.join(dirpath, file), | |
| include_outputs=True, | |
| max_output_length=20, | |
| remove_newline=True, | |
| ) | |
| ipynb_documents.extend(ipynb_loader.load()) | |
| except Exception: | |
| pass | |
| print(f".ipynb ํ์ผ์ ๊ฐ์: {len(ipynb_documents)}") | |
| ## wikidocs | |
| df = pd.read_parquet("./docs/wikidocs_14314.parquet") | |
| loader = DataFrameLoader(df, page_content_column="content") | |
| wiki_documents = loader.load() | |
| # Split documents into chunks | |
| def split_documents(documents, language, chunk_size=2000, chunk_overlap=200): | |
| splitter = RecursiveCharacterTextSplitter.from_language( | |
| language=language, chunk_size=chunk_size, chunk_overlap=chunk_overlap | |
| ) | |
| return splitter.split_documents(documents) | |
| py_docs = split_documents(py_documents, Language.PYTHON) | |
| mdx_docs = split_documents(mdx_documents, Language.MARKDOWN) | |
| ipynb_docs = split_documents(ipynb_documents, Language.PYTHON) | |
| wiki_docs = split_documents(wiki_documents, Language.MARKDOWN) | |
| print(f"๋ถํ ๋ .py ๋ฌธ์์ ๊ฐ์: {len(py_docs)}") | |
| print(f"๋ถํ ๋ .mdx ๋ฌธ์์ ๊ฐ์: {len(mdx_docs)}") | |
| print(f"๋ถํ ๋ .ipynb ๋ฌธ์์ ๊ฐ์: {len(ipynb_docs)}") | |
| print(f"๋ถํ ๋ wiki ๋ฌธ์์ ๊ฐ์: {len(wiki_docs)}") | |
| combined_documents = py_docs + mdx_docs + ipynb_docs + wiki_docs | |
| print(f"์ด ๋ํ๋จผํธ ๊ฐ์: {len(combined_documents)}") | |
| # Define the device setting function | |
| def get_device(): | |
| if torch.cuda.is_available(): | |
| return "cuda:0" | |
| elif torch.backends.mps.is_available(): | |
| return "mps" | |
| else: | |
| return "cpu" | |
| # Use the function to set the device in model_kwargs | |
| device = get_device() | |
| # Initialize embeddings and cache | |
| store = LocalFileStore("~/.cache/embedding") | |
| embeddings = HuggingFaceBgeEmbeddings( | |
| model_name="BAAI/bge-m3", | |
| model_kwargs={"device": device}, | |
| encode_kwargs={"normalize_embeddings": True}, | |
| ) | |
| cached_embeddings = CacheBackedEmbeddings.from_bytes_store( | |
| embeddings, store, namespace=embeddings.model_name | |
| ) | |
| # Create and save FAISS index | |
| FAISS_DB_INDEX = "./langchain_faiss" | |
| if not os.path.exists(FAISS_DB_INDEX): | |
| faiss_db = FAISS.from_documents( | |
| documents=combined_documents, | |
| embedding=cached_embeddings, | |
| ) | |
| faiss_db.save_local(folder_path=FAISS_DB_INDEX) | |
| # Create and save Chroma index | |
| CHROMA_DB_INDEX = "./langchain_chroma" | |
| if not os.path.exists(CHROMA_DB_INDEX): | |
| chroma_db = Chroma.from_documents( | |
| documents=combined_documents, | |
| embedding=cached_embeddings, | |
| persist_directory=CHROMA_DB_INDEX, | |
| ) | |
| # load vectorstore | |
| faiss_db = FAISS.load_local( | |
| FAISS_DB_INDEX, cached_embeddings, allow_dangerous_deserialization=True | |
| ) | |
| chroma_db = Chroma( | |
| embedding_function=cached_embeddings, | |
| persist_directory=CHROMA_DB_INDEX, | |
| ) | |
| # Create retrievers | |
| faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10}) | |
| chroma_retriever = chroma_db.as_retriever( | |
| search_type="similarity", search_kwargs={"k": 10} | |
| ) | |
| bm25_retriever = BM25Retriever.from_documents(combined_documents) | |
| bm25_retriever.k = 10 | |
| ensemble_retriever = EnsembleRetriever( | |
| retrievers=[ | |
| bm25_retriever.with_config(run_name="bm25"), | |
| faiss_retriever.with_config(run_name="faiss"), | |
| chroma_retriever.with_config(run_name="chroma"), | |
| ], | |
| weights=[0.4, 0.3, 0.3], | |
| ) | |
| compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=10) | |
| compression_retriever = ContextualCompressionRetriever( | |
| base_compressor=compressor, | |
| base_retriever=ensemble_retriever, | |
| ) | |
| # Create prompt template | |
| prompt = PromptTemplate.from_template( | |
| """๋น์ ์ 20๋ ์ฐจ AI ๊ฐ๋ฐ์์ ๋๋ค. ๋น์ ์ ์๋ฌด๋ ์ฃผ์ด์ง ์ง๋ฌธ์ ๋ํด ์ฃผ์ด์ง ๋ฌธ์์ ์ ๋ณด๋ฅผ ์ต๋ํ ํ์ฉํ์ฌ ๋ต๋ณํ๋ ๊ฒ์ ๋๋ค. | |
| ๋ฌธ์๋ Python ์ฝ๋์ ๋ํ ์ ๋ณด๋ฅผ ํฌํจํ๊ณ ์์ผ๋ฏ๋ก, ๋ต๋ณ ์์ฑ ์ Python ์ฝ๋ ์ค๋ํซ๊ณผ ๊ตฌ์ฒด์ ์ธ ์ค๋ช ์ ํฌํจํด ์ฃผ์ธ์. | |
| ๋ต๋ณ์ ๊ฐ๋ฅํ ํ ์์ธํ๊ณ ๋ช ํํ๊ฒ ์์ฑํ๋ฉฐ, ์ดํดํ๊ธฐ ์ฌ์ด ํ๊ธ๋ก ์์ฑํด ์ฃผ์ธ์. | |
| ํ์ฌ ์ฃผ์ด์ง ๋ฌธ์์์ ๋ต๋ณ์ ์ฐพ์ ์ ์๋ ๊ฒฝ์ฐ, "ํ์ฌ ์ ๊ณต๋ ์ง๋ฌธ๋ง์ผ๋ก๋ ์ ํํ ๋ต๋ณ์ ๋๋ฆฌ๊ธฐ ์ด๋ ค์์. ์ถ๊ฐ ์ ๋ณด๋ฅผ ์ฃผ์๋ฉด ๋ ๋์์ ๋๋ฆด ์ ์์ ๊ฒ ๊ฐ์ต๋๋ค. ์ธ์ ๋ ์ง ์ง๋ฌธํด ์ฃผ์ธ์!"๋ผ๊ณ ๋ต๋ณํด ์ฃผ์ธ์. | |
| ๊ฐ ๋ต๋ณ์ ์ถ์ฒ(source)๋ฅผ ๋ฐ๋์ ํ๊ธฐํด ์ฃผ์ธ์. | |
| # ์ฐธ๊ณ ๋ฌธ์: | |
| {context} | |
| # ์ง๋ฌธ: | |
| {question} | |
| # ๋ต๋ณ: | |
| ์ถ์ฒ: | |
| - source1 | |
| - source2 | |
| - ... | |
| """ | |
| ) | |
| # Define callback handler for streaming | |
| class StreamCallback(BaseCallbackHandler): | |
| def on_llm_new_token(self, token: str, **kwargs): | |
| print(token, end="", flush=True) | |
| streaming = os.getenv("STREAMING", "true") == "true" | |
| print("STREAMING", streaming) | |
| # Initialize LLMs with configuration | |
| llm = ChatOpenAI( | |
| model="gpt-4o", | |
| temperature=0, | |
| streaming=streaming, | |
| callbacks=[StreamCallback()], | |
| ).configurable_alternatives( | |
| ConfigurableField(id="llm"), | |
| default_key="gpt4", | |
| claude=ChatAnthropic( | |
| model="claude-3-opus-20240229", | |
| temperature=0, | |
| streaming=True, | |
| callbacks=[StreamCallback()], | |
| ), | |
| gpt3=ChatOpenAI( | |
| model="gpt-3.5-turbo", | |
| temperature=0, | |
| streaming=True, | |
| callbacks=[StreamCallback()], | |
| ), | |
| gemini=GoogleGenerativeAI( | |
| model="gemini-1.5-flash", | |
| temperature=0, | |
| streaming=True, | |
| callbacks=[StreamCallback()], | |
| ), | |
| llama3=ChatGroq( | |
| model_name="llama3-70b-8192", | |
| temperature=0, | |
| streaming=True, | |
| callbacks=[StreamCallback()], | |
| ), | |
| ollama=ChatOllama( | |
| model="EEVE-Korean-10.8B:long", | |
| callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), | |
| ), | |
| ) | |
| # Create retrieval-augmented generation chain | |
| rag_chain = ( | |
| { | |
| "context": compression_retriever | |
| | RunnableLambda(LongContextReorder().transform_documents), | |
| "question": RunnablePassthrough(), | |
| } | |
| | prompt | |
| | llm | |
| | StrOutputParser() | |
| ) | |
| model_key = os.getenv("MODEL_KEY", "gemini") | |
| print("MODEL_KEY", model_key) | |
| def respond_stream( | |
| message, | |
| history: list[tuple[str, str]], | |
| ): | |
| response = "" | |
| for chunk in rag_chain.with_config(configurable={"llm": model_key}).stream(message): | |
| response += chunk | |
| yield response | |
| def respond( | |
| message, | |
| history: list[tuple[str, str]], | |
| ): | |
| return rag_chain.with_config(configurable={"llm": model_key}).invoke(message) | |
| """ | |
| For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
| """ | |
| demo = gr.ChatInterface( | |
| respond_stream if streaming else respond, | |
| title="๋ญ์ฒด์ธ์ ๋ํด์ ๋ฌผ์ด๋ณด์ธ์!", | |
| description="์๋ ํ์ธ์!\n์ ๋ ๋ญ์ฒด์ธ์ ๋ํ ์ธ๊ณต์ง๋ฅ QA๋ด์ ๋๋ค. ๋ญ์ฒด์ธ์ ๋ํด ๊น์ ์ง์์ ๊ฐ์ง๊ณ ์์ด์. ๋ญ์ฒด์ธ ๊ฐ๋ฐ์ ๊ดํ ๋์์ด ํ์ํ์๋ฉด ์ธ์ ๋ ์ง ์ง๋ฌธํด์ฃผ์ธ์!", | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |