Spaces:
Sleeping
Sleeping
| from typing import Dict, Optional, Tuple, List, Any, Set, Union | |
| import re | |
| import xml.etree.ElementTree as ET | |
| from datetime import datetime | |
| import json | |
| import logging | |
| from enum import Enum | |
# Module-level logger; attach a single console handler on first import only,
# so repeated imports do not duplicate output.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    _console = logging.StreamHandler()
    _console.setLevel(logging.INFO)
    _console.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(_console)
class StreamingFormatter:
    """Holds per-message streaming state: de-dup event ids, accumulated text,
    and the tool/citation/metadata collections for the message in flight."""

    def __init__(self):
        # Event ids already emitted, used to skip duplicate streamed events.
        self.processed_events = set()
        self.current_tool_outputs = []
        self.current_citations = []
        self.current_metadata = {}
        # Id of the message currently being streamed (None before the first).
        self.current_message_id = None
        # Text accumulated so far for the current message.
        self.current_message_buffer = ""

    def reset(self):
        """Clear all accumulated state so a new message starts fresh."""
        for collection in (
            self.processed_events,
            self.current_tool_outputs,
            self.current_citations,
            self.current_metadata,
        ):
            collection.clear()
        self.current_message_id = None
        self.current_message_buffer = ""

    def append_to_buffer(self, text: str):
        """Append text to the current message buffer."""
        self.current_message_buffer = self.current_message_buffer + text

    def get_and_clear_buffer(self) -> str:
        """Return the buffered content, leaving the buffer empty."""
        buffered, self.current_message_buffer = self.current_message_buffer, ""
        return buffered
class ToolType(Enum):
    """Enum of supported tool types; each member's value is the wire-level
    tool name that callers pass around (e.g. "mermaid_output")."""
    DUCKDUCKGO = "ddgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_output"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"

    @classmethod
    def get_tool_type(cls, tool_name: str) -> Optional['ToolType']:
        """Resolve a tool-name string to its enum member, or None if unknown.

        Tries a lookup by member *value* first (the strings callers actually
        pass, e.g. "heatmap_code"), then falls back to a lookup by member
        *name* (case-insensitive) for backward compatibility.
        """
        # BUG FIX: the original definition was missing @classmethod (so the
        # tool-name string was bound to `cls`) and only looked up by member
        # NAME via cls[tool_name.upper()], which never matched real tool
        # names such as "mermaid_output".
        try:
            return cls(tool_name)
        except ValueError:
            try:
                return cls[tool_name.upper()]
            except KeyError:
                return None
class ResponseFormatter:
    """Singleton that renders agent events as (terminal, XML) output pairs.

    Every ``format_*`` method is idempotent per ``event_id`` (duplicates
    return None) and resets the shared streaming state whenever
    ``message_id`` changes.
    """

    _instance = None

    def __new__(cls):
        # Classic singleton: state is created once and shared by all callers.
        if cls._instance is None:
            cls._instance = super(ResponseFormatter, cls).__new__(cls)
            cls._instance.streaming_state = StreamingFormatter()
            cls._instance.logger = logger
        return cls._instance

    # ---- shared streaming-state helpers -----------------------------------

    def _already_processed(self, event_id: str) -> bool:
        """True when this event id was already emitted (skip duplicates)."""
        return bool(event_id) and event_id in self.streaming_state.processed_events

    def _sync_message_state(self, message_id: str) -> None:
        """Reset streaming state when a new message id begins."""
        if message_id != self.streaming_state.current_message_id:
            self.streaming_state.reset()
            self.streaming_state.current_message_id = message_id

    def _mark_processed(self, event_id: str) -> None:
        """Record the event id so the same event is not emitted twice."""
        if event_id:
            self.streaming_state.processed_events.add(event_id)

    @staticmethod
    def _dedupe_tool_outputs(tool_outputs: List[Dict]) -> List[Dict]:
        """Drop tool outputs whose (type, content) pair was already seen,
        preserving first-occurrence order."""
        seen_outputs = set()
        unique_outputs = []
        for output in tool_outputs:
            output_key = f"{output.get('type')}:{output.get('content')}"
            if output_key not in seen_outputs:
                seen_outputs.add(output_key)
                unique_outputs.append(output)
        return unique_outputs

    # ---- public formatting API --------------------------------------------

    def format_thought(
        self,
        thought: str,
        observation: str,
        citations: List[Dict] = None,
        metadata: Dict = None,
        tool_outputs: List[Dict] = None,
        event_id: str = None,
        message_id: str = None
    ) -> Optional[Tuple[str, str]]:
        """Format an agent thought for both terminal (JSON) and XML output.

        Returns (terminal_json, xml_string), or None when the event is a
        duplicate or carries no content at all.
        """
        if self._already_processed(event_id):
            return None
        self._sync_message_state(message_id)
        # Skip thoughts that carry nothing to render.
        if not thought and not observation and not tool_outputs:
            return None

        terminal_output = {
            "type": "agent_thought",
            "content": thought,
            "metadata": metadata or {}
        }
        unique_outputs = []
        if tool_outputs:
            unique_outputs = self._dedupe_tool_outputs(tool_outputs)
            terminal_output["tool_outputs"] = unique_outputs

        root = ET.Element("agent_response")
        if thought:
            ET.SubElement(root, "thought").text = thought
        if observation:
            ET.SubElement(root, "observation").text = observation
        if unique_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for tool_output in unique_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                tool_elem.attrib["type"] = str(tool_output.get("type", ""))
                content = tool_output.get("content", "")
                # BUG FIX: Element.text must be a string for ET.tostring();
                # tool content can be a dict (e.g. heatmap output), which
                # previously raised TypeError during serialization.
                tool_elem.text = content if isinstance(content, str) else json.dumps(content)
        if citations:
            cites_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(cites_elem, "citation")
                for key, value in citation.items():
                    cite_elem.attrib[key] = str(value)
        xml_output = ET.tostring(root, encoding='unicode')

        self._mark_processed(event_id)
        # default=str so non-JSON-serializable metadata values degrade to
        # strings instead of raising at display time.
        return json.dumps(terminal_output, default=str), xml_output

    def format_message(
        self,
        message: str,
        event_id: str = None,
        message_id: str = None
    ) -> Optional[Tuple[str, str]]:
        """Accumulate and format an agent message for terminal and XML output.

        Message chunks for the same message_id are buffered; returns None
        until the buffer contains non-whitespace content.
        """
        if self._already_processed(event_id):
            return None
        self._sync_message_state(message_id)

        self.streaming_state.append_to_buffer(message)
        # Only emit once there is meaningful (non-whitespace) content.
        if not self.streaming_state.current_message_buffer.strip():
            return None

        terminal_output = self.streaming_state.current_message_buffer.strip()
        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = terminal_output
        xml_output = ET.tostring(root, encoding='unicode')

        self._mark_processed(event_id)
        return terminal_output, xml_output

    def format_error(
        self,
        error: str,
        event_id: str = None,
        message_id: str = None
    ) -> Optional[Tuple[str, str]]:
        """Format an error message for both terminal and XML output.

        Returns (terminal_string, xml_string), or None for duplicates and
        empty errors.
        """
        if self._already_processed(event_id):
            return None
        self._sync_message_state(message_id)
        if not error:
            return None

        terminal_output = f"Error: {error}"
        root = ET.Element("agent_response")
        ET.SubElement(root, "error").text = error
        xml_output = ET.tostring(root, encoding='unicode')

        self._mark_processed(event_id)
        return terminal_output, xml_output

    def format_tool_output(
        self,
        tool_type: str,
        content: Union[str, Dict],
        metadata: Optional[Dict] = None
    ) -> Dict:
        """Format a raw tool output into the standardized dict structure.

        Unknown tool types pass through unchanged (with a warning); mermaid
        and heatmap outputs get tool-specific cleaning; anything else keeps
        its content with the canonical tool name. Never raises — failures
        are logged and returned as an "error"-typed dict.
        """
        try:
            tool = ToolType.get_tool_type(tool_type)
            if not tool:
                # Unknown tools are passed through so callers can still render them.
                self.logger.warning("Unknown tool type: %s", tool_type)
                return {
                    "type": tool_type,
                    "content": content,
                    "metadata": metadata or {}
                }
            if tool == ToolType.MERMAID:
                return {
                    "type": "mermaid",
                    "content": self._clean_mermaid_content(content),
                    "metadata": metadata or {}
                }
            if tool == ToolType.HEATMAP:
                return {
                    "type": "heatmap",
                    "content": self._format_heatmap_data(content),
                    "metadata": metadata or {}
                }
            # Default formatting for all other tools.
            return {
                "type": tool.value,
                "content": content,
                "metadata": metadata or {}
            }
        except Exception as e:
            # Deliberate best-effort boundary: log and return an error record
            # rather than letting a formatting failure break the stream.
            self.logger.error("Error formatting tool output: %s", e)
            return {
                "type": "error",
                "content": str(e),
                "metadata": metadata or {}
            }

    # ---- tool-specific cleaners -------------------------------------------

    def _clean_mermaid_content(self, content: Union[str, Dict]) -> str:
        """Return the bare mermaid diagram text: unwrap the dict payload and
        strip markdown code fences and surrounding whitespace."""
        try:
            if isinstance(content, dict):
                content = content.get("mermaid_diagram", "")
            content = re.sub(r'```mermaid\s*|\s*```', '', content)
            return content.strip()
        except Exception as e:
            # Fall back to a string rendering of whatever we were given.
            self.logger.error("Error cleaning mermaid content: %s", e)
            return str(content)

    def _format_heatmap_data(self, content: Union[str, Dict]) -> Dict:
        """Normalize heatmap output (JSON string or dict) into a dict with
        "data", "options" and "metadata" keys; on failure return {"error": ...}."""
        try:
            if isinstance(content, str):
                content = json.loads(content)
            return {
                "data": content.get("data", []),
                "options": content.get("options", {}),
                "metadata": content.get("metadata", {})
            }
        except Exception as e:
            self.logger.error("Error formatting heatmap data: %s", e)
            return {"error": str(e)}
| def _clean_markdown(text: str) -> str: | |
| """Clean markdown formatting from text""" | |
| text = re.sub(r'```.*?```', '', text, flags=re.DOTALL) | |
| text = re.sub(r'[*_`#]', '', text) | |
| return re.sub(r'\n{3,}', '\n\n', text.strip()) |