# llm_setup/llm_setup.py
import os

from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
)
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_classic.chains import create_history_aware_retriever, create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory

from caching.lfu import LFUCache

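# LFUCache comes from this project's caching module; its implementation is not
# shown here. The service only relies on a two-method interface, sketched below
# under that assumption (the real class may differ):
#
#   class LFUCache:
#       def __init__(self, capacity: int): ...
#       def get(self, key):         # returns the cached value, or None on a miss
#           ...
#       def put(self, key, value):  # inserts/updates, evicting the least
#           ...                     # frequently used entry once at capacity
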
def create_llm(provider: str, model_name: str, api_key: str | None = None):
    """
    Factory that creates a LangChain chat model for the given provider.

    The key is resolved from the explicit ``api_key`` argument first, falling
    back to the provider's environment variable from ``PROVIDER_ENV_KEYS``.
    """
    # Imported at call time rather than module level (e.g. to avoid import cycles).
    from configs.config import PROVIDER_ENV_KEYS

    env_key = PROVIDER_ENV_KEYS.get(provider)
    resolved_key = api_key or (os.environ.get(env_key) if env_key else None)
    if not resolved_key:
        raise ValueError(
            f"No API key for {provider}. "
            f"Set {env_key or 'the provider API key'} or provide one in the UI."
        )
    if provider == "Google Gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(model=model_name, google_api_key=resolved_key)
    elif provider == "OpenAI":
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model=model_name, api_key=resolved_key)
    elif provider == "Anthropic":
        from langchain_anthropic import ChatAnthropic
        return ChatAnthropic(model=model_name, api_key=resolved_key)
    elif provider == "HuggingFace":
        from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
        llm = HuggingFaceEndpoint(
            repo_id=model_name,
            huggingfacehub_api_token=resolved_key,
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unknown provider: {provider}")

class LLMService:
    def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever,
                 cache_capacity: int = 50,
                 provider: str = "Google Gemini",
                 llm_model_name: str = "gemini-2.5-flash-lite"):
        self._conversational_rag_chain = None
        self._logger = logger
        self.system_prompt = system_prompt
        self._web_retriever = web_retriever
        self.current_provider = provider
        self.current_model_name = llm_model_name
        ### Statefully manage chat history ###
        # Created before the chain so that _get_session_history, which the
        # chain calls at runtime, always finds the store initialized.
        self.store = LFUCache(capacity=cache_capacity)
        self.llm = create_llm(provider, llm_model_name)
        self._initialize_conversational_rag_chain()
    def _initialize_conversational_rag_chain(self):
        """
        Initializes the conversational RAG chain: a history-aware retriever
        feeding a stuff-documents QA chain, wrapped with per-session history.
        """
        ### Contextualize question ###
        contextualize_q_system_prompt = (
            "Given the full chat history and the latest user message, "
            "rewrite the message as a fully self-contained question that can be understood without any prior context. "
            "Preserve intent and key entities, expand pronouns and references, and include necessary constraints, dates, or assumptions. "
            "Do not add new information, do not omit essential details, and do not answer the question. Return only the rewritten question."
        )
        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        history_aware_retriever = create_history_aware_retriever(
            self.llm, self._web_retriever, contextualize_q_prompt)

        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        # Render each retrieved document with its source so answers can cite it.
        document_prompt = PromptTemplate.from_template(
            "{page_content}\n[Source: {source}]"
        )
        question_answer_chain = create_stuff_documents_chain(
            self.llm, qa_prompt, document_prompt=document_prompt
        )
        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

        self._conversational_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            self._get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )
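
    # The wrapped chain is invoked with a session_id in the run config; history
    # for that session is then loaded and persisted automatically. A minimal
    # sketch (the session id and question are illustrative):
    #
    #   chain = service.conversational_rag_chain()
    #   result = chain.invoke(
    #       {"input": "What does AIS stand for?"},
    #       config={"configurable": {"session_id": "user-123"}},
    #   )
    #   print(result["answer"])
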
    def _get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """Return the chat history for a session, creating it on first use."""
        history = self.store.get(session_id)
        if history is None:
            history = ChatMessageHistory()
            self.store.put(session_id, history)
        return history
    def update_llm(self, provider: str, model_name: str, api_key: str | None = None):
        """
        Swap the LLM at runtime. Rebuilds the chain but preserves the retriever
        and the chat history store.
        """
        # create_llm treats an empty key as missing, so api_key can be passed as-is.
        new_llm = create_llm(provider, model_name, api_key)
        self.llm = new_llm
        self.current_provider = provider
        self.current_model_name = model_name
        self._initialize_conversational_rag_chain()
        self._logger.info(f"LLM switched to {provider} / {model_name}")
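
    # Example hot swap (the provider string must match one handled by
    # create_llm; the model name below is illustrative):
    #
    #   service.update_llm("Anthropic", "claude-3-5-haiku-latest")
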
    def conversational_rag_chain(self):
        """
        Returns the initialized conversational RAG chain.

        Returns:
            The conversational RAG chain instance.
        """
        return self._conversational_rag_chain
    def get_llm(self):
        """
        Returns the LLM instance.
        """
        if self.llm is None:
            raise RuntimeError("LLM is not initialized")
        return self.llm
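
# End-to-end wiring, sketched under assumptions: `logger` is a standard-library
# logger, `retriever` is a VectorStoreRetriever built elsewhere in the app, and
# GOOGLE_API_KEY is set in the environment (all names are illustrative). Note
# that create_stuff_documents_chain expects the system prompt to contain a
# {context} placeholder for the retrieved documents:
#
#   import logging
#   service = LLMService(
#       logger=logging.getLogger("aisviz"),
#       system_prompt="Answer using the provided context.\n\n{context}",
#       web_retriever=retriever,
#   )
#   chain = service.conversational_rag_chain()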