Spaces:
Running
Running
File size: 4,117 Bytes
557ee65 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | from typing import List, Dict, Any, Optional
import logging
import json
import time
from observability import logger as obs_logger
from observability import components as obs_components
from ..base import BaseAgent
from domain.training.charts import ChartSpec, ChartType
from llm.base import LLMClient
logger = logging.getLogger(__name__)
class VisualizationAgent(BaseAgent):
    """
    Agent responsible for specifying visualizations from run data using LLM.
    Returns ChartSpec objects instead of executing code.

    The LLM is handed ``request_chart`` as a tool; every successful tool call
    appends a ``ChartSpec`` to ``requested_specs``, which ``run`` resets at the
    start of each call and returns at the end. Because that list is
    per-instance state shared across an ``await``, concurrent ``run`` calls on
    the same instance would interleave their specs.
    """

    def __init__(self, llm_client: LLMClient):
        # NOTE(review): BaseAgent.__init__ is not invoked; confirm BaseAgent
        # requires no initialisation of its own.
        self.llm_client = llm_client
        # Specs recorded by request_chart during the current run().
        self.requested_specs: List[ChartSpec] = []
        # System instruction passed to every llm_client.generate() call.
        # NOTE(review): the text below tells the model the data is in a
        # `features` variable, but run() only sends len(features) in the
        # prompt -- the model never receives the feature dicts themselves.
        self.instruction = """You are a data visualization expert. Your goal is to specify charts to visualize run data.
You have access to a tool `request_chart(chart_type, title, params_json)`.
Chart Types available:
- 'pace': Pace vs Date
- 'heart_rate': Heart Rate vs Date
The data is available in the variable `features`, which is a list of dictionaries.
Each dictionary represents a run and has keys like:
- 'start_time': datetime object
- 'avg_pace_s_per_km': float (seconds per km)
- 'avg_hr_bpm': float (beats per minute)
- 'total_distance_m': float
- 'total_duration_s': float
When asked to generate charts:
1. Determine which chart types are appropriate.
2. Call `request_chart` for each chart needed.
3. `params_json` should be a JSON string of optional parameters (e.g., {"color": "blue"}).
Always request 'pace' and 'heart_rate' charts by default.
"""

    @staticmethod
    def _parse_params(params_json: Any) -> Dict[str, Any]:
        """Decode ``params_json`` into a parameter dict.

        ``None`` or a blank string means "no parameters". An already-decoded
        dict is accepted as-is, in case the tool-calling layer passes objects
        through (not verified for LLMClient). Raises ``ValueError`` (including
        ``json.JSONDecodeError``) when the JSON does not encode an object.
        """
        if isinstance(params_json, dict):
            return params_json
        if params_json is None or not str(params_json).strip():
            return {}
        params = json.loads(params_json)
        if not isinstance(params, dict):
            raise ValueError(
                f"params_json must encode a JSON object, got {type(params).__name__}"
            )
        return params

    def request_chart(self, chart_type: str, title: str, params_json: str = "{}") -> str:
        """
        Records a request for a specific chart type with title and parameters.

        Args:
            chart_type: Chart type name, e.g. 'pace' or 'heart_rate'
                (surrounding whitespace is ignored; it is lower-cased).
            title: Title for the chart.
            params_json: JSON object string of optional chart parameters,
                e.g. '{"color": "blue"}'. Empty means no parameters.

        Returns:
            A confirmation message, or an error message describing why the
            request could not be recorded.
        """
        # This method is an LLM tool: it never raises, so a bad tool call is
        # reported back to the model instead of aborting generation.
        try:
            params = self._parse_params(params_json)
            spec = ChartSpec(
                chart_type=ChartType(chart_type.strip().lower()),
                title=title,
                params=params,
            )
        except Exception as e:
            logger.error("Error recording chart request: %s", e)
            return f"Error recording request: {e}"
        self.requested_specs.append(spec)
        return f"Successfully recorded request for {chart_type} chart: {title}"

    async def run(self, features: List[Dict[str, Any]], query: Optional[str] = None) -> List[ChartSpec]:
        """
        Generates chart specifications using the LLM.

        Args:
            features: Per-run feature dicts; only their count is included in
                the prompt sent to the LLM.
            query: Optional free-text visualization request. When omitted (or
                empty), the LLM is asked for the default pace and heart-rate
                charts.

        Returns:
            The ChartSpecs recorded via ``request_chart`` during this call.
            If none were recorded and no ``query`` was given -- including when
            the LLM call fails -- default pace and heart-rate specs are
            returned. With a ``query``, the result may be empty.
        """
        with obs_logger.start_span("visualization_agent.run", obs_components.AGENT):
            # Start each run with a fresh list; request_chart appends to it.
            self.requested_specs = []
            if query:
                prompt = f"""
The user has requested a specific visualization: "{query}"
There are {len(features)} runs available.
Please request the most appropriate chart(s) using `request_chart`.
"""
            else:
                prompt = f"""
Generate default visualizations for the provided {len(features)} runs.
Please request:
1. A pace chart ('pace')
2. A heart rate chart ('heart_rate')
"""
            try:
                await self.llm_client.generate(
                    prompt,
                    instruction=self.instruction,
                    tools=[self.request_chart],
                    name="visualization_agent",
                )
            except Exception as e:
                # Swallowed on purpose so the default-chart fallback below can
                # still run. Because it does not propagate, the span exits
                # normally and does NOT record this error -- log it here with
                # the traceback instead.
                logger.exception("LLM specification failed: %s", e)
            # Fallback if no specs were generated (default request only).
            if not self.requested_specs and not query:
                self.requested_specs = [
                    ChartSpec(chart_type=ChartType.PACE, title="Pace Chart"),
                    ChartSpec(chart_type=ChartType.HEART_RATE, title="Heart Rate Chart"),
                ]
            return self.requested_specs
|