import logging
import re
from typing import Any, List, Optional

from pydantic import Field
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
    Runnable,
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
)

from stores.llm.LLMProviderFactory import LLMProviderFactory
from config import get_settings

logger = logging.getLogger(__name__)


def _or_none(value: str) -> str:
    """Return *value*, or the placeholder "None" when it is empty/falsy."""
    return value if value else "None"


def _optional_section(title: str, value: str) -> str:
    """Render a titled prompt section followed by a blank line, or "" when *value* is empty.

    Used for sections that should be omitted entirely (rather than rendered as
    "None") when absent, so the model is not told the data is empty.
    """
    return f"{title}:\n{value}\n\n" if value else ""


class ProviderLLMWrapper(LLM):
    """Expose a project LLM provider through LangChain's ``LLM`` interface.

    The wrapped ``provider`` must offer ``generate_text(prompt)``. Its result is
    accepted either as a ``str`` or as a ``dict`` holding the text under the
    ``"response"`` key; anything else (including ``None``) raises ``ValueError``.
    """

    provider: Any = Field(..., description="The wrapped LLM provider")

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for *prompt* via the wrapped provider.

        Args:
            prompt: Fully rendered prompt text.
            stop: Optional stop sequences. They are not passed to
                ``generate_text``, so the returned text is cut at the earliest
                occurrence of any of them here instead.
            run_manager: LangChain callback manager; accepted so LangChain can
                pass it, not used.
            **kwargs: Extra invocation options forwarded by LangChain; ignored
                because ``generate_text`` is only called with the prompt.

        Returns:
            The generated text.

        Raises:
            ValueError: If the provider returns ``None``, a dict without a
                string ``"response"``, or any other unsupported type.
        """
        result = self.provider.generate_text(prompt)
        text = self._extract_text(result)
        return self._truncate_at_stop(text, stop) if stop else text

    @staticmethod
    def _extract_text(result: Any) -> str:
        """Normalise a raw provider result to a string, raising ValueError otherwise."""
        if result is None:
            raise ValueError("LLM provider returned None (likely due to timeout or error)")
        if isinstance(result, str):
            return result
        if isinstance(result, dict):
            response = result.get("response")
            if response is None:
                raise ValueError(f"LLM provider returned dict without 'response' key: {result.keys()}")
            if not isinstance(response, str):
                raise ValueError(
                    f"LLM provider returned non-string 'response': {type(response).__name__}"
                )
            return response
        raise ValueError(f"Unexpected LLM response type: {type(result).__name__}")

    @staticmethod
    def _truncate_at_stop(text: str, stop: List[str]) -> str:
        """Cut *text* at the earliest occurrence of any non-empty stop sequence."""
        hits = [pos for pos in (text.find(token) for token in stop if token) if pos != -1]
        return text[:min(hits)] if hits else text

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "custom-provider"

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count as the number of whitespace-separated words."""
        return len(text.split())


class AssistantRagGen:
    """Route a user question to a specialised prompt and answer it with the generation LLM."""

    # How many times the router asks the LLM for a label before giving up.
    ROUTER_MAX_ATTEMPTS = 3
    # Route used when no attempt yields a recognisable label (matches the
    # router prompt's "when ambiguous, prefer pdf_query" rule).
    DEFAULT_ROUTE = "pdf_query"

    def __init__(self):
        """Read settings, create the generation provider, and wrap it for LangChain."""
        config = get_settings()
        self.factory = LLMProviderFactory(config)
        self.generator = self.factory.create(config.GENERATION_BACKEND)
        self.generator.set_generation_model(config.GENERATION_MODEL_ID)
        self.llm = ProviderLLMWrapper(provider=self.generator)
        self.valid_routes = {"user_info", "site_query", "pdf_query"}

    def build_router_prompt(self, user_prompt: str) -> str:
        """Build the classification prompt that maps *user_prompt* to one route label."""
        return f"""You are a query routing classifier. Your sole job is to categorize a user's question into exactly one routing category.

## Categories

| Category | Routes questions about... |
|--------------|-------------------------------------------------------------------------------------------|
| `user_info` | Personal profile, enrolled courses, username, role, learning progress, achievements |
| `site_query` | Platform features, website navigation, rules, policies, FAQs, general platform knowledge |
| `pdf_query` | Document content, uploaded files, PDF search, lesson materials, reading resources |

## Examples

user_info → "What courses am I enrolled in?"
user_info → "What is my current progress in the Python course?"
site_query → "How do I reset my password?"
site_query → "What are the platform's refund policies?"
pdf_query → "What does the document say about recursion?"
pdf_query → "Find me the section on neural networks in the materials"

## Decision Rules

1. If the question involves the **current user's personal data** → `user_info`
2. If the question is about **how the platform works** → `site_query`
3. If the question requires **reading or searching a document** → `pdf_query`
4. When ambiguous, prefer `pdf_query` over `site_query`, and `user_info` over both.

## Output Format

Respond with a single lowercase word. No punctuation. No explanation. No whitespace.
Valid outputs: user_info | site_query | pdf_query

Question: {user_prompt}
"""

    def build_unified_prompt(
        self,
        context: str,
        question: str,
        conversation_history: str = "",
        User_Info: str = "",
        enrolled_courses: str = "",
    ) -> str:
        """Build the general RAG answer prompt (used for the ``pdf_query`` route).

        ``enrolled_courses`` is rendered only when non-empty, so prompts for
        callers that do not pass it are unchanged.
        """
        return f"""
You are a helpful university assistant.

Rules:
- Use the provided context FIRST.
- Use conversation history to understand follow-up questions.
- If the question is about the user, use the User_Info and enrolled_courses.
- If the answer is not in the context, say:
  "Not found in the provided materials."
  Then add: "From my own information:" and answer briefly.
- Be concise and clear.

Conversation History:
{_or_none(conversation_history)}

User Info:
{_or_none(User_Info)}

{_optional_section("Enrolled Courses", enrolled_courses)}Context:
{context}

Current Question:
{question}

Answer:
"""

    def build_user_info_prompt(
        self,
        question: str,
        conversation_history: str = "",
        User_Info: str = "",
        enrolled_courses: str = "",
    ) -> str:
        """Build the prompt for questions about the current user's account.

        ``enrolled_courses`` is rendered only when non-empty, so prompts for
        callers that do not pass it are unchanged.
        """
        return f"""
You are a university assistant handling a user account inquiry.
Use the provided User Info and Enrolled Courses to answer the question accurately.

Conversation History:
{_or_none(conversation_history)}

User Info:
{_or_none(User_Info)}

{_optional_section("Enrolled Courses", enrolled_courses)}Current Question:
{question}

Answer:
"""

    def build_site_query_prompt(self, question: str, context: str = "", conversation_history: str = "") -> str:
        """Build the prompt for questions about how the platform/site works."""
        return f"""
You are a university assistant handling a platform or site-related question.
Provide clear instructions, rules, or general information about how the university platform works.

Conversation History:
{_or_none(conversation_history)}

Current Question:
{question}

Site Context:
{_or_none(context)}

Answer:
"""

    def _parse_route(self, raw: str) -> Optional[str]:
        """Extract a single valid route label from the router LLM's raw reply.

        Accepts an exact label, or a reply that mentions exactly one distinct
        valid label as a whole word (e.g. wrapped in backticks/quotes or
        followed by punctuation). Returns None when no label, or more than one
        distinct label, is found.
        """
        text = raw.strip().lower()
        if text in self.valid_routes:
            return text
        if not self.valid_routes:
            return None
        pattern = r"\b(" + "|".join(map(re.escape, sorted(self.valid_routes))) + r")\b"
        found = set(re.findall(pattern, text))
        return found.pop() if len(found) == 1 else None

    def robust_router(self, input_data: dict) -> str:
        """Classify ``input_data["question"]`` into one of ``self.valid_routes``.

        Asks the LLM up to ``ROUTER_MAX_ATTEMPTS`` times. An attempt fails when
        the reply has no unambiguous label or the wrapper raises ValueError
        (e.g. the provider returned None). Falls back to ``DEFAULT_ROUTE``.
        """
        prompt = self.build_router_prompt(input_data["question"])
        for attempt in range(1, self.ROUTER_MAX_ATTEMPTS + 1):
            try:
                raw = self.llm.invoke(prompt)
            except ValueError as exc:
                logger.warning(
                    "Router attempt %d/%d failed: %s", attempt, self.ROUTER_MAX_ATTEMPTS, exc
                )
                continue
            route = self._parse_route(raw)
            if route is not None:
                return route
            logger.debug(
                "Router attempt %d/%d returned unrecognised label %r",
                attempt, self.ROUTER_MAX_ATTEMPTS, raw,
            )
        return self.DEFAULT_ROUTE

    def get_chain(self) -> Runnable:
        """Assemble the routed RAG chain.

        Input is a dict with a required ``"question"`` and optional
        ``"context"``, ``"conversation_history"``, ``"User_Info"`` and
        ``"enrolled_courses"`` strings. The router labels the question, the
        matching prompt builder renders a prompt, the LLM answers it, and the
        answer is returned as a string.
        """
        # Each branch renders its prompt and pipes it into the LLM, so the LLM
        # call is a regular step of the chain rather than a nested invoke().
        user_info_chain = RunnableLambda(
            lambda x: self.build_user_info_prompt(
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ) | self.llm

        site_query_chain = RunnableLambda(
            lambda x: self.build_site_query_prompt(
                question=x["question"],
                context=x.get("context", ""),
                conversation_history=x.get("conversation_history", ""),
            )
        ) | self.llm

        pdf_query_chain = RunnableLambda(
            lambda x: self.build_unified_prompt(
                # The parallel step below always sets "context" (to "" when the
                # caller omitted it), so a .get() default would never apply.
                context=x.get("context") or "No context provided.",
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ) | self.llm

        branching_logic = RunnableBranch(
            (lambda x: x["topic"] == "user_info", user_info_chain),
            (lambda x: x["topic"] == "site_query", site_query_chain),
            pdf_query_chain,
        )

        full_chain = (
            RunnableParallel({
                "topic": RunnableLambda(self.robust_router),
                # Pass all incoming variables straight through to the branches
                "question": lambda x: x["question"],
                "context": lambda x: x.get("context", ""),
                "conversation_history": lambda x: x.get("conversation_history", ""),
                "User_Info": lambda x: x.get("User_Info", ""),
                "enrolled_courses": lambda x: x.get("enrolled_courses", ""),
            })
            | branching_logic
            | StrOutputParser()
        )
        return full_chain