File size: 7,323 Bytes
676582c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
Cohere Provider Implementation

Cohere API provider with function calling support.
Optional provider (trial only, not recommended for production).
"""

import logging
from typing import List, Dict, Any
import cohere

from .base import LLMProvider, LLMResponse

logger = logging.getLogger(__name__)


class CohereProvider(LLMProvider):
    """
    Cohere API provider implementation.

    Features:
    - Native function calling support
    - Trial tier only (not recommended for production)
    - Model: command-r-plus (best for function calling)

    Note: Cohere requires a paid plan after trial expires.
    Use Gemini or OpenRouter for free-tier operation.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "command-r-plus",
        temperature: float = 0.7,
        max_tokens: int = 8192
    ):
        """
        Create a Cohere chat client.

        Args:
            api_key: Cohere API key.
            model: Model name used for every chat request.
            temperature: Sampling temperature passed to every chat call.
            max_tokens: Generation cap passed to every chat call.
        """
        super().__init__(api_key, model, temperature, max_tokens)
        self.client = cohere.Client(api_key)
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Initialized CohereProvider with model: %s", model)

    @staticmethod
    def _format_chat_history(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Convert generic role/content messages into Cohere chat_history entries.

        Any role other than "user" (e.g. "assistant") maps to "CHATBOT" —
        the exact mapping previously duplicated inline in each generate_*
        method.

        Args:
            messages: Messages with "role" and "content" keys.

        Returns:
            List of {"role": "USER"|"CHATBOT", "message": ...} dicts.
        """
        return [
            {
                "role": "USER" if msg["role"] == "user" else "CHATBOT",
                "message": msg["content"]
            }
            for msg in messages
        ]

    def _convert_tools_to_cohere_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert MCP tool definitions to Cohere tool format.

        NOTE(review): this passes the raw JSON-schema "properties" dict
        straight through as parameter_definitions; Cohere's documented
        format expects per-parameter "type"/"description"/"required"
        entries with Python-style type names — confirm against the Cohere
        tool-use docs before relying on complex schemas.

        Args:
            tools: MCP tool definitions (each with "name", "description",
                and a JSON-schema "parameters" dict).

        Returns:
            List of Cohere-formatted tool definitions.
        """
        return [
            {
                "name": tool["name"],
                "description": tool["description"],
                "parameter_definitions": tool["parameters"].get("properties", {})
            }
            for tool in tools
        ]

    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a response with function calling support.

        Args:
            messages: Conversation history; the last entry is treated as
                the current user message.
            system_prompt: System instructions (sent as Cohere "preamble").
            tools: MCP tool definitions.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            Exception: Any Cohere client error is logged and re-raised.
        """
        try:
            cohere_tools = self._convert_tools_to_cohere_format(tools)

            # History is everything except the final (current) user message.
            chat_history = self._format_chat_history(messages[:-1])
            current_message = messages[-1]["content"] if messages else ""

            response = self.client.chat(
                message=current_message,
                chat_history=chat_history,
                preamble=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tools=cohere_tools
            )

            # Tool-call path: surface the requested calls instead of text.
            if response.tool_calls:
                tool_calls = [
                    {"name": tc.name, "arguments": tc.parameters}
                    for tc in response.tool_calls
                ]
                logger.info(
                    "Cohere requested function calls: %s",
                    [tc["name"] for tc in tool_calls]
                )
                # NOTE(review): finish_reason casing is inconsistent across
                # this class ("tool_calls" here vs "COMPLETE" below); left
                # as-is since callers may match on these exact strings.
                return LLMResponse(
                    content=None,
                    tool_calls=tool_calls,
                    finish_reason="tool_calls"
                )

            # Regular text response
            logger.info("Cohere generated text response")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error("Cohere API error: %s", str(e))
            raise

    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a final response after tool execution.

        Args:
            messages: Original conversation history (all of it goes into
                chat_history; the prompt itself is a fixed instruction).
            tool_calls: Tool calls that were made ("name"/"arguments" dicts).
            tool_results: Results from tool execution, zipped positionally
                with tool_calls.

        Returns:
            LLMResponse with final content.

        Raises:
            Exception: Any Cohere client error is logged and re-raised.
        """
        try:
            chat_history = self._format_chat_history(messages)

            # Pair each call with its result; outputs are stringified since
            # Cohere expects a list of dict outputs per call.
            tool_results_formatted = [
                {
                    "call": {"name": call["name"], "parameters": call["arguments"]},
                    "outputs": [{"result": str(result)}]
                }
                for call, result in zip(tool_calls, tool_results)
            ]

            response = self.client.chat(
                message="Based on the tool results, provide a natural language response.",
                chat_history=chat_history,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tool_results=tool_results_formatted
            )

            logger.info("Cohere generated final response after tool execution")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error("Cohere API error in tool results: %s", str(e))
            raise

    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> LLMResponse:
        """
        Generate a simple response without function calling.

        Args:
            messages: Conversation history; the last entry is treated as
                the current user message.
            system_prompt: System instructions (sent as Cohere "preamble").

        Returns:
            LLMResponse with content.

        Raises:
            Exception: Any Cohere client error is logged and re-raised.
        """
        try:
            # History is everything except the final (current) user message.
            chat_history = self._format_chat_history(messages[:-1])
            current_message = messages[-1]["content"] if messages else ""

            response = self.client.chat(
                message=current_message,
                chat_history=chat_history,
                preamble=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )

            logger.info("Cohere generated simple response")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error("Cohere API error: %s", str(e))
            raise