File size: 5,832 Bytes
f2b5c2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
Cohere client service for AI chatbot.

This module provides a wrapper around the Cohere API with:
- API key management
- Retry logic for transient failures
- Timeout handling
- Structured logging
- Token usage tracking
"""

import asyncio
import functools
import json
import logging
import os
import time
from typing import List, Dict, Any, Optional

import cohere
from cohere.errors import TooManyRequestsError, ServiceUnavailableError
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type
)

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class CohereClient:
    """
    Cohere API client with retry logic and structured logging.

    This client is specifically configured for the AI chatbot use case
    with a low default temperature and tool-calling support.

    Configuration is read from environment variables at construction time:

    - ``COHERE_API_KEY`` (required)
    - ``COHERE_MODEL`` (default ``command-r-plus``)
    - ``COHERE_TEMPERATURE`` (default ``0.3``)
    - ``COHERE_MAX_TOKENS`` (default ``2000``)
    - ``COHERE_TIMEOUT`` in whole seconds (default ``30``)
    """

    def __init__(self):
        """
        Initialize Cohere client with environment configuration.

        Raises:
            ValueError: If COHERE_API_KEY is unset or empty, or if one of the
                numeric settings cannot be parsed.
        """
        self.api_key = os.getenv("COHERE_API_KEY")
        if not self.api_key:
            raise ValueError("COHERE_API_KEY not found in environment variables")

        self.model = os.getenv("COHERE_MODEL", "command-r-plus")
        self.temperature = float(os.getenv("COHERE_TEMPERATURE", "0.3"))
        self.max_tokens = int(os.getenv("COHERE_MAX_TOKENS", "2000"))
        self.timeout = int(os.getenv("COHERE_TIMEOUT", "30"))

        # Hand the configured timeout to the SDK so COHERE_TIMEOUT actually
        # bounds each HTTP request instead of being parsed and ignored.
        self.client = cohere.ClientV2(self.api_key, timeout=self.timeout)
        logger.info("Cohere client initialized with model: %s", self.model)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((TooManyRequestsError, ServiceUnavailableError))
    )
    async def chat(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """
        Send chat request to Cohere API with retry logic.

        The SDK client held by this class is synchronous, so the request is
        executed in the event loop's default executor; awaiting this
        coroutine does not block other tasks while the HTTP call is in flight.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            tools: Optional list of tool definitions for tool-calling

        Returns:
            Dictionary with three keys:
            ``"response"`` - text of the first text item in the reply
            (empty string when there is none),
            ``"tool_calls"`` - list of ``{"name", "parameters"}`` dicts,
            ``"latency"`` - seconds spent on the successful API call.

        Raises:
            tenacity.RetryError: If rate-limit or service-unavailable errors
                persist through all 3 attempts (the decorator does not set
                ``reraise``, so the last error is wrapped).
            Exception: Any other failure is logged and re-raised immediately
                without being retried.
        """
        start_time = time.perf_counter()

        try:
            logger.info("Sending chat request to Cohere (model: %s)", self.model)
            logger.debug("Messages: %d, Tools: %d", len(messages), len(tools) if tools else 0)

            request = functools.partial(
                self.client.chat,
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tools=tools if tools else None,
            )
            # Run the blocking SDK call off the event loop thread.
            response = await asyncio.get_running_loop().run_in_executor(None, request)

            latency = time.perf_counter() - start_time

            # Full payload may contain user/model content: debug level only.
            logger.debug("Cohere response received: %s", response)

            response_text = self._extract_text(response)
            tool_calls = self._extract_tool_calls(response)

            logger.info("Cohere API call successful (latency: %.2fs)", latency)
            self._log_token_usage(response)

            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "latency": latency
            }

        except TooManyRequestsError as e:
            logger.warning("Rate limit hit: %s", e)
            raise
        except ServiceUnavailableError as e:
            logger.error("Cohere service unavailable: %s", e)
            raise
        except Exception as e:
            # logger.exception attaches the active traceback to the record.
            logger.exception("Cohere API call failed: %s", e)
            raise

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Return the text of the first content item carrying ``text``, else ``""``."""
        message = getattr(response, "message", None)
        content = getattr(message, "content", None) or []
        return next((item.text for item in content if hasattr(item, "text")), "")

    @staticmethod
    def _extract_tool_calls(response: Any) -> List[Dict[str, Any]]:
        """Convert the reply's tool-call objects into ``{"name", "parameters"}`` dicts."""
        message = getattr(response, "message", None)
        parsed_calls: List[Dict[str, Any]] = []

        for tool_call in getattr(message, "tool_calls", None) or []:
            function = getattr(tool_call, "function", None)
            if function is None:
                logger.error("Skipping tool call without a function payload")
                continue

            raw_arguments = function.arguments
            try:
                if isinstance(raw_arguments, str):
                    # Arguments arrive as a JSON string; decode into a dict.
                    arguments = json.loads(raw_arguments)
                elif raw_arguments is None:
                    # No arguments supplied: treat as an empty parameter set
                    # so the entry still passes validate_tool_call().
                    arguments = {}
                else:
                    arguments = raw_arguments
            except json.JSONDecodeError as e:
                # One malformed call must not discard the rest of the reply.
                logger.error("Failed to parse tool call arguments: %s", e)
                continue

            parsed_calls.append({
                "name": function.name,
                "parameters": arguments
            })

        return parsed_calls

    @staticmethod
    def _log_token_usage(response: Any) -> None:
        """Log input/output token counts when the reply carries usage data."""
        usage = getattr(response, "usage", None)
        tokens = getattr(usage, "tokens", None)
        if tokens is None:
            # Usage is optional; a missing value must not fail a good call.
            return
        logger.info(
            "Token usage - Input: %s, Output: %s",
            getattr(tokens, "input_tokens", None),
            getattr(tokens, "output_tokens", None),
        )

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """
        Validate that a tool call has the required structure.

        A valid tool call is a dict with a string ``"name"`` and a dict
        ``"parameters"``; extra keys are ignored.

        Args:
            tool_call: Tool call dictionary to validate

        Returns:
            True if valid, False otherwise
        """
        return (
            isinstance(tool_call, dict)
            and isinstance(tool_call.get("name"), str)
            and isinstance(tool_call.get("parameters"), dict)
        )


# Global Cohere client instance, constructed when this module is imported.
# CohereClient.__init__ raises ValueError if COHERE_API_KEY is not set, so
# importing this module fails in that case.
cohere_client = CohereClient()