Spaces:
Build error
Build error
| import os | |
| import streamlit as st | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
| from pprint import pprint | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.pydantic_v1 import BaseModel, Field | |
| from langchain import hub | |
| from langchain_core.output_parsers import StrOutputParser | |
| from typing import List | |
| from typing_extensions import TypedDict | |
| from langgraph.graph import StateGraph, END | |
# Page-level configuration; Streamlit requires this to be the first st.* call.
st.set_page_config(
    page_title="SELF-RAG Workflow Application",
    page_icon="🤖",
    layout="centered",
)

# Theme overrides (colors and typography) injected as raw HTML/CSS.
_THEME_CSS = """
<style>
.main {
background-color: #272727;
font-family: 'Helvetica Neue', sans-serif;
}
.sidebar .sidebar-content {
background-color: #2E3944;
color: #ffffff;
}
h1 {
color: #14A76C;
}
.stTextInput {
border: 1px solid #272727;
border-radius: 5px;
}
</style>
"""
st.markdown(_THEME_CSS, unsafe_allow_html=True)
# Sidebar: usage instructions plus the (masked) OpenAI API key field.
_sidebar = st.sidebar
_sidebar.title("Instructions")
_sidebar.write("""
1. Enter your OpenAI API Key.
2. Enter your question in the text box.
3. Provide URLs for the documents you want to use.
4. Click on the 'Run Workflow' button.
5. View the results below.
""")
api_key = _sidebar.text_input("Enter your OpenAI API Key:", type="password")

# Main panel: the user's question and the source URLs, one per line.
st.title("SELF-RAG Workflow Application")
input_text = st.text_input("Enter your question : ")
urls_input = st.text_area("Enter URLs (one per line) :")
# Strip each line and drop the blank ones (filter(None, ...) removes "").
urls = list(filter(None, (line.strip() for line in urls_input.split('\n'))))

# Initial LangGraph state: the question plus a zeroed query-rewrite counter.
inputs = dict(question=input_text, transform_attempts=0)
# Upper bounds on the self-correction loops. Without them a question whose
# answer is never judged grounded/useful cycles through the graph until
# LangGraph's recursion limit aborts the whole run with an exception.
MAX_TRANSFORM_ATTEMPTS = 3
MAX_GENERATION_ATTEMPTS = 3


class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")


class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""
    binary_score: str = Field(description="Answer is grounded in the facts, 'yes' or 'no'")


class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""
    binary_score: str = Field(description="Answer addresses the question, 'yes' or 'no'")


class GraphState(TypedDict):
    """State passed between the nodes of the SELF-RAG graph.

    Attributes:
        question: the (possibly rewritten) user question
        generation: the latest LLM answer
        documents: LangChain documents returned by the retriever
        transform_attempts: number of query rewrites performed so far
        generation_attempts: number of answers generated so far in this run
    """
    question: str
    generation: str
    documents: list
    transform_attempts: int
    generation_attempts: int


def _is_yes(score) -> bool:
    """Return True if a structured grader's ``binary_score`` is 'yes'.

    Normalized so that a capitalized or whitespace-padded 'Yes' still counts.
    """
    return str(score.binary_score).strip().lower() == "yes"


def _format_docs(documents) -> str:
    """Join the ``page_content`` of each document, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in documents)


def _load_documents(url_list):
    """Load every URL with WebBaseLoader, reporting per-URL failures in the UI.

    Returns:
        list: all documents that loaded successfully (possibly empty).
    """
    loaded = []
    for url in url_list:
        try:
            loaded.extend(WebBaseLoader(url).load())
        except Exception as e:
            st.error(f"Error loading document from {url}: {e}")
    return loaded


def _build_retriever(documents, key):
    """Split documents into 250-token chunks and index them in FAISS.

    Returns:
        A retriever over the index, or None when splitting produced no chunks
        (so we never ask FAISS to build an index from an empty list).
    """
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=250, chunk_overlap=0
    )
    doc_splits = text_splitter.split_documents(documents)
    if not doc_splits:
        return None
    vectorstore = FAISS.from_documents(
        documents=doc_splits,
        embedding=OpenAIEmbeddings(openai_api_key=key),
    )
    return vectorstore.as_retriever()


def _build_workflow(retriever, key, log):
    """Assemble and compile the SELF-RAG LangGraph workflow.

    Args:
        retriever: retriever whose ``invoke(question)`` returns documents.
        key: OpenAI API key used by every chat model in the graph.
        log: list that nodes and routing functions append progress lines to.

    Returns:
        The compiled graph (used via ``.stream(inputs)``).
    """
    # The graders and the rewriter use the same deterministic model; build it once.
    grader_llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, openai_api_key=key)
    generation_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, openai_api_key=key)

    ### Retrieval Grader
    relevance_system = """You are a grader assessing relevance of a retrieved document to a user question. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
    grade_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", relevance_system),
            ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
        ]
    )
    retrieval_grader = grade_prompt | grader_llm.with_structured_output(GradeDocuments)

    ### Generate
    rag_prompt = hub.pull("rlm/rag-prompt")
    rag_chain = rag_prompt | generation_llm | StrOutputParser()

    ### Hallucination Grader
    hallucination_system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
    hallucination_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", hallucination_system),
            ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
        ]
    )
    hallucination_grader = hallucination_prompt | grader_llm.with_structured_output(GradeHallucinations)

    ### Answer Grader
    answer_system = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", answer_system),
            ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
        ]
    )
    answer_grader = answer_prompt | grader_llm.with_structured_output(GradeAnswer)

    ### Question Re-writer
    rewrite_system = """You are a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
    re_write_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", rewrite_system),
            (
                "human",
                "Here is the initial question: \n\n {question} \n Formulate an improved question.",
            ),
        ]
    )
    question_rewriter = re_write_prompt | grader_llm | StrOutputParser()

    ### Nodes
    def retrieve(state):
        """Fetch documents for the current question."""
        log.append("---RETRIEVE---")
        question = state["question"]
        documents = retriever.invoke(question)
        return {"documents": documents, "question": question,
                "transform_attempts": state.get("transform_attempts", 0)}

    def generate(state):
        """Answer the question from the current documents and count the attempt."""
        log.append("---GENERATE---")
        question = state["question"]
        documents = state["documents"]
        # The prompt expects plain text context, not a list of Document objects.
        generation = rag_chain.invoke({"context": _format_docs(documents), "question": question})
        return {"documents": documents, "question": question, "generation": generation,
                "transform_attempts": state.get("transform_attempts", 0),
                "generation_attempts": state.get("generation_attempts", 0) + 1}

    def grade_documents(state):
        """Keep only the documents the relevance grader accepts."""
        log.append("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        question = state["question"]
        filtered_docs = []
        for doc in state["documents"]:
            score = retrieval_grader.invoke({"question": question, "document": doc.page_content})
            if _is_yes(score):
                log.append("---GRADE: DOCUMENT RELEVANT---")
                filtered_docs.append(doc)
            else:
                log.append("---GRADE: DOCUMENT NOT RELEVANT---")
        return {"documents": filtered_docs, "question": question,
                "transform_attempts": state.get("transform_attempts", 0)}

    def transform_query(state):
        """Rewrite the question for better retrieval and count the rewrite."""
        log.append("---TRANSFORM QUERY---")
        better_question = question_rewriter.invoke({"question": state["question"]})
        return {"documents": state["documents"], "question": better_question,
                "transform_attempts": state.get("transform_attempts", 0) + 1}

    def conclude_no_answer(state):
        """Terminal node used when rewrites never surfaced a relevant document."""
        return {"question": state["question"],
                "generation": "I don't know the answer since none of the given documents are relevant to the question.",
                "documents": [],
                "transform_attempts": state.get("transform_attempts", 0)}

    ### Edges
    def decide_to_generate(state):
        """Route to generate, transform_query, or conclude_no_answer."""
        log.append("---ASSESS GRADED DOCUMENTS---")
        if state["documents"]:
            log.append("---DECISION: GENERATE---")
            return "generate"
        if state.get("transform_attempts", 0) >= MAX_TRANSFORM_ATTEMPTS:
            log.append("---DECISION: NO RELEVANT DOCUMENTS AFTER MAX REWRITES, CONCLUDE---")
            return "conclude_no_answer"
        log.append("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
        return "transform_query"

    def grade_generation_v_documents_and_question(state):
        """Route after generation: finish, regenerate, rewrite, or give up."""
        log.append("---CHECK HALLUCINATIONS---")
        question = state["question"]
        generation = state["generation"]
        generations_exhausted = state.get("generation_attempts", 0) >= MAX_GENERATION_ATTEMPTS
        score = hallucination_grader.invoke(
            {"documents": _format_docs(state["documents"]), "generation": generation}
        )
        if not _is_yes(score):
            log.append("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
            if generations_exhausted:
                log.append("---DECISION: RETRY LIMIT REACHED, RETURNING LAST GENERATION---")
                return "give_up"
            return "not supported"
        log.append("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        log.append("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        if _is_yes(score):
            log.append("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        log.append("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
        if generations_exhausted or state.get("transform_attempts", 0) >= MAX_TRANSFORM_ATTEMPTS:
            log.append("---DECISION: RETRY LIMIT REACHED, RETURNING LAST GENERATION---")
            return "give_up"
        return "not useful"

    workflow = StateGraph(GraphState)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("conclude_no_answer", conclude_no_answer)

    workflow.set_entry_point("retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
            "conclude_no_answer": "conclude_no_answer",
        },
    )
    workflow.add_edge("transform_query", "retrieve")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
            "give_up": END,
        },
    )
    return workflow.compile()


def _render_run(graph, graph_inputs, log):
    """Stream the graph, echo buffered progress lines, then show the final answer."""
    final_output = {}
    for output in graph.stream(graph_inputs):
        for _node, value in output.items():
            for line in log:
                st.write(line)
            log.clear()
            if isinstance(value, dict):
                final_output = value
    # Flush anything still buffered once the stream has finished.
    for line in log:
        st.write(line)
    log.clear()
    st.write('## Final Answer')
    st.write(final_output.get("generation") or "No answer was generated.")


if st.button("Run Workflow"):
    if not api_key:
        st.error("Please enter your OpenAI API Key.")
    elif not urls:
        st.error("Please provide at least one URL.")
    elif not input_text.strip():
        st.error("Please enter a question.")
    else:
        try:
            texts = []
            docs = _load_documents(urls)
            if not docs:
                st.error("No documents loaded. Please check the URLs.")
            else:
                retriever = _build_retriever(docs, api_key)
                if retriever is None:
                    st.error("No text could be extracted from the loaded documents.")
                else:
                    app = _build_workflow(retriever, api_key, texts)
                    run_inputs = {**inputs, "question": input_text.strip()}
                    try:
                        _render_run(app, run_inputs, texts)
                    except Exception as e:
                        st.error(f"Error in workflow execution: {e}")
        except Exception as e:
            st.error(f"Error in document processing: {e}")