from langchain.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI from langgraph.graph import START, END, StateGraph from config import RAG_MODEL from utils.state import State RAG_PROMPT = """\ You are a helpful assistant specializing in AI evaluation and research papers analysis. Your task is to answer questions based ONLY on the provided context. If the context doesn't contain the information needed, acknowledge that limitation rather than making up information. ### Question {question} ### Context {context} Based on the above context only, provide a comprehensive answer. Include specific details from the research papers when relevant. If the question cannot be answered from the context, clearly state that the information is not available in the provided documents. """ def create_rag_graph(retriever): """Create a RAG graph that uses a retriever to answer questions.""" rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT) llm = ChatOpenAI(model=RAG_MODEL) def retrieve(state): retrieved_docs = retriever.invoke(state["question"]) return {"context": retrieved_docs} def generate(state): # Extract document content and include document identifiers doc_contents = [] for i, doc in enumerate(state["context"]): # Include document identifier in the content for reference metadata = doc.metadata if hasattr(doc, 'metadata') else {} doc_id = metadata.get('id', f'doc_{i}') doc_contents.append(f"Document {doc_id}:\n{doc.page_content}") # Join all document contents with clear separation docs_content = "\n\n" + "\n\n---\n\n".join(doc_contents) # Format messages with question and context messages = rag_prompt.format_messages(question=state["question"], context=docs_content) response = llm.invoke(messages) return {"response": response.content} # Create graph with individual nodes and edges instead of using add_sequence graph_builder = StateGraph(State) # Add nodes individually graph_builder.add_node("retrieve", retrieve) graph_builder.add_node("generate", generate) # Add edges individually graph_builder.add_edge(START, "retrieve") graph_builder.add_edge("retrieve", "generate") graph_builder.add_edge("generate", END) return graph_builder.compile()