# scratch_chat/chat_agent/services/groq_client.py
# (Hugging Face upload metadata: WebashalarForML — "Upload 178 files" — commit 330b6e4, verified)
"""
Groq LangChain Integration Service
This module provides integration with Groq's LangChain API for generating
chat responses with programming language context and chat history support.
"""
import os
import logging
import time
from typing import List, Dict, Any, Optional, Generator, AsyncGenerator
from dataclasses import dataclass
from enum import Enum
from langchain_groq import ChatGroq
from langchain.schema import HumanMessage, AIMessage, SystemMessage
from groq import Groq
from groq.types.chat import ChatCompletion
from ..utils.error_handler import ChatAgentError, ErrorCategory, ErrorSeverity, get_error_handler
from ..utils.circuit_breaker import circuit_breaker, CircuitBreakerConfig, get_circuit_breaker_manager
from ..utils.logging_config import get_logger, get_performance_logger
logger = get_logger('groq_client')
performance_logger = get_performance_logger('groq_client')
# Legacy error classes for backward compatibility
class GroqError(ChatAgentError):
"""Base exception for Groq API errors"""
def __init__(self, message: str, **kwargs):
super().__init__(message, category=ErrorCategory.API_ERROR, **kwargs)
class GroqRateLimitError(ChatAgentError):
"""Exception raised when API rate limits are exceeded"""
def __init__(self, message: str, **kwargs):
super().__init__(message, category=ErrorCategory.RATE_LIMIT_ERROR, **kwargs)
class GroqAuthenticationError(ChatAgentError):
"""Exception raised when API authentication fails"""
def __init__(self, message: str, **kwargs):
super().__init__(message, category=ErrorCategory.AUTHENTICATION_ERROR, severity=ErrorSeverity.HIGH, **kwargs)
class GroqNetworkError(ChatAgentError):
"""Exception raised when network errors occur"""
def __init__(self, message: str, **kwargs):
super().__init__(message, category=ErrorCategory.NETWORK_ERROR, **kwargs)
@dataclass
class ChatMessage:
"""Represents a chat message with role and content"""
role: str # 'user', 'assistant', 'system'
content: str
language: Optional[str] = None
timestamp: Optional[str] = None
@dataclass
class LanguageContext:
"""Represents programming language context for chat"""
language: str
prompt_template: str
syntax_highlighting: str
class GroqClient:
"""
Groq LangChain integration client for chat-based programming assistance.
Provides methods for generating responses with chat history context,
language-specific prompts, and streaming capabilities.
"""
def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
"""
Initialize Groq client with API authentication and configuration.
Args:
api_key: Groq API key (defaults to GROQ_API_KEY env var)
model: Model name (defaults to GROQ_MODEL env var or mixtral-8x7b-32768)
"""
self.api_key = api_key or os.getenv('GROQ_API_KEY')
self.model = model or os.getenv('GROQ_MODEL', 'llama-3.1-8b-instant')
if not self.api_key:
raise GroqAuthenticationError("Groq API key not provided")
# Configuration from environment
self.max_tokens = int(os.getenv('MAX_TOKENS', '2048'))
self.temperature = float(os.getenv('TEMPERATURE', '0.7'))
self.stream_responses = os.getenv('STREAM_RESPONSES', 'True').lower() == 'true'
# Initialize error handler
self.error_handler = get_error_handler()
# Initialize circuit breaker for API calls
circuit_config = CircuitBreakerConfig(
failure_threshold=5,
recovery_timeout=60,
success_threshold=3,
timeout=30.0,
expected_exception=(Exception,)
)
circuit_manager = get_circuit_breaker_manager()
self.circuit_breaker = circuit_manager.create_breaker(
name="groq_api",
config=circuit_config,
fallback_function=self._fallback_response
)
# Initialize clients
self._initialize_clients()
# Rate limiting and retry configuration
self.max_retries = 3
self.base_delay = 1.0
self.max_delay = 60.0
logger.info(f"GroqClient initialized with model: {self.model}", extra={
'model': self.model,
'max_tokens': self.max_tokens,
'temperature': self.temperature,
'circuit_breaker': 'enabled'
})
def _initialize_clients(self):
"""Initialize Groq and LangChain clients with error handling"""
try:
# Initialize direct Groq client for streaming
self.groq_client = Groq(api_key=self.api_key)
# Initialize LangChain Groq client
self.langchain_client = ChatGroq(
groq_api_key=self.api_key,
model_name=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens
)
logger.info("Groq clients initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize Groq clients: {str(e)}")
raise GroqAuthenticationError(f"Client initialization failed: {str(e)}")
def generate_response(
self,
prompt: str,
chat_history: List[ChatMessage],
language_context: LanguageContext,
stream: bool = False
) -> str:
"""
Generate response using Groq LangChain API with chat history and language context.
Args:
prompt: User's input message
chat_history: List of previous chat messages for context
language_context: Programming language context and templates
stream: Whether to return streaming response
Returns:
Generated response string
Raises:
ChatAgentError: For various API-related errors
"""
start_time = time.time()
try:
# Build messages with context
messages = self._build_messages(prompt, chat_history, language_context)
# Use circuit breaker for API call
if stream:
result = self.circuit_breaker.call(self._generate_streaming_response, messages)
else:
result = self.circuit_breaker.call(self._generate_standard_response, messages)
# Log performance
duration = time.time() - start_time
performance_logger.log_operation(
operation="generate_response",
duration=duration,
context={
'model': self.model,
'language': language_context.language,
'stream': stream,
'message_count': len(messages),
'prompt_length': len(prompt)
}
)
return result
except ChatAgentError:
# Re-raise ChatAgentError as-is
raise
except Exception as e:
# Handle and convert other exceptions
duration = time.time() - start_time
context = {
'model': self.model,
'language': language_context.language,
'stream': stream,
'duration': duration,
'prompt_length': len(prompt)
}
chat_error = self.error_handler.handle_error(e, context)
# Return fallback response instead of raising for user-facing calls
return self.error_handler.get_fallback_response(chat_error)
def stream_response(
self,
prompt: str,
chat_history: List[ChatMessage],
language_context: LanguageContext
) -> Generator[str, None, None]:
"""
Generate streaming response for real-time chat experience.
Args:
prompt: User's input message
chat_history: List of previous chat messages for context
language_context: Programming language context and templates
Yields:
Response chunks as they are generated
Raises:
ChatAgentError: For various API-related errors
"""
start_time = time.time()
chunk_count = 0
try:
messages = self._build_messages(prompt, chat_history, language_context)
# Use circuit breaker for streaming API call
def _stream_call():
return self.groq_client.chat.completions.create(
model=self.model,
messages=[{"role": msg.role, "content": msg.content} for msg in messages],
temperature=self.temperature,
max_tokens=self.max_tokens,
stream=True
)
response = self.circuit_breaker.call(_stream_call)
for chunk in response:
if chunk.choices[0].delta.content:
chunk_count += 1
yield chunk.choices[0].delta.content
# Log performance
duration = time.time() - start_time
performance_logger.log_operation(
operation="stream_response",
duration=duration,
context={
'model': self.model,
'language': language_context.language,
'message_count': len(messages),
'chunk_count': chunk_count,
'prompt_length': len(prompt)
}
)
except ChatAgentError as e:
# Yield fallback response for chat errors
fallback = self.error_handler.get_fallback_response(e)
for word in fallback.split():
yield word + " "
time.sleep(0.05) # Simulate streaming
except Exception as e:
# Handle and convert other exceptions
duration = time.time() - start_time
context = {
'model': self.model,
'language': language_context.language,
'duration': duration,
'chunk_count': chunk_count,
'prompt_length': len(prompt)
}
chat_error = self.error_handler.handle_error(e, context)
fallback = self.error_handler.get_fallback_response(chat_error)
# Yield fallback response as streaming chunks
for word in fallback.split():
yield word + " "
time.sleep(0.05) # Simulate streaming
def _build_messages(
self,
prompt: str,
chat_history: List[ChatMessage],
language_context: LanguageContext
) -> List[ChatMessage]:
"""
Build message list with system prompt, chat history, and current prompt.
Args:
prompt: Current user message
chat_history: Previous conversation messages
language_context: Programming language context
Returns:
List of formatted chat messages
"""
messages = []
# Add system message with language context
system_prompt = language_context.prompt_template.format(
language=language_context.language
)
messages.append(ChatMessage(role="system", content=system_prompt))
# Add chat history (limit to recent messages to stay within context window)
context_window = int(os.getenv('CONTEXT_WINDOW_SIZE', '10'))
recent_history = chat_history[-context_window:] if chat_history else []
for msg in recent_history:
messages.append(msg)
# Add current user message
messages.append(ChatMessage(role="user", content=prompt, language=language_context.language))
return messages
def _generate_standard_response(self, messages: List[ChatMessage]) -> str:
"""Generate standard (non-streaming) response"""
langchain_messages = []
for msg in messages:
if msg.role == "system":
langchain_messages.append(SystemMessage(content=msg.content))
elif msg.role == "user":
langchain_messages.append(HumanMessage(content=msg.content))
elif msg.role == "assistant":
langchain_messages.append(AIMessage(content=msg.content))
response = self.langchain_client.invoke(langchain_messages)
return response.content
def _generate_streaming_response(self, messages: List[ChatMessage]) -> str:
"""Generate response using streaming and return complete response"""
response_chunks = []
for chunk in self.stream_response("", [], LanguageContext("python", "", "")):
response_chunks.append(chunk)
return "".join(response_chunks)
def _fallback_response(self, *args, **kwargs) -> str:
"""
Provide fallback response when circuit breaker is open.
Returns:
Fallback response string
"""
return ("I'm currently experiencing high demand and need a moment to catch up. "
"While you wait, here are some general programming tips:\n\n"
"• Break down complex problems into smaller steps\n"
"• Use descriptive variable names\n"
"• Add comments to explain your logic\n"
"• Test your code frequently\n\n"
"Please try your question again in a moment!")
def _handle_api_error(self, error: Exception) -> str:
"""
Handle various API errors with appropriate fallback responses.
Args:
error: Exception that occurred during API call
Returns:
Fallback error message for user
Raises:
ChatAgentError: Re-raises as appropriate error type
"""
error_str = str(error).lower()
if "rate limit" in error_str or "429" in error_str:
logger.warning(f"Rate limit exceeded: {error}")
raise GroqRateLimitError("API rate limit exceeded", context={'original_error': str(error)})
elif "authentication" in error_str or "401" in error_str:
logger.error(f"Authentication error: {error}")
raise GroqAuthenticationError("API authentication failed", context={'original_error': str(error)})
elif "network" in error_str or "connection" in error_str:
logger.error(f"Network error: {error}")
raise GroqNetworkError("Network connection failed", context={'original_error': str(error)})
elif "quota" in error_str or "billing" in error_str:
logger.error(f"Quota/billing error: {error}")
raise GroqError("API quota exceeded", context={'original_error': str(error)})
else:
logger.error(f"Unexpected API error: {error}")
raise GroqError("Unexpected API error", context={'original_error': str(error)})
def test_connection(self) -> bool:
"""
Test connection to Groq API.
Returns:
True if connection successful, False otherwise
"""
try:
# Simple test message
test_messages = [
ChatMessage(role="system", content="You are a helpful assistant."),
ChatMessage(role="user", content="Hello")
]
response = self._generate_standard_response(test_messages)
logger.info("Groq API connection test successful")
return True
except Exception as e:
logger.error(f"Groq API connection test failed: {e}")
return False
def get_model_info(self) -> Dict[str, Any]:
"""
Get information about the current model configuration.
Returns:
Dictionary with model configuration details
"""
return {
"model": self.model,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"stream_responses": self.stream_responses,
"api_key_configured": bool(self.api_key)
}
# Default language prompt templates
DEFAULT_LANGUAGE_TEMPLATES = {
"python": """You are a helpful programming assistant specializing in Python.
You help students learn Python programming by providing clear explanations,
debugging assistance, and code examples. Always use Python syntax and best practices.
Keep explanations beginner-friendly and provide practical examples.""",
"javascript": """You are a helpful programming assistant specializing in JavaScript.
You help students learn JavaScript programming by providing clear explanations,
debugging assistance, and code examples. Always use modern JavaScript (ES6+) syntax.
Keep explanations beginner-friendly and provide practical examples.""",
"java": """You are a helpful programming assistant specializing in Java.
You help students learn Java programming by providing clear explanations,
debugging assistance, and code examples. Always use modern Java syntax and best practices.
Keep explanations beginner-friendly and provide practical examples.""",
"cpp": """You are a helpful programming assistant specializing in C++.
You help students learn C++ programming by providing clear explanations,
debugging assistance, and code examples. Always use modern C++ (C++11 or later) syntax.
Keep explanations beginner-friendly and provide practical examples."""
}
def create_language_context(language: str) -> LanguageContext:
"""
Create language context with appropriate prompt template.
Args:
language: Programming language name
Returns:
LanguageContext object with template and settings
"""
template = DEFAULT_LANGUAGE_TEMPLATES.get(
language.lower(),
DEFAULT_LANGUAGE_TEMPLATES["python"]
)
return LanguageContext(
language=language,
prompt_template=template,
syntax_highlighting=language.lower()
)