|
|
import inspect |
|
|
import json |
|
|
import logging |
|
|
import warnings |
|
|
from collections.abc import AsyncGenerator |
|
|
from contextlib import asynccontextmanager |
|
|
from typing import Annotated, Any |
|
|
from uuid import UUID, uuid4 |
|
|
|
|
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, status |
|
|
from fastapi.responses import StreamingResponse |
|
|
from fastapi.routing import APIRoute |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer |
|
|
from langchain_core._api import LangChainBetaWarning |
|
|
from langchain_core.messages import AIMessage, AIMessageChunk, AnyMessage, HumanMessage, ToolMessage |
|
|
from langchain_core.runnables import RunnableConfig |
|
|
from langfuse import Langfuse |
|
|
from langfuse.langchain import ( |
|
|
CallbackHandler, |
|
|
) |
|
|
from langgraph.types import Command, Interrupt |
|
|
from langsmith import Client as LangsmithClient |
|
|
|
|
|
from agents import DEFAULT_AGENT, AgentGraph, get_agent, get_all_agent_info, load_agent |
|
|
from core import settings |
|
|
from memory import initialize_database, initialize_store |
|
|
from schema import ( |
|
|
ChatHistory, |
|
|
ChatHistoryInput, |
|
|
ChatMessage, |
|
|
Feedback, |
|
|
FeedbackResponse, |
|
|
ServiceMetadata, |
|
|
StreamInput, |
|
|
UserInput, |
|
|
) |
|
|
from service.utils import ( |
|
|
convert_message_content_to_string, |
|
|
langchain_to_chat_message, |
|
|
remove_tool_calls, |
|
|
) |
|
|
|
|
|
# This service intentionally relies on beta LangChain APIs; silence the warning noise.
warnings.filterwarnings("ignore", category=LangChainBetaWarning)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def custom_generate_unique_id(route: APIRoute) -> str:
    """Use the route handler's function name as its OpenAPI operation ID.

    Generated API clients then get idiomatic method names (e.g. ``invoke``)
    instead of FastAPI's long auto-generated identifiers.
    """
    return route.name
|
|
|
|
|
|
|
|
def verify_bearer(
    http_auth: Annotated[
        HTTPAuthorizationCredentials | None,
        Depends(HTTPBearer(description="Please provide AUTH_SECRET api key.", auto_error=False)),
    ],
) -> None:
    """Reject the request unless it carries the configured bearer secret.

    Auth is a no-op when ``AUTH_SECRET`` is unset; otherwise the Authorization
    header must contain exactly the configured secret or a 401 is raised.
    """
    if not settings.AUTH_SECRET:
        return
    expected = settings.AUTH_SECRET.get_secret_value()
    provided = http_auth.credentials if http_auth else None
    if provided != expected:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Configurable lifespan that initializes the appropriate database checkpointer, store,
    and agents with async loading - for example for starting up MCP clients.

    Any initialization failure is logged and re-raised so the app fails fast
    instead of serving with a broken persistence layer.
    """
    try:
        async with initialize_database() as saver, initialize_store() as store:
            # Some backends (e.g. Postgres-based savers/stores) need a one-time
            # schema setup; in-memory implementations may not expose setup().
            if hasattr(saver, "setup"):
                await saver.setup()
            if hasattr(store, "setup"):
                await store.setup()

            agents = get_all_agent_info()
            for a in agents:
                try:
                    await load_agent(a.key)
                    logger.info(f"Agent loaded: {a.key}")
                except Exception as e:
                    logger.error(f"Failed to load agent {a.key}: {e}")
                    # BUGFIX: skip wiring for an agent that failed to load.
                    # Previously execution fell through to get_agent(a.key),
                    # which raises for an unloaded agent and aborted startup,
                    # defeating the per-agent error handling above.
                    continue

                # Attach persistence to each successfully loaded agent graph.
                agent = get_agent(a.key)
                agent.checkpointer = saver
                agent.store = store
            yield
    except Exception as e:
        logger.error(f"Error during database/store/agents initialization: {e}")
        raise
|
|
|
|
|
|
|
|
# FastAPI application: the lifespan wires up the checkpointer/store and agents;
# custom_generate_unique_id yields clean operation IDs for generated clients.
app = FastAPI(lifespan=lifespan, generate_unique_id_function=custom_generate_unique_id)

# CORS configuration is driven entirely by settings.CORS_ORIGINS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Every route registered on this router requires bearer auth
# (verify_bearer is a no-op when AUTH_SECRET is unset).
router = APIRouter(dependencies=[Depends(verify_bearer)])
|
|
|
|
|
|
|
|
@router.get("/info") |
|
|
async def info() -> ServiceMetadata: |
|
|
models = list(settings.AVAILABLE_MODELS) |
|
|
models.sort() |
|
|
return ServiceMetadata( |
|
|
agents=get_all_agent_info(), |
|
|
models=models, |
|
|
default_agent=DEFAULT_AGENT, |
|
|
default_model=settings.DEFAULT_MODEL, |
|
|
) |
|
|
|
|
|
|
|
|
async def _handle_input(user_input: UserInput, agent: AgentGraph) -> tuple[dict[str, Any], UUID]:
    """
    Parse user input and handle any required interrupt resumption.
    Returns kwargs for agent invocation and the run_id.
    """
    run_id = uuid4()

    # Missing thread/user ids get fresh UUIDs (i.e. a brand-new conversation).
    configurable: dict[str, Any] = {
        "thread_id": user_input.thread_id or str(uuid4()),
        "user_id": user_input.user_id or str(uuid4()),
    }
    if user_input.model is not None:
        configurable["model"] = user_input.model

    callbacks: list[Any] = []
    if settings.LANGFUSE_TRACING:
        callbacks.append(CallbackHandler())

    if user_input.agent_config:
        # Caller-supplied config must not clobber the keys we manage above.
        reserved_keys = {"thread_id", "user_id", "model"}
        if overlap := reserved_keys & user_input.agent_config.keys():
            raise HTTPException(
                status_code=422,
                detail=f"agent_config contains reserved keys: {overlap}",
            )
        configurable.update(user_input.agent_config)

    config = RunnableConfig(
        configurable=configurable,
        run_id=run_id,
        callbacks=callbacks,
    )

    # If the thread is paused on an interrupt, resume it with the user's
    # message instead of starting a new turn.
    state = await agent.aget_state(config=config)
    awaiting_resume = any(
        getattr(task, "interrupts", None) for task in state.tasks
    )

    agent_input: Command | dict[str, Any]
    if awaiting_resume:
        agent_input = Command(resume=user_input.message)
    else:
        agent_input = {"messages": [HumanMessage(content=user_input.message)]}

    return {"input": agent_input, "config": config}, run_id
|
|
|
|
|
|
|
|
@router.post("/{agent_id}/invoke", operation_id="invoke_with_agent_id") |
|
|
@router.post("/invoke") |
|
|
async def invoke(user_input: UserInput, agent_id: str = DEFAULT_AGENT) -> ChatMessage: |
|
|
""" |
|
|
Invoke an agent with user input to retrieve a final response. |
|
|
|
|
|
If agent_id is not provided, the default agent will be used. |
|
|
Use thread_id to persist and continue a multi-turn conversation. run_id kwarg |
|
|
is also attached to messages for recording feedback. |
|
|
Use user_id to persist and continue a conversation across multiple threads. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent: AgentGraph = get_agent(agent_id) |
|
|
kwargs, run_id = await _handle_input(user_input, agent) |
|
|
|
|
|
try: |
|
|
response_events: list[tuple[str, Any]] = await agent.ainvoke(**kwargs, stream_mode=["updates", "values"]) |
|
|
response_type, response = response_events[-1] |
|
|
if response_type == "values": |
|
|
|
|
|
output = langchain_to_chat_message(response["messages"][-1]) |
|
|
elif response_type == "updates" and "__interrupt__" in response: |
|
|
|
|
|
|
|
|
output = langchain_to_chat_message( |
|
|
AIMessage(content=response["__interrupt__"][0].value) |
|
|
) |
|
|
else: |
|
|
raise ValueError(f"Unexpected response type: {response_type}") |
|
|
|
|
|
output.run_id = str(run_id) |
|
|
return output |
|
|
except Exception as e: |
|
|
logger.error(f"An exception occurred: {e}") |
|
|
raise HTTPException(status_code=500, detail="Unexpected error") |
|
|
|
|
|
|
|
|
async def message_generator(
    user_input: StreamInput, agent_id: str = DEFAULT_AGENT
) -> AsyncGenerator[str, None]:
    """
    Generate a stream of messages from the agent.

    This is the workhorse method for the /stream endpoint.

    Yields Server-Sent-Event strings ("data: {...}\\n\\n") of types:
    'message' (complete chat messages), 'update' (non-message node state),
    'token' (streamed LLM tokens), 'error', and a final "data: [DONE]".
    """
    agent: AgentGraph = get_agent(agent_id)
    kwargs, run_id = await _handle_input(user_input, agent)

    try:
        # subgraphs=True makes nested-graph events arrive as 3-tuples
        # (namespace, stream_mode, event); top-level events are 2-tuples.
        async for stream_event in agent.astream(
            **kwargs, stream_mode=["updates", "messages", "custom"], subgraphs=True
        ):
            if not isinstance(stream_event, tuple):
                continue
            if len(stream_event) == 3:
                # Drop the subgraph namespace; only mode and payload are used.
                _, stream_mode, event = stream_event
            else:
                stream_mode, event = stream_event
            new_messages = []
            if stream_mode == "updates":
                for node, updates in event.items():
                    # A graph interrupt surfaces under the special
                    # "__interrupt__" node; forward each prompt as an AIMessage.
                    if node == "__interrupt__":
                        interrupt: Interrupt
                        for interrupt in updates:
                            new_messages.append(AIMessage(content=interrupt.value))
                        continue
                    updates = updates or {}
                    update_messages = updates.get("messages", [])
                    # Special-casing for supervisor/sub-agent nodes: only pass
                    # through tool results (and, for sub-agents, the preceding
                    # message) to avoid duplicating intermediate chatter.
                    # NOTE(review): update_messages[-1] assumes these nodes
                    # always emit at least one message — confirm upstream.
                    if "supervisor" in node or "sub-agent" in node:
                        if isinstance(update_messages[-1], ToolMessage):
                            if "sub-agent" in node and len(update_messages) > 1:
                                update_messages = update_messages[-2:]
                            else:
                                update_messages = [update_messages[-1]]
                        else:
                            update_messages = []
                    new_messages.extend(update_messages)

            if stream_mode == "custom":
                # Custom events are emitted by agent code and forwarded as-is.
                new_messages = [event]

            if stream_mode == "updates" and hasattr(event, "items"):
                # Emit non-message state changes as 'update' events so clients
                # can observe agent-internal state transitions.
                for node, updates in event.items():
                    if updates:
                        other_updates = {k: v for k, v in updates.items() if k != "messages"}
                        if other_updates:
                            yield f"data: {json.dumps({'type': 'update', 'node': node, 'updates': other_updates})}\n\n"

            # LangGraph streaming may deliver an AIMessage as a sequence of
            # (attribute, value) tuples; coalesce consecutive tuples into one
            # AIMessage, passing complete messages through unchanged.
            processed_messages = []
            current_message: dict[str, Any] = {}
            for message in new_messages:
                if isinstance(message, tuple):
                    key, value = message
                    # Accumulate attributes until a non-tuple item arrives.
                    current_message[key] = value
                else:
                    # Flush any partially accumulated message first.
                    if current_message:
                        processed_messages.append(_create_ai_message(current_message))
                        current_message = {}
                    processed_messages.append(message)

            # Flush a trailing accumulated message, if any.
            if current_message:
                processed_messages.append(_create_ai_message(current_message))

            for message in processed_messages:
                try:
                    chat_message = langchain_to_chat_message(message)
                    chat_message.run_id = str(run_id)
                except Exception as e:
                    # A single unparsable message becomes an error event but
                    # does not abort the stream.
                    logger.error(f"Error parsing message: {e}")
                    yield f"data: {json.dumps({'type': 'error', 'content': 'Unexpected error'})}\n\n"
                    continue
                # LangGraph re-emits the user's own input; don't echo it back.
                if chat_message.type == "human" and chat_message.content == user_input.message:
                    continue
                yield f"data: {json.dumps({'type': 'message', 'content': chat_message.model_dump()})}\n\n"

            if stream_mode == "messages":
                if not user_input.stream_tokens:
                    continue
                msg, metadata = event
                # Runs tagged "skip_stream" opt out of token streaming.
                if "skip_stream" in metadata.get("tags", []):
                    continue
                # Only stream LLM token chunks; skip tool/human message events.
                if not isinstance(msg, AIMessageChunk):
                    continue
                content = remove_tool_calls(msg.content)
                if content:
                    yield f"data: {json.dumps({'type': 'token', 'content': convert_message_content_to_string(content)})}\n\n"
    except Exception as e:
        logger.error(f"Error in message generator: {e}")
        yield f"data: {json.dumps({'type': 'error', 'content': 'Internal server error'})}\n\n"
    finally:
        # Always terminate the SSE stream, even after an error.
        yield "data: [DONE]\n\n"
|
|
|
|
|
|
|
|
def _create_ai_message(parts: dict) -> AIMessage:
    """Build an AIMessage from accumulated attribute parts.

    Keys that AIMessage's constructor does not accept are silently dropped.
    """
    accepted = set(inspect.signature(AIMessage).parameters)
    return AIMessage(**{k: v for k, v in parts.items() if k in accepted})
|
|
|
|
|
|
|
|
def _sse_response_example() -> dict[int | str, Any]:
    """Build the OpenAPI responses entry documenting the SSE stream format."""
    sse_example = (
        "data: {'type': 'token', 'content': 'Hello'}\n\n"
        "data: {'type': 'token', 'content': ' World'}\n\n"
        "data: [DONE]\n\n"
    )
    event_stream_schema = {
        "example": sse_example,
        "schema": {"type": "string"},
    }
    return {
        status.HTTP_200_OK: {
            "description": "Server Sent Event Response",
            "content": {"text/event-stream": event_stream_schema},
        }
    }
|
|
|
|
|
|
|
|
@router.post(
    "/{agent_id}/stream",
    response_class=StreamingResponse,
    responses=_sse_response_example(),
    operation_id="stream_with_agent_id",
)
@router.post("/stream", response_class=StreamingResponse, responses=_sse_response_example())
async def stream(user_input: StreamInput, agent_id: str = DEFAULT_AGENT) -> StreamingResponse:
    """
    Stream an agent's response to a user input, including intermediate messages and tokens.

    If agent_id is not provided, the default agent will be used.
    Use thread_id to persist and continue a multi-turn conversation. run_id kwarg
    is also attached to all messages for recording feedback.
    Use user_id to persist and continue a conversation across multiple threads.

    Set `stream_tokens=false` to return intermediate messages but not token-by-token.
    """
    # All streaming work happens in message_generator; this handler only
    # wraps it in an SSE response.
    generator = message_generator(user_input, agent_id)
    return StreamingResponse(generator, media_type="text/event-stream")
|
|
|
|
|
|
|
|
@router.post("/feedback") |
|
|
async def feedback(feedback: Feedback) -> FeedbackResponse: |
|
|
""" |
|
|
Record feedback for a run to LangSmith. |
|
|
|
|
|
This is a simple wrapper for the LangSmith create_feedback API, so the |
|
|
credentials can be stored and managed in the service rather than the client. |
|
|
See: https://api.smith.langchain.com/redoc#tag/feedback/operation/create_feedback_api_v1_feedback_post |
|
|
""" |
|
|
client = LangsmithClient() |
|
|
kwargs = feedback.kwargs or {} |
|
|
client.create_feedback( |
|
|
run_id=feedback.run_id, |
|
|
key=feedback.key, |
|
|
score=feedback.score, |
|
|
**kwargs, |
|
|
) |
|
|
return FeedbackResponse() |
|
|
|
|
|
|
|
|
@router.post("/history") |
|
|
async def history(input: ChatHistoryInput) -> ChatHistory: |
|
|
""" |
|
|
Get chat history. |
|
|
""" |
|
|
|
|
|
agent: AgentGraph = get_agent(DEFAULT_AGENT) |
|
|
try: |
|
|
state_snapshot = await agent.aget_state( |
|
|
config=RunnableConfig(configurable={"thread_id": input.thread_id}) |
|
|
) |
|
|
messages: list[AnyMessage] = state_snapshot.values["messages"] |
|
|
chat_messages: list[ChatMessage] = [langchain_to_chat_message(m) for m in messages] |
|
|
return ChatHistory(messages=chat_messages) |
|
|
except Exception as e: |
|
|
logger.error(f"An exception occurred: {e}") |
|
|
raise HTTPException(status_code=500, detail="Unexpected error") |
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
async def health_check(): |
|
|
"""Health check endpoint.""" |
|
|
|
|
|
health_status = {"status": "ok"} |
|
|
|
|
|
if settings.LANGFUSE_TRACING: |
|
|
try: |
|
|
langfuse = Langfuse() |
|
|
health_status["langfuse"] = "connected" if langfuse.auth_check() else "disconnected" |
|
|
except Exception as e: |
|
|
logger.error(f"Langfuse connection error: {e}") |
|
|
health_status["langfuse"] = "disconnected" |
|
|
|
|
|
return health_status |
|
|
|
|
|
|
|
|
# Mount all authenticated routes on the app; must run after route definitions.
app.include_router(router)
|
|
|