Spaces:
Running
Running
| from langchain_core.prompts import ( | |
| ChatPromptTemplate, | |
| MessagesPlaceholder, | |
| PromptTemplate, | |
| ) | |
| from langchain_classic.chains import create_history_aware_retriever, create_retrieval_chain | |
| from langchain_classic.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_core.vectorstores import VectorStoreRetriever | |
| from langchain_core.chat_history import BaseChatMessageHistory | |
| from langchain_community.chat_message_histories import ChatMessageHistory | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
| from caching.lfu import LFUCache | |
| import os | |
| def create_llm(provider: str, model_name: str, api_key: str | None = None): | |
| """ | |
| Factory that creates a LangChain chat model for the given provider. | |
| """ | |
| from configs.config import PROVIDER_ENV_KEYS | |
| env_key = PROVIDER_ENV_KEYS.get(provider) | |
| resolved_key = api_key or (os.environ.get(env_key) if env_key else None) | |
| if not resolved_key: | |
| raise ValueError( | |
| f"No API key for {provider}. Set {env_key} or provide one in the UI." | |
| ) | |
| if provider == "Google Gemini": | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| return ChatGoogleGenerativeAI(model=model_name, google_api_key=resolved_key) | |
| elif provider == "OpenAI": | |
| from langchain_openai import ChatOpenAI | |
| return ChatOpenAI(model=model_name, api_key=resolved_key) | |
| elif provider == "Anthropic": | |
| from langchain_anthropic import ChatAnthropic | |
| return ChatAnthropic(model=model_name, api_key=resolved_key) | |
| elif provider == "HuggingFace": | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| llm = HuggingFaceEndpoint( | |
| repo_id=model_name, | |
| huggingfacehub_api_token=resolved_key, | |
| ) | |
| return ChatHuggingFace(llm=llm) | |
| else: | |
| raise ValueError(f"Unknown provider: {provider}") | |
class LLMService:
    """
    Conversational RAG service.

    Wraps a chat model, a history-aware retriever, and per-session chat
    histories kept in a bounded LFU cache, exposed as a single
    ``RunnableWithMessageHistory`` chain.
    """

    def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever,
                 cache_capacity: int = 50,
                 provider: str = "Google Gemini",
                 llm_model_name: str = "gemini-2.5-flash-lite"):
        """
        Args:
            logger: Logger used for service events (e.g. LLM swaps).
            system_prompt: System prompt for the question-answering step.
            web_retriever: Retriever that supplies context documents.
            cache_capacity: Max number of concurrent session histories kept.
            provider: Initial LLM provider name (see ``create_llm``).
            llm_model_name: Initial model identifier for that provider.
        """
        self._conversational_rag_chain = None
        self._logger = logger
        self.system_prompt = system_prompt
        self._web_retriever = web_retriever
        self.current_provider = provider
        self.current_model_name = llm_model_name
        # Create the chat-history store BEFORE building the chain: the chain
        # wires in self._get_session_history, which reads self.store. Only
        # the callback's lazy invocation hid the previous ordering hazard.
        self.store = LFUCache(capacity=cache_capacity)
        self.llm = create_llm(provider, llm_model_name)
        self._initialize_conversational_rag_chain()

    def _initialize_conversational_rag_chain(self):
        """
        (Re)build the conversational RAG chain from the current LLM,
        retriever, and prompts. Safe to call again after swapping the LLM.
        """
        # Step 1: rewrite the latest user turn into a standalone question so
        # retrieval works without the chat history.
        contextualize_q_system_prompt = """Given the full chat history and the latest user message, \
rewrite the message as a fully self-contained question that can be understood without any prior context. \
Preserve intent and key entities, expand pronouns and references, and include necessary constraints, dates, or assumptions. \
Do not add new information, do not omit essential details, and do not answer the question. Return only the rewritten question."""
        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        history_aware_retriever = create_history_aware_retriever(
            self.llm, self._web_retriever, contextualize_q_prompt)

        # Step 2: answer using the retrieved documents stuffed into the prompt.
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        # Each document is rendered with its source so answers can cite it.
        document_prompt = PromptTemplate.from_template(
            "{page_content}\n[Source: {source}]"
        )
        question_answer_chain = create_stuff_documents_chain(
            self.llm, qa_prompt, document_prompt=document_prompt
        )
        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

        # Step 3: make the chain stateful — history is looked up per session_id.
        self._conversational_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            self._get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

    def _get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """
        Return the chat history for ``session_id``, creating an empty one on
        first use. Backed by the LFU cache, so rarely used sessions may be
        evicted once capacity is reached.
        """
        history = self.store.get(session_id)
        if history is None:
            history = ChatMessageHistory()
            self.store.put(session_id, history)
        return history

    def update_llm(self, provider: str, model_name: str, api_key: str | None = None):
        """
        Swap the LLM at runtime. Rebuilds the chain but preserves the retriever
        and chat history store.

        Raises:
            ValueError: Propagated from ``create_llm`` for an unknown provider
                or a missing API key; in that case the current LLM is kept.
        """
        # Create the new model first so a failure leaves the service unchanged.
        new_llm = create_llm(provider, model_name, api_key or None)
        self.llm = new_llm
        self.current_provider = provider
        self.current_model_name = model_name
        self._initialize_conversational_rag_chain()
        self._logger.info(f"LLM switched to {provider} / {model_name}")

    def conversational_rag_chain(self):
        """
        Returns the initialized conversational RAG chain.

        Returns:
            The conversational RAG chain instance.
        """
        return self._conversational_rag_chain

    def get_llm(self):
        """
        Returns the LLM instance.

        Raises:
            RuntimeError: If the LLM was never initialized.
        """
        if self.llm is None:
            # RuntimeError instead of bare Exception: narrower, and still
            # caught by any existing `except Exception` handler.
            raise RuntimeError("llm is not initialized")
        return self.llm