# ==============================================================================
# 1. SETUP AND IMPORTS
# ==============================================================================
import os

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional

# Read the Google API key from an environment variable rather than hard-coding it
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Google API key not found. Please set the GOOGLE_API_KEY environment variable.")

# LangChain / LangGraph imports
import bs4
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langgraph.graph import MessagesState, StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver

# ==============================================================================
# 2. CORE LOGIC
# ==============================================================================
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=google_api_key)
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=google_api_key)
vector_store = InMemoryVectorStore(embeddings)

# Load the blog post, keeping only the main content, title, and header
web_url = "https://lilianweng.github.io/posts/2023-06-23-agent/"
loader = WebBaseLoader(
    web_paths=(web_url,),
    bs_kwargs=dict(parse_only=bs4.SoupStrainer(class_=("post-content", "post-title", "post-header"))),
)
docs = loader.load()

# Split into overlapping chunks and index them in the vector store
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, add_start_index=True)
all_splits = text_splitter.split_documents(docs)
vector_store.add_documents(all_splits)

# Tool definition: the LLM calls this to pull relevant chunks from the index
@tool
def retrieve(query: str):
    """Retrieve information related to a query."""
    retrieved_docs = vector_store.similarity_search(query, k=2)
    return [doc.page_content for doc in retrieved_docs]

# Graph node functions
llm_with_tools = llm.bind_tools([retrieve])

def query_or_respond(state: MessagesState):
    # Either answer directly or emit a tool call to `retrieve`
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}

tools = ToolNode([retrieve])

def generate(state: MessagesState):
    # Produce the final answer; the retrieved context is already present in the
    # ToolMessages appended to state["messages"] by the tools node
    response = llm.invoke(state["messages"])
    return {"messages": [response]}

# Build and compile the LangGraph StateGraph
graph_builder = StateGraph(MessagesState)
graph_builder.add_node("query_or_respond", query_or_respond)
graph_builder.add_node("tools", tools)
graph_builder.add_node("generate", generate)
graph_builder.set_entry_point("query_or_respond")
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {"tools": "tools", END: END},
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)
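# A minimal in-process usage sketch (illustrative only; this function is never
# called). MemorySaver keys conversation state on the thread_id in the config,
# so reusing the same id across invocations gives the agent multi-turn memory.
# The thread_id and questions below are placeholders, not part of the app.
def _example_direct_invocation():
    config = {"configurable": {"thread_id": "example_thread"}}
    for question in ("What is task decomposition?", "Can you elaborate on that?"):
        result = graph.invoke({"messages": [HumanMessage(content=question)]}, config=config)
        print(result["messages"][-1].content)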
# ==============================================================================
# 3. API SERVER
# ==============================================================================
app = FastAPI(
    title="LangGraph RAG Agent Server",
    description="An API server for a RAG agent built with LangGraph.",
)

# Input model for the API
class UserRequest(BaseModel):
    message: str
    thread_id: Optional[str] = "default_thread"  # fall back to a shared default thread

# API endpoint. Declared as a plain `def` so FastAPI runs it in its threadpool:
# graph.invoke is synchronous and would otherwise block the event loop.
@app.post("/invoke")
def invoke_agent(request: UserRequest):
    # Route memory to the caller's conversation thread
    config = {"configurable": {"thread_id": request.thread_id}}
    # Wrap the user's message as graph input
    inputs = {"messages": [HumanMessage(content=request.message)]}
    # Run the graph to completion
    response = graph.invoke(inputs, config=config)
    # Return the AI's final message
    final_message = response["messages"][-1]
    return {"response": final_message.content}

# Local development entry point; a production deployment would typically run
# uvicorn (or another ASGI server) directly instead
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
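# Example client call (a sketch, assuming the server is running locally on port
# 7860 and the third-party `requests` package is installed; neither is required
# by this module):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/invoke",
#       json={"message": "What is task decomposition?", "thread_id": "demo"},
#   )
#   print(resp.json()["response"])
#
# Repeating the call with the same thread_id continues the same conversation,
# because MemorySaver persists the message history per thread.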