Spaces:
Paused
Paused
File size: 7,582 Bytes
1bc3f18 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 | from typing import Any
from pydantic import Field
from langchain_core.language_models import LLM
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from stores.llm.LLMProviderFactory import LLMProviderFactory
from config import get_settings
class ProviderLLMWrapper(LLM):
    """Expose a project LLM provider through LangChain's ``LLM`` interface.

    The wrapped ``provider`` is only required to offer
    ``generate_text(prompt)``, returning either a plain string or a dict
    that carries the completion under the ``"response"`` key.
    """

    # Provider object whose ``generate_text`` method produces the completion.
    provider: Any = Field(..., description="The wrapped LLM provider")

    def _call(self, prompt: str, stop=None, run_manager: Any = None, **kwargs: Any) -> str:
        """Run the provider on ``prompt`` and return the completion text.

        Args:
            prompt: Fully rendered prompt text.
            stop: Optional sequences; the output is cut at the first one found.
            run_manager: Accepted for compatibility with the ``LLM`` base
                class signature; not used here.
            **kwargs: Extra call-time options forwarded by LangChain. They are
                accepted so such calls do not fail with ``TypeError``, but the
                provider's ``generate_text`` only receives the prompt.

        Returns:
            The provider's completion.

        Raises:
            ValueError: If the provider returns ``None``, a dict without a
                ``"response"`` value, or any other unsupported type.
        """
        result = self.provider.generate_text(prompt)
        if result is None:
            raise ValueError("LLM provider returned None (likely due to timeout or error)")
        if isinstance(result, dict):
            text = result.get("response")
            if text is None:
                raise ValueError(f"LLM provider returned dict without 'response' key: {result.keys()}")
        elif isinstance(result, str):
            text = result
        else:
            raise ValueError(f"Unexpected LLM response type: {type(result).__name__}")

        # The provider is never told about stop sequences, so honour them here
        # by truncating at the earliest occurrence. Empty stop strings are
        # skipped because ``find("")`` is always 0 and would erase the output.
        if stop and isinstance(text, str):
            cut_points = [pos for pos in (text.find(seq) for seq in stop if seq) if pos != -1]
            if cut_points:
                text = text[:min(cut_points)]
        return text

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses to label this LLM implementation."""
        return "custom-provider"

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count as the number of whitespace-separated words."""
        return len(text.split())
class AssistantRagGen:
    """Route a user question to one of three prompt styles and answer it.

    A routing prompt asks the LLM to classify the question as ``user_info``,
    ``site_query`` or ``pdf_query``; ``get_chain`` then dispatches to the
    matching prompt builder and returns the LLM's answer as a string.
    """

    # Number of times the router asks the LLM before giving up.
    MAX_ROUTER_ATTEMPTS = 3
    # Route used when the LLM never produces a recognisable category.
    FALLBACK_ROUTE = "pdf_query"
    # Characters models commonly wrap around a one-word answer.
    _ROUTE_WRAPPERS = "`'\"*.,:;!?()[]{}<> \t\r\n"

    def __init__(self):
        """Create the generation provider from settings and wrap it for LangChain."""
        config = get_settings()
        self.factory = LLMProviderFactory(config)
        self.generator = self.factory.create(config.GENERATION_BACKEND)
        self.generator.set_generation_model(config.GENERATION_MODEL_ID)
        self.llm = ProviderLLMWrapper(provider=self.generator)
        self.valid_routes = {"user_info", "site_query", "pdf_query"}

    @staticmethod
    def _enrolled_courses_section(enrolled_courses) -> str:
        """Return an ``Enrolled Courses`` prompt section, or ``""`` when empty.

        Returning an empty string for a falsy value keeps the prompts unchanged
        for callers that do not supply enrolled courses.
        """
        if not enrolled_courses:
            return ""
        return f"Enrolled Courses:\n{enrolled_courses}\n"

    def build_router_prompt(self, user_prompt: str) -> str:
        """Build the classification prompt that asks for exactly one route name."""
        return f"""You are a query routing classifier. Your sole job is to categorize a user's question into exactly one routing category.
## Categories
| Category | Routes questions about... |
|--------------|-------------------------------------------------------------------------------------------|
| `user_info` | Personal profile, enrolled courses, username, role, learning progress, achievements |
| `site_query` | Platform features, website navigation, rules, policies, FAQs, general platform knowledge |
| `pdf_query` | Document content, uploaded files, PDF search, lesson materials, reading resources |
## Examples
user_info β "What courses am I enrolled in?"
user_info β "What is my current progress in the Python course?"
site_query β "How do I reset my password?"
site_query β "What are the platform's refund policies?"
pdf_query β "What does the document say about recursion?"
pdf_query β "Find me the section on neural networks in the materials"
## Decision Rules
1. If the question involves the **current user's personal data** β `user_info`
2. If the question is about **how the platform works** β `site_query`
3. If the question requires **reading or searching a document** β `pdf_query`
4. When ambiguous, prefer `pdf_query` over `site_query`, and `user_info` over both.
## Output Format
Respond with a single lowercase word. No punctuation. No explanation. No whitespace.
Valid outputs: user_info | site_query | pdf_query
Question: {user_prompt}
"""

    def build_unified_prompt(self, context: str, question: str, conversation_history: str = "", User_Info: str = "", enrolled_courses: str = "") -> str:
        """Build the document-grounded answer prompt.

        Args:
            context: Retrieved material the answer should rely on first.
            question: The user's current question.
            conversation_history: Prior turns; rendered as ``None`` when empty.
            User_Info: User profile text; rendered as ``None`` when empty.
            enrolled_courses: Optional course data; when non-empty it is added
                as its own ``Enrolled Courses`` section after the user info.
        """
        courses_section = self._enrolled_courses_section(enrolled_courses)
        return f"""
You are a helpful university assistant.
Rules:
- Use the provided context FIRST.
- Use conversation history to understand follow-up questions.
- If the question is about the user, use the User_Info and enrolled_courses.
- If the answer is not in the context, say:
"Not found in the provided materials."
Then add:
"From my own information:" and answer briefly.
- Be concise and clear.
Conversation History:
{conversation_history or "None"}
User Info:
{User_Info or "None"}
{courses_section}Context:
{context}
Current Question:
{question}
Answer:
"""

    def build_user_info_prompt(self, question: str, conversation_history: str = "", User_Info: str = "", enrolled_courses: str = "") -> str:
        """Build the prompt for questions about the current user's account.

        Args:
            question: The user's current question.
            conversation_history: Prior turns; rendered as ``None`` when empty.
            User_Info: User profile text; rendered as ``None`` when empty.
            enrolled_courses: Optional course data; when non-empty it is added
                as its own ``Enrolled Courses`` section after the user info.
        """
        courses_section = self._enrolled_courses_section(enrolled_courses)
        return f"""
You are a university assistant handling a user account inquiry.
Use the provided User Info and Enrolled Courses to answer the question accurately.
Conversation History:
{conversation_history or "None"}
User Info:
{User_Info or "None"}
{courses_section}Current Question:
{question}
Answer:
"""

    def build_site_query_prompt(self, question: str, context: str = "", conversation_history: str = "") -> str:
        """Build the prompt for questions about how the platform works.

        Empty ``context`` and ``conversation_history`` are rendered as ``None``.
        """
        return f"""
You are a university assistant handling a platform or site-related question.
Provide clear instructions, rules, or general information about how the university platform works.
Conversation History:
{conversation_history or "None"}
Current Question:
{question}
Site Context:
{context or "None"}
Answer:
"""

    def _parse_route(self, raw: str) -> "str | None":
        """Extract a valid route name from a raw LLM reply, or ``None``.

        Accepts, in order: an exact match after trimming and lower-casing; the
        same after removing wrapping punctuation such as backticks, quotes or a
        trailing period; and finally a reply that mentions exactly one valid
        route anywhere in its text. A reply naming several routes is treated
        as ambiguous and rejected.
        """
        cleaned = raw.strip().lower()
        if cleaned in self.valid_routes:
            return cleaned

        unwrapped = cleaned.strip(self._ROUTE_WRAPPERS)
        if unwrapped in self.valid_routes:
            return unwrapped

        mentioned = [route for route in self.valid_routes if route in cleaned]
        if len(mentioned) == 1:
            return mentioned[0]
        return None

    def robust_router(self, input_data: dict) -> str:
        """Classify ``input_data["question"]`` into one of ``self.valid_routes``.

        The LLM is queried up to ``MAX_ROUTER_ATTEMPTS`` times; the first reply
        that yields a recognisable route wins. If none does,
        ``FALLBACK_ROUTE`` is returned. Exceptions raised by the LLM call are
        not caught here and propagate to the caller.

        Raises:
            KeyError: If ``input_data`` has no ``"question"`` entry.
        """
        # The prompt depends only on the question, so build it once.
        prompt = self.build_router_prompt(input_data["question"])
        for _ in range(self.MAX_ROUTER_ATTEMPTS):
            route = self._parse_route(self.llm.invoke(prompt))
            if route is not None:
                return route
        return self.FALLBACK_ROUTE

    def get_chain(self):
        """Assemble the routing chain.

        The chain expects a dict with ``"question"`` and optionally
        ``"context"``, ``"conversation_history"``, ``"User_Info"`` and
        ``"enrolled_courses"``. It first computes ``"topic"`` with
        ``robust_router`` while copying the inputs through, then branches on
        the topic and parses the LLM answer to a string.
        """
        router_node = RunnableLambda(self.robust_router)

        user_info_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_user_info_prompt(
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                # Previously carried through the chain but never shown to the model.
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ))
        site_query_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_site_query_prompt(
                question=x["question"],
                context=x.get("context", ""),
                conversation_history=x.get("conversation_history", ""),
            )
        ))
        pdf_query_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_unified_prompt(
                context=x.get("context", "No context provided."),
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ))

        # Anything that is not user_info or site_query falls through to the
        # document-grounded prompt.
        branching_logic = RunnableBranch(
            (lambda x: x["topic"] == "user_info", user_info_chain),
            (lambda x: x["topic"] == "site_query", site_query_chain),
            pdf_query_chain,
        )

        full_chain = (
            RunnableParallel({
                "topic": router_node,
                # Pass all incoming variables straight through to the branches
                "question": lambda x: x["question"],
                "context": lambda x: x.get("context", ""),
                "conversation_history": lambda x: x.get("conversation_history", ""),
                "User_Info": lambda x: x.get("User_Info", ""),
                "enrolled_courses": lambda x: x.get("enrolled_courses", ""),
            })
            | branching_logic
            | StrOutputParser()
        )
        return full_chain
|