Spaces:
Sleeping
Sleeping
| from typing import Dict, Optional, Tuple, List, Any | |
| import re | |
| import xml.etree.ElementTree as ET | |
| from datetime import datetime | |
| import json | |
class ToolType:
    """Namespace of tool-name strings.

    Each attribute holds the identifier of one tool. ResponseFormatter
    compares incoming tool names against some of these values to choose a
    tool-specific XML layout.
    """

    DUCKDUCKGO: str = "duckduckgo_search"
    REDDIT_NEWS: str = "reddit_x_gnews_newswire_crunchbase"
    PUBMED: str = "pubmed_search"
    CENSUS: str = "get_census_data"
    HEATMAP: str = "heatmap_code"
    MERMAID: str = "mermaid_diagram"
    WISQARS: str = "wisqars"
    WONDER: str = "wonder"
    NCHS: str = "nchs"
    ONESTEP: str = "onestep"
    DQS: str = "dqs_nhis_adult_summary_health_statistics"
class ResponseFormatter:
    """Render agent output as a ``(terminal_text, xml_text)`` pair.

    Every method is a ``@staticmethod``; the class only groups them, so the
    methods work whether called on the class or on an instance.
    """

    # Characters XML 1.0 forbids in character data. ElementTree escapes
    # ``&``, ``<`` and ``>`` but leaves these in place, which would yield a
    # document that XML parsers reject.
    _INVALID_XML_CHARS = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\ud800-\udfff\ufffe\uffff]")
    # ElementTree does not validate tag names, so keys used as tags are
    # restricted to this ASCII set (everything else becomes "_").
    _INVALID_TAG_CHARS = re.compile(r"[^A-Za-z0-9_.\-]")

    @staticmethod
    def format_thought(
        thought: str,
        observation: Optional[str] = None,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
    ) -> Tuple[str, str]:
        """Format agent thought and observation for both terminal and XML output.

        Args:
            thought: The agent's reasoning text; surrounding whitespace is
                stripped.
            observation: Optional tool observation. When it is a JSON object
                (serialized as a string) or a ``dict``, entries that mention
                ``llm_result`` are rendered as ``<tool>`` elements. Otherwise
                the markdown-cleaned text is stored in
                ``<observation><content>``.
            citations: Optional list of dicts; each becomes a ``<citation>``
                element with one child element per key.
            metadata: Optional dict; each key becomes a child of ``<metadata>``.

        Returns:
            ``(terminal_output, xml_output)`` where ``xml_output`` is a
            serialized ``<agent_response>`` element.
        """
        thought_text = thought.strip()
        obs_text = ""
        if observation:
            raw = observation if isinstance(observation, str) else str(observation)
            obs_text = ResponseFormatter._clean_markdown(raw)

        # Terminal format
        terminal_output = thought_text
        if obs_text:
            terminal_output += f"\n\nObservation:\n{obs_text}"

        # XML format
        root = ET.Element("agent_response")
        ET.SubElement(root, "thought").text = ResponseFormatter._xml_text(thought_text)
        if observation:
            obs_elem = ET.SubElement(root, "observation")
            tool_outputs = ResponseFormatter._extract_tool_outputs(observation)
            if tool_outputs:
                tools_elem = ET.SubElement(obs_elem, "tools")
                for tool_name, tool_data in tool_outputs.items():
                    ResponseFormatter._create_tool_element(tools_elem, tool_name, tool_data)
            elif obs_text:
                # Nothing structured could be extracted: keep the text instead
                # of emitting an empty <observation/>.
                ET.SubElement(obs_elem, "content").text = ResponseFormatter._xml_text(obs_text)
        if citations:
            citations_elem = ET.SubElement(root, "citations")
            for citation in citations:
                ResponseFormatter._append_fields(ET.SubElement(citations_elem, "citation"), citation)
        if metadata:
            ResponseFormatter._append_fields(ET.SubElement(root, "metadata"), metadata)
        return terminal_output, ET.tostring(root, encoding="unicode")

    @staticmethod
    def _create_tool_element(parent: ET.Element, tool_name: str, tool_data: Any) -> ET.Element:
        """Append a ``<tool name="...">`` child to *parent*, laid out for *tool_name*.

        Census, mermaid and health tools (WISQARS/WONDER/NCHS) get dedicated
        layouts; anything else is stored as markdown-cleaned ``<content>``.
        Returns the new ``<tool>`` element.
        """
        tool_elem = ET.SubElement(parent, "tool")
        tool_elem.set("name", ResponseFormatter._xml_text(tool_name))
        if tool_name == ToolType.CENSUS:
            ResponseFormatter._format_census_data(tool_elem, tool_data)
        elif tool_name == ToolType.MERMAID:
            ResponseFormatter._format_mermaid_data(tool_elem, tool_data)
        elif tool_name in (ToolType.WISQARS, ToolType.WONDER, ToolType.NCHS):
            ResponseFormatter._format_health_data(tool_elem, tool_data)
        else:
            ResponseFormatter._append_content(tool_elem, tool_data)
        return tool_elem

    @staticmethod
    def _format_census_data(tool_elem: ET.Element, data: Any) -> None:
        """Render census output as ``<census_tracts><tract id=...>`` elements.

        ``data["llm_result"]`` must be a JSON string (or an already-decoded
        dict) of the form ``{"tracts": [{"tract": ..., <field>: ...}, ...]}``.
        Underscores are stripped from field names. If that shape is not
        present, only a ``<content>`` fallback is appended.
        """
        safe_tag = ResponseFormatter._safe_tag
        xml_text = ResponseFormatter._xml_text
        llm_result = data.get("llm_result") if isinstance(data, dict) else None
        try:
            result = json.loads(llm_result) if isinstance(llm_result, str) else llm_result
            # Build detached so a failure part-way leaves no half-filled element.
            tracts_elem = ET.Element("census_tracts")
            for tract_data in result.get("tracts", []):
                tract_elem = ET.SubElement(tracts_elem, "tract")
                tract_elem.set("id", xml_text(tract_data.get("tract", "")))
                for key, value in tract_data.items():
                    if key != "tract":
                        detail = ET.SubElement(tract_elem, safe_tag(str(key).replace("_", "")))
                        detail.text = xml_text(value)
        except (ValueError, TypeError, AttributeError):
            # llm_result missing, not valid JSON, or not shaped as expected.
            ResponseFormatter._append_content(tool_elem, data)
            return
        tool_elem.append(tracts_elem)

    @staticmethod
    def _format_mermaid_data(tool_elem: ET.Element, data: Any) -> None:
        """Render ``data["mermaid_diagram"]`` as a ``<diagram>`` element.

        Markdown code fences (```` ```mermaid ```` / ```` ``` ````) are removed.
        If no string diagram is present, a ``<content>`` fallback is used.
        """
        code = data.get("mermaid_diagram") if isinstance(data, dict) else None
        if not isinstance(code, str):
            ResponseFormatter._append_content(tool_elem, data)
            return
        mermaid_code = re.sub(r'```mermaid\s*|\s*```', '', code)
        ET.SubElement(tool_elem, "diagram").text = ResponseFormatter._xml_text(mermaid_code.strip())

    @staticmethod
    def _format_health_data(tool_elem: ET.Element, data: Any) -> None:
        """Render health-tool output (WISQARS, WONDER, NCHS) as nested elements.

        Each top-level key becomes a child element; dict values get one
        grandchild per sub-key, and other values become the element text.
        Underscores are stripped from keys. Non-dict data is stored as a
        ``<content>`` fallback.
        """
        if not isinstance(data, dict):
            ResponseFormatter._append_content(tool_elem, data)
            return
        safe_tag = ResponseFormatter._safe_tag
        xml_text = ResponseFormatter._xml_text
        for key, value in data.items():
            category_elem = ET.SubElement(tool_elem, safe_tag(str(key).replace("_", "")))
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    sub_elem = ET.SubElement(category_elem, safe_tag(str(sub_key).replace("_", "")))
                    sub_elem.text = xml_text(sub_value)
            else:
                category_elem.text = xml_text(value)

    @staticmethod
    def _extract_tool_outputs(observation: Any) -> Dict[str, Any]:
        """Extract tool outputs from an observation.

        *observation* may be a JSON-object string or a dict. Values that are
        strings containing ``llm_result`` are JSON-decoded when possible (kept
        as the raw string otherwise). Dict values with an ``llm_result`` key
        are kept as-is. Returns ``{}`` when nothing can be extracted.
        """
        if isinstance(observation, str):
            try:
                data = json.loads(observation)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                return {}
        elif isinstance(observation, dict):
            data = observation
        else:
            return {}
        if not isinstance(data, dict):
            return {}

        tool_outputs: Dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, str) and "llm_result" in value:
                try:
                    tool_outputs[key] = json.loads(value)
                except ValueError:
                    tool_outputs[key] = value
            elif isinstance(value, dict) and "llm_result" in value:
                tool_outputs[key] = value
        return tool_outputs

    @staticmethod
    def format_message(message: str) -> Tuple[str, str]:
        """Format agent message for both terminal and XML output.

        Returns ``(stripped_message, "<agent_response><message>...")``.
        """
        text = message.strip()
        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = ResponseFormatter._xml_text(text)
        return text, ET.tostring(root, encoding="unicode")

    @staticmethod
    def format_error(error: str) -> Tuple[str, str]:
        """Format error message for both terminal and XML output.

        *error* is converted with ``str()``, so an exception object may be
        passed directly (ElementTree cannot serialize non-string text).
        """
        terminal_output = f"Error: {error}"
        root = ET.Element("agent_response")
        ET.SubElement(root, "error").text = ResponseFormatter._xml_text(error)
        return terminal_output, ET.tostring(root, encoding="unicode")

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Clean markdown formatting from text.

        Removes fenced code blocks entirely and deletes every ``*``, ``_``,
        backtick and ``#`` character. Then strips the result and collapses
        runs of three or more newlines to two.
        """
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        text = re.sub(r'[*_`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())

    @staticmethod
    def _append_content(parent: ET.Element, data: Any) -> None:
        """Append a ``<content>`` child holding markdown-cleaned ``str(data)``."""
        text = ResponseFormatter._clean_markdown(str(data))
        ET.SubElement(parent, "content").text = ResponseFormatter._xml_text(text)

    @staticmethod
    def _append_fields(parent: ET.Element, fields: Dict) -> None:
        """Append one child per ``key: value`` pair, with ``str(value)`` as text."""
        for key, value in fields.items():
            child = ET.SubElement(parent, ResponseFormatter._safe_tag(key))
            child.text = ResponseFormatter._xml_text(value)

    @staticmethod
    def _safe_tag(name: Any) -> str:
        """Coerce *name* into a well-formed XML element name.

        E.g. ``"source url"`` -> ``"source_url"``, ``"1st"`` -> ``"_1st"``.
        """
        tag = ResponseFormatter._INVALID_TAG_CHARS.sub("_", str(name))
        if not tag or not (tag[0].isalpha() or tag[0] == "_"):
            # XML names must start with a letter or underscore.
            tag = f"_{tag}"
        return tag

    @staticmethod
    def _xml_text(value: Any) -> str:
        """Return ``str(value)`` with characters illegal in XML 1.0 removed."""
        return ResponseFormatter._INVALID_XML_CHARS.sub("", str(value))