import os
import dotenv
import logging
import gradio as gr
import glob
from typing import List

# LangChain imports
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from langgraph.graph import END, StateGraph, START
from typing_extensions import TypedDict
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.callbacks import get_openai_callback
# Plain pydantic instead of the deprecated langchain_core.pydantic_v1 shim
from pydantic import BaseModel, Field

# Load environment variables
dotenv.load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Fail fast if the OpenAI API key is not set
if os.getenv("OPENAI_API_KEY") is None:
    raise ValueError("OPENAI_API_KEY is not set in the environment or .env file")
# Initialize a retriever over all Markdown files in ./MarkdownOutput
def initialize_retriever():
    # Find all markdown files in ./MarkdownOutput
    markdown_files = glob.glob("./MarkdownOutput/**/*.md", recursive=True)
    logger.info(f"Found {len(markdown_files)} markdown files in ./MarkdownOutput.")
    if not markdown_files:
        # Chroma.from_documents fails on an empty document list, so bail out early
        return None

    # Load and split all markdown documents
    all_doc_splits = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    for idx, md_path in enumerate(markdown_files, 1):
        logger.info(f"Loading and splitting file {idx}/{len(markdown_files)}: {md_path}")
        loader = UnstructuredMarkdownLoader(md_path)
        docs = loader.load()
        splits = text_splitter.split_documents(docs)
        all_doc_splits.extend(splits)
        logger.info(f"File {md_path} loaded and split into {len(splits)} chunks.")
    logger.info(f"Total document splits: {len(all_doc_splits)}. Creating vector store...")

    # Create vector store
    vectorstore = Chroma.from_documents(
        documents=all_doc_splits,
        collection_name="rag-chroma",
        embedding=OpenAIEmbeddings(
            model="text-embedding-3-large",
            dimensions=3072,
            timeout=120,
        ),
        persist_directory="./chroma_rag_cache",
    )
    logger.info("Vector store created and persisted to ./chroma_rag_cache.")

    # Configure retriever. Note: score_threshold applies only to the
    # "similarity_score_threshold" search type and has no effect with MMR,
    # so it is omitted here.
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": 40,
            "fetch_k": 200,
            "lambda_mult": 0.2,
            "filter": None,
        },
    )
    logger.info("Retriever configured and ready to use.")
    return retriever
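
# Uncalled debugging sketch: exercises the retriever above in isolation. The
# query string is illustrative only; .invoke() is the Runnable entry point for
# LangChain retrievers and returns a list of Documents.
def _debug_retriever_smoke_test(query: str = "What topics does the corpus cover?"):
    r = initialize_retriever()
    if r is None:
        return
    for d in r.invoke(query)[:3]:
        print(d.metadata.get("source"), d.page_content[:100])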
# Define graders and components
def setup_components(retriever, model_choice):
    # Data models for grading
    class GradeDocuments(BaseModel):
        """Binary score for relevance check on retrieved documents."""

        binary_score: str = Field(
            description="Documents are relevant to the question, 'yes' or 'no'"
        )

    class GradeHallucinations(BaseModel):
        """Binary score for hallucination present in the generated answer."""

        binary_score: str = Field(
            description="Answer is grounded in the facts, 'yes' or 'no'"
        )

    class GradeAnswer(BaseModel):
        """Binary score to assess whether the answer addresses the question."""

        binary_score: str = Field(
            description="Answer addresses the question, 'yes' or 'no'"
        )

    # LLM models
    llm = ChatOpenAI(model=model_choice, temperature=0)
    doc_grader = llm.with_structured_output(GradeDocuments)
    hallucination_grader_llm = llm.with_structured_output(GradeHallucinations)
    answer_grader_llm = llm.with_structured_output(GradeAnswer)

    # Document grading prompt
    system_doc = """You are a grader assessing relevance of a retrieved document to a user question. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
    grade_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_doc),
            ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
        ]
    )
    retrieval_grader = grade_prompt | doc_grader

    # Hallucination grading prompt
    system_hallucination = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
    hallucination_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_hallucination),
            ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
        ]
    )
    hallucination_grader = hallucination_prompt | hallucination_grader_llm

    # Answer grading prompt
    system_answer = """You are a grader assessing whether an answer addresses / resolves a question. \n
    Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_answer),
            ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
        ]
    )
    answer_grader = answer_prompt | answer_grader_llm

    # Question rewriter prompt
    system_rewrite = """You are a question re-writer that converts an input question to a better version that is optimized \n
    for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
    re_write_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_rewrite),
            (
                "human",
                "Here is the initial question: \n\n {question} \n Formulate an improved question.",
            ),
        ]
    )
    question_rewriter = re_write_prompt | llm | StrOutputParser()

    # RAG generation prompt and chain
    prompt = hub.pull("rlm/rag-prompt")
    rag_chain = prompt | llm | StrOutputParser()

    return {
        "retriever": retriever,
        "retrieval_grader": retrieval_grader,
        "hallucination_grader": hallucination_grader,
        "answer_grader": answer_grader,
        "question_rewriter": question_rewriter,
        "rag_chain": rag_chain,
    }
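
# Uncalled sketch: shows how one of the structured-output graders behaves in
# isolation. The question/document strings are made-up examples, not corpus data.
def _debug_retrieval_grader_example(comps):
    score = comps["retrieval_grader"].invoke(
        {
            "question": "How does MMR retrieval work?",
            "document": "Maximal Marginal Relevance trades off query relevance against result diversity.",
        }
    )
    print(score.binary_score)  # a GradeDocuments instance; expected "yes"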
# Build the RAG graph
def build_rag_graph(components):
    # Define graph state
    class GraphState(TypedDict):
        """Represents the state of our graph."""

        question: str
        generation: str
        documents: List[str]

    # Node functions
    def retrieve(state):
        """Retrieve documents."""
        question = state["question"]
        documents = components["retriever"].invoke(question)
        return {"documents": documents, "question": question}

    def generate(state):
        """Generate an answer."""
        question = state["question"]
        documents = state["documents"]
        generation = components["rag_chain"].invoke({"context": documents, "question": question})
        return {"documents": documents, "question": question, "generation": generation}

    def grade_documents(state):
        """Determine whether the retrieved documents are relevant to the question."""
        question = state["question"]
        documents = state["documents"]
        # Score each doc, keeping only those graded relevant
        filtered_docs = []
        for d in documents:
            score = components["retrieval_grader"].invoke(
                {"question": question, "document": d.page_content}
            )
            if score.binary_score == "yes":
                filtered_docs.append(d)
        return {"documents": filtered_docs, "question": question}

    def transform_query(state):
        """Transform the query to produce a better question."""
        question = state["question"]
        documents = state["documents"]
        better_question = components["question_rewriter"].invoke({"question": question})
        return {"documents": documents, "question": better_question}

    # Edge functions
    def decide_to_generate(state):
        """Decide whether to generate an answer or rewrite the question."""
        filtered_documents = state["documents"]
        if not filtered_documents:
            # All documents have been filtered out
            return "transform_query"
        # We have relevant documents, so generate an answer
        return "generate"

    def grade_generation_v_documents_and_question(state):
        """Check that the generation is grounded in the documents and answers the question."""
        question = state["question"]
        documents = state["documents"]
        generation = state["generation"]
        score = components["hallucination_grader"].invoke(
            {"documents": documents, "generation": generation}
        )
        # Check hallucination
        if score.binary_score == "yes":
            # Check question-answering
            score = components["answer_grader"].invoke({"question": question, "generation": generation})
            if score.binary_score == "yes":
                return "useful"
            return "not useful"
        return "not supported"

    # Build the graph
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)

    # Add edges
    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "retrieve")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )

    # Compile the graph
    return workflow.compile()
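
# Uncalled sketch: stream() (used in process_query below) yields one state
# update per node; for a one-shot call, the compiled graph also supports
# .invoke(), which returns the final state directly.
def _debug_run_graph_once(app, question: str) -> str:
    final_state = app.invoke({"question": question})
    return final_state["generation"]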
# Initialize global variables
retriever = None
rag_app = None
components = None
current_model_choice = "gpt-4.1"  # Default

# Build the retriever and RAG graph ONCE at startup, with the default model
retriever = initialize_retriever()
if retriever is not None:
    components = setup_components(retriever, current_model_choice)
    rag_app = build_rag_graph(components)
else:
    logger.error("No retriever could be initialized. Please add Markdown files to ./MarkdownOutput.")
# Processing function for Gradio
def process_query(question, display_logs=False, model_choice="gpt-4.1"):
    global retriever, rag_app, components, current_model_choice
    logs = []
    answer = ""
    token_usage = {}
    try:
        if retriever is None:
            logs.append("Error: No Markdown files found. Please add files to ./MarkdownOutput and restart the app.")
            return "Error: No Markdown files found. Please add files to ./MarkdownOutput.", "\n".join(logs), token_usage
        # If model_choice changed, re-initialize components and rag_app
        if model_choice != current_model_choice:
            logs.append(f"Switching model to {model_choice} ...")
            components = setup_components(retriever, model_choice)
            rag_app = build_rag_graph(components)
            current_model_choice = model_choice
        logs.append(f"Processing query: {question}")
        logs.append(f"Using model: {model_choice}")
        logs.append("Starting RAG pipeline...")
        final_output = None
        with get_openai_callback() as cb:
            for i, output in enumerate(rag_app.stream({"question": question})):
                step_info = f"Step {i + 1}: "
                if "retrieve" in output:
                    step_info += f"Retrieved {len(output['retrieve']['documents'])} documents"
                elif "grade_documents" in output:
                    step_info += f"Graded documents, {len(output['grade_documents']['documents'])} deemed relevant"
                elif "transform_query" in output:
                    step_info += f"Transformed query to: {output['transform_query']['question']}"
                elif "generate" in output:
                    step_info += "Generated answer"
                final_output = output
                logs.append(step_info)
        # Store token usage information
        token_usage = {
            "total_tokens": cb.total_tokens,
            "prompt_tokens": cb.prompt_tokens,
            "completion_tokens": cb.completion_tokens,
            "total_cost": cb.total_cost,
        }
        logs.append(f"Token usage: {token_usage}")
        if final_output and "generate" in final_output:
            answer = final_output["generate"]["generation"]
            logs.append("Final answer generated successfully")
        else:
            answer = "No answer could be generated. Please try rephrasing your question."
            logs.append("Failed to generate answer")
    except Exception as e:
        logs.append(f"Error: {e}")
        answer = f"An error occurred: {e}"
    return answer, "\n".join(logs) if display_logs else "", token_usage
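
# Uncalled usage sketch for process_query (bypasses the Gradio UI); the
# question string is illustrative only.
def _debug_process_query_example():
    answer, logs, usage = process_query(
        "Summarize the main topics covered in the documents.",
        display_logs=True,
        model_choice="gpt-4.1-mini",
    )
    print(answer)
    print(usage)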
# Create Gradio interface
with gr.Blocks(title="Self-RAG Document Assistant", theme=gr.themes.Base()) as demo:
    with gr.Row():
        gr.Markdown("# Self-RAG Document Assistant")
    with gr.Row():
        gr.Markdown("""This application uses a Self-RAG (Retrieval Augmented Generation) system to
        provide accurate answers by:
        1. Retrieving relevant documents from your document database
        2. Grading document relevance to your question
        3. Generating answers grounded in these documents
        4. Self-checking for hallucinations and whether the answer addresses the question""")
    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask a question about your documents...",
                lines=4,
            )
        with gr.Column(scale=1):
            model_choice_input = gr.Dropdown(
                label="Model",
                choices=["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano"],
                value="gpt-4.1",
            )
            show_logs = gr.Checkbox(label="Show Debugging Logs", value=False)
            submit_btn = gr.Button("Submit", variant="primary")
    with gr.Row():
        with gr.Column():
            answer_output = gr.Textbox(
                label="Answer",
                lines=10,
                placeholder="Your answer will appear here...",
            )
    with gr.Row():
        logs_output = gr.Textbox(
            label="Process Logs",
            lines=15,
            visible=False,
        )
    with gr.Row():
        token_usage_output = gr.JSON(
            label="Token Usage Statistics",
            visible=True,
        )

    # Event handlers
    submit_btn.click(
        fn=process_query,
        inputs=[query_input, show_logs, model_choice_input],
        outputs=[answer_output, logs_output, token_usage_output],
    )
    show_logs.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[show_logs],
        outputs=[logs_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=False)