File size: 4,117 Bytes
557ee65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import List, Dict, Any, Optional
import logging
import json
import time
from observability import logger as obs_logger
from observability import components as obs_components
from ..base import BaseAgent
from domain.training.charts import ChartSpec, ChartType
from llm.base import LLMClient

logger = logging.getLogger(__name__)


class VisualizationAgent(BaseAgent):
    """
    Agent responsible for specifying visualizations from run data using LLM.

    Rather than generating and executing plotting code, the agent exposes a
    single tool, ``request_chart``, to the LLM and records each successful call
    as a ``ChartSpec``. ``run`` returns the specs collected during that call.
    """

    def __init__(self, llm_client: LLMClient):
        """
        Args:
            llm_client: Client whose ``generate`` coroutine ``run`` awaits, with
                this agent's ``request_chart`` passed as the only tool.
        """
        # NOTE(review): BaseAgent.__init__ is not called; confirm the base class
        # needs no initialization of its own.
        self.llm_client = llm_client
        # Specs recorded by request_chart; reset at the start of every run().
        self.requested_specs: List[ChartSpec] = []
        # System instruction sent with every LLM call made by run().
        # NOTE(review): this text mentions a `features` variable, but run() only
        # sends the prompt (which contains the run count) and the request_chart
        # tool to the LLM -- confirm this wording is still intended.
        self.instruction = """You are a data visualization expert. Your goal is to specify charts to visualize run data.
            
            You have access to a tool `request_chart(chart_type, title, params_json)`.
            
            Chart Types available:
            - 'pace': Pace vs Date
            - 'heart_rate': Heart Rate vs Date
            
            The data is available in the variable `features`, which is a list of dictionaries.
            Each dictionary represents a run and has keys like:
            - 'start_time': datetime object
            - 'avg_pace_s_per_km': float (seconds per km)
            - 'avg_hr_bpm': float (beats per minute)
            - 'total_distance_m': float
            - 'total_duration_s': float
            
            When asked to generate charts:
            1. Determine which chart types are appropriate.
            2. Call `request_chart` for each chart needed.
            3. `params_json` should be a JSON string of optional parameters (e.g., {"color": "blue"}).
            
            Always request 'pace' and 'heart_rate' charts by default.
            """

    def request_chart(self, chart_type: str, title: str, params_json: str = "{}") -> str:
        """
        Records a request for a specific chart type with title and parameters.

        Args:
            chart_type: Chart type name such as 'pace' or 'heart_rate'; it is
                lower-cased before being converted to a ChartType.
            title: Title for the chart.
            params_json: JSON object string of optional chart parameters,
                e.g. '{"color": "blue"}'. An empty value means no parameters.

        Returns:
            A confirmation message, or an error message if the request was invalid.
        """
        # This method is called by the LLM as a tool, so failures are reported
        # back as text (letting the model retry) instead of being raised.
        try:
            # LLMs often send "" for an optional argument; treat it as "no params".
            params = json.loads(params_json) if params_json else {}
            if not isinstance(params, dict):
                raise ValueError(
                    f"params_json must be a JSON object, got {type(params).__name__}"
                )
            spec = ChartSpec(chart_type=ChartType(chart_type.lower()), title=title, params=params)
        except Exception as e:
            logger.error("Error recording chart request: %s", e)
            return f"Error recording request: {e}"

        self.requested_specs.append(spec)
        return f"Successfully recorded request for {chart_type} chart: {title}"

    async def run(self, features: List[Dict[str, Any]], query: Optional[str] = None) -> List[ChartSpec]:
        """
        Generates chart specifications using the LLM.

        Args:
            features: Per-run feature dicts. Only their count is placed in the
                prompt; the dicts themselves are not sent to the LLM.
            query: Optional free-text visualization request. When omitted or
                empty, the LLM is asked for default pace and heart-rate charts.

        Returns:
            The ChartSpecs recorded via ``request_chart`` during this call.
            Without a query, if nothing was recorded (including when the LLM
            call fails), default pace and heart-rate specs are returned. With a
            query, the list may be empty.
        """
        with obs_logger.start_span("visualization_agent.run", obs_components.AGENT):
            # NOTE(review): request_chart appends to shared instance state, so
            # concurrent run() calls on the same instance would mix their specs.
            self.requested_specs = []

            if query:
                prompt = f"""
                The user has requested a specific visualization: "{query}"
                There are {len(features)} runs available.
                Please request the most appropriate chart(s) using `request_chart`.
                """
            else:
                prompt = f"""
                Generate default visualizations for the provided {len(features)} runs.
                Please request:
                1. A pace chart ('pace')
                2. A heart rate chart ('heart_rate')
                """

            try:
                await self.llm_client.generate(
                    prompt,
                    instruction=self.instruction,
                    tools=[self.request_chart],
                    name="visualization_agent",
                )
            except Exception as e:
                # Best-effort: the failure is logged and swallowed so the default
                # fallback below can still apply. Because it is caught here, the
                # span context manager does not receive this exception on exit.
                logger.exception("LLM specification failed: %s", e)

            # Fallback if no specs were generated (default request only).
            if not self.requested_specs and not query:
                self.requested_specs = [
                    ChartSpec(chart_type=ChartType.PACE, title="Pace Chart"),
                    ChartSpec(chart_type=ChartType.HEART_RATE, title="Heart Rate Chart"),
                ]

            return self.requested_specs