import os
import platform

import gradio as gr
from langchain_community.document_loaders import ObsidianLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import ConfigurableField, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_google_genai import GoogleGenerativeAI

from prompt_template import PROMPT_TEMPLATE
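
# Note: ChatGroq, GoogleGenerativeAI, and CohereRerank read their credentials from the
# GROQ_API_KEY, GOOGLE_API_KEY, and COHERE_API_KEY environment variables (e.g. set as
# Space secrets); the BGE embedding model is downloaded locally and needs no API key.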
DIRECTORIES = ["./docs/obsidian-help", "./docs/obsidian-developer"]
FAISS_DB_INDEX = "db_index"
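
# The FAISS index is persisted in the FAISS_DB_INDEX folder and embeddings are cached
# under ./.cache/; deleting both forces a full rebuild after the documentation changes.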


def load_and_process_documents(directories):
    """Load Obsidian markdown vaults and split them into overlapping chunks."""
    md_docs = []
    for directory in directories:
        try:
            loader = ObsidianLoader(directory, encoding="utf-8")
            md_docs.extend(loader.load())
        except Exception:
            # Skip directories that are missing or unreadable.
            pass

    # Markdown-aware splitting keeps headings and code blocks together where possible.
    md_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.MARKDOWN,
        chunk_size=2000,
        chunk_overlap=200,
    )
    return md_splitter.split_documents(md_docs)


def setup_retrieval_system(splitted_docs):
    """Build a hybrid BM25 + FAISS retriever whose results are reranked by Cohere."""
    # Use Apple's Metal backend when available, otherwise fall back to CPU.
    if platform.system() == "Darwin":
        model_kwargs = {"device": "mps"}
    else:
        model_kwargs = {"device": "cpu"}

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs=model_kwargs,
        encode_kwargs={"normalize_embeddings": True},
    )

    # Cache embeddings on disk so repeated runs do not re-encode unchanged chunks.
    store = LocalFileStore("./.cache/")
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )

    # Reuse a previously built FAISS index if one exists; otherwise build and persist it.
    if os.path.exists(FAISS_DB_INDEX):
        db = FAISS.load_local(
            FAISS_DB_INDEX,
            cached_embeddings,
            allow_dangerous_deserialization=True,
        )
    else:
        db = FAISS.from_documents(splitted_docs, cached_embeddings)
        db.save_local(folder_path=FAISS_DB_INDEX)

    # Dense retrieval with MMR for diversity, plus sparse BM25 keyword retrieval.
    faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
    bm25_retriever = BM25Retriever.from_documents(splitted_docs)
    bm25_retriever.k = 10

    # Fuse both result lists with equal weights.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5],
    )

    # Rerank the fused candidates and keep the five most relevant chunks.
    compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )


def setup_language_model():
    """Groq-hosted Llama 3 70B by default, with Gemini as a configurable alternative."""
    return ChatGroq(
        model_name="llama3-70b-8192",
        temperature=0,
    ).configurable_alternatives(
        ConfigurableField(id="llm"),
        default_key="llama3",
        gemini=GoogleGenerativeAI(
            model="gemini-pro",
            temperature=0,
        ),
    )
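
# The alternative model is selected per call rather than at construction time, e.g.
# (illustrative): rag_chain.with_config(configurable={"llm": "gemini"}).invoke(question)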


def format_docs(docs):
    """Render retrieved documents as plain text blocks annotated with their source paths."""
    formatted_docs = []
    for doc in docs:
        formatted_doc = f"Page Content:\n{doc.page_content}\n"
        if doc.metadata.get("source"):
            formatted_doc += f"Source: {doc.metadata['source']}\n"
        formatted_docs.append(formatted_doc)
    return "\n---\n".join(formatted_docs)


def main():
    splitted_docs = load_and_process_documents(DIRECTORIES)
    compression_retriever = setup_retrieval_system(splitted_docs)
    llm = setup_language_model()

    # Retrieve context for the question, format it, fill the prompt, and parse the answer.
    rag_chain = (
        {"context": compression_retriever | format_docs, "question": RunnablePassthrough()}
        | PROMPT_TEMPLATE
        | llm
        | StrOutputParser()
    )

    def predict(message, history=None):
        # Gradio passes the chat history, but each question is answered independently.
        return rag_chain.invoke(message)

    gr.ChatInterface(
        predict,
        title="Ask me about the Obsidian note-taking app and plugin development!",
        description=(
            "Hello!\n"
            "I am an AI QA bot for the Obsidian note-taking app and plugin development. "
            "I have in-depth knowledge of how to use Obsidian, its advanced features, "
            "and plugin and theme development. Feel free to ask whenever you need help "
            "with writing notes, organizing information, or development!"
        ),
    ).launch()


if __name__ == "__main__":
    main()