File size: 6,737 Bytes
b27eb78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150cd80
b27eb78
4af9b90
b27eb78
 
 
 
 
 
 
65229ef
 
8c6064d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65229ef
 
 
8c6064d
 
65229ef
8c6064d
 
 
 
 
 
 
 
 
 
 
65229ef
 
8c6064d
65229ef
 
 
8c6064d
65229ef
 
 
b27eb78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from pathlib import Path
from typing import Any

import opik
from loguru import logger
from opik import opik_context
from smolagents import LiteLLMModel, MessageRole, MultiStepAgent, ToolCallingAgent

from second_brain_online.config import settings

from .tools import (
    HuggingFaceEndpointSummarizerTool,
    MongoDBRetrieverTool,
    OpenAISummarizerTool,
    what_can_i_do,
)


def get_agent(retriever_config_path: Path) -> "AgentWrapper":
    """Create the application's agent.

    Delegates to ``AgentWrapper.build_from_smolagents``.

    Args:
        retriever_config_path: Path to the retriever configuration file.

    Returns:
        AgentWrapper: The wrapped, ready-to-run agent.
    """
    return AgentWrapper.build_from_smolagents(retriever_config_path=retriever_config_path)


class AgentWrapper:
    """Facade over a smolagents ``MultiStepAgent`` used by this application.

    Exposes a few attributes of the wrapped agent and a traced ``run`` method
    that prefers the raw observations of an ``answer_with_sources`` tool call
    over the agent's final answer.
    """

    # Tool whose raw observations ``run`` returns in preference to the final answer.
    _ANSWER_TOOL_NAME = "answer_with_sources"

    def __init__(self, agent: MultiStepAgent) -> None:
        self.__agent = agent

    @property
    def input_messages(self) -> list[dict]:
        """The wrapped agent's ``input_messages``."""
        return self.__agent.input_messages

    @property
    def agent_name(self) -> str:
        """The wrapped agent's ``agent_name``."""
        return self.__agent.agent_name

    @property
    def max_steps(self) -> int:
        """The wrapped agent's step budget (``4`` when built via ``build_from_smolagents``)."""
        return self.__agent.max_steps

    @classmethod
    def build_from_smolagents(cls, retriever_config_path: Path) -> "AgentWrapper":
        """Build a ``ToolCallingAgent`` with retrieval and summarization tools.

        The summarizer tool is backed by the Hugging Face dedicated endpoint when
        ``settings.USE_HUGGINGFACE_DEDICATED_ENDPOINT`` is truthy, otherwise by
        OpenAI. The agent's own model is ``settings.OPENAI_MODEL_ID`` called
        through LiteLLM against the OpenAI API base URL.

        Args:
            retriever_config_path: Path to the MongoDB retriever configuration.

        Returns:
            AgentWrapper: A wrapper around the configured agent.
        """
        retriever_tool = MongoDBRetrieverTool(config_path=retriever_config_path)
        if settings.USE_HUGGINGFACE_DEDICATED_ENDPOINT:
            logger.warning(
                f"Using Hugging Face dedicated endpoint as the summarizer with URL: {settings.HUGGINGFACE_DEDICATED_ENDPOINT}"
            )
            summarizer_tool = HuggingFaceEndpointSummarizerTool()
        else:
            logger.warning(
                f"Using OpenAI as the summarizer with model: {settings.OPENAI_MODEL_ID}"
            )
            summarizer_tool = OpenAISummarizerTool(stream=False)

        model = LiteLLMModel(
            model_id=settings.OPENAI_MODEL_ID,
            api_base="https://api.openai.com/v1",
            api_key=settings.OPENAI_API_KEY,
        )

        agent = ToolCallingAgent(
            tools=[what_can_i_do, retriever_tool, summarizer_tool],
            model=model,
            max_steps=4,  # Allow more steps for complex queries
            verbosity_level=2,
        )

        return cls(agent)

    @opik.track(name="Agent.run")
    def run(self, task: str, **kwargs) -> Any:
        """Run the agent on ``task`` and return its most useful output.

        Args:
            task: The query forwarded to the wrapped agent.
            **kwargs: Extra keyword arguments forwarded to the agent's ``run``.

        Returns:
            The non-empty observations of the first step that called
            ``answer_with_sources``, if any. Otherwise ``result.output`` when
            the result has an ``output`` attribute, else the raw result.
        """
        result = self.__agent.run(task, return_full_result=True, **kwargs)
        # Tolerate results without ``steps`` or with ``steps=None``.
        steps = getattr(result, "steps", None) or []

        self._log_steps(result, steps)

        # Prefer the raw answer_with_sources output over the agent's final_answer.
        answer = self._find_answer_observations(steps)
        if answer is not None:
            return answer

        logger.warning("⚠️ answer_with_sources output not found, falling back to result.output")
        if hasattr(result, "output"):
            return result.output

        return result

    @staticmethod
    def _log_steps(result: Any, steps: list) -> None:
        """Log the structure of a run result at DEBUG level (troubleshooting aid)."""
        logger.debug(f"Result type: {type(result)}")
        logger.debug(f"Number of steps: {len(steps)}")
        for i, step in enumerate(steps):
            is_dict = isinstance(step, dict)
            logger.debug(f"Step {i}: type={type(step)}, keys={step.keys() if is_dict else 'not a dict'}")
            if not is_dict or "tool_calls" not in step:
                continue
            logger.debug(f"  Tool calls: {step['tool_calls']}")
            for tc in step["tool_calls"] or []:
                if isinstance(tc, dict):
                    logger.debug(f"    Tool call dict: {tc}")
                    continue
                logger.debug(f"    Tool call object: {tc}, type: {type(tc)}")
                if hasattr(tc, "function"):
                    logger.debug(f"      Function: {tc.function}")
                if hasattr(tc, "name"):
                    logger.debug(f"      Name: {tc.name}")

    @staticmethod
    def _tool_call_name(tool_call: Any) -> "str | None":
        """Return the tool name of a tool call given as a dict or an object, or ``None``."""
        if isinstance(tool_call, dict):
            # ``function`` may be missing, ``None``, or (unexpectedly) not a dict.
            function = tool_call.get("function") or {}
            return function.get("name") if isinstance(function, dict) else None
        function = getattr(tool_call, "function", None)
        if function is not None and hasattr(function, "name"):
            return function.name
        return getattr(tool_call, "name", None)

    @classmethod
    def _find_answer_observations(cls, steps: list) -> Any:
        """Return the non-empty observations of the first step calling the answer tool, else ``None``."""
        for step_idx, step in enumerate(steps):
            if not isinstance(step, dict) or not step.get("tool_calls"):
                continue
            called = any(
                cls._tool_call_name(tool_call) == cls._ANSWER_TOOL_NAME
                for tool_call in step["tool_calls"]
            )
            if called and step.get("observations"):
                logger.info(f"✅ Found answer_with_sources at step {step_idx}, returning its observations")
                return step["observations"]
        return None


def extract_tool_responses(agent: ToolCallingAgent) -> str:
    """
    Extracts and concatenates all tool response contents with numbered observation delimiters.

    Args:
        agent: Either an agent exposing an ``input_messages`` attribute, or the
            list of messages itself. Each message is a mapping with ``"role"``
            and ``"content"`` keys. A ``None`` message list is treated as empty.

    Returns:
        str: Contents of the ``MessageRole.TOOL_RESPONSE`` messages, in order,
        each preceded by a numbered observation delimiter and joined by newlines.
        Returns an empty string when there are no tool responses.

    Example:
        >>> messages = [
        ...     {"role": MessageRole.TOOL_RESPONSE, "content": "First response"},
        ...     {"role": MessageRole.USER, "content": "Question"},
        ...     {"role": MessageRole.TOOL_RESPONSE, "content": "Second response"}
        ... ]
        >>> print(extract_tool_responses(messages))
        -------- OBSERVATION 1 --------
        First response
        -------- OBSERVATION 2 --------
        Second response
    """
    # Accept a plain message list as well as an agent (backward compatible).
    messages = getattr(agent, "input_messages", agent) or []

    tool_responses = [
        msg["content"]
        for msg in messages
        if msg["role"] == MessageRole.TOOL_RESPONSE
    ]

    return "\n".join(
        f"-------- OBSERVATION {i} --------\n{response}"
        for i, response in enumerate(tool_responses, start=1)
    )


class OpikAgentMonitorCallback:
    """Per-step callback that reports agent steps through an ``opik.track``-decorated method.

    Each call stores the step's observations in ``output_state``. It then calls
    ``trace`` with a dict holding the step's ``agent_memory`` and ``tool_calls``.
    That dict is what the tracker sees as the input, and the stored observations
    are what ``trace`` returns.
    """

    def __init__(self) -> None:
        # Output of the most recent step; returned by ``trace``.
        self.output_state: dict = {}

    def __call__(self, step_log) -> None:
        memory = step_log.agent_memory
        calls = step_log.tool_calls
        self.output_state = dict(observations=step_log.observations)
        self.trace({"agent_memory": memory, "tool_calls": calls})

    @opik.track(name="Callback.agent_step")
    def trace(self, step_log) -> dict:
        # ``step_log`` is the input-state dict built in ``__call__``. It is unused
        # here and is passed only so the tracker records it as the call's input.
        return self.output_state