from langgraph.graph import StateGraph, MessagesState, END, START
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.messages import SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langchain_community.document_loaders import WikipediaLoader
from langchain_experimental.utilities.python import PythonREPL
from pinecone import Pinecone
from typing import List, Annotated
from pydantic import BaseModel, Field
from IPython.display import Image, display
import operator
import prompts

# set environment variables
import os
from dotenv import load_dotenv

load_dotenv()

llm = ChatOpenAI(model="gpt-4o", temperature=0)
weak_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5)


class QuestionState(MessagesState):
    topic: str  # topic of the question
    subtopic: str  # subtopic of the question
    difficulty: str  # difficulty of the question
    description: str  # description of the subtopic
    context: Annotated[list, operator.add]  # knowledge base of the subtopic
    relevant_questions: List[dict]  # relevant questions
    num_questions: int  # number of relevant questions to extract
    human_feedback: str  # feedback from the human
    question: str  # question to ask
    steps: List[str]  # steps to solve the question
    tool_requests: List[dict]  # tool requests to solve the question
    tool_results: List[dict]  # tool results to solve the question
    verified: bool  # if the solution is verified
    solution: str  # solution to the question
    answer: str  # answer to the question
    error: str  # error message propagated between nodes (written by planner/decider/executor/verifier)


# -------------------------------
# Node 1: Generate Description Node
# -------------------------------
def generate_description(state: QuestionState):
    """
    Generate a description for the subtopic
    """
    topic = state["topic"]
    subtopic = state["subtopic"]

    # generate description
    system_message = prompts.DESCRIPTION_INSTRUCTION.format(
        topic=topic, subtopic=subtopic
    )
    description = weak_llm.invoke(
        [SystemMessage(content=system_message)], max_tokens=30
    ).content

    # write description to state
    return {"description": description}


# -------------------------------
# Node 2: Search Wikipedia Node
# -------------------------------
def search_wikipedia(state: QuestionState):
    """
    Search wikipedia for the topic and subtopic
    """
    subtopic = state["subtopic"]
    search_query = f"What is {subtopic}"

    # search wikipedia
    search_docs = WikipediaLoader(
        query=search_query, load_max_docs=1, doc_content_chars_max=1500
    ).load()

    # Format
    formatted_search_docs = "\n\n---\n\n".join(
        [f'\n{doc.page_content}\n' for doc in search_docs]
    )

    return {"context": [formatted_search_docs]}


# -------------------------------
# Node 3: Search Document Node
# -------------------------------
def search_document(state: QuestionState):
    """
    Search the document for relevant context
    """
    topic = state["topic"]
    subtopic = state["subtopic"]

    # Initialize OpenAI Embeddings client
    client = OpenAIEmbeddings(model="text-embedding-3-large")
    query = f"Search about {topic} in area of {subtopic}"
    embedded_query = client.embed_query(query)

    # Initialize Pinecone client
    api_key = os.environ.get("PINECONE_API_KEY")
    pc = Pinecone(api_key=api_key)
    # Vector DB query with metadata filter
    index_name = os.environ.get("PINECONE_INDEX_NAME")
    index = pc.Index(index_name)
    filters = {
        "topic": {"$eq": topic},
        "subtopic": {"$eq": subtopic},
        "type": {"$eq": "description"},
    }

    # Execute similarity search
    try:
        results = index.query(
            vector=embedded_query,
            filter=filters,
            top_k=1,  # Get the single most similar description
            include_metadata=True,
        )
    except Exception as e:
        raise ConnectionError(f"Vector DB query failed: {str(e)}")

    # Get the context
    if results and hasattr(results, "matches") and len(results.matches) > 0:
        context = results.matches[0].metadata.get("context", "")
        return {"context": [context]}
    else:
        return {"context": []}


# -------------------------------
# Node 4: Search Questions Node
# -------------------------------
def search_questions(state: QuestionState):
    """
    Search the document for relevant questions
    """
    topic = state["topic"]
    subtopic = state["subtopic"]
    num_questions = state["num_questions"]
    difficulty = state["difficulty"]

    # Initialize OpenAI Embeddings client
    client = OpenAIEmbeddings(model="text-embedding-3-large")
    query = f"Questions related to {topic} in area of {subtopic}"
    embedded_query = client.embed_query(query)

    # Initialize Pinecone client
    api_key = os.environ.get("PINECONE_API_KEY")
    pc = Pinecone(api_key=api_key)

    # Vector DB query with metadata filter
    index_name = os.environ.get("PINECONE_INDEX_NAME")
    index = pc.Index(index_name)
    filters = {
        "topic": {"$eq": topic},
        "subtopic": {"$eq": subtopic},
        "type": {"$eq": "question"},
        "difficulty": {"$eq": difficulty},
    }

    # Execute similarity search
    try:
        results = index.query(
            vector=embedded_query,
            filter=filters,
            top_k=num_questions,
            include_metadata=True,
        )
    except Exception as e:
        raise ConnectionError(f"Vector DB query failed: {str(e)}")

    references = []
    for match in results.matches:
        metadata = match.metadata
        references.append(
            {
                "question": metadata["question"],
                "answer": metadata["answer"],
                "difficulty": metadata["difficulty"],
            }
        )

    return {"relevant_questions": references}


# -------------------------------
# Node 5: Generate Question Node
# -------------------------------
def generate_question(state: QuestionState):
    """
    Generate a question for the subtopic
    """
    topic = state["topic"]
    subtopic = state["subtopic"]
    difficulty = state["difficulty"]
    context = state["context"]
    relevant_questions = state["relevant_questions"]
    human_feedback = state.get("human_feedback", "")

    # generate question
    query = prompts.QUESTION_INSTRUCTION.format(
        topic=topic,
        subtopic=subtopic,
        difficulty=difficulty,
        context=context,
        relevant_questions=relevant_questions,
        feedback=human_feedback,
    )
    question = llm.invoke([SystemMessage(content=query)], temperature=0.3).content

    # Clean residual markdown formatting
    question = question.strip().strip("`").replace("**Question:**", "").strip()
    print("Generated Question: ", question)

    # write question to state
    return {"question": question}


# -------------------------------
# Node 6: Feedback Node
# -------------------------------
def human_feedback(state: QuestionState):
    """No-op node that should be interrupted on"""
    print("Human Feedback Node: ", state)
    pass


def should_continue(state: QuestionState):
    """Return the next node to execute"""
    print("Should Continue: ", state)

    # If human feedback was provided, regenerate the question
    human_feedback = state.get("human_feedback", None)
    if human_feedback:
        return "generate_question"

    # Otherwise continue to the solution pipeline
    return "llm_step_planner"
# -------------------------------
# Node 7: LLM Step Planner
# -------------------------------
class SolutionPlan(BaseModel):
    solution_steps: List[str] = Field(description="List of steps to solve the problem")


def llm_step_planner(state: QuestionState):
    question = state["question"]
    try:
        prompt = prompts.STEP_INSTRUCTION.format(question=question)
        structured_llm = llm.with_structured_output(SolutionPlan)
        steps = structured_llm.invoke([SystemMessage(content=prompt)])
        print("Steps", steps)
        return {"steps": steps.solution_steps}
    except Exception as e:
        return {"error": f"LLM Parsing Error: {str(e)}"}


# -------------------------------
# Node 8: LLM Tool Decider
# -------------------------------
class ToolRequest(BaseModel):
    code: str = Field(description="Python code to execute")
    description: str = Field(description="Description of the code")


class ToolRequestList(BaseModel):
    tool_requests: List[ToolRequest] = Field(description="List of tool requests")


def llm_tool_decider(state: QuestionState):
    if state.get("error"):
        return {}  # Pass through: the error is already stored in state
    try:
        question = state["question"]
        steps = state.get("steps", [])
        prompt = prompts.TOOL_INSTRUCTION.format(question=question, steps=steps)
        structured_llm = llm.with_structured_output(ToolRequestList)
        tool_requests = structured_llm.invoke(
            [SystemMessage(content=prompt)], max_tokens=500, temperature=0.2
        )
        print("Tool Requests", tool_requests)
        return {
            "tool_requests": [req.model_dump() for req in tool_requests.tool_requests]
        }
    except Exception as e:
        return {"error": f"LLM Tool Decider Error: {str(e)}"}


# -------------------------------
# Node 9: LLM Tool Executor
# -------------------------------
code_executor = PythonREPL()


def tool_executor(state: QuestionState):
    if state.get("error"):
        return {}  # Pass through: the error is already stored in state
    try:
        tool_results = []
        for req in state.get("tool_requests", []):
            print("Req", req)
            if req.get("type", "sympy") == "sympy":  # default to sympy
                try:
                    output = code_executor.run(req["code"])  # Executes full code
                    tool_results.append(
                        {
                            "description": req.get("description", ""),
                            "result": output.strip(),
                        }
                    )
                except Exception as e:
                    tool_results.append(
                        {
                            "description": req.get("description", ""),
                            "result": f"Execution Error: {str(e)}",
                        }
                    )
            else:
                tool_results.append(
                    {
                        "description": f"Unknown tool type: {req.get('type')}",
                        "result": None,
                    }
                )
        print("Tool Results", tool_results)
        return {"tool_results": tool_results}
    except Exception as e:
        return {"error": f"Tool Execution Error: {str(e)}"}


# -------------------------------
# Node 10: LLM Verifier
# -------------------------------
class VerifierResponse(BaseModel):
    verified: bool = Field(description="Whether the solution is verified")
    explanation: str = Field(description="Explanation for verification decision")


def llm_verifier(state: QuestionState):
    if state.get("error"):
        return {}  # Pass through: the error is already stored in state
    try:
        question = state["question"]
        steps = state.get("steps", [])
        tool_results = state.get("tool_results", [])

        prompt = prompts.VERIFICATION_INSTRUCTION.format(
            question=question, steps=steps, tool_results=tool_results
        )
        structured_llm = weak_llm.with_structured_output(VerifierResponse)
        verification_results = structured_llm.invoke(
            [SystemMessage(content=prompt)], max_tokens=500
        ).model_dump()

        result = verification_results.get("verified", False)
        return {
            "verified": result,
            "error": (
                None
                if result
                else f"Verification Failed: {verification_results.get('explanation', 'No explanation')}"
            ),
        }
    except Exception as e:
        return {"error": f"LLM Verifier Error: {str(e)}"}
# -------------------------------
# Node 11: LLM Finalizer
# -------------------------------
class FinalizerResponse(BaseModel):
    solution: str = Field(description="Markdown solution")
    answer: str = Field(description="Final answer")


def llm_finalizer(state: QuestionState):
    if state.get("error"):
        # Surface the accumulated error instead of generating a solution
        return {"solution": f"### Error\n{state['error']}", "answer": "N/A"}
    try:
        question = state["question"]
        steps = state.get("steps", [])
        tool_results = state.get("tool_results", [])
        verified = state.get("verified", False)

        prompt = prompts.FINALIZE_INSTRUCTION.format(
            question=question,
            steps=steps,
            tool_results=tool_results,
            verified=verified,
        )
        structured_llm = llm.with_structured_output(FinalizerResponse)
        final_response = structured_llm.invoke(
            [SystemMessage(content=prompt)], max_tokens=1000, temperature=0.2
        )
        return {"solution": final_response.solution, "answer": final_response.answer}
    except Exception as e:
        return {"solution": f"### Finalization Error\n{str(e)}", "answer": "N/A"}


# -------------------------------
# Graph Construction
# -------------------------------
builder = StateGraph(QuestionState)
builder.add_node("generate_description", generate_description)
# builder.add_node("search_wikipedia", search_wikipedia)
builder.add_node("search_document", search_document)
builder.add_node("search_questions", search_questions)
builder.add_node("generate_question", generate_question)
builder.add_node("feedback", human_feedback)
builder.add_node("llm_step_planner", llm_step_planner)
builder.add_node("llm_tool_decider", llm_tool_decider)
builder.add_node("tool_executor", tool_executor)
builder.add_node("llm_verifier", llm_verifier)
builder.add_node("llm_finalizer", llm_finalizer)

# Add edges
builder.add_edge(START, "generate_description")
# builder.add_edge("generate_description", "search_wikipedia")
builder.add_edge("generate_description", "search_document")
builder.add_edge("generate_description", "search_questions")
# builder.add_edge("search_wikipedia", "generate_question")
builder.add_edge("search_document", "generate_question")
builder.add_edge("search_questions", "generate_question")
builder.add_edge("generate_question", "feedback")
builder.add_conditional_edges(
    "feedback", should_continue, ["generate_question", "llm_step_planner"]
)
# builder.add_edge("generate_question", "llm_step_planner")
builder.add_edge("llm_step_planner", "llm_tool_decider")
builder.add_edge("llm_tool_decider", "tool_executor")
builder.add_edge("tool_executor", "llm_verifier")
builder.add_edge("llm_verifier", "llm_finalizer")
builder.add_edge("llm_finalizer", END)

# Compile
memory = MemorySaver()
question_graph = builder.compile(interrupt_before=["feedback"], checkpointer=memory)
question_graph.name = "QuestionGenerationGraph"
# question_graph = builder.compile(checkpointer=memory)
# display(Image(question_graph.get_graph(xray=1).draw_mermaid_png()))
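

# Minimal usage sketch for the interrupt/resume loop around the "feedback" node.
# The topic/subtopic/difficulty values and the thread_id below are hypothetical
# placeholders; running this requires valid OpenAI and Pinecone credentials and a
# populated index, so treat it as an illustration rather than a test.
if __name__ == "__main__":
    config = {"configurable": {"thread_id": "demo-1"}}  # hypothetical thread id
    initial_state = {
        "topic": "Calculus",  # hypothetical inputs
        "subtopic": "Integration by parts",
        "difficulty": "medium",
        "num_questions": 3,
    }

    # First run: executes retrieval and question generation, then pauses before
    # the "feedback" node because of interrupt_before=["feedback"].
    question_graph.invoke(initial_state, config)

    # Inspect the draft question at the checkpoint.
    draft = question_graph.get_state(config).values.get("question")
    print("Draft question:", draft)

    # To request a rewrite, a reviewer could inject feedback before resuming,
    # which makes should_continue route back to generate_question, e.g.:
    # question_graph.update_state(config, {"human_feedback": "Use smaller numbers."})

    # Resume from the checkpoint with no new input; with no feedback set,
    # should_continue routes to llm_step_planner and the run continues through
    # tool execution, verification, and finalization.
    final_state = question_graph.invoke(None, config)
    print(final_state.get("solution"))
    print(final_state.get("answer"))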