Cyril Dupland (commit 043f287)
Add SynthesisAgent functionality to A3 workflow, including new prompts and payload fields. Introduce synthesis model in A3Payload, update AgentState to include synthesis output, and create synthesis_node for generating executive summaries. Enhance agent service to support specific LLM configurations for synthesis, improving overall data analysis capabilities.
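For context, synthesis_model is read straight off the payload by the agent service below, alongside the existing analysis_model and coherence_model fields. A minimal sketch of what an A3 request payload might carry; the field names query, analysis_model and synthesis_model appear in the code, while the concrete values and any other fields are assumptions:

    payload = A3Payload(
        query="Analyse the Q3 dataset and produce an executive summary",  # placeholder value
        analysis_model="gpt-4o",        # assumed ModelName value string
        synthesis_model="gpt-4o-mini",  # assumed ModelName value string
    )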
043f287 | """Agent service for executing LangGraph agents.""" | |
| from typing import Optional, AsyncIterator, List, Dict, Any, Union | |
| import time | |
| from langchain_core.messages import AIMessageChunk, HumanMessage, AIMessage, BaseMessage, SystemMessage | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from domain.enums import ModelName, AgentType | |
| from domain.payloads.base import BaseAgentPayload | |
| from .llm_service import llm_service | |
| from .agent_registry import agent_registry | |
| from services.postprocessing.registry import build_orchestrator | |
| from services.postprocessing.context import RunContext | |


class AgentService:
    """
    Service for executing agent graphs with different LLMs.

    This service is the bridge between the API layer and the LangGraph agents.
    It handles:
    - Creating the right LLM based on model selection
    - Getting the right agent graph from the registry
    - Executing the graph with or without streaming
    - Supporting agent-specific LLM configurations (e.g., coherence_model for A2)
    """

    def __init__(self):
        """Initialize the agent service."""
        pass

    async def invoke(
        self,
        payload: BaseAgentPayload,
        model_name: ModelName,
        agent_type: AgentType = AgentType.A2,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> dict:
        """
        Invoke agent for a single response (non-streaming).

        Args:
            payload: Validated agent payload containing all agent-specific data
            model_name: LLM model to use
            agent_type: Type of agent graph
            temperature: Sampling temperature
            max_tokens: Max tokens to generate

        Returns:
            Response dictionary with content and metadata
        """
| print("\n" + "#"*70, flush=True) | |
| print(f"# AGENT EXECUTION START - {agent_type.value}", flush=True) | |
| print(f"# Model: {model_name.value} | Temperature: {temperature}", flush=True) | |
| print("#"*70, flush=True) | |
| # Create default LLM instance | |
| llm = llm_service.get_llm( | |
| model_name=model_name, | |
| temperature=temperature, | |
| streaming=False, | |
| max_tokens=max_tokens | |
| ) | |
| # Convert payload to dict for state | |
| payload_dict = payload.model_dump() | |
| print(f"[AGENT_SERVICE] Payload keys: {list(payload_dict.keys())}", flush=True) | |
| # Handle agent-specific LLM configurations | |
| coherence_llm = self._get_coherence_llm(payload_dict, temperature, max_tokens) | |
| analysis_llm = self._get_analysis_llm(payload_dict, temperature, max_tokens) | |
| synthesis_llm = self._get_synthesis_llm(payload_dict, temperature, max_tokens) | |
| if coherence_llm: | |
| print(f"[AGENT_SERVICE] Using specific coherence_model: {payload_dict.get('coherence_model')}", flush=True) | |
| if analysis_llm: | |
| print(f"[AGENT_SERVICE] Using specific analysis_model: {payload_dict.get('analysis_model')}", flush=True) | |
| if synthesis_llm: | |
| print(f"[AGENT_SERVICE] Using specific synthesis_model: {payload_dict.get('synthesis_model')}", flush=True) | |
| # Get agent builder and create graph with appropriate LLMs | |
| builder = agent_registry.get_builder(agent_type) | |
| # Pass specific LLMs if the agent supports them | |
| if agent_type == AgentType.A2 and coherence_llm: | |
| graph = builder(llm, coherence_llm=coherence_llm) | |
| elif agent_type == AgentType.A3: | |
| # A3 supports both analysis_llm and synthesis_llm | |
| kwargs = {} | |
| if analysis_llm: | |
| kwargs["analysis_llm"] = analysis_llm | |
| if synthesis_llm: | |
| kwargs["synthesis_llm"] = synthesis_llm | |
| graph = builder(llm, **kwargs) if kwargs else builder(llm) | |
| else: | |
| graph = builder(llm) | |
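
        # For reference, a builder compatible with the A3 branch above is expected
        # to accept optional per-node LLMs, roughly (sketch only; the real builder
        # lives behind agent_registry, not in this file):
        #
        #   def build_a3_graph(llm, analysis_llm=None, synthesis_llm=None):
        #       ...  # wires analysis_node / synthesis_node with their own models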

        # Prepare messages from payload (for agents that need them)
        messages = self._prepare_messages_from_payload(payload_dict)
        print(f"[AGENT_SERVICE] Initial messages prepared: {len(messages)}", flush=True)

        # Execute graph and measure latency
        print("[AGENT_SERVICE] Starting graph execution...", flush=True)
        start_time = time.time()
        result = await graph.ainvoke({
            "messages": messages,
            "payload": payload_dict
        })
        latency_s = time.time() - start_time
        print(f"[AGENT_SERVICE] Graph execution completed in {latency_s:.2f}s", flush=True)

        # Extract response - use agent_output if available, otherwise fall back to the last message
        response_message = result["messages"][-1]
        agent_output = result.get("agent_output")
        print(f"[AGENT_SERVICE] Last message type: {type(response_message).__name__}", flush=True)
        print(f"[AGENT_SERVICE] Last message content preview: {str(response_message.content)[:100]}...", flush=True)
        if agent_output is not None:
            # Agent provides structured output
            response_content: Any = agent_output
            print(f"[AGENT_SERVICE] Using agent_output: {list(agent_output.keys())}", flush=True)
        else:
            # Fall back to message content for simple agents
            response_content = response_message.content
            print("[AGENT_SERVICE] Using message content (no agent_output)", flush=True)

        # Get usage from total_usage in state (accumulated across all nodes)
        total_usage = result.get("total_usage", {})
        usage_by_model = result.get("usage_by_model", {})
        if total_usage:
            print(f"[AGENT_SERVICE] Using total_usage from state: {total_usage}", flush=True)
            usage_totals = total_usage
        else:
            # Fall back to the last message's usage (for backward compatibility)
            usage = getattr(response_message, "usage_metadata", None) or {}
            print(f"[AGENT_SERVICE] Fallback to message usage_metadata: {usage}", flush=True)
            usage_totals = self._normalize_usage(usage)
        # If no usage_by_model came from state, create it from the default model
        if not usage_by_model:
            usage_by_model = {model_name.value: usage_totals}
        print(f"[AGENT_SERVICE] Final usage_totals: {usage_totals}", flush=True)
        print(f"[AGENT_SERVICE] Usage by model: {usage_by_model}", flush=True)

        ctx = RunContext(
            provider=model_name.provider.value,
            model=model_name.value,
            usage_totals=usage_totals,
            usage_by_model=usage_by_model,
            latency_s=latency_s,
        )
        build_orchestrator().run(ctx)

        base_metadata: Dict[str, Any] = {
            "message_count": len(result["messages"]),
        }
        base_metadata.update(ctx.metadata_out)
        # Add token usage breakdown by model to metadata
        if usage_by_model:
            base_metadata["usage_by_model"] = usage_by_model
        # Add individual token counts for easy access
        base_metadata["input_tokens"] = usage_totals.get("input_tokens", 0)
        base_metadata["output_tokens"] = usage_totals.get("output_tokens", 0)
        base_metadata["total_tokens"] = usage_totals.get("total_tokens", 0)

        print("\n" + "#" * 70, flush=True)
        print(f"# AGENT EXECUTION END - {agent_type.value}", flush=True)
        print(f"# Latency: {latency_s:.2f}s | Messages: {len(result['messages'])}", flush=True)
        print(f"# Usage totals: {usage_totals}", flush=True)
        if agent_output:
            print(f"# Agent output keys: {list(agent_output.keys())}", flush=True)
        print("#" * 70 + "\n", flush=True)

        return {
            "response": response_content,
            "model": model_name.value,
            "agent_type": agent_type.value,
            "metadata": base_metadata,
        }

    async def stream(
        self,
        payload: BaseAgentPayload,
        model_name: ModelName,
        agent_type: AgentType = AgentType.A2,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> AsyncIterator[dict]:
        """
        Stream agent response token by token.

        Args:
            payload: Validated agent payload containing all agent-specific data
            model_name: LLM model to use
            agent_type: Type of agent graph
            temperature: Sampling temperature
            max_tokens: Max tokens to generate

        Yields:
            Dictionary chunks with content and metadata
        """
| print("\n" + "#"*70) | |
| print(f"# AGENT STREAM START - {agent_type.value}") | |
| print(f"# Model: {model_name.value} | Temperature: {temperature}") | |
| print("#"*70) | |
| # Create default LLM instance with streaming enabled | |
| llm = llm_service.get_llm( | |
| model_name=model_name, | |
| temperature=temperature, | |
| streaming=True, | |
| max_tokens=max_tokens | |
| ) | |
| # Convert payload to dict for state | |
| payload_dict = payload.model_dump() | |
| print(f"[AGENT_SERVICE] Payload keys: {list(payload_dict.keys())}") | |
| # Handle agent-specific LLM configurations | |
| coherence_llm = self._get_coherence_llm(payload_dict, temperature, max_tokens, streaming=True) | |
| analysis_llm = self._get_analysis_llm(payload_dict, temperature, max_tokens, streaming=True) | |
| synthesis_llm = self._get_synthesis_llm(payload_dict, temperature, max_tokens, streaming=True) | |
| if coherence_llm: | |
| print(f"[AGENT_SERVICE] Using specific coherence_model: {payload_dict.get('coherence_model')}") | |
| if analysis_llm: | |
| print(f"[AGENT_SERVICE] Using specific analysis_model: {payload_dict.get('analysis_model')}") | |
| if synthesis_llm: | |
| print(f"[AGENT_SERVICE] Using specific synthesis_model: {payload_dict.get('synthesis_model')}") | |
| # Get agent builder and create graph with appropriate LLMs | |
| builder = agent_registry.get_builder(agent_type) | |
| # Pass specific LLMs if the agent supports them | |
| if agent_type == AgentType.A2 and coherence_llm: | |
| graph = builder(llm, coherence_llm=coherence_llm) | |
| elif agent_type == AgentType.A3: | |
| # A3 supports both analysis_llm and synthesis_llm | |
| kwargs = {} | |
| if analysis_llm: | |
| kwargs["analysis_llm"] = analysis_llm | |
| if synthesis_llm: | |
| kwargs["synthesis_llm"] = synthesis_llm | |
| graph = builder(llm, **kwargs) if kwargs else builder(llm) | |
| else: | |
| graph = builder(llm) | |
| # Prepare messages from payload | |
| messages = self._prepare_messages_from_payload(payload_dict) | |
| print(f"[AGENT_SERVICE] Initial messages prepared: {len(messages)}") | |
| print(f"[AGENT_SERVICE] Starting streaming graph execution...") | |
| # Track usage and latency for final emissions calculation | |
| usage_totals: Dict[str, int] = {} | |
| usage_by_model: Dict[str, Dict[str, int]] = {} | |
| start_time = time.time() | |
| documents = [] | |
| # Track agent_output - will be captured from any node that produces it | |
| captured_agent_output: Optional[Dict[str, Any]] = None | |
| # Stream graph execution | |
| async for msg in graph.astream({ | |
| "messages": messages, | |
| "payload": payload_dict | |
| }, stream_mode=["messages","updates"]): | |
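            # Event shapes seen here (roughly; the exact metadata keys depend on the
            # LangGraph version, ls_model_name is one of the keys probed later by
            # _extract_model_from_params):
            #   ("messages", (AIMessageChunk(...), {"ls_model_name": ..., ...}))
            #   ("updates", {"<node_name>": {"agent_output": ..., "total_usage": ..., ...}})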
            # LangGraph may yield (node_name, message) tuples in messages mode
            event = None
            params = None
            # Only emit assistant outputs; ignore user/history echoes
            text: Optional[str] = None

            if msg[0] == "messages":
                chunk = msg[1]
                if isinstance(chunk, tuple) and len(chunk) == 2:
                    # Prefer the BaseMessage element if present
                    if isinstance(chunk[1], BaseMessage):
                        event = chunk[1]
                        params = chunk[0]
                    elif isinstance(chunk[0], BaseMessage):
                        event = chunk[0]
                        params = chunk[1]
                    else:
                        # Fallback to the second element by convention
                        event = chunk[1]
                else:
                    event = chunk
| if msg[0] == "updates": | |
| node = msg[1] | |
| # Capture agent_output, total_usage, and usage_by_model from any node | |
| if isinstance(node, dict): | |
| for node_name, node_output in node.items(): | |
| if isinstance(node_output, dict): | |
| if "agent_output" in node_output: | |
| captured_agent_output = node_output["agent_output"] | |
| print(f"[AGENT_SERVICE] Captured agent_output from node '{node_name}'", flush=True) | |
| if "total_usage" in node_output: | |
| usage_totals = node_output["total_usage"] | |
| print(f"[AGENT_SERVICE] Captured total_usage from node '{node_name}': {usage_totals}", flush=True) | |
| if "usage_by_model" in node_output: | |
| usage_by_model = node_output["usage_by_model"] | |
| print(f"[AGENT_SERVICE] Captured usage_by_model from node '{node_name}': {usage_by_model}", flush=True) | |
| # Handle summarizer_export (existing logic for documents) | |
| summarizer_export = node.get("summarizer_export") | |
| if summarizer_export and isinstance(summarizer_export, dict): | |
| node_messages = summarizer_export.get("messages", []) | |
| last_message = node_messages[-1] if node_messages else None | |
| if isinstance(last_message, AIMessage): | |
| text = self._extract_text_content(last_message.content) | |
| doc_meta = last_message.metadata.get("document") if last_message.metadata else None | |
| if doc_meta is not None: | |
| documents.append(doc_meta) | |

            if isinstance(event, AIMessageChunk):
                text = self._extract_text_content(event.content)
                # Capture usage if present on chunks
                try:
                    chunk_usage = getattr(event, "usage_metadata", None)
                    if isinstance(chunk_usage, dict):
                        norm = self._normalize_usage(chunk_usage)
                        model_id = self._extract_model_from_params(params) or model_name.value
                        bucket = usage_by_model.setdefault(model_id, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                        bucket["input_tokens"] += norm["input_tokens"]
                        bucket["output_tokens"] += norm["output_tokens"]
                        bucket["total_tokens"] += norm["total_tokens"]
                        usage_totals["input_tokens"] = usage_totals.get("input_tokens", 0) + norm["input_tokens"]
                        usage_totals["output_tokens"] = usage_totals.get("output_tokens", 0) + norm["output_tokens"]
                        usage_totals["total_tokens"] = usage_totals.get("total_tokens", 0) + norm["total_tokens"]
                except Exception:
                    pass
            elif event is not None:
                # Not an assistant output we should stream (e.g., HumanMessage); skip it.
                # When event is None (pure "updates" events), fall through so that any
                # text extracted from summarizer_export above is still yielded.
                continue

            if text:
                yield {
                    "content": text,
                    "done": False,
                    "metadata": {
                        "model": model_name.value,
                        "agent_type": agent_type.value,
                        "usage": usage_totals
                    },
                    "documents": documents
                }

        # Compute latency and run the post-processing pipeline for the final chunk
        latency_s = time.time() - start_time
        ctx = RunContext(
            provider=model_name.provider.value,
            model=model_name.value,
            usage_totals=usage_totals,
            usage_by_model=usage_by_model,
            latency_s=latency_s,
        )
        build_orchestrator().run(ctx)

        # Build final metadata
        final_metadata = {
            "model": model_name.value,
            "agent_type": agent_type.value,
            "usage": usage_totals,
            "usage_by_model": usage_by_model,
            "latency_s": latency_s,
            **ctx.metadata_out
        }
        # Add token usage breakdown by model to metadata
        if usage_by_model:
            final_metadata["usage_by_model"] = usage_by_model
        # Add individual token counts for easy access
        final_metadata["input_tokens"] = usage_totals.get("input_tokens", 0)
        final_metadata["output_tokens"] = usage_totals.get("output_tokens", 0)
        final_metadata["total_tokens"] = usage_totals.get("total_tokens", 0)

        print("\n" + "#" * 70, flush=True)
        print(f"# AGENT STREAM END - {agent_type.value}", flush=True)
        print(f"# Latency: {latency_s:.2f}s", flush=True)
        print(f"# Usage totals: {usage_totals}", flush=True)
        if usage_by_model:
            print(f"# Usage by model: {usage_by_model}", flush=True)
        if captured_agent_output:
            print(f"# Agent output keys: {list(captured_agent_output.keys())}", flush=True)
        print("#" * 70 + "\n", flush=True)

        # Build final chunk
        final_chunk: Dict[str, Any] = {
            "content": "",
            "done": True,
            "metadata": final_metadata,
            "documents": documents
        }
        # Add structured response if the agent provided agent_output
        if captured_agent_output is not None:
            final_chunk["response"] = captured_agent_output

        # Send final chunk
        yield final_chunk

    def _get_coherence_llm(
        self,
        payload_dict: Dict[str, Any],
        temperature: float,
        max_tokens: Optional[int],
        streaming: bool = False
    ) -> Optional[BaseChatModel]:
        """
        Get a specific LLM for the CoherenceAgent if specified in the payload.

        Args:
            payload_dict: Payload dictionary
            temperature: Sampling temperature
            max_tokens: Max tokens to generate
            streaming: Whether to enable streaming

        Returns:
            LLM instance if coherence_model is specified, None otherwise
        """
        coherence_model_str = payload_dict.get("coherence_model")
        if not coherence_model_str:
            return None
        try:
            coherence_model_name = ModelName(coherence_model_str)
            return llm_service.get_llm(
                model_name=coherence_model_name,
                temperature=temperature,
                streaming=streaming,
                max_tokens=max_tokens
            )
        except ValueError:
            # Invalid model name, return None to use the default
            return None

    def _get_analysis_llm(
        self,
        payload_dict: Dict[str, Any],
        temperature: float,
        max_tokens: Optional[int],
        streaming: bool = False
    ) -> Optional[BaseChatModel]:
        """
        Get a specific LLM for the AnalysisAgent if specified in the payload.

        Args:
            payload_dict: Payload dictionary
            temperature: Sampling temperature
            max_tokens: Max tokens to generate
            streaming: Whether to enable streaming

        Returns:
            LLM instance if analysis_model is specified, None otherwise
        """
        analysis_model_str = payload_dict.get("analysis_model")
        if not analysis_model_str:
            return None
        try:
            analysis_model_name = ModelName(analysis_model_str)
            return llm_service.get_llm(
                model_name=analysis_model_name,
                temperature=temperature,
                streaming=streaming,
                max_tokens=max_tokens
            )
        except ValueError:
            # Invalid model name, return None to use the default
            return None

    def _get_synthesis_llm(
        self,
        payload_dict: Dict[str, Any],
        temperature: float,
        max_tokens: Optional[int],
        streaming: bool = False
    ) -> Optional[BaseChatModel]:
        """
        Get a specific LLM for the SynthesisAgent if specified in the payload.

        Args:
            payload_dict: Payload dictionary
            temperature: Sampling temperature
            max_tokens: Max tokens to generate
            streaming: Whether to enable streaming

        Returns:
            LLM instance if synthesis_model is specified, None otherwise
        """
        synthesis_model_str = payload_dict.get("synthesis_model")
        if not synthesis_model_str:
            return None
        try:
            synthesis_model_name = ModelName(synthesis_model_str)
            return llm_service.get_llm(
                model_name=synthesis_model_name,
                temperature=temperature,
                streaming=streaming,
                max_tokens=max_tokens
            )
        except ValueError:
            # Invalid model name, return None to use the default
            return None

    def _prepare_messages_from_payload(
        self,
        payload: Dict[str, Any]
    ) -> List[BaseMessage]:
        """
        Prepare messages list from payload data.

        Extracts 'query' and 'conversation_history' from the payload
        if they exist. Agents can define their own structure.

        Args:
            payload: Validated payload dictionary

        Returns:
            List of LangChain messages
        """
        messages = []

        # Extract conversation history if present
        conversation_history = payload.get("conversation_history")
        if conversation_history:
            for msg in conversation_history:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                if role == "user":
                    messages.append(HumanMessage(content=content))
                elif role == "assistant":
                    messages.append(AIMessage(content=content))
                elif role == "system":
                    messages.append(SystemMessage(content=content))

        # Extract query if present and add as the current message
        query = payload.get("query")
        if query:
            messages.append(HumanMessage(content=query))

        return messages
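
    # Example mapping (illustrative):
    #   {"conversation_history": [{"role": "user", "content": "Hi"},
    #                             {"role": "assistant", "content": "Hello!"}],
    #    "query": "Summarize our chat"}
    #   -> [HumanMessage("Hi"), AIMessage("Hello!"), HumanMessage("Summarize our chat")]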

    def _extract_text_content(self, content: object) -> Optional[str]:
        """
        Normalize LangChain message content into a plain text string.

        Handles both string content and list-structured content with text parts.
        """
        if content is None:
            return None
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # LangChain can represent content as a list of parts like
            # [{"type": "text", "text": "..."}, ...]
            text_parts: List[str] = []
            for part in content:
                try:
                    # Dict-like parts with type/text
                    if isinstance(part, dict):
                        if part.get("type") == "text" and isinstance(part.get("text"), str):
                            text_parts.append(part["text"])
                    # Object-like parts with attributes
                    elif hasattr(part, "type") and getattr(part, "type") == "text" and hasattr(part, "text"):
                        value = getattr(part, "text")
                        if isinstance(value, str):
                            text_parts.append(value)
                except Exception:
                    # Skip any malformed parts
                    continue
            return "".join(text_parts) if text_parts else None
        # Fallback: unknown content structure
        return None

    def _normalize_usage(self, usage: Dict[str, Any]) -> Dict[str, int]:
        """Normalize usage keys to input/output/total integers.

        Supports variants like prompt_tokens/completion_tokens.
        """
        try:
            input_val = usage.get("input_tokens")
            if not isinstance(input_val, (int, float)):
                input_val = usage.get("prompt_tokens", 0)
            output_val = usage.get("output_tokens")
            if not isinstance(output_val, (int, float)):
                output_val = usage.get("completion_tokens", 0)
            total_val = usage.get("total_tokens")
            if not isinstance(total_val, (int, float)):
                total_val = int(input_val or 0) + int(output_val or 0)
            return {
                "input_tokens": int(input_val or 0),
                "output_tokens": int(output_val or 0),
                "total_tokens": int(total_val or 0),
            }
        except Exception:
            return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    def _extract_model_from_params(self, params: Optional[Dict[str, Any]]) -> Optional[str]:
        """Best-effort extraction of a model identifier from LangGraph params."""
        if not isinstance(params, dict):
            return None
        keys = ("ls_model_name", "model", "model_name", "model_id", "name", "llm_model", "openai_model")
        for key in keys:
            val = params.get(key)
            if isinstance(val, str) and val:
                return val
        # Search likely nested containers
        for container_key in ("configuration", "config", "kwargs", "meta", "metadata"):
            sub = params.get(container_key)
            if isinstance(sub, dict):
                for key in keys:
                    val = sub.get(key)
                    if isinstance(val, str) and val:
                        return val
        return None


# Singleton instance
agent_service = AgentService()