Spaces:
Runtime error
Runtime error
| import os | |
| import streamlit as st | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain.schema import Document | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langgraph.graph import StateGraph, END | |
| from graphviz import Digraph # For workflow visualization | |
| from typing_extensions import TypedDict | |
| from typing import List | |
| from utils.build_rag import RAG | |
# Fetch API Keys
# Credentials are read from the environment (e.g. populated from a .env file).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
# Check for Missing API Keys
# Fail fast in the UI rather than letting downstream API calls blow up later.
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("❌ API keys missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in your `.env` file.")
    st.stop()  # Stop the app execution
# Set up LLM and Tools
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)  # chat model used by all nodes below
web_search_tool = TavilySearchResults(api_key=TAVILY_API_KEY, k=2)  # Tavily search, top-2 results
# Prompt templates
def get_prompt():
    """Build the grounded-answer prompt: answer strictly from the given context."""
    return ChatPromptTemplate.from_template(
        "Answer the question based only on the following context:\n"
        "{context}\n"
        "Question: {question}\n"
    )
# Define Graph State
class GraphState(TypedDict):
    # Shared state threaded through every LangGraph node.
    question: str  # current (possibly rewritten) user question
    generation: str  # final answer text produced by the generate node
    web_search: str  # "Yes"/"No" flag set by grade_documents
    documents: List[Document]  # retrieved and/or web-searched context documents
# RAG Setup
rag = RAG()  # project-local vector-store wrapper (see utils.build_rag)
retriever = rag.get_retriever()  # document retriever backing both pipelines
prompt = get_prompt()  # grounded-answer prompt template
output_parser = StrOutputParser()  # extracts plain text from LLM output
# Nodes
def retrieve(state):
    """Fetch documents relevant to the current question from the vector store.

    Args:
        state: Graph state containing at least "question".

    Returns:
        Partial state update with the retrieved documents and the question.
    """
    question = state["question"]
    # .invoke() is the current retriever API; get_relevant_documents() is
    # deprecated in recent LangChain releases, and the basic-RAG path in the
    # UI section below already uses .invoke() — keep both consistent.
    documents = retriever.invoke(question)
    st.sidebar.write(f"Retrieved Documents: {len(documents)}")
    return {"documents": documents, "question": question}
def grade_documents(state):
    """Filter retrieved documents for relevance and flag whether a web search is needed.

    The grader is currently a stub that marks every document relevant; swap in
    a real relevance grader as needed.
    """
    question = state["question"]
    relevant_docs = []
    needs_web_search = "No"
    for candidate in state["documents"]:
        verdict = {"binary_score": "yes"}  # Dummy grader; integrate as needed
        if verdict["binary_score"] == "yes":
            relevant_docs.append(candidate)
        else:
            needs_web_search = "Yes"
    st.sidebar.write(f"Document Grading Results: {len(relevant_docs)} relevant")
    return {"documents": relevant_docs, "web_search": needs_web_search, "question": question}
def generate(state):
    """Produce the final answer from the graded documents.

    Bug fix: the original passed a raw dict straight to llm.invoke(); a chat
    model cannot interpret a dict as messages and raises at runtime. The
    context and question must first be rendered through the prompt template.

    Args:
        state: Graph state with "documents" and "question".

    Returns:
        Partial state update with the generated answer under "generation".
    """
    context = "\n".join([doc.page_content for doc in state["documents"]])
    # template -> chat model -> plain text
    chain = prompt | llm | output_parser
    response = chain.invoke({"context": context, "question": state["question"]})
    return {"generation": response}
def transform_query(state):
    """Ask the LLM to rephrase the question before falling back to web search."""
    original_question = state["question"]
    rewritten = llm.invoke(f"Rewrite: {original_question}").content
    st.sidebar.write(f"Rewritten Question: {rewritten}")
    return {"question": rewritten}
def web_search(state):
    """Augment the document set with Tavily web-search results.

    Fix: the original appended to state["documents"] in place, mutating the
    incoming graph state; a fresh list is built instead so node inputs stay
    untouched.

    Args:
        state: Graph state with "question" and "documents".

    Returns:
        Partial state update with the search results appended as one Document.
    """
    question = state["question"]
    results = web_search_tool.invoke({"query": question})
    docs = "\n".join([result["content"] for result in results])
    documents = state["documents"] + [Document(page_content=docs)]
    st.sidebar.write("Web Search Completed")
    return {"documents": documents, "question": question}
def decide_to_generate(state):
    """Route to answer generation when no web search is needed, otherwise rewrite the query."""
    if state["web_search"] == "No":
        return "generate"
    return "transform_query"
# Build Graph
# Pipeline: retrieve -> grade; if all docs graded relevant -> generate,
# otherwise rewrite the query, run a web search, then generate.
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)
workflow.set_entry_point("retrieve")  # every run starts with retrieval
workflow.add_edge("retrieve", "grade_documents")
# Branch on the "web_search" flag produced by grade_documents.
workflow.add_conditional_edges("grade_documents", decide_to_generate, {"transform_query": "transform_query", "generate": "generate"})
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)
app = workflow.compile()  # compiled LangGraph app used by the UI below
# Visualize Workflows
def plot_workflow():
    """Render the self-corrective RAG pipeline as a Graphviz digraph."""
    node_specs = [
        ("retrieve", "Retrieve Documents"),
        ("grade_documents", "Grade Documents"),
        ("generate", "Generate Answer"),
        ("transform_query", "Transform Query"),
        ("web_search_node", "Web Search"),
        ("END", "End"),
    ]
    edge_specs = [
        ("retrieve", "grade_documents", None),
        ("grade_documents", "generate", "Relevant Docs"),
        ("grade_documents", "transform_query", "No Relevant Docs"),
        ("transform_query", "web_search_node", None),
        ("web_search_node", "generate", None),
        ("generate", "END", None),
    ]
    graph = Digraph()
    graph.attr(size='6,6')
    # Add nodes
    for node_id, caption in node_specs:
        graph.node(node_id, caption)
    # Add edges
    for tail, head, caption in edge_specs:
        if caption is None:
            graph.edge(tail, head)
        else:
            graph.edge(tail, head, label=caption)
    return graph
# Streamlit App
st.title("Self-Corrective RAG")
st.write("### Compare RAG Pipeline Outputs (With and Without Self-Correction)")

# Plot Workflow
st.subheader("Workflow Visualization")
st.graphviz_chart(plot_workflow().source)

# User Input
question = st.text_input("Enter your question:", "What is Llama2?")
if st.button("Run Comparison"):
    # Run Basic RAG (single retrieval pass, no grading/correction loop)
    st.subheader("Without Self-Correction:")
    docs = retriever.invoke(question)
    basic_context = "\n".join([doc.page_content for doc in docs])
    # Bug fix: the original passed a raw dict straight to llm.invoke(), which
    # a chat model cannot interpret as messages; the context/question must be
    # rendered through the prompt template first.
    basic_response = (prompt | llm | output_parser).invoke(
        {"context": basic_context, "question": question}
    )
    st.write(basic_response)

    # Run Self-Corrective RAG
    st.subheader("With Self-Correction:")
    inputs = {"question": question}
    final_generation = ""
    # Bug fix: app.stream() yields {node_name: node_output} mappings, so the
    # top-level key is the node name ("generate"), never "generation" — the
    # original comparison never matched and always printed an empty answer.
    # Look inside each node's output for the "generation" field instead.
    for output in app.stream(inputs):
        for node_name, node_output in output.items():
            if "generation" in node_output:
                final_generation = node_output["generation"]
    st.write(final_generation)