import inspect import json import logging import warnings from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import Annotated, Any from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, FastAPI, HTTPException, status from fastapi.responses import StreamingResponse from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from langchain_core._api import LangChainBetaWarning from langchain_core.messages import AIMessage, AIMessageChunk, AnyMessage, HumanMessage, ToolMessage from langchain_core.runnables import RunnableConfig from langfuse import Langfuse # type: ignore[import-untyped] from langfuse.langchain import ( CallbackHandler, # type: ignore[import-untyped] ) from langgraph.types import Command, Interrupt from langsmith import Client as LangsmithClient from agents import DEFAULT_AGENT, AgentGraph, get_agent, get_all_agent_info, load_agent from core import settings from memory import initialize_database, initialize_store from schema import ( ChatHistory, ChatHistoryInput, ChatMessage, Feedback, FeedbackResponse, ServiceMetadata, StreamInput, UserInput, ) from service.utils import ( convert_message_content_to_string, langchain_to_chat_message, remove_tool_calls, ) warnings.filterwarnings("ignore", category=LangChainBetaWarning) logger = logging.getLogger(__name__) def custom_generate_unique_id(route: APIRoute) -> str: """Generate idiomatic operation IDs for OpenAPI client generation.""" return route.name def verify_bearer( http_auth: Annotated[ HTTPAuthorizationCredentials | None, Depends(HTTPBearer(description="Please provide AUTH_SECRET api key.", auto_error=False)), ], ) -> None: if not settings.AUTH_SECRET: return auth_secret = settings.AUTH_SECRET.get_secret_value() if not http_auth or http_auth.credentials != auth_secret: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """ Configurable lifespan that initializes the appropriate database checkpointer, store, and agents with async loading - for example for starting up MCP clients. """ try: # Initialize both checkpointer (for short-term memory) and store (for long-term memory) async with initialize_database() as saver, initialize_store() as store: # Set up both components if hasattr(saver, "setup"): # ignore: union-attr await saver.setup() # Only setup store for Postgres as InMemoryStore doesn't need setup if hasattr(store, "setup"): # ignore: union-attr await store.setup() # Configure agents with both memory components and async loading agents = get_all_agent_info() for a in agents: try: await load_agent(a.key) logger.info(f"Agent loaded: {a.key}") except Exception as e: logger.error(f"Failed to load agent {a.key}: {e}") # Continue with other agents rather than failing startup agent = get_agent(a.key) # Set checkpointer for thread-scoped memory (conversation history) agent.checkpointer = saver # Set store for long-term memory (cross-conversation knowledge) agent.store = store yield except Exception as e: logger.error(f"Error during database/store/agents initialization: {e}") raise app = FastAPI(lifespan=lifespan, generate_unique_id_function=custom_generate_unique_id) app.add_middleware( CORSMiddleware, allow_origins=settings.CORS_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) router = APIRouter(dependencies=[Depends(verify_bearer)]) @router.get("/info") async def info() -> ServiceMetadata: models = list(settings.AVAILABLE_MODELS) models.sort() return ServiceMetadata( agents=get_all_agent_info(), models=models, default_agent=DEFAULT_AGENT, default_model=settings.DEFAULT_MODEL, ) async def _handle_input(user_input: UserInput, agent: AgentGraph) -> tuple[dict[str, Any], UUID]: """ Parse user input and handle any required interrupt resumption. Returns kwargs for agent invocation and the run_id. """ run_id = uuid4() thread_id = user_input.thread_id or str(uuid4()) user_id = user_input.user_id or str(uuid4()) configurable = {"thread_id": thread_id, "user_id": user_id} if user_input.model is not None: configurable["model"] = user_input.model callbacks: list[Any] = [] if settings.LANGFUSE_TRACING: # Initialize Langfuse CallbackHandler for Langchain (tracing) langfuse_handler = CallbackHandler() callbacks.append(langfuse_handler) if user_input.agent_config: # Check for reserved keys (including 'model' even if not in configurable) reserved_keys = {"thread_id", "user_id", "model"} if overlap := reserved_keys & user_input.agent_config.keys(): raise HTTPException( status_code=422, detail=f"agent_config contains reserved keys: {overlap}", ) configurable.update(user_input.agent_config) config = RunnableConfig( configurable=configurable, run_id=run_id, callbacks=callbacks, ) # Check for interrupts that need to be resumed state = await agent.aget_state(config=config) interrupted_tasks = [ task for task in state.tasks if hasattr(task, "interrupts") and task.interrupts ] input: Command | dict[str, Any] if interrupted_tasks: # assume user input is response to resume agent execution from interrupt input = Command(resume=user_input.message) else: input = {"messages": [HumanMessage(content=user_input.message)]} kwargs = { "input": input, "config": config, } return kwargs, run_id @router.post("/{agent_id}/invoke", operation_id="invoke_with_agent_id") @router.post("/invoke") async def invoke(user_input: UserInput, agent_id: str = DEFAULT_AGENT) -> ChatMessage: """ Invoke an agent with user input to retrieve a final response. If agent_id is not provided, the default agent will be used. Use thread_id to persist and continue a multi-turn conversation. run_id kwarg is also attached to messages for recording feedback. Use user_id to persist and continue a conversation across multiple threads. """ # NOTE: Currently this only returns the last message or interrupt. # In the case of an agent outputting multiple AIMessages (such as the background step # in interrupt-agent, or a tool step in research-assistant), it's omitted. Arguably, # you'd want to include it. You could update the API to return a list of ChatMessages # in that case. agent: AgentGraph = get_agent(agent_id) kwargs, run_id = await _handle_input(user_input, agent) try: response_events: list[tuple[str, Any]] = await agent.ainvoke(**kwargs, stream_mode=["updates", "values"]) # type: ignore # fmt: skip response_type, response = response_events[-1] if response_type == "values": # Normal response, the agent completed successfully output = langchain_to_chat_message(response["messages"][-1]) elif response_type == "updates" and "__interrupt__" in response: # The last thing to occur was an interrupt # Return the value of the first interrupt as an AIMessage output = langchain_to_chat_message( AIMessage(content=response["__interrupt__"][0].value) ) else: raise ValueError(f"Unexpected response type: {response_type}") output.run_id = str(run_id) return output except Exception as e: logger.error(f"An exception occurred: {e}") raise HTTPException(status_code=500, detail="Unexpected error") async def message_generator( user_input: StreamInput, agent_id: str = DEFAULT_AGENT ) -> AsyncGenerator[str, None]: """ Generate a stream of messages from the agent. This is the workhorse method for the /stream endpoint. """ agent: AgentGraph = get_agent(agent_id) kwargs, run_id = await _handle_input(user_input, agent) try: # Process streamed events from the graph and yield messages over the SSE stream. async for stream_event in agent.astream( **kwargs, stream_mode=["updates", "messages", "custom"], subgraphs=True ): if not isinstance(stream_event, tuple): continue # Handle different stream event structures based on subgraphs if len(stream_event) == 3: # With subgraphs=True: (node_path, stream_mode, event) _, stream_mode, event = stream_event else: # Without subgraphs: (stream_mode, event) stream_mode, event = stream_event new_messages = [] if stream_mode == "updates": for node, updates in event.items(): # A simple approach to handle agent interrupts. # In a more sophisticated implementation, we could add # some structured ChatMessage type to return the interrupt value. if node == "__interrupt__": interrupt: Interrupt for interrupt in updates: new_messages.append(AIMessage(content=interrupt.value)) continue updates = updates or {} update_messages = updates.get("messages", []) # special cases for using langgraph-supervisor library if "supervisor" in node or "sub-agent" in node: # the only tools that come from the actual agent are the handoff and handback tools if isinstance(update_messages[-1], ToolMessage): if "sub-agent" in node and len(update_messages) > 1: # If this is a sub-agent, we want to keep the last 2 messages - the handback tool, and it's result update_messages = update_messages[-2:] else: # If this is a supervisor, we want to keep the last message only - the handoff result. The tool comes from the 'agent' node. update_messages = [update_messages[-1]] else: update_messages = [] new_messages.extend(update_messages) if stream_mode == "custom": new_messages = [event] # Send update events for non-message updates (like follow_up) if stream_mode == "updates" and hasattr(event, "items"): for node, updates in event.items(): if updates: other_updates = {k: v for k, v in updates.items() if k != "messages"} if other_updates: yield f"data: {json.dumps({'type': 'update', 'node': node, 'updates': other_updates})}\n\n" # LangGraph streaming may emit tuples: (field_name, field_value) # e.g. ('content', ), ('tool_calls', [ToolCall,...]), ('additional_kwargs', {...}), etc. # We accumulate only supported fields into `parts` and skip unsupported metadata. # More info at: https://langchain-ai.github.io/langgraph/cloud/how-tos/stream_messages/ processed_messages = [] current_message: dict[str, Any] = {} for message in new_messages: if isinstance(message, tuple): key, value = message # Store parts in temporary dict current_message[key] = value else: # Add complete message if we have one in progress if current_message: processed_messages.append(_create_ai_message(current_message)) current_message = {} processed_messages.append(message) # Add any remaining message parts if current_message: processed_messages.append(_create_ai_message(current_message)) for message in processed_messages: try: chat_message = langchain_to_chat_message(message) chat_message.run_id = str(run_id) except Exception as e: logger.error(f"Error parsing message: {e}") yield f"data: {json.dumps({'type': 'error', 'content': 'Unexpected error'})}\n\n" continue # LangGraph re-sends the input message, which feels weird, so drop it if chat_message.type == "human" and chat_message.content == user_input.message: continue yield f"data: {json.dumps({'type': 'message', 'content': chat_message.model_dump()})}\n\n" if stream_mode == "messages": if not user_input.stream_tokens: continue msg, metadata = event if "skip_stream" in metadata.get("tags", []): continue # For some reason, astream("messages") causes non-LLM nodes to send extra messages. # Drop them. if not isinstance(msg, AIMessageChunk): continue content = remove_tool_calls(msg.content) if content: # Empty content in the context of OpenAI usually means # that the model is asking for a tool to be invoked. # So we only print non-empty content. yield f"data: {json.dumps({'type': 'token', 'content': convert_message_content_to_string(content)})}\n\n" except Exception as e: logger.error(f"Error in message generator: {e}") yield f"data: {json.dumps({'type': 'error', 'content': 'Internal server error'})}\n\n" finally: yield "data: [DONE]\n\n" def _create_ai_message(parts: dict) -> AIMessage: sig = inspect.signature(AIMessage) valid_keys = set(sig.parameters) filtered = {k: v for k, v in parts.items() if k in valid_keys} return AIMessage(**filtered) def _sse_response_example() -> dict[int | str, Any]: return { status.HTTP_200_OK: { "description": "Server Sent Event Response", "content": { "text/event-stream": { "example": "data: {'type': 'token', 'content': 'Hello'}\n\ndata: {'type': 'token', 'content': ' World'}\n\ndata: [DONE]\n\n", "schema": {"type": "string"}, } }, } } @router.post( "/{agent_id}/stream", response_class=StreamingResponse, responses=_sse_response_example(), operation_id="stream_with_agent_id", ) @router.post("/stream", response_class=StreamingResponse, responses=_sse_response_example()) async def stream(user_input: StreamInput, agent_id: str = DEFAULT_AGENT) -> StreamingResponse: """ Stream an agent's response to a user input, including intermediate messages and tokens. If agent_id is not provided, the default agent will be used. Use thread_id to persist and continue a multi-turn conversation. run_id kwarg is also attached to all messages for recording feedback. Use user_id to persist and continue a conversation across multiple threads. Set `stream_tokens=false` to return intermediate messages but not token-by-token. """ return StreamingResponse( message_generator(user_input, agent_id), media_type="text/event-stream", ) @router.post("/feedback") async def feedback(feedback: Feedback) -> FeedbackResponse: """ Record feedback for a run to LangSmith. This is a simple wrapper for the LangSmith create_feedback API, so the credentials can be stored and managed in the service rather than the client. See: https://api.smith.langchain.com/redoc#tag/feedback/operation/create_feedback_api_v1_feedback_post """ client = LangsmithClient() kwargs = feedback.kwargs or {} client.create_feedback( run_id=feedback.run_id, key=feedback.key, score=feedback.score, **kwargs, ) return FeedbackResponse() @router.post("/history") async def history(input: ChatHistoryInput) -> ChatHistory: """ Get chat history. """ # TODO: Hard-coding DEFAULT_AGENT here is wonky agent: AgentGraph = get_agent(DEFAULT_AGENT) try: state_snapshot = await agent.aget_state( config=RunnableConfig(configurable={"thread_id": input.thread_id}) ) messages: list[AnyMessage] = state_snapshot.values["messages"] chat_messages: list[ChatMessage] = [langchain_to_chat_message(m) for m in messages] return ChatHistory(messages=chat_messages) except Exception as e: logger.error(f"An exception occurred: {e}") raise HTTPException(status_code=500, detail="Unexpected error") @app.get("/health") async def health_check(): """Health check endpoint.""" health_status = {"status": "ok"} if settings.LANGFUSE_TRACING: try: langfuse = Langfuse() health_status["langfuse"] = "connected" if langfuse.auth_check() else "disconnected" except Exception as e: logger.error(f"Langfuse connection error: {e}") health_status["langfuse"] = "disconnected" return health_status app.include_router(router)