Spaces:
Runtime error
Runtime error
| """ | |
| Groq LangChain Integration Service | |
| This module provides integration with Groq's LangChain API for generating | |
| chat responses with programming language context and chat history support. | |
| """ | |
| import os | |
| import logging | |
| import time | |
| from typing import List, Dict, Any, Optional, Generator, AsyncGenerator | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from langchain_groq import ChatGroq | |
| from langchain.schema import HumanMessage, AIMessage, SystemMessage | |
| from groq import Groq | |
| from groq.types.chat import ChatCompletion | |
| from ..utils.error_handler import ChatAgentError, ErrorCategory, ErrorSeverity, get_error_handler | |
| from ..utils.circuit_breaker import circuit_breaker, CircuitBreakerConfig, get_circuit_breaker_manager | |
| from ..utils.logging_config import get_logger, get_performance_logger | |
# Module-level loggers shared by everything in this module.
logger = get_logger('groq_client')
performance_logger = get_performance_logger('groq_client')
# Legacy error classes for backward compatibility
class GroqError(ChatAgentError):
    """Base exception for Groq API errors"""

    def __init__(self, message: str, **kwargs):
        # Every Groq failure is reported under the generic API error category.
        super().__init__(message, category=ErrorCategory.API_ERROR, **kwargs)
class GroqRateLimitError(ChatAgentError):
    """Exception raised when API rate limits are exceeded"""

    def __init__(self, message: str, **kwargs):
        # Tagged with the rate-limit category so handlers can back off.
        super().__init__(message, category=ErrorCategory.RATE_LIMIT_ERROR, **kwargs)
class GroqAuthenticationError(ChatAgentError):
    """Exception raised when API authentication fails"""

    def __init__(self, message: str, **kwargs):
        # Auth failures are unrecoverable without operator action, hence HIGH severity.
        super().__init__(message, category=ErrorCategory.AUTHENTICATION_ERROR, severity=ErrorSeverity.HIGH, **kwargs)
class GroqNetworkError(ChatAgentError):
    """Exception raised when network errors occur"""

    def __init__(self, message: str, **kwargs):
        # Connectivity problems get their own category for retry policies.
        super().__init__(message, category=ErrorCategory.NETWORK_ERROR, **kwargs)
@dataclass
class ChatMessage:
    """Represents a chat message with role and content.

    Fix: the class declared annotated fields but was missing the
    ``@dataclass`` decorator, so calls like ``ChatMessage(role=..., content=...)``
    elsewhere in this module raised ``TypeError`` at runtime.
    """
    role: str  # 'user', 'assistant', 'system'
    content: str
    language: Optional[str] = None   # programming language tag, if any
    timestamp: Optional[str] = None  # ISO timestamp string, if recorded
@dataclass
class LanguageContext:
    """Represents programming language context for chat.

    Fix: the class declared annotated fields but was missing the
    ``@dataclass`` decorator, so ``LanguageContext(language=..., ...)``
    construction elsewhere in this module raised ``TypeError`` at runtime.
    """
    language: str            # display name of the programming language
    prompt_template: str     # system-prompt template; may contain {language}
    syntax_highlighting: str # lowercase identifier used for code highlighting
class GroqClient:
    """
    Groq LangChain integration client for chat-based programming assistance.

    Provides methods for generating responses with chat history context,
    language-specific prompts, and streaming capabilities. All API calls go
    through a circuit breaker so repeated upstream failures degrade into a
    canned fallback response instead of cascading errors.
    """

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """
        Initialize Groq client with API authentication and configuration.

        Args:
            api_key: Groq API key (defaults to GROQ_API_KEY env var)
            model: Model name (defaults to GROQ_MODEL env var or
                llama-3.1-8b-instant)

        Raises:
            GroqAuthenticationError: If no API key is available or the
                underlying SDK clients fail to initialize.
        """
        self.api_key = api_key or os.getenv('GROQ_API_KEY')
        self.model = model or os.getenv('GROQ_MODEL', 'llama-3.1-8b-instant')
        if not self.api_key:
            raise GroqAuthenticationError("Groq API key not provided")

        # Generation configuration from environment
        self.max_tokens = int(os.getenv('MAX_TOKENS', '2048'))
        self.temperature = float(os.getenv('TEMPERATURE', '0.7'))
        self.stream_responses = os.getenv('STREAM_RESPONSES', 'True').lower() == 'true'

        # Centralized error handler: converts raw exceptions into
        # ChatAgentError instances and produces user-facing fallback text.
        self.error_handler = get_error_handler()

        # Circuit breaker wraps every API call; when open, the fallback
        # function below is invoked instead of hitting the API.
        circuit_config = CircuitBreakerConfig(
            failure_threshold=5,
            recovery_timeout=60,
            success_threshold=3,
            timeout=30.0,
            expected_exception=(Exception,)
        )
        circuit_manager = get_circuit_breaker_manager()
        self.circuit_breaker = circuit_manager.create_breaker(
            name="groq_api",
            config=circuit_config,
            fallback_function=self._fallback_response
        )

        # Initialize SDK clients (direct Groq for streaming, LangChain wrapper
        # for standard completions)
        self._initialize_clients()

        # Rate limiting and retry configuration
        self.max_retries = 3
        self.base_delay = 1.0
        self.max_delay = 60.0

        logger.info(f"GroqClient initialized with model: {self.model}", extra={
            'model': self.model,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'circuit_breaker': 'enabled'
        })

    def _initialize_clients(self):
        """Initialize Groq and LangChain clients with error handling.

        Raises:
            GroqAuthenticationError: If either client fails to construct.
        """
        try:
            # Direct Groq client is used for streaming completions
            self.groq_client = Groq(api_key=self.api_key)
            # LangChain wrapper is used for standard (non-streaming) calls
            self.langchain_client = ChatGroq(
                groq_api_key=self.api_key,
                model_name=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
            logger.info("Groq clients initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Groq clients: {str(e)}")
            raise GroqAuthenticationError(f"Client initialization failed: {str(e)}")

    def generate_response(
        self,
        prompt: str,
        chat_history: List[ChatMessage],
        language_context: LanguageContext,
        stream: bool = False
    ) -> str:
        """
        Generate response using Groq LangChain API with chat history and language context.

        Args:
            prompt: User's input message
            chat_history: List of previous chat messages for context
            language_context: Programming language context and templates
            stream: Whether to generate via the streaming API (the full
                response is still returned as one string)

        Returns:
            Generated response string; on non-ChatAgentError failure a
            user-facing fallback string is returned instead of raising.

        Raises:
            ChatAgentError: Re-raised unchanged for API-related errors.
        """
        start_time = time.time()
        try:
            # Build messages with context
            messages = self._build_messages(prompt, chat_history, language_context)

            # Use circuit breaker for API call
            if stream:
                result = self.circuit_breaker.call(self._generate_streaming_response, messages)
            else:
                result = self.circuit_breaker.call(self._generate_standard_response, messages)

            # Log performance
            duration = time.time() - start_time
            performance_logger.log_operation(
                operation="generate_response",
                duration=duration,
                context={
                    'model': self.model,
                    'language': language_context.language,
                    'stream': stream,
                    'message_count': len(messages),
                    'prompt_length': len(prompt)
                }
            )
            return result
        except ChatAgentError:
            # Re-raise ChatAgentError as-is
            raise
        except Exception as e:
            # Convert unexpected exceptions and return a fallback string
            # instead of raising, since this is a user-facing entry point.
            duration = time.time() - start_time
            context = {
                'model': self.model,
                'language': language_context.language,
                'stream': stream,
                'duration': duration,
                'prompt_length': len(prompt)
            }
            chat_error = self.error_handler.handle_error(e, context)
            return self.error_handler.get_fallback_response(chat_error)

    def stream_response(
        self,
        prompt: str,
        chat_history: List[ChatMessage],
        language_context: LanguageContext
    ) -> Generator[str, None, None]:
        """
        Generate streaming response for real-time chat experience.

        Args:
            prompt: User's input message
            chat_history: List of previous chat messages for context
            language_context: Programming language context and templates

        Yields:
            Response chunks as they are generated; on failure, the fallback
            message is yielded word-by-word to preserve the streaming shape.
        """
        start_time = time.time()
        chunk_count = 0
        try:
            messages = self._build_messages(prompt, chat_history, language_context)

            # Use circuit breaker for streaming API call
            def _stream_call():
                return self.groq_client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": msg.role, "content": msg.content} for msg in messages],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    stream=True
                )

            response = self.circuit_breaker.call(_stream_call)
            for chunk in response:
                if chunk.choices[0].delta.content:
                    chunk_count += 1
                    yield chunk.choices[0].delta.content

            # Log performance
            duration = time.time() - start_time
            performance_logger.log_operation(
                operation="stream_response",
                duration=duration,
                context={
                    'model': self.model,
                    'language': language_context.language,
                    'message_count': len(messages),
                    'chunk_count': chunk_count,
                    'prompt_length': len(prompt)
                }
            )
        except ChatAgentError as e:
            # Yield fallback response for chat errors
            fallback = self.error_handler.get_fallback_response(e)
            for word in fallback.split():
                yield word + " "
                time.sleep(0.05)  # Simulate streaming
        except Exception as e:
            # Handle and convert other exceptions
            duration = time.time() - start_time
            context = {
                'model': self.model,
                'language': language_context.language,
                'duration': duration,
                'chunk_count': chunk_count,
                'prompt_length': len(prompt)
            }
            chat_error = self.error_handler.handle_error(e, context)
            fallback = self.error_handler.get_fallback_response(chat_error)
            # Yield fallback response as streaming chunks
            for word in fallback.split():
                yield word + " "
                time.sleep(0.05)  # Simulate streaming

    def _build_messages(
        self,
        prompt: str,
        chat_history: List[ChatMessage],
        language_context: LanguageContext
    ) -> List[ChatMessage]:
        """
        Build message list with system prompt, chat history, and current prompt.

        Args:
            prompt: Current user message
            chat_history: Previous conversation messages
            language_context: Programming language context

        Returns:
            List of formatted chat messages
        """
        messages = []

        # Add system message with language context
        system_prompt = language_context.prompt_template.format(
            language=language_context.language
        )
        messages.append(ChatMessage(role="system", content=system_prompt))

        # Add chat history (limit to recent messages to stay within context window)
        context_window = int(os.getenv('CONTEXT_WINDOW_SIZE', '10'))
        recent_history = chat_history[-context_window:] if chat_history else []
        for msg in recent_history:
            messages.append(msg)

        # Add current user message
        messages.append(ChatMessage(role="user", content=prompt, language=language_context.language))
        return messages

    def _generate_standard_response(self, messages: List[ChatMessage]) -> str:
        """Generate standard (non-streaming) response via the LangChain client."""
        langchain_messages = []
        for msg in messages:
            if msg.role == "system":
                langchain_messages.append(SystemMessage(content=msg.content))
            elif msg.role == "user":
                langchain_messages.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                langchain_messages.append(AIMessage(content=msg.content))
        response = self.langchain_client.invoke(langchain_messages)
        return response.content

    def _generate_streaming_response(self, messages: List[ChatMessage]) -> str:
        """Generate response using the streaming API and return the complete text.

        Fix: the previous implementation ignored ``messages`` entirely and
        recursed into ``stream_response`` with an empty prompt, empty history,
        and a hard-coded Python context, discarding the caller's conversation.
        It now streams the supplied messages directly. This method is invoked
        through the circuit breaker by ``generate_response``, so no extra
        breaker wrapping is done here.
        """
        response = self.groq_client.chat.completions.create(
            model=self.model,
            messages=[{"role": msg.role, "content": msg.content} for msg in messages],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stream=True
        )
        response_chunks = []
        for chunk in response:
            if chunk.choices[0].delta.content:
                response_chunks.append(chunk.choices[0].delta.content)
        return "".join(response_chunks)

    def _fallback_response(self, *args, **kwargs) -> str:
        """
        Provide fallback response when circuit breaker is open.

        Returns:
            Fallback response string
        """
        return ("I'm currently experiencing high demand and need a moment to catch up. "
                "While you wait, here are some general programming tips:\n\n"
                "• Break down complex problems into smaller steps\n"
                "• Use descriptive variable names\n"
                "• Add comments to explain your logic\n"
                "• Test your code frequently\n\n"
                "Please try your question again in a moment!")

    def _handle_api_error(self, error: Exception) -> str:
        """
        Handle various API errors with appropriate fallback responses.

        Args:
            error: Exception that occurred during API call

        Raises:
            ChatAgentError: Re-raises as the appropriate typed error based on
                substring matching against the original error message.
        """
        error_str = str(error).lower()
        if "rate limit" in error_str or "429" in error_str:
            logger.warning(f"Rate limit exceeded: {error}")
            raise GroqRateLimitError("API rate limit exceeded", context={'original_error': str(error)})
        elif "authentication" in error_str or "401" in error_str:
            logger.error(f"Authentication error: {error}")
            raise GroqAuthenticationError("API authentication failed", context={'original_error': str(error)})
        elif "network" in error_str or "connection" in error_str:
            logger.error(f"Network error: {error}")
            raise GroqNetworkError("Network connection failed", context={'original_error': str(error)})
        elif "quota" in error_str or "billing" in error_str:
            logger.error(f"Quota/billing error: {error}")
            raise GroqError("API quota exceeded", context={'original_error': str(error)})
        else:
            logger.error(f"Unexpected API error: {error}")
            raise GroqError("Unexpected API error", context={'original_error': str(error)})

    def test_connection(self) -> bool:
        """
        Test connection to Groq API.

        Returns:
            True if connection successful, False otherwise
        """
        try:
            # Simple test message
            test_messages = [
                ChatMessage(role="system", content="You are a helpful assistant."),
                ChatMessage(role="user", content="Hello")
            ]
            self._generate_standard_response(test_messages)
            logger.info("Groq API connection test successful")
            return True
        except Exception as e:
            logger.error(f"Groq API connection test failed: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the current model configuration.

        Returns:
            Dictionary with model configuration details
        """
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "stream_responses": self.stream_responses,
            "api_key_configured": bool(self.api_key)
        }
# Default language prompt templates
# Maps a lowercase language key to the system-prompt text used when no
# custom template is supplied; consumed by create_language_context().
DEFAULT_LANGUAGE_TEMPLATES = {
    "python": """You are a helpful programming assistant specializing in Python.
You help students learn Python programming by providing clear explanations,
debugging assistance, and code examples. Always use Python syntax and best practices.
Keep explanations beginner-friendly and provide practical examples.""",
    "javascript": """You are a helpful programming assistant specializing in JavaScript.
You help students learn JavaScript programming by providing clear explanations,
debugging assistance, and code examples. Always use modern JavaScript (ES6+) syntax.
Keep explanations beginner-friendly and provide practical examples.""",
    "java": """You are a helpful programming assistant specializing in Java.
You help students learn Java programming by providing clear explanations,
debugging assistance, and code examples. Always use modern Java syntax and best practices.
Keep explanations beginner-friendly and provide practical examples.""",
    "cpp": """You are a helpful programming assistant specializing in C++.
You help students learn C++ programming by providing clear explanations,
debugging assistance, and code examples. Always use modern C++ (C++11 or later) syntax.
Keep explanations beginner-friendly and provide practical examples.""",
}
def create_language_context(language: str) -> LanguageContext:
    """
    Create language context with appropriate prompt template.

    Args:
        language: Programming language name

    Returns:
        LanguageContext object with template and settings
    """
    normalized = language.lower()
    # Unknown languages fall back to the Python template.
    template = DEFAULT_LANGUAGE_TEMPLATES.get(normalized, DEFAULT_LANGUAGE_TEMPLATES["python"])
    return LanguageContext(
        language=language,
        prompt_template=template,
        syntax_highlighting=normalized,
    )