Spaces:
Sleeping
Sleeping
Claude Code - Backend Implementation Specialist
Deploy AI-powered Todo application to Hugging Face Spaces
"""
Cohere client service for AI chatbot.

This module provides a wrapper around the Cohere API with:
- API key management
- Retry logic for transient failures
- Timeout handling
- Structured logging
- Token usage tracking
"""
import asyncio
import functools
import json
import logging
import os
import time
from typing import List, Dict, Any, Optional

import cohere
from cohere.errors import TooManyRequestsError, ServiceUnavailableError
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type
)

logger = logging.getLogger(__name__)
class CohereClient:
    """
    Cohere API client with retry logic and structured logging.

    Configuration is read from environment variables when the client is
    constructed:

    - COHERE_API_KEY (required)
    - COHERE_MODEL (default "command-r-plus")
    - COHERE_TEMPERATURE (default 0.3)
    - COHERE_MAX_TOKENS (default 2000)
    - COHERE_TIMEOUT, in seconds (default 30)
    """

    # Retry policy for transient failures (rate limiting / service outage).
    MAX_ATTEMPTS = 3
    RETRY_WAIT_MIN_SECONDS = 1
    RETRY_WAIT_MAX_SECONDS = 10

    def __init__(self):
        """Initialize Cohere client with environment configuration.

        Raises:
            ValueError: If COHERE_API_KEY is unset or empty, or if a numeric
                setting cannot be parsed.
        """
        self.api_key = os.getenv("COHERE_API_KEY")
        if not self.api_key:
            raise ValueError("COHERE_API_KEY not found in environment variables")

        self.model = os.getenv("COHERE_MODEL", "command-r-plus")
        self.temperature = float(os.getenv("COHERE_TEMPERATURE", "0.3"))
        self.max_tokens = int(os.getenv("COHERE_MAX_TOKENS", "2000"))
        self.timeout = int(os.getenv("COHERE_TIMEOUT", "30"))

        # Pass the configured timeout to the SDK; previously it was read but
        # never applied, so requests used the SDK default.
        self.client = cohere.ClientV2(self.api_key, timeout=self.timeout)
        logger.info(f"Cohere client initialized with model: {self.model}")

    async def chat(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """
        Send chat request to Cohere API with retry logic.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            tools: Optional list of tool definitions for tool-calling

        Returns:
            Dictionary with keys:
            - "response": text of the first text item in the reply ("" if none)
            - "tool_calls": list of {"name", "parameters"} dictionaries
            - "latency": seconds from request start to response, including
              any time spent waiting between retries

        Raises:
            TooManyRequestsError: Still rate limited after all attempts.
            ServiceUnavailableError: Service still unavailable after all attempts.
            Exception: Any other failure, logged and re-raised unchanged.
        """
        started = time.perf_counter()  # monotonic; unaffected by clock changes

        try:
            logger.info(f"Sending chat request to Cohere (model: {self.model})")
            logger.debug(f"Messages: {len(messages)}, Tools: {len(tools) if tools else 0}")

            response = await self._request_with_retry(messages, tools)
            latency = time.perf_counter() - started

            # Full payloads may contain user content, so keep them at DEBUG.
            logger.debug(f"Cohere response received: {response}")

            message = getattr(response, "message", None)
            response_text = self._extract_text(message)
            tool_calls = self._extract_tool_calls(message)

            logger.info(f"Cohere API call successful (latency: {latency:.2f}s)")
            self._log_token_usage(response)

            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "latency": latency
            }
        except TooManyRequestsError as e:
            logger.warning(f"Rate limit hit: {str(e)}")
            raise
        except ServiceUnavailableError as e:
            logger.error(f"Cohere service unavailable: {str(e)}")
            raise
        except Exception as e:
            # logger.exception attaches the active traceback to the record.
            logger.exception(f"Cohere API call failed: {str(e)}")
            raise

    async def _request_with_retry(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]]
    ) -> Any:
        """Call the Cohere chat endpoint, retrying transient failures."""
        request = functools.partial(
            self.client.chat,
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            tools=tools if tools else None
        )

        # reraise=True surfaces the original Cohere exception once attempts are
        # exhausted, so callers keep catching the same exception types.
        @retry(
            stop=stop_after_attempt(self.MAX_ATTEMPTS),
            wait=wait_exponential(
                multiplier=1,
                min=self.RETRY_WAIT_MIN_SECONDS,
                max=self.RETRY_WAIT_MAX_SECONDS
            ),
            retry=retry_if_exception_type(
                (TooManyRequestsError, ServiceUnavailableError)
            ),
            reraise=True
        )
        async def _attempt() -> Any:
            # The SDK client is synchronous; run it in the default executor so
            # the event loop is not blocked for the duration of the request.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, request)

        return await _attempt()

    @staticmethod
    def _extract_text(message: Any) -> str:
        """Return the text of the first content item that has one, else ""."""
        for item in getattr(message, "content", None) or []:
            if hasattr(item, "text"):
                return item.text
        return ""

    @staticmethod
    def _extract_tool_calls(message: Any) -> List[Dict[str, Any]]:
        """Convert the message's tool calls into {"name", "parameters"} dicts."""
        parsed = []
        for call in getattr(message, "tool_calls", None) or []:
            raw_arguments = call.function.arguments
            try:
                # Arguments may arrive as a JSON string; decode those and pass
                # anything else through unchanged.
                arguments = (
                    json.loads(raw_arguments)
                    if isinstance(raw_arguments, str)
                    else raw_arguments
                )
            except json.JSONDecodeError as e:
                # Skip the malformed call but keep the remaining ones.
                logger.error(f"Failed to parse tool call arguments: {e}")
                continue
            parsed.append({
                "name": call.function.name,
                "parameters": arguments
            })
        return parsed

    @staticmethod
    def _log_token_usage(response: Any) -> None:
        """Log input/output token counts when the response reports them."""
        tokens = getattr(getattr(response, "usage", None), "tokens", None)
        if tokens is None:
            return
        input_tokens = getattr(tokens, "input_tokens", None)
        output_tokens = getattr(tokens, "output_tokens", None)
        logger.info(f"Token usage - Input: {input_tokens}, "
                    f"Output: {output_tokens}")

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """
        Validate that a tool call has the required structure.

        Args:
            tool_call: Tool call dictionary to validate

        Returns:
            True if tool_call is a dict whose "name" is a str and whose
            "parameters" is a dict, False otherwise
        """
        if not isinstance(tool_call, dict):
            return False
        # A missing key yields None, which fails both isinstance checks.
        return (
            isinstance(tool_call.get("name"), str)
            and isinstance(tool_call.get("parameters"), dict)
        )
# Shared module-level Cohere client. It is constructed when this module is
# imported, so COHERE_API_KEY must already be set at import time; otherwise
# CohereClient() raises ValueError.
cohere_client = CohereClient()