Spaces:
Running
Running
| from typing import List, Dict, Any, Optional | |
| import logging | |
| import json | |
| import time | |
| from observability import logger as obs_logger | |
| from observability import components as obs_components | |
| from ..base import BaseAgent | |
| from domain.training.charts import ChartSpec, ChartType | |
| from llm.base import LLMClient | |
| logger = logging.getLogger(__name__) | |
class VisualizationAgent(BaseAgent):
    """
    Agent responsible for specifying visualizations from run data using LLM.

    Returns ChartSpec objects instead of executing code. The LLM is handed
    `request_chart` as a tool; every successful tool call appends a ChartSpec
    to `requested_specs`, which `run` returns.
    """

    def __init__(self, llm_client: LLMClient):
        """
        Stores the LLM client and prepares the system instruction.

        Args:
            llm_client: Client whose `generate` coroutine is awaited by `run`.
        """
        self.llm_client = llm_client
        # Specs collected by `request_chart` during the current `run` call.
        self.requested_specs: List[ChartSpec] = []
        self.instruction = """You are a data visualization expert. Your goal is to specify charts to visualize run data.
You have access to a tool `request_chart(chart_type, title, params_json)`.
Chart Types available:
- 'pace': Pace vs Date
- 'heart_rate': Heart Rate vs Date
The data is available in the variable `features`, which is a list of dictionaries.
Each dictionary represents a run and has keys like:
- 'start_time': datetime object
- 'avg_pace_s_per_km': float (seconds per km)
- 'avg_hr_bpm': float (beats per minute)
- 'total_distance_m': float
- 'total_duration_s': float
When asked to generate charts:
1. Determine which chart types are appropriate.
2. Call `request_chart` for each chart needed.
3. `params_json` should be a JSON string of optional parameters (e.g., {"color": "blue"}).
Always request 'pace' and 'heart_rate' charts by default.
"""

    @staticmethod
    def _parse_params(params_json: Any) -> Dict[str, Any]:
        """Decodes the tool's `params_json` argument into a parameter dict.

        Blank/None input and JSON `null` mean "no parameters". An already
        decoded dict is passed through. Anything that does not decode to a
        JSON object raises ValueError (json.JSONDecodeError is a ValueError).
        """
        if params_json is None:
            return {}
        if isinstance(params_json, dict):
            return params_json
        text = params_json.strip()
        if not text:
            return {}
        decoded = json.loads(text)
        if decoded is None:
            return {}
        if not isinstance(decoded, dict):
            raise ValueError(
                f"params_json must be a JSON object, got {type(decoded).__name__}"
            )
        return decoded

    def request_chart(self, chart_type: str, title: str, params_json: str = "{}") -> str:
        """
        Records a request for a specific chart type with title and parameters.

        Args:
            chart_type: Chart type name such as 'pace' or 'heart_rate'.
                Surrounding whitespace and letter case are ignored.
            title: Title of the chart.
            params_json: JSON object string of optional parameters. An empty
                or blank value is treated as no parameters.

        Returns:
            A human-readable status string: a success message when the spec
            was recorded, otherwise a message describing the error. This
            method does not raise.
        """
        try:
            params = self._parse_params(params_json)
            spec = ChartSpec(
                chart_type=ChartType(chart_type.strip().lower()),
                title=title,
                params=params,
            )
            self.requested_specs.append(spec)
            return f"Successfully recorded request for {chart_type} chart: {title}"
        except Exception as e:
            # Deliberately broad: this is a tool entry point, so any failure is
            # reported back as text instead of propagating to the caller.
            logger.error("Error recording chart request: %s", e)
            return f"Error recording request: {e}"

    async def run(self, features: List[Dict[str, Any]], query: Optional[str] = None) -> List[ChartSpec]:
        """
        Generates chart specifications using the LLM.

        Args:
            features: Run records; only their count is placed in the prompt here.
            query: Optional user request for a specific visualization. When
                empty or None, default pace and heart-rate charts are requested.

        Returns:
            The chart specs recorded via `request_chart` during this call. If
            none were recorded and no query was given, a default pace and
            heart-rate pair is returned. With a query and no recorded specs
            the list is empty.
        """
        with obs_logger.start_span("visualization_agent.run", obs_components.AGENT):
            # Start from a fresh list so specs from a previous run are not reused.
            self.requested_specs = []

            if query:
                prompt = (
                    "\n"
                    f'The user has requested a specific visualization: "{query}"\n'
                    f"There are {len(features)} runs available.\n"
                    "Please request the most appropriate chart(s) using `request_chart`.\n"
                )
            else:
                prompt = (
                    "\n"
                    f"Generate default visualizations for the provided {len(features)} runs.\n"
                    "Please request:\n"
                    "1. A pace chart ('pace')\n"
                    "2. A heart rate chart ('heart_rate')\n"
                )

            try:
                await self.llm_client.generate(
                    prompt,
                    instruction=self.instruction,
                    tools=[self.request_chart],
                    name="visualization_agent",
                )
            except Exception as e:
                # Best-effort: the failure is handled here and does not propagate
                # to the enclosing span, so this log call is what records it.
                logger.error("LLM specification failed: %s", e)

            # Fallback if no specs were generated (default request only).
            if not self.requested_specs and not query:
                self.requested_specs = [
                    ChartSpec(chart_type=ChartType.PACE, title="Pace Chart"),
                    ChartSpec(chart_type=ChartType.HEART_RATE, title="Heart Rate Chart"),
                ]
            return self.requested_specs