# Streamlit Agentic RAG assistant: answers questions with a LangChain
# tool-calling agent backed by a Supabase vector store, and lets users
# upload PDF/TXT/Markdown documents into the knowledge base.
# import basics (stdlib)
import html
import json
import os
import re
import time

from dotenv import load_dotenv

# import streamlit
import streamlit as st
from PIL import Image

# import langchain
from langchain import hub
from langchain.agents import AgentExecutor
from langchain.agents import create_tool_calling_agent
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.chat_models import init_chat_model
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# import supabase db
from supabase.client import Client, create_client
# load environment variables (expects SUPABASE_URL, SUPABASE_SERVICE_KEY
# and the OpenAI key used implicitly by the langchain_openai clients)
load_dotenv()

# initiating supabase client
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
# NOTE(review): os.environ.get returns None when unset and create_client
# will fail — confirm both variables are always provided in deployment
supabase: Client = create_client(supabase_url, supabase_key)

# initiating embeddings model (used for both ingestion and retrieval)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

# initiating vector store over the Supabase "documents" table;
# "match_documents" is the SQL function invoked for similarity search
vector_store = SupabaseVectorStore(
    embedding=embeddings,
    client=supabase,
    table_name="documents",
    query_name="match_documents",
)

# initiating llm used by the tool-calling agent below
llm = ChatOpenAI(model="gpt-4.1", temperature=1)

# pulling the agent prompt from the LangChain hub
prompt = hub.pull("jackfengrag/myrag")
# Store for captured documents
# NOTE(review): this session-state key is initialised but never read in this
# file — source tracking uses st.session_state.sources_history instead;
# confirm whether "retrieved_documents" is still needed.
if "retrieved_documents" not in st.session_state:
    st.session_state.retrieved_documents = {}
# Lightweight collector used by the retrieve tool so the UI can show which
# documents were fetched for the most recent query.
class DocumentCaptureHandler:
    """Accumulates the documents retrieved while answering one question."""

    def __init__(self):
        # Documents captured since the last manual reset of `captured_docs`.
        self.captured_docs = []

    def capture_docs(self, docs):
        """Append each document in *docs* to the capture list."""
        for doc in docs:
            self.captured_docs.append(doc)

# Shared module-level handler: written by the retrieve tool, read by the UI.
document_handler = DocumentCaptureHandler()
# creating the retriever tool used by the agent
# NOTE(review): langchain_core.tools.tool is imported at the top of the file
# but never applied here — confirm whether the @tool decorator was intended.
def retrieve(query: str):
    """Retrieve information related to a query."""
    docs = vector_store.similarity_search(query, k=5)
    # Record the documents so the UI can render them as sources afterwards
    document_handler.capture_docs(docs)
    sections = []
    for doc in docs:
        sections.append(f"Source: {doc.metadata}\nContent: {doc.page_content}")
    serialized = "\n\n".join(sections)
    return serialized, docs
# combining all tools
# NOTE(review): retrieve is a plain function, not a BaseTool — confirm the
# pinned langchain version accepts it in create_tool_calling_agent as-is.
tools = [retrieve]

# initiating the tool-calling agent from the llm, tools and hub prompt
agent = create_tool_calling_agent(llm, tools, prompt)

# create the agent executor that runs the tool-calling loop
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
# Function to format a retrieved document as an HTML card for display
def format_source_document(doc, index):
    """Return an HTML snippet rendering *doc* as source card *index* + 1.

    The snippet is injected via st.markdown(..., unsafe_allow_html=True), so
    the source name and the (untrusted) document text are HTML-escaped to
    prevent markup/script injection from uploaded document content.
    """
    source = doc.metadata.get("source", "Unknown source")
    # Reduce a path-like source to its filename (handles / and \ separators,
    # unlike the previous split("/") which missed Windows paths)
    if isinstance(source, str):
        source = os.path.basename(source) or source
    # Show at most 300 characters; only append an ellipsis when truncated
    content = doc.page_content[:300]
    if len(doc.page_content) > 300:
        content += "..."
    safe_source = html.escape(str(source))
    safe_content = html.escape(content)
    # Format source document for display with everything in black color
    return f"""
    <div style="padding: 10px; margin-bottom: 10px; border-radius: 5px; background-color: #f5f5f5; color: #000000;">
        <p><strong style="color: #000000;">Source {index+1}: {safe_source}</strong></p>
        <p style="font-size: 0.9em; color: #000000;">{safe_content}</p>
    </div>
    """
# initiating streamlit app with a new logo
# (set_page_config must be the first Streamlit call in the script)
# NOTE(review): "π§ " looks like a mojibake'd emoji (possibly 🧠) — confirm
# the intended character and fix the file encoding.
st.set_page_config(
    page_title="LangChain RAG Assistant",
    page_icon="π§ ",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom styling for the app: CSS classes used by the markdown snippets
# rendered further down (main-header, subheader, source-title, source-content)
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        color: #4CAF50;
        text-align: center;
        margin-bottom: 1rem;
    }
    .subheader {
        font-size: 1.2rem;
        color: #555;
        text-align: center;
        margin-bottom: 2rem;
    }
    .source-title {
        font-weight: bold;
        margin-bottom: 5px;
    }
    .source-content {
        font-size: 0.9em;
        color: #333;
        padding-left: 10px;
        border-left: 2px solid #4CAF50;
    }
</style>
""", unsafe_allow_html=True)
# Create sidebar with display settings and an about section
with st.sidebar:
    st.markdown("## Settings")
    # Toggle read further down to show/hide retrieved source documents
    show_sources = st.checkbox("Show source documents", value=True)
    st.markdown("---")
    st.markdown("## About")
    # Fixed garbled sentence ("by default, With any technical document")
    st.markdown("This assistant uses Agentic RAG (Retrieval-Augmented Generation) to provide information about LangChain by default, along with any technical documents you upload.")
    st.markdown("It retrieves relevant documents from a vector database and uses them to generate responses.")

# Display custom header with new logo
# NOTE(review): "π§ " looks like a mojibake'd emoji — confirm intended character
st.markdown("<h1 class='main-header'>π§  Technical Document Knowledge Assistant</h1>", unsafe_allow_html=True)
st.markdown("<p class='subheader'>Powered by Agentic RAG Technology</p>", unsafe_allow_html=True)

# Add a horizontal line
st.markdown("---")
# initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# initialize sources history (one list of source docs per assistant answer)
if "sources_history" not in st.session_state:
    st.session_state.sources_history = []

# display chat messages from history on app rerun
for i, message in enumerate(st.session_state.messages):
    if isinstance(message, HumanMessage):
        with st.chat_message("user"):
            st.markdown(message.content)
    elif isinstance(message, AIMessage):
        with st.chat_message("assistant"):
            st.markdown(message.content)
            # Display sources if available and option is enabled.
            # NOTE(review): i//2 assumes messages strictly alternate
            # Human, AI, Human, AI, ... — confirm nothing else is ever
            # appended, otherwise answers pair with the wrong sources.
            if show_sources and i//2 < len(st.session_state.sources_history):
                sources = st.session_state.sources_history[i//2]
                if sources:
                    with st.expander("π View Source Documents", expanded=False):
                        for j, doc in enumerate(sources):
                            st.markdown(format_source_document(doc, j), unsafe_allow_html=True)
# --- Document Upload and Ingestion UI ---
st.markdown("## π Upload and Ingest Documents")
# Accepts multiple files; allowed extensions are enforced by the widget
uploaded_files = st.file_uploader(
    "Upload PDF, TXT, or Markdown (MD) files to ingest into the knowledge base:",
    type=["pdf", "txt", "md"],
    accept_multiple_files=True,
    key="file_uploader"
)
# Ingest each uploaded file: save to disk, load, split, embed in batches.
if uploaded_files:
    # Ensure the local staging directory exists before writing into it
    # (open(..., "wb") would otherwise raise FileNotFoundError on first run)
    os.makedirs("documents", exist_ok=True)
    for uploaded_file in uploaded_files:
        file_name = uploaded_file.name
        file_path = os.path.join("documents", file_name)
        # Save uploaded file to disk so the document loaders can read it
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        # Choose a loader based on the file extension
        lower_name = file_name.lower()
        if lower_name.endswith(".pdf"):
            loader = PyPDFLoader(file_path)
        elif lower_name.endswith(".txt"):
            loader = TextLoader(file_path)
        elif lower_name.endswith(".md"):
            loader = UnstructuredMarkdownLoader(file_path)
        else:
            # Should not happen (uploader restricts types) — be defensive
            st.warning(f"Unsupported file type: {file_name}")
            continue
        # Load and split the document into overlapping chunks
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)
        # Postgres rejects NUL characters in text columns — strip them first
        for doc in docs:
            doc.page_content = doc.page_content.replace('\u0000', '')
        # Ingest into the vector store in batches of `batch_size` chunks
        # (the old comment claimed batches of 10; the actual size is 3)
        batch_size = 3
        num_batches = (len(docs) + batch_size - 1) // batch_size
        for batch_idx in range(num_batches):
            batch_docs = docs[batch_idx*batch_size:(batch_idx+1)*batch_size]
            retry_count = 0
            while retry_count < 3:
                try:
                    # query_name now matches the store configured at the top
                    # of the file ("match_documents"); it only affects queries
                    # against the returned store, not the inserts themselves.
                    SupabaseVectorStore.from_documents(
                        batch_docs,
                        embeddings,
                        client=supabase,
                        table_name="documents",
                        query_name="match_documents",
                        chunk_size=100,
                    )
                    if retry_count > 0:
                        st.info(f"Batch {batch_idx+1} for {file_name} succeeded after {retry_count} retries.")
                    break  # Success, exit retry loop
                except Exception as e:
                    error_message = str(e)
                    # Transient SSL/TLS failures: brief back-off, up to 3 tries
                    if any(kw in error_message.lower() for kw in ["ssl", "tls", "certificate", "handshake", "bad record"]):
                        retry_count += 1
                        st.warning(f"SSL error on batch {batch_idx+1} for {file_name}, retrying ({retry_count}/3)...")
                        time.sleep(1)
                        continue
                    # Duplicate rows: batch already ingested, skip it
                    if any(kw in error_message.lower() for kw in ["duplicate", "already exists", "unique constraint", "unique violation", "conflict"]):
                        st.warning(f"Duplicate detected in batch {batch_idx+1} for {file_name}, skipping batch: {error_message}")
                        break
                    # Other errors: report and skip the batch
                    st.error(f"Error in batch {batch_idx+1} for {file_name}: {error_message}")
                    break
            else:
                # while-else: runs only when retry_count hit 3 with no break
                st.error(f"Failed to ingest batch {batch_idx+1} for {file_name} after 3 SSL retries.")
        st.success(f"Ingested {file_name} in {num_batches} batches!")
# create the bar where we can type messages
user_question = st.chat_input("Ask me anything about LangChain...")

# Render *text* as markdown, routing fenced ``` blocks and
# <related_code>...</related_code> spans through st.code so they get
# proper code formatting instead of raw markdown rendering.
def render_markdown_with_codeblocks(text):
    code_block_pattern = r"```([\w\+\-]*)\n([\s\S]*?)```"
    related_code_pattern = r"<related_code>([\s\S]*?)</related_code>"
    # Find all code blocks (triple backtick and related_code) in order
    matches = []
    for m in re.finditer(code_block_pattern, text):
        matches.append((m.start(), m.end(), 'backtick', m))
    for m in re.finditer(related_code_pattern, text):
        matches.append((m.start(), m.end(), 'related_code', m))
    # Sort by position explicitly: the old bare tuple sort would fall back to
    # comparing re.Match objects on (start, end, kind) ties and raise TypeError
    matches.sort(key=lambda item: (item[0], item[1]))
    last_end = 0
    for start, end, kind, m in matches:
        # NOTE(review): overlapping matches (e.g. <related_code> inside a
        # fenced block) are not de-duplicated — confirm inputs never overlap
        if start > last_end:
            st.markdown(text[last_end:start])
        if kind == 'backtick':
            code_lang = m.group(1) or None
            code_content = m.group(2)
            st.code(code_content, language=code_lang)
        elif kind == 'related_code':
            code_content = m.group(1)
            st.code(code_content)
        last_end = end
    # Trailing markdown after the final code block
    if last_end < len(text):
        st.markdown(text[last_end:])

# did the user submit a prompt?
if user_question:
    # Reset the handler so only this query's documents are captured
    document_handler.captured_docs = []
    # add the message from the user (prompt) to the screen with streamlit
    with st.chat_message("user"):
        st.markdown(user_question)
    st.session_state.messages.append(HumanMessage(user_question))
    # Show spinner while agent is generating a response
    with st.spinner("Thinking... Generating response..."):
        # invoking the agent with the full chat history as context
        result = agent_executor.invoke({"input": user_question, "chat_history": st.session_state.messages})
        ai_message = result["output"]
        # Store the captured documents for this response (kept in lockstep
        # with assistant messages so history replay can pair them by i//2)
        st.session_state.sources_history.append(document_handler.captured_docs)
        # adding the response from the llm to the screen (and chat)
        with st.chat_message("assistant"):
            render_markdown_with_codeblocks(ai_message)
        st.session_state.messages.append(AIMessage(ai_message))
        # Display sources if option is enabled
        if show_sources and document_handler.captured_docs:
            with st.expander("π View Source Documents", expanded=True):
                for i, doc in enumerate(document_handler.captured_docs):
                    st.markdown(format_source_document(doc, i), unsafe_allow_html=True)