import os

from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
)
from langchain_classic.chains import (
    create_history_aware_retriever,
    create_retrieval_chain,
)
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

from caching.lfu import LFUCache


def create_llm(provider: str, model_name: str, api_key: str | None = None):
    """
    Factory that creates a LangChain chat model for the given provider.

    Args:
        provider: Human-readable provider name ("Google Gemini", "OpenAI",
            "Anthropic", or "HuggingFace").
        model_name: Provider-specific model identifier (or HF repo id).
        api_key: Explicit API key. When omitted, the key is resolved from the
            environment variable mapped for ``provider`` in
            ``PROVIDER_ENV_KEYS``.

    Returns:
        A LangChain chat model instance for ``provider``.

    Raises:
        ValueError: If no API key can be resolved, or the provider is unknown.
    """
    # Imported lazily to avoid import-time coupling/cycles with configs.
    from configs.config import PROVIDER_ENV_KEYS

    env_key = PROVIDER_ENV_KEYS.get(provider)
    resolved_key = api_key or (os.environ.get(env_key) if env_key else None)
    if not resolved_key:
        # Only mention the environment variable when one is actually mapped;
        # otherwise the message would read "Set None or ...".
        if env_key:
            hint = f"Set {env_key} or provide one in the UI."
        else:
            hint = "Provide one in the UI."
        raise ValueError(f"No API key for {provider}. {hint}")

    # Provider SDKs are imported lazily so only the selected one needs to be
    # installed in the running environment.
    if provider == "Google Gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(model=model_name, google_api_key=resolved_key)
    elif provider == "OpenAI":
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model=model_name, api_key=resolved_key)
    elif provider == "Anthropic":
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(model=model_name, api_key=resolved_key)
    elif provider == "HuggingFace":
        from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

        llm = HuggingFaceEndpoint(
            repo_id=model_name,
            huggingfacehub_api_token=resolved_key,
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unknown provider: {provider}")


class LLMService:
    """
    Conversational RAG service: wraps a runtime-swappable LLM, a web
    retriever, and per-session chat histories kept in an LFU cache.
    """

    def __init__(self, logger, system_prompt: str,
                 web_retriever: VectorStoreRetriever,
                 cache_capacity: int = 50,
                 provider: str = "Google Gemini",
                 llm_model_name: str = "gemini-2.5-flash-lite"):
        """
        Args:
            logger: Logger used for operational messages.
            system_prompt: System prompt for the question-answering stage.
            web_retriever: Retriever supplying context documents.
            cache_capacity: Max number of session histories kept (LFU evicted).
            provider: Initial LLM provider name (see ``create_llm``).
            llm_model_name: Initial model identifier for ``provider``.
        """
        self._conversational_rag_chain = None
        self._logger = logger
        self.system_prompt = system_prompt
        self._web_retriever = web_retriever
        self.current_provider = provider
        self.current_model_name = llm_model_name

        ### Statefully manage chat history ###
        # Created before the chain: the chain captures the bound
        # _get_session_history, which reads self.store on first invocation.
        self.store = LFUCache(capacity=cache_capacity)

        self.llm = create_llm(provider, llm_model_name)
        self._initialize_conversational_rag_chain()

    def _initialize_conversational_rag_chain(self):
        """
        Initializes the conversational RAG chain.

        Pipeline: history-aware question rewriting -> retrieval -> stuffed
        document QA, all wrapped with per-session message history.
        """
        ### Contextualize question ###
        contextualize_q_system_prompt = """Given the full chat history and the latest user message, \
rewrite the message as a fully self-contained question that can be understood without any prior context. \
Preserve intent and key entities, expand pronouns and references, and include necessary constraints, dates, or assumptions. \
Do not add new information, do not omit essential details, and do not answer the question. 
Return only the rewritten question."""
        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        history_aware_retriever = create_history_aware_retriever(
            self.llm, self._web_retriever, contextualize_q_prompt)

        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        # Each retrieved document is rendered with its source so the model
        # can cite where the content came from.
        document_prompt = PromptTemplate.from_template(
            "{page_content}\n[Source: {source}]"
        )
        question_answer_chain = create_stuff_documents_chain(
            self.llm, qa_prompt, document_prompt=document_prompt
        )
        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

        self._conversational_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            self._get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

    def _get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """
        Returns the chat history for ``session_id``, creating (and caching)
        an empty one on first access. LFU eviction bounds total memory.
        """
        history = self.store.get(session_id)
        if history is None:
            history = ChatMessageHistory()
            self.store.put(session_id, history)
        return history

    def update_llm(self, provider: str, model_name: str, api_key: str | None = None):
        """
        Swap the LLM at runtime.

        Rebuilds the chain but preserves the retriever and chat history store.

        Raises:
            ValueError: If ``create_llm`` cannot resolve a key or provider;
                in that case the current LLM and chain are left untouched.
        """
        # Build the new LLM first so a failure leaves current state intact.
        new_llm = create_llm(provider, model_name, api_key)
        self.llm = new_llm
        self.current_provider = provider
        self.current_model_name = model_name
        self._initialize_conversational_rag_chain()
        self._logger.info(f"LLM switched to {provider} / {model_name}")

    def conversational_rag_chain(self):
        """
        Returns the initialized conversational RAG chain.

        Returns:
            The conversational RAG chain instance.
        """
        return self._conversational_rag_chain

    def get_llm(self):
        """
        Returns the LLM instance.

        Raises:
            RuntimeError: If the LLM was never initialized.
        """
        if self.llm is None:
            # RuntimeError (subclass of Exception) keeps existing
            # `except Exception` callers working while being more specific.
            raise RuntimeError("llm is not initialized")
        return self.llm