| import os |
| from typing import Any, Optional |
|
|
| import weave |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_groq import ChatGroq |
| from pydantic import BaseModel |
|
|
| from rag_pipelines.prompts import STRUCTURED_RAG_PROMPT, RAGResponseModel |
|
|
|
|
class ChatGroqGenerator:
    """Generate answers to user questions with a ChatGroq model and retrieved context.

    Each call builds a chat prompt from ``system_prompt`` and fills it with the
    user's question and the joined context texts. It then invokes the ChatGroq
    model with structured output parsed into ``structured_output_model``. It
    returns the response's ``final_answer`` together with the input question,
    context and documents.
    """

    model: str
    api_key: str
    llm_params: dict[str, Any]
    llm: Optional[ChatGroq] = None
    # The pydantic model *class* (not an instance) used to parse the LLM output.
    structured_output_model: type[BaseModel]
    system_prompt: str

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        llm_params: Optional[dict[str, Any]] = None,
        structured_output_model: type[BaseModel] = RAGResponseModel,
        system_prompt: str = STRUCTURED_RAG_PROMPT,
    ):
        """Initialize the ChatGroqGenerator with configuration parameters.

        Args:
            model (str): The name of the ChatGroq model to use.
            api_key (Optional[str]): API key for the ChatGroq service. If not provided
                (or empty), the "GROQ_API_KEY" environment variable is used.
            llm_params (Optional[dict]): Additional keyword arguments forwarded to the
                ChatGroq constructor.
            structured_output_model (type[BaseModel]): Pydantic model class used to
                parse the LLM's structured response. The generator reads its
                ``final_answer`` attribute, so the model must define that field.
            system_prompt (str): The system prompt template for the ChatGroq model.
                It is filled with the ``question`` and ``context`` variables
                (presumably via ``{question}`` / ``{context}`` placeholders —
                confirm against the prompt definition).

        Raises:
            ValueError: If no non-empty API key is provided and the "GROQ_API_KEY"
                environment variable is unset or empty.
        """
        if llm_params is None:
            llm_params = {}

        # Treat an empty string (argument or environment) the same as a missing key,
        # so a blank GROQ_API_KEY is rejected here instead of being passed to ChatGroq.
        api_key = api_key or os.environ.get("GROQ_API_KEY")
        if not api_key:
            msg = "GROQ_API_KEY is not set. Please provide an API key or set it as an environment variable."
            raise ValueError(msg)

        self.model = model
        self.api_key = api_key
        self.llm_params = llm_params

        self.structured_output_model = structured_output_model
        self.system_prompt = system_prompt

        self.llm = ChatGroq(model=self.model, api_key=self.api_key, **llm_params)

    @weave.op()
    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
        """Generate a response using the current state of user prompts and graded documents.

        Args:
            state (dict[str, Any]): The current state, containing:
                - 'question': The user question.
                - 'context': A list of filtered document texts.
                - 'documents': A list of retrieved documents.

        Returns:
            dict[str, Any]: A dictionary containing:
                - 'question': The user question.
                - 'context': A list of filtered document texts.
                - 'documents': A list of retrieved documents.
                - 'answer': The generated response (``final_answer`` of the parsed output).
        """
        question = state["question"]
        context = state["context"]
        documents = state["documents"]

        # Context texts are concatenated one per line before being injected into the prompt.
        formatted_context = "\n".join(context)

        # The prompt and chain are rebuilt on every call, so any later changes to
        # self.system_prompt or self.structured_output_model take effect immediately.
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
            ]
        )

        response_chain = prompt | self.llm.with_structured_output(self.structured_output_model)

        response = response_chain.invoke({"question": question, "context": formatted_context})

        # NOTE(review): depending on the structured-output method, the parser may
        # return None when the model emits no structured output; that would surface
        # here as an AttributeError — confirm and consider explicit handling.
        answer = response.final_answer

        return {"question": question, "context": context, "documents": documents, "answer": answer}
|
|