| """ |
| Cohere Provider Implementation |
| |
| Cohere API provider with function calling support. |
| Optional provider (trial only, not recommended for production). |
| """ |
|
|
| import logging |
| from typing import List, Dict, Any |
| import cohere |
|
|
| from .base import LLMProvider, LLMResponse |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class CohereProvider(LLMProvider):
    """
    Cohere API provider implementation.

    Features:
    - Native function calling support
    - Trial tier only (not recommended for production)
    - Model: command-r-plus (best for function calling)

    Note: Cohere requires a paid plan after trial expires.
    Use Gemini or OpenRouter for free-tier operation.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "command-r-plus",
        temperature: float = 0.7,
        max_tokens: int = 8192
    ):
        """
        Initialize the provider and its underlying Cohere client.

        Args:
            api_key: Cohere API key.
            model: Cohere model name (command-r-plus is best for tools).
            temperature: Sampling temperature passed to every chat call.
            max_tokens: Per-response token cap passed to every chat call.
        """
        super().__init__(api_key, model, temperature, max_tokens)
        # NOTE(review): cohere.Client is the synchronous client, so the
        # .chat() calls below block the event loop inside these async
        # methods. Consider cohere.AsyncClient (or asyncio.to_thread) if
        # this provider is used under a busy event loop — TODO confirm.
        self.client = cohere.Client(api_key)
        # Lazy %-formatting: the string is only built if INFO is enabled.
        logger.info("Initialized CohereProvider with model: %s", model)

    @staticmethod
    def _to_chat_history(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Convert generic {role, content} messages to Cohere chat_history.

        Cohere distinguishes only USER and CHATBOT turns, so every
        non-"user" role (assistant, system, tool, ...) maps to CHATBOT —
        mirroring the original inline conversion used by all three
        generate_* methods.

        Args:
            messages: Messages with "role" and "content" keys.

        Returns:
            List of {"role": "USER"|"CHATBOT", "message": str} entries.
        """
        return [
            {
                "role": "USER" if msg["role"] == "user" else "CHATBOT",
                "message": msg["content"],
            }
            for msg in messages
        ]

    def _convert_tools_to_cohere_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert MCP tool definitions to Cohere tool format.

        Args:
            tools: MCP tool definitions (each with "name", "description",
                and a JSON-Schema-style "parameters" object).

        Returns:
            List of Cohere-formatted tool definitions.
        """
        # NOTE(review): this forwards the raw JSON-Schema "properties"
        # mapping as parameter_definitions; Cohere's schema expects
        # per-parameter {description, type, required} entries — verify the
        # upstream MCP schemas match that shape.
        return [
            {
                "name": tool["name"],
                "description": tool["description"],
                "parameter_definitions": tool["parameters"].get("properties", {}),
            }
            for tool in tools
        ]

    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a response with function calling support.

        Args:
            messages: Conversation history; the last entry is sent as the
                current message, the rest as chat_history.
            system_prompt: System instructions (Cohere "preamble").
            tools: MCP tool definitions made available to the model.

        Returns:
            LLMResponse with tool_calls (finish_reason "tool_calls") when
            the model requests functions, otherwise text content
            (finish_reason "COMPLETE").

        Raises:
            Exception: Re-raises any Cohere API error after logging it.
        """
        try:
            cohere_tools = self._convert_tools_to_cohere_format(tools)

            # Everything before the last message becomes history; the last
            # message is the current turn. Empty input yields an empty turn.
            chat_history = self._to_chat_history(messages[:-1])
            current_message = messages[-1]["content"] if messages else ""

            response = self.client.chat(
                message=current_message,
                chat_history=chat_history,
                preamble=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tools=cohere_tools
            )

            # Guard the attribute access: older/newer SDK response objects
            # may omit tool_calls entirely rather than set it to None.
            requested_calls = getattr(response, "tool_calls", None)
            if requested_calls:
                tool_calls = [
                    {
                        "name": tc.name,
                        "arguments": tc.parameters,
                    }
                    for tc in requested_calls
                ]
                logger.info(
                    "Cohere requested function calls: %s",
                    [tc["name"] for tc in tool_calls],
                )
                # NOTE(review): finish_reason casing is inconsistent across
                # this class ("tool_calls" here vs "COMPLETE" below) —
                # preserved as-is because callers may match these values.
                return LLMResponse(
                    content=None,
                    tool_calls=tool_calls,
                    finish_reason="tool_calls"
                )

            content = response.text
            logger.info("Cohere generated text response")
            return LLMResponse(
                content=content,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error("Cohere API error: %s", str(e))
            raise

    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a final response after tool execution.

        Args:
            messages: Original conversation history (all sent as
                chat_history; a fixed instruction is sent as the message).
            tool_calls: Tool calls that were made (paired positionally
                with tool_results).
            tool_results: Results from tool execution; each is stringified
                into the Cohere outputs payload.

        Returns:
            LLMResponse with final text content (finish_reason "COMPLETE").

        Raises:
            Exception: Re-raises any Cohere API error after logging it.
        """
        try:
            chat_history = self._to_chat_history(messages)

            # Pair each call with its result positionally; extra items on
            # either side are dropped by zip (original behavior).
            tool_results_formatted = [
                {
                    "call": {"name": call["name"], "parameters": call["arguments"]},
                    "outputs": [{"result": str(result)}],
                }
                for call, result in zip(tool_calls, tool_results)
            ]

            response = self.client.chat(
                message="Based on the tool results, provide a natural language response.",
                chat_history=chat_history,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tool_results=tool_results_formatted
            )

            content = response.text
            logger.info("Cohere generated final response after tool execution")
            return LLMResponse(
                content=content,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error("Cohere API error in tool results: %s", str(e))
            raise

    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> LLMResponse:
        """
        Generate a simple response without function calling.

        Args:
            messages: Conversation history; the last entry is sent as the
                current message, the rest as chat_history.
            system_prompt: System instructions (Cohere "preamble").

        Returns:
            LLMResponse with text content (finish_reason "COMPLETE").

        Raises:
            Exception: Re-raises any Cohere API error after logging it.
        """
        try:
            chat_history = self._to_chat_history(messages[:-1])
            current_message = messages[-1]["content"] if messages else ""

            response = self.client.chat(
                message=current_message,
                chat_history=chat_history,
                preamble=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )

            content = response.text
            logger.info("Cohere generated simple response")
            return LLMResponse(
                content=content,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error("Cohere API error: %s", str(e))
            raise
|
|