import os
from typing import Any, Optional

import weave
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from pydantic import BaseModel

from rag_pipelines.prompts import STRUCTURED_RAG_PROMPT, RAGResponseModel


class ChatGroqGenerator:
"""Interact with the ChatGroq model to generate responses based on user queries and documents.
This class provides an interface for generating responses using the ChatGroq model.
It handles prompt formatting, LLM invocation, document integration, and result generation.
"""
model: str
api_key: str
llm_params: dict[str, Any]
llm: Optional[ChatGroq] = None
    structured_output_model: type[BaseModel]
    system_prompt: str

    def __init__(
self,
model: str,
api_key: Optional[str] = None,
llm_params: Optional[dict[str, Any]] = None,
        structured_output_model: type[BaseModel] = RAGResponseModel,
system_prompt: str = STRUCTURED_RAG_PROMPT,
):
"""Initialize the ChatGroqGenerator with configuration parameters.
Args:
model (str): The name of the ChatGroq model to use.
api_key (Optional[str]): API key for the ChatGroq service. If not provided,
the "GROQ_API_KEY" environment variable will be used.
llm_params (Optional[dict]): Additional parameters for configuring the ChatGroq model.
structured_output_model (BaseModel): The output model for structured responses.
system_prompt (str): The system prompt for the ChatGroq model.
Raises:
ValueError: If the API key is not provided and the "GROQ_API_KEY" environment variable is not set.
"""
if llm_params is None:
llm_params = {}
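        # Prefer an explicitly supplied key, then fall back to the environment variable.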
api_key = api_key or os.environ.get("GROQ_API_KEY")
if api_key is None:
msg = "GROQ_API_KEY is not set. Please provide an API key or set it as an environment variable."
raise ValueError(msg)
self.model = model
self.api_key = api_key
self.llm_params = llm_params
self.structured_output_model = structured_output_model
self.system_prompt = system_prompt
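        # Build the ChatGroq client once at construction time; every call reuses it.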
        self.llm = ChatGroq(model=self.model, api_key=self.api_key, **self.llm_params)

    @weave.op()
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
"""Generate a response using the current state of user prompts and graded documents.
Args:
state (dict[str, Any]): The current state, containing:
- 'question': The user question.
- 'context': A list of filtered document texts.
- 'documents': A list of retrieved documents.
Returns:
dict[str, Any]: A dictionary containing:
- 'question': The user question.
- 'context': A list of filtered document texts.
- 'documents': A list of retrieved documents.
- 'answer': The generated response.
"""
question = state["question"]
context = state["context"]
documents = state["documents"]
formatted_context = "\n".join(context)
prompt = ChatPromptTemplate.from_messages(
[
("system", self.system_prompt),
]
)
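        # Pipe the prompt into the LLM, constraining its output to the configured Pydantic schema.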
response_chain = prompt | self.llm.with_structured_output(self.structured_output_model)
response = response_chain.invoke({"question": question, "context": formatted_context})
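        # The structured response carries the generated text in its final_answer field.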
answer = response.final_answer
return {"question": question, "context": context, "documents": documents, "answer": answer}