Spaces:
Paused
Paused
import re
from typing import Any, List, Optional

from langchain_core.language_models import LLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnableParallel
from pydantic import Field

from config import get_settings
from stores.llm.LLMProviderFactory import LLMProviderFactory
class ProviderLLMWrapper(LLM):
    """LangChain ``LLM`` adapter around a project LLM provider.

    The wrapped ``provider`` is called as ``provider.generate_text(prompt)``;
    its result may be a plain string or a dict carrying the text under the
    ``"response"`` key. Anything else is reported as a ``ValueError``.
    """

    provider: Any = Field(..., description="The wrapped LLM provider")

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> str:
        """Generate text for *prompt* through the wrapped provider.

        Args:
            prompt: Fully rendered prompt text.
            stop: Optional stop sequences. The provider is not told about
                them; instead the returned text is cut at the earliest
                occurrence of any of them, as LangChain's ``_call`` contract
                requires.
            run_manager: LangChain callback manager. Accepted so LangChain can
                pass it; not used here.
            **kwargs: Extra invocation kwargs forwarded by ``invoke``. Accepted
                for API compatibility but not forwarded, because
                ``generate_text`` is only ever called with the prompt.

        Returns:
            The generated text.

        Raises:
            ValueError: If the provider returns ``None``, a dict without a
                usable ``"response"``, or a value of an unsupported type.
        """
        result = self.provider.generate_text(prompt)
        if result is None:
            raise ValueError("LLM provider returned None (likely due to timeout or error)")

        if isinstance(result, dict):
            response = result.get("response")
            if response is None:
                raise ValueError(f"LLM provider returned dict without 'response' key: {result.keys()}")
            text = response
        elif isinstance(result, str):
            text = result
        else:
            raise ValueError(f"Unexpected LLM response type: {type(result).__name__}")

        # A non-string "response" would otherwise fail later inside LangChain's
        # Generation model with a less helpful error.
        if not isinstance(text, str):
            raise ValueError(f"Unexpected LLM response type: {type(text).__name__}")

        return self._apply_stop(text, stop)

    @staticmethod
    def _apply_stop(text: str, stop: Optional[List[str]]) -> str:
        """Truncate *text* at the earliest occurrence of any non-empty stop sequence."""
        if not stop:
            return text
        hits = [pos for pos in (text.find(seq) for seq in stop if seq) if pos != -1]
        return text[: min(hits)] if hits else text

    @property
    def _llm_type(self) -> str:
        # LangChain declares _llm_type as a property; as a plain method,
        # self._llm_type would yield a bound method instead of this string.
        return "custom-provider"

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count as the number of whitespace-separated words."""
        return len(text.split())
class AssistantRagGen:
    """Answer user questions by routing them to one of three prompt templates.

    ``get_chain`` builds a LangChain pipeline that first asks the LLM to
    classify the question as ``user_info``, ``site_query`` or ``pdf_query``,
    then sends the matching category-specific prompt to the same LLM.
    """

    # Route used when the classifier never yields a usable category.
    DEFAULT_ROUTE = "pdf_query"

    def __init__(self):
        # Build the generation provider from application settings and wrap it
        # so it can be driven as a LangChain LLM.
        config = get_settings()
        self.factory = LLMProviderFactory(config)
        self.generator = self.factory.create(config.GENERATION_BACKEND)
        self.generator.set_generation_model(config.GENERATION_MODEL_ID)
        self.llm = ProviderLLMWrapper(provider=self.generator)
        self.valid_routes = {"user_info", "site_query", "pdf_query"}

    def build_router_prompt(self, user_prompt: str) -> str:
        """Return the classification prompt asking the LLM to pick a route for *user_prompt*."""
        return f"""You are a query routing classifier. Your sole job is to categorize a user's question into exactly one routing category.
## Categories
| Category | Routes questions about... |
|--------------|-------------------------------------------------------------------------------------------|
| `user_info` | Personal profile, enrolled courses, username, role, learning progress, achievements |
| `site_query` | Platform features, website navigation, rules, policies, FAQs, general platform knowledge |
| `pdf_query` | Document content, uploaded files, PDF search, lesson materials, reading resources |
## Examples
user_info → "What courses am I enrolled in?"
user_info → "What is my current progress in the Python course?"
site_query → "How do I reset my password?"
site_query → "What are the platform's refund policies?"
pdf_query → "What does the document say about recursion?"
pdf_query → "Find me the section on neural networks in the materials"
## Decision Rules
1. If the question involves the **current user's personal data** → `user_info`
2. If the question is about **how the platform works** → `site_query`
3. If the question requires **reading or searching a document** → `pdf_query`
4. When ambiguous, prefer `pdf_query` over `site_query`, and `user_info` over both.
## Output Format
Respond with a single lowercase word. No punctuation. No explanation. No whitespace.
Valid outputs: user_info | site_query | pdf_query
Question: {user_prompt}
"""

    @staticmethod
    def _format_enrolled_courses(enrolled_courses: Any) -> str:
        """Return an ``Enrolled Courses`` prompt section, or ``""`` when nothing was supplied.

        Lists/tuples are joined with ", "; any other value is rendered with
        ``str()``. Returning ``""`` for empty input keeps prompts identical to
        the versions without this section.
        """
        if not enrolled_courses:
            return ""
        if isinstance(enrolled_courses, (list, tuple)):
            enrolled_courses = ", ".join(str(course) for course in enrolled_courses)
        return f"Enrolled Courses:\n{enrolled_courses}\n"

    def build_unified_prompt(self, context: str, question: str, conversation_history: str = "", User_Info: str = "", enrolled_courses: Any = "") -> str:
        """Return the document (pdf_query) answering prompt.

        Args:
            context: Retrieved material the answer should be based on.
            question: The user's current question.
            conversation_history: Prior turns; rendered as "None" when empty.
            User_Info: User profile text; rendered as "None" when empty.
            enrolled_courses: Optional courses; the section is omitted when empty.
        """
        courses_section = self._format_enrolled_courses(enrolled_courses)
        return f"""
You are a helpful university assistant.
Rules:
- Use the provided context FIRST.
- Use conversation history to understand follow-up questions.
- If the question is about the user, use the User_Info and enrolled_courses.
- If the answer is not in the context, say:
"Not found in the provided materials."
Then add:
"From my own information:" and answer briefly.
- Be concise and clear.
Conversation History:
{conversation_history if conversation_history else "None"}
User Info:
{User_Info if User_Info else "None"}
{courses_section}Context:
{context}
Current Question:
{question}
Answer:
"""

    def build_user_info_prompt(self, question: str, conversation_history: str = "", User_Info: str = "", enrolled_courses: Any = "") -> str:
        """Return the prompt for questions about the user's own account.

        The prompt tells the model to use "User Info and Enrolled Courses", so
        the courses are included here whenever the caller supplies them.
        """
        courses_section = self._format_enrolled_courses(enrolled_courses)
        return f"""
You are a university assistant handling a user account inquiry.
Use the provided User Info and Enrolled Courses to answer the question accurately.
Conversation History:
{conversation_history if conversation_history else "None"}
User Info:
{User_Info if User_Info else "None"}
{courses_section}Current Question:
{question}
Answer:
"""

    def build_site_query_prompt(self, question: str, context: str = "", conversation_history: str = "") -> str:
        """Return the prompt for questions about how the platform works."""
        return f"""
You are a university assistant handling a platform or site-related question.
Provide clear instructions, rules, or general information about how the university platform works.
Conversation History:
{conversation_history if conversation_history else "None"}
Current Question:
{question}
Site Context:
{context if context else "None"}
Answer:
"""

    def _parse_route(self, raw: Any) -> Optional[str]:
        """Extract a single valid route from a raw classifier reply, else ``None``.

        Tolerates case, backticks, quotes and punctuation (the router prompt
        itself shows the names in backticks). A reply naming more than one
        distinct route is treated as unusable.
        """
        if not isinstance(raw, str):
            return None
        found = {token for token in re.findall(r"\w+", raw.lower()) if token in self.valid_routes}
        return found.pop() if len(found) == 1 else None

    def robust_router(self, input_data: dict, max_attempts: int = 3) -> str:
        """Classify ``input_data["question"]`` into one of ``self.valid_routes``.

        Args:
            input_data: Chain input; must contain a ``"question"`` key.
            max_attempts: Number of classification calls before giving up.

        Returns:
            The chosen route, or ``DEFAULT_ROUTE`` if no attempt produced a
            usable answer.
        """
        question = input_data["question"]
        prompt = self.build_router_prompt(question)  # identical for every attempt
        for _ in range(max_attempts):
            try:
                raw = self.llm.invoke(prompt)
            except ValueError:
                # ProviderLLMWrapper raises ValueError for None / malformed
                # provider results; count it as a failed attempt and retry.
                continue
            route = self._parse_route(raw)
            if route is not None:
                return route
        return self.DEFAULT_ROUTE

    def get_chain(self):
        """Build the routing chain.

        Input is a dict with ``"question"`` (required) and optional
        ``"context"``, ``"conversation_history"``, ``"User_Info"`` and
        ``"enrolled_courses"``. The chain's output is the answer string.
        """
        router_node = RunnableLambda(self.robust_router)
        user_info_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_user_info_prompt(
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ))
        site_query_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_site_query_prompt(
                question=x["question"],
                context=x.get("context", ""),
                conversation_history=x.get("conversation_history", ""),
            )
        ))
        pdf_query_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_unified_prompt(
                # The parallel step always sets "context" (default ""), so a
                # .get() default would never apply; fall back on any empty value.
                context=x.get("context") or "No context provided.",
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ))
        branching_logic = RunnableBranch(
            (lambda x: x["topic"] == "user_info", user_info_chain),
            (lambda x: x["topic"] == "site_query", site_query_chain),
            pdf_query_chain,
        )
        full_chain = (
            RunnableParallel({
                "topic": router_node,
                # Pass all incoming variables straight through to the branches
                "question": lambda x: x["question"],
                "context": lambda x: x.get("context", ""),
                "conversation_history": lambda x: x.get("conversation_history", ""),
                "User_Info": lambda x: x.get("User_Info", ""),
                "enrolled_courses": lambda x: x.get("enrolled_courses", ""),
            })
            | branching_logic
            | StrOutputParser()
        )
        return full_chain