# merithalle-ai/services/agent_service.py
# Commit 043f287 (Cyril Dupland): Add SynthesisAgent functionality to the A3
# workflow, including new prompts and payload fields. Introduces a synthesis
# model on A3Payload, adds synthesis output to AgentState, and creates a
# synthesis_node for generating executive summaries. The agent service now
# supports agent-specific LLM configurations for synthesis.
"""Agent service for executing LangGraph agents."""
from typing import Optional, AsyncIterator, List, Dict, Any
import time
from langchain_core.messages import AIMessageChunk, HumanMessage, AIMessage, BaseMessage, SystemMessage
from langchain_core.language_models.chat_models import BaseChatModel
from domain.enums import ModelName, AgentType
from domain.payloads.base import BaseAgentPayload
from .llm_service import llm_service
from .agent_registry import agent_registry
from services.postprocessing.registry import build_orchestrator
from services.postprocessing.context import RunContext
class AgentService:
"""
Service for executing agent graphs with different LLMs.
This service is the bridge between the API layer and the LangGraph agents.
It handles:
- Creating the right LLM based on model selection
- Getting the right agent graph from the registry
- Executing the graph with or without streaming
- Supporting agent-specific LLM configurations (e.g., coherence_model for A2)
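
    Illustrative usage (the ModelName member and payload are placeholders,
    not names defined in this module):
        result = await agent_service.invoke(
            payload=my_payload,          # a validated BaseAgentPayload subclass
            model_name=ModelName(...),   # pick a member of your ModelName enum
            agent_type=AgentType.A3,
        )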
"""
def __init__(self):
"""Initialize the agent service."""
pass
async def invoke(
self,
payload: BaseAgentPayload,
model_name: ModelName,
agent_type: AgentType = AgentType.A2,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
) -> dict:
"""
Invoke agent for a single response (non-streaming).
Args:
payload: Validated agent payload containing all agent-specific data
model_name: LLM model to use
agent_type: Type of agent graph
temperature: Sampling temperature
max_tokens: Max tokens to generate
Returns:
Response dictionary with content and metadata
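            Illustrative shape (mirrors the return statement below):
                {
                    "response": <structured agent_output or plain message text>,
                    "model": "<model identifier>",
                    "agent_type": "<agent type>",
                    "metadata": {"message_count": ..., "input_tokens": ..., ...},
                }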
"""
print("\n" + "#"*70, flush=True)
print(f"# AGENT EXECUTION START - {agent_type.value}", flush=True)
print(f"# Model: {model_name.value} | Temperature: {temperature}", flush=True)
print("#"*70, flush=True)
# Create default LLM instance
llm = llm_service.get_llm(
model_name=model_name,
temperature=temperature,
streaming=False,
max_tokens=max_tokens
)
# Convert payload to dict for state
payload_dict = payload.model_dump()
print(f"[AGENT_SERVICE] Payload keys: {list(payload_dict.keys())}", flush=True)
# Handle agent-specific LLM configurations
coherence_llm = self._get_coherence_llm(payload_dict, temperature, max_tokens)
analysis_llm = self._get_analysis_llm(payload_dict, temperature, max_tokens)
synthesis_llm = self._get_synthesis_llm(payload_dict, temperature, max_tokens)
if coherence_llm:
print(f"[AGENT_SERVICE] Using specific coherence_model: {payload_dict.get('coherence_model')}", flush=True)
if analysis_llm:
print(f"[AGENT_SERVICE] Using specific analysis_model: {payload_dict.get('analysis_model')}", flush=True)
if synthesis_llm:
print(f"[AGENT_SERVICE] Using specific synthesis_model: {payload_dict.get('synthesis_model')}", flush=True)
# Get agent builder and create graph with appropriate LLMs
builder = agent_registry.get_builder(agent_type)
# Pass specific LLMs if the agent supports them
if agent_type == AgentType.A2 and coherence_llm:
graph = builder(llm, coherence_llm=coherence_llm)
elif agent_type == AgentType.A3:
# A3 supports both analysis_llm and synthesis_llm
kwargs = {}
if analysis_llm:
kwargs["analysis_llm"] = analysis_llm
if synthesis_llm:
kwargs["synthesis_llm"] = synthesis_llm
graph = builder(llm, **kwargs) if kwargs else builder(llm)
else:
graph = builder(llm)
# Prepare messages from payload (for agents that need them)
messages = self._prepare_messages_from_payload(payload_dict)
print(f"[AGENT_SERVICE] Initial messages prepared: {len(messages)}", flush=True)
# Execute graph with latency
print(f"[AGENT_SERVICE] Starting graph execution...", flush=True)
start_time = time.time()
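        # The initial state carries "messages" and "payload"; graph nodes may add
        # "agent_output", "total_usage", and "usage_by_model", which are read below.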
result = await graph.ainvoke({
"messages": messages,
"payload": payload_dict
})
latency_s = time.time() - start_time
print(f"[AGENT_SERVICE] Graph execution completed in {latency_s:.2f}s", flush=True)
        # Extract response: use agent_output if available, otherwise fall back to the last message
response_message = result["messages"][-1]
agent_output = result.get("agent_output")
print(f"[AGENT_SERVICE] Last message type: {type(response_message).__name__}", flush=True)
print(f"[AGENT_SERVICE] Last message content preview: {str(response_message.content)[:100]}...", flush=True)
if agent_output is not None:
# Agent provides structured output
response_content: Any = agent_output
print(f"[AGENT_SERVICE] Using agent_output: {list(agent_output.keys())}", flush=True)
else:
# Fallback to message content for simple agents
response_content = response_message.content
print(f"[AGENT_SERVICE] Using message content (no agent_output)", flush=True)
# Get usage from total_usage in state (accumulated across all nodes)
total_usage = result.get("total_usage", {})
usage_by_model = result.get("usage_by_model", {})
if total_usage:
print(f"[AGENT_SERVICE] Using total_usage from state: {total_usage}", flush=True)
usage_totals = total_usage
else:
# Fallback to last message usage (for backward compatibility)
usage = getattr(response_message, "usage_metadata", None) or {}
print(f"[AGENT_SERVICE] Fallback to message usage_metadata: {usage}", flush=True)
usage_totals = self._normalize_usage(usage)
# If no usage_by_model from state, create it from default model
if not usage_by_model:
usage_by_model = {model_name.value: usage_totals}
print(f"[AGENT_SERVICE] Final usage_totals: {usage_totals}", flush=True)
print(f"[AGENT_SERVICE] Usage by model: {usage_by_model}", flush=True)
ctx = RunContext(
provider=model_name.provider.value,
model=model_name.value,
usage_totals=usage_totals,
usage_by_model=usage_by_model,
latency_s=latency_s,
)
build_orchestrator().run(ctx)
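        # Post-processing steps (e.g., the emissions calculation used in stream())
        # write derived metrics into ctx.metadata_out, merged into metadata below.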
base_metadata: Dict[str, Any] = {
"message_count": len(result["messages"]),
}
base_metadata.update(ctx.metadata_out)
# Add token usage breakdown by model to metadata
if usage_by_model:
base_metadata["usage_by_model"] = usage_by_model
# Add individual token counts for easy access
base_metadata["input_tokens"] = usage_totals.get("input_tokens", 0)
base_metadata["output_tokens"] = usage_totals.get("output_tokens", 0)
base_metadata["total_tokens"] = usage_totals.get("total_tokens", 0)
print("\n" + "#"*70, flush=True)
print(f"# AGENT EXECUTION END - {agent_type.value}", flush=True)
print(f"# Latency: {latency_s:.2f}s | Messages: {len(result['messages'])}", flush=True)
print(f"# Usage totals: {usage_totals}", flush=True)
if agent_output:
print(f"# Agent output keys: {list(agent_output.keys())}", flush=True)
print("#"*70 + "\n", flush=True)
return {
"response": response_content,
"model": model_name.value,
"agent_type": agent_type.value,
"metadata": base_metadata,
}
async def stream(
self,
payload: BaseAgentPayload,
model_name: ModelName,
agent_type: AgentType = AgentType.A2,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
) -> AsyncIterator[dict]:
"""
Stream agent response token by token.
Args:
payload: Validated agent payload containing all agent-specific data
model_name: LLM model to use
agent_type: Type of agent graph
temperature: Sampling temperature
max_tokens: Max tokens to generate
Yields:
Dictionary chunks with content and metadata
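            Illustrative chunk shapes (mirroring the yields below):
                {"content": "<token>", "done": False, "metadata": {...}, "documents": [...]}
                {"content": "", "done": True, "metadata": {...}, "documents": [...],
                 "response": {...}}  # "response" only when the agent produced agent_output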
"""
print("\n" + "#"*70)
print(f"# AGENT STREAM START - {agent_type.value}")
print(f"# Model: {model_name.value} | Temperature: {temperature}")
print("#"*70)
# Create default LLM instance with streaming enabled
llm = llm_service.get_llm(
model_name=model_name,
temperature=temperature,
streaming=True,
max_tokens=max_tokens
)
# Convert payload to dict for state
payload_dict = payload.model_dump()
print(f"[AGENT_SERVICE] Payload keys: {list(payload_dict.keys())}")
# Handle agent-specific LLM configurations
coherence_llm = self._get_coherence_llm(payload_dict, temperature, max_tokens, streaming=True)
analysis_llm = self._get_analysis_llm(payload_dict, temperature, max_tokens, streaming=True)
synthesis_llm = self._get_synthesis_llm(payload_dict, temperature, max_tokens, streaming=True)
        if coherence_llm:
            print(f"[AGENT_SERVICE] Using specific coherence_model: {payload_dict.get('coherence_model')}", flush=True)
        if analysis_llm:
            print(f"[AGENT_SERVICE] Using specific analysis_model: {payload_dict.get('analysis_model')}", flush=True)
        if synthesis_llm:
            print(f"[AGENT_SERVICE] Using specific synthesis_model: {payload_dict.get('synthesis_model')}", flush=True)
# Get agent builder and create graph with appropriate LLMs
builder = agent_registry.get_builder(agent_type)
# Pass specific LLMs if the agent supports them
if agent_type == AgentType.A2 and coherence_llm:
graph = builder(llm, coherence_llm=coherence_llm)
elif agent_type == AgentType.A3:
# A3 supports both analysis_llm and synthesis_llm
kwargs = {}
if analysis_llm:
kwargs["analysis_llm"] = analysis_llm
if synthesis_llm:
kwargs["synthesis_llm"] = synthesis_llm
graph = builder(llm, **kwargs) if kwargs else builder(llm)
else:
graph = builder(llm)
# Prepare messages from payload
messages = self._prepare_messages_from_payload(payload_dict)
print(f"[AGENT_SERVICE] Initial messages prepared: {len(messages)}")
print(f"[AGENT_SERVICE] Starting streaming graph execution...")
# Track usage and latency for final emissions calculation
usage_totals: Dict[str, int] = {}
usage_by_model: Dict[str, Dict[str, int]] = {}
start_time = time.time()
documents = []
# Track agent_output - will be captured from any node that produces it
captured_agent_output: Optional[Dict[str, Any]] = None
# Stream graph execution
async for msg in graph.astream({
"messages": messages,
"payload": payload_dict
}, stream_mode=["messages","updates"]):
            # With stream_mode=["messages", "updates"], astream yields (mode, payload)
            # tuples; in "messages" mode the payload is typically a
            # (message_chunk, metadata) pair, handled defensively below
event = None
params = None
# Only emit assistant outputs; ignore user/history echoes
text: Optional[str] = None
if msg[0] == "messages":
chunk = msg[1]
                if isinstance(chunk, tuple) and len(chunk) == 2:
                    # Prefer whichever element is the BaseMessage (imported at module level)
                    if isinstance(chunk[1], BaseMessage):
                        event = chunk[1]
                        params = chunk[0]
                    elif isinstance(chunk[0], BaseMessage):
                        event = chunk[0]
                        params = chunk[1]
                    else:
                        # Fallback to the second element by convention
                        event = chunk[1]
                else:
                    event = chunk
if msg[0] == "updates":
node = msg[1]
# Capture agent_output, total_usage, and usage_by_model from any node
if isinstance(node, dict):
for node_name, node_output in node.items():
if isinstance(node_output, dict):
if "agent_output" in node_output:
captured_agent_output = node_output["agent_output"]
print(f"[AGENT_SERVICE] Captured agent_output from node '{node_name}'", flush=True)
if "total_usage" in node_output:
usage_totals = node_output["total_usage"]
print(f"[AGENT_SERVICE] Captured total_usage from node '{node_name}': {usage_totals}", flush=True)
if "usage_by_model" in node_output:
usage_by_model = node_output["usage_by_model"]
print(f"[AGENT_SERVICE] Captured usage_by_model from node '{node_name}': {usage_by_model}", flush=True)
# Handle summarizer_export (existing logic for documents)
summarizer_export = node.get("summarizer_export")
if summarizer_export and isinstance(summarizer_export, dict):
node_messages = summarizer_export.get("messages", [])
last_message = node_messages[-1] if node_messages else None
if isinstance(last_message, AIMessage):
text = self._extract_text_content(last_message.content)
                            # BaseMessage does not guarantee a `metadata` attribute; read it defensively
                            doc_meta = (getattr(last_message, "metadata", None) or {}).get("document")
if doc_meta is not None:
documents.append(doc_meta)
if isinstance(event, AIMessageChunk):
text = self._extract_text_content(event.content)
# Capture usage if present on chunks
try:
chunk_usage = getattr(event, "usage_metadata", None)
if isinstance(chunk_usage, dict):
norm = self._normalize_usage(chunk_usage)
model_id = self._extract_model_from_params(params) or model_name.value
bucket = usage_by_model.setdefault(model_id, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
bucket["input_tokens"] += norm["input_tokens"]
bucket["output_tokens"] += norm["output_tokens"]
bucket["total_tokens"] += norm["total_tokens"]
usage_totals["input_tokens"] = usage_totals.get("input_tokens", 0) + norm["input_tokens"]
usage_totals["output_tokens"] = usage_totals.get("output_tokens", 0) + norm["output_tokens"]
usage_totals["total_tokens"] = usage_totals.get("total_tokens", 0) + norm["total_tokens"]
except Exception:
pass
            elif not text:
                # Not an assistant output we should stream (e.g., HumanMessage),
                # and no text was recovered from an "updates" event either
                continue
if text:
yield {
"content": text,
"done": False,
"metadata": {
"model": model_name.value,
"agent_type": agent_type.value,
"usage": usage_totals
},
"documents": documents
}
# Compute latency and run post-processing pipeline for the final chunk
latency_s = time.time() - start_time
ctx = RunContext(
provider=model_name.provider.value,
model=model_name.value,
usage_totals=usage_totals,
usage_by_model=usage_by_model,
latency_s=latency_s,
)
build_orchestrator().run(ctx)
# Build final metadata
        final_metadata = {
            "model": model_name.value,
            "agent_type": agent_type.value,
            "usage": usage_totals,
            "latency_s": latency_s,
            **ctx.metadata_out,
            # Usage breakdown by model and individual token counts for easy access
            "usage_by_model": usage_by_model,
            "input_tokens": usage_totals.get("input_tokens", 0),
            "output_tokens": usage_totals.get("output_tokens", 0),
            "total_tokens": usage_totals.get("total_tokens", 0),
        }
print("\n" + "#"*70, flush=True)
print(f"# AGENT STREAM END - {agent_type.value}", flush=True)
print(f"# Latency: {latency_s:.2f}s", flush=True)
print(f"# Usage totals: {usage_totals}", flush=True)
if usage_by_model:
print(f"# Usage by model: {usage_by_model}", flush=True)
if captured_agent_output:
print(f"# Agent output keys: {list(captured_agent_output.keys())}", flush=True)
print("#"*70 + "\n", flush=True)
# Build final chunk
final_chunk: Dict[str, Any] = {
"content": "",
"done": True,
"metadata": final_metadata,
"documents": documents
}
# Add structured response if agent provided agent_output
if captured_agent_output is not None:
final_chunk["response"] = captured_agent_output
# Send final chunk
yield final_chunk
    def _get_coherence_llm(
        self,
        payload_dict: Dict[str, Any],
        temperature: float,
        max_tokens: Optional[int],
        streaming: bool = False
    ) -> Optional[BaseChatModel]:
        """Get a specific LLM for the CoherenceAgent if `coherence_model` is set in the payload."""
        return self._get_payload_llm(payload_dict, "coherence_model", temperature, max_tokens, streaming)

    def _get_analysis_llm(
        self,
        payload_dict: Dict[str, Any],
        temperature: float,
        max_tokens: Optional[int],
        streaming: bool = False
    ) -> Optional[BaseChatModel]:
        """Get a specific LLM for the AnalysisAgent if `analysis_model` is set in the payload."""
        return self._get_payload_llm(payload_dict, "analysis_model", temperature, max_tokens, streaming)

    def _get_synthesis_llm(
        self,
        payload_dict: Dict[str, Any],
        temperature: float,
        max_tokens: Optional[int],
        streaming: bool = False
    ) -> Optional[BaseChatModel]:
        """Get a specific LLM for the SynthesisAgent if `synthesis_model` is set in the payload."""
        return self._get_payload_llm(payload_dict, "synthesis_model", temperature, max_tokens, streaming)

    def _get_payload_llm(
        self,
        payload_dict: Dict[str, Any],
        field: str,
        temperature: float,
        max_tokens: Optional[int],
        streaming: bool = False
    ) -> Optional[BaseChatModel]:
        """
        Build an LLM from a model name stored under `field` in the payload.

        Args:
            payload_dict: Payload dictionary
            field: Payload key holding the model name (e.g., "synthesis_model")
            temperature: Sampling temperature
            max_tokens: Max tokens to generate
            streaming: Whether to enable streaming

        Returns:
            LLM instance if the field holds a valid model name, None otherwise
            (callers then fall back to the default LLM)
        """
        model_str = payload_dict.get(field)
        if not model_str:
            return None
        try:
            model = ModelName(model_str)
        except ValueError:
            # Invalid model name: signal the caller to use the default LLM
            return None
        return llm_service.get_llm(
            model_name=model,
            temperature=temperature,
            streaming=streaming,
            max_tokens=max_tokens
        )
def _prepare_messages_from_payload(
self,
payload: Dict[str, Any]
) -> List[BaseMessage]:
"""
Prepare messages list from payload data.
Extracts 'query' and 'conversation_history' from the payload
if they exist. Agents can define their own structure.
Args:
payload: Validated payload dictionary
Returns:
List of LangChain messages
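
        Example (illustrative):
            {"conversation_history": [{"role": "assistant", "content": "Hi"}],
             "query": "What changed?"}
            -> [AIMessage("Hi"), HumanMessage("What changed?")]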
"""
messages = []
# Extract conversation history if present
conversation_history = payload.get("conversation_history")
if conversation_history:
for msg in conversation_history:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "user":
messages.append(HumanMessage(content=content))
elif role == "assistant":
messages.append(AIMessage(content=content))
elif role == "system":
messages.append(SystemMessage(content=content))
# Extract query if present and add as current message
query = payload.get("query")
if query:
messages.append(HumanMessage(content=query))
return messages
def _extract_text_content(self, content: object) -> Optional[str]:
"""
Normalize LangChain message content into a plain text string.
Handles both string content and list-structured content with text parts.
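
        Example (illustrative):
            [{"type": "text", "text": "Hello "}, {"type": "text", "text": "world"}]
            -> "Hello world"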
"""
if content is None:
return None
if isinstance(content, str):
return content
if isinstance(content, list):
# LangChain can represent content as a list of parts like
# [{"type": "text", "text": "..."}, ...]
text_parts: List[str] = []
for part in content:
try:
# Dict-like parts with type/text
if isinstance(part, dict):
if part.get("type") == "text" and isinstance(part.get("text"), str):
text_parts.append(part["text"])
# Object-like parts with attributes
elif hasattr(part, "type") and getattr(part, "type") == "text" and hasattr(part, "text"):
value = getattr(part, "text")
if isinstance(value, str):
text_parts.append(value)
except Exception:
# Skip any malformed parts
continue
return "".join(text_parts) if text_parts else None
# Fallback: unknown content structure
return None
def _normalize_usage(self, usage: Dict[str, Any]) -> Dict[str, int]:
"""Normalize usage keys to input/output/total integers.
Supports variants like prompt_tokens/completion_tokens.
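
        Example (illustrative):
            {"prompt_tokens": 10, "completion_tokens": 5}
            -> {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}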
"""
try:
input_val = usage.get("input_tokens")
if not isinstance(input_val, (int, float)):
input_val = usage.get("prompt_tokens", 0)
output_val = usage.get("output_tokens")
if not isinstance(output_val, (int, float)):
output_val = usage.get("completion_tokens", 0)
total_val = usage.get("total_tokens")
if not isinstance(total_val, (int, float)):
total_val = (int(input_val or 0)) + (int(output_val or 0))
return {
"input_tokens": int(input_val or 0),
"output_tokens": int(output_val or 0),
"total_tokens": int(total_val or 0),
}
except Exception:
return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
def _extract_model_from_params(self, params: Optional[Dict[str, Any]]) -> Optional[str]:
"""Best-effort extraction of model identifier from LangGraph params."""
if not isinstance(params, dict):
return None
keys = ("ls_model_name", "model", "model_name", "model_id", "name", "llm_model", "openai_model")
for key in keys:
val = params.get(key)
if isinstance(val, str) and val:
return val
# Search likely nested containers
for container_key in ("configuration", "config", "kwargs", "meta", "metadata"):
sub = params.get(container_key)
if isinstance(sub, dict):
for key in keys:
val = sub.get(key)
if isinstance(val, str) and val:
return val
return None
# Singleton instance
agent_service = AgentService()
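
# Illustrative usage from the API layer (caller names are assumptions):
#     result = await agent_service.invoke(payload, model_name, agent_type)
#     async for chunk in agent_service.stream(payload, model_name, agent_type):
#         handle(chunk)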