# EXAM_RAG_API/generation/AssistantRagGenerator.py
# (repository page header residue) MinaNasser - "1st" - 1bc3f18
from typing import Any
from pydantic import Field
from langchain_core.language_models import LLM
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from stores.llm.LLMProviderFactory import LLMProviderFactory
from config import get_settings
class ProviderLLMWrapper(LLM):
    """Adapt a project LLM provider to LangChain's ``LLM`` interface.

    The wrapped ``provider`` is expected to expose ``generate_text(prompt)``
    returning either a plain string or a dict carrying the completion under
    the ``"response"`` key.
    """

    # Provider object whose ``generate_text`` method produces the completion.
    provider: Any = Field(..., description="The wrapped LLM provider")

    def _call(self, prompt: str, stop=None, run_manager=None, **kwargs: Any) -> str:
        """Run the wrapped provider on ``prompt`` and return its text.

        Args:
            prompt: Fully rendered prompt passed to ``provider.generate_text``.
            stop: Accepted for interface compatibility.
                NOTE(review): stop sequences are not forwarded to the provider.
            run_manager: Callback manager supplied by the LangChain base
                class; accepted for interface compatibility and unused.
            **kwargs: Extra invocation arguments forwarded by the base class.
                They are accepted (and ignored) so that calls such as
                ``llm.invoke(prompt, some_option=...)`` do not fail with a
                ``TypeError``.

        Returns:
            The completion text.

        Raises:
            ValueError: If the provider returns ``None``, a dict without a
                ``"response"`` value, or any type other than ``str``/``dict``.
        """
        result = self.provider.generate_text(prompt)

        if result is None:
            raise ValueError("LLM provider returned None (likely due to timeout or error)")

        if isinstance(result, dict):
            response = result.get("response")
            if response is None:
                raise ValueError(f"LLM provider returned dict without 'response' key: {result.keys()}")
            return response

        if isinstance(result, str):
            return result

        raise ValueError(f"Unexpected LLM response type: {type(result).__name__}")

    @property
    def _llm_type(self) -> str:
        """Identifier reported for this LLM implementation."""
        return "custom-provider"

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count as the number of whitespace-separated words."""
        return len(text.split())
class AssistantRagGen:
    """Route a user question to one of three prompt styles and answer it.

    A routing prompt classifies the question as ``user_info``, ``site_query``
    or ``pdf_query``; the matching prompt builder then renders the final
    prompt that is sent to the wrapped LLM.
    """

    # Number of times the router LLM is asked before falling back.
    _MAX_ROUTER_ATTEMPTS = 3
    # Route used when the router never yields a valid category.
    _FALLBACK_ROUTE = "pdf_query"
    # Characters models commonly wrap around a one-word answer.
    _ROUTE_WRAPPING_CHARS = " \t\r\n`'\"*.,:;!"

    def __init__(self):
        """Create the generation backend from settings and wrap it for LangChain."""
        config = get_settings()
        self.factory = LLMProviderFactory(config)
        self.generator = self.factory.create(config.GENERATION_BACKEND)
        self.generator.set_generation_model(config.GENERATION_MODEL_ID)
        self.llm = ProviderLLMWrapper(provider=self.generator)
        self.valid_routes = {"user_info", "site_query", "pdf_query"}

    @staticmethod
    def _courses_section(enrolled_courses: str) -> str:
        """Return the "Enrolled Courses" prompt section, or "" when there is none."""
        if not enrolled_courses:
            return ""
        return f"Enrolled Courses:\n{enrolled_courses}\n"

    @classmethod
    def _normalize_route(cls, raw: str) -> str:
        """Lower-case a router reply and strip wrapping quotes/backticks/punctuation."""
        return raw.strip(cls._ROUTE_WRAPPING_CHARS).lower()

    def build_router_prompt(self, user_prompt: str) -> str:
        """Render the classification prompt that asks for exactly one route name."""
        return f"""You are a query routing classifier. Your sole job is to categorize a user's question into exactly one routing category.
## Categories
| Category | Routes questions about... |
|--------------|-------------------------------------------------------------------------------------------|
| `user_info` | Personal profile, enrolled courses, username, role, learning progress, achievements |
| `site_query` | Platform features, website navigation, rules, policies, FAQs, general platform knowledge |
| `pdf_query` | Document content, uploaded files, PDF search, lesson materials, reading resources |
## Examples
user_info → "What courses am I enrolled in?"
user_info → "What is my current progress in the Python course?"
site_query → "How do I reset my password?"
site_query → "What are the platform's refund policies?"
pdf_query → "What does the document say about recursion?"
pdf_query → "Find me the section on neural networks in the materials"
## Decision Rules
1. If the question involves the **current user's personal data** → `user_info`
2. If the question is about **how the platform works** → `site_query`
3. If the question requires **reading or searching a document** → `pdf_query`
4. When ambiguous, prefer `pdf_query` over `site_query`, and `user_info` over both.
## Output Format
Respond with a single lowercase word. No punctuation. No explanation. No whitespace.
Valid outputs: user_info | site_query | pdf_query
Question: {user_prompt}
"""

    def build_unified_prompt(
        self,
        context: str,
        question: str,
        conversation_history: str = "",
        User_Info: str = "",
        enrolled_courses: str = "",
    ) -> str:
        """Render the document-question prompt.

        Args:
            context: Retrieved material the answer should be based on.
            question: The user's current question.
            conversation_history: Prior turns; rendered as "None" when empty.
            User_Info: User profile text; rendered as "None" when empty.
            enrolled_courses: Optional course list. The prompt rules refer to
                it, so it is rendered as its own section when non-empty; when
                empty the prompt is identical to the one produced without it.
        """
        courses_section = self._courses_section(enrolled_courses)
        return f"""
You are a helpful university assistant.
Rules:
- Use the provided context FIRST.
- Use conversation history to understand follow-up questions.
- If the question is about the user, use the User_Info and enrolled_courses.
- If the answer is not in the context, say:
"Not found in the provided materials."
Then add:
"From my own information:" and answer briefly.
- Be concise and clear.
Conversation History:
{conversation_history if conversation_history else "None"}
User Info:
{User_Info if User_Info else "None"}
{courses_section}Context:
{context}
Current Question:
{question}
Answer:
"""

    def build_user_info_prompt(
        self,
        question: str,
        conversation_history: str = "",
        User_Info: str = "",
        enrolled_courses: str = "",
    ) -> str:
        """Render the account-inquiry prompt.

        Args:
            question: The user's current question.
            conversation_history: Prior turns; rendered as "None" when empty.
            User_Info: User profile text; rendered as "None" when empty.
            enrolled_courses: Optional course list. The instructions tell the
                model to use enrolled courses, so it is rendered as its own
                section when non-empty; when empty the prompt is identical to
                the one produced without it.
        """
        courses_section = self._courses_section(enrolled_courses)
        return f"""
You are a university assistant handling a user account inquiry.
Use the provided User Info and Enrolled Courses to answer the question accurately.
Conversation History:
{conversation_history if conversation_history else "None"}
User Info:
{User_Info if User_Info else "None"}
{courses_section}Current Question:
{question}
Answer:
"""

    def build_site_query_prompt(self, question: str, context: str = "", conversation_history: str = "") -> str:
        """Render the platform-question prompt; empty sections are shown as "None"."""
        return f"""
You are a university assistant handling a platform or site-related question.
Provide clear instructions, rules, or general information about how the university platform works.
Conversation History:
{conversation_history if conversation_history else "None"}
Current Question:
{question}
Site Context:
{context if context else "None"}
Answer:
"""

    def robust_router(self, input_data: dict) -> str:
        """Classify ``input_data["question"]`` into one of ``self.valid_routes``.

        The router LLM is asked up to three times. A reply counts once it
        matches a valid route after trimming wrapping quotes, backticks and
        punctuation. A ``ValueError`` from the LLM wrapper is logged and
        treated as a failed attempt instead of aborting the retries.

        Returns:
            The matched route, or ``"pdf_query"`` when no attempt succeeds.

        Raises:
            KeyError: If ``input_data`` has no ``"question"`` entry.
        """
        import logging  # local import: keeps this block free of file-level changes

        log = logging.getLogger(__name__)
        # The prompt depends only on the question, so build it once.
        prompt = self.build_router_prompt(input_data["question"])

        for attempt in range(1, self._MAX_ROUTER_ATTEMPTS + 1):
            try:
                raw_reply = self.llm.invoke(prompt)
            except ValueError as err:
                # Not swallowed for good: the answer branch calls the same LLM
                # afterwards, so a persistent failure still reaches the caller.
                log.warning("Router attempt %d/%d failed: %s", attempt, self._MAX_ROUTER_ATTEMPTS, err)
                continue

            route = self._normalize_route(raw_reply)
            if route in self.valid_routes:
                return route

        return self._FALLBACK_ROUTE

    def get_chain(self):
        """Build the runnable that routes a request and produces the answer text.

        The chain input is a dict with a required ``"question"`` and optional
        ``"context"``, ``"conversation_history"``, ``"User_Info"`` and
        ``"enrolled_courses"`` entries. The router result is stored under
        ``"topic"`` and selects the answering branch; anything other than
        ``user_info`` / ``site_query`` goes to the document branch.
        """
        router_node = RunnableLambda(self.robust_router)

        def answer_user_info(x: dict) -> str:
            return self.llm.invoke(
                self.build_user_info_prompt(
                    question=x["question"],
                    conversation_history=x.get("conversation_history", ""),
                    User_Info=x.get("User_Info", ""),
                    enrolled_courses=x.get("enrolled_courses", ""),
                )
            )

        def answer_site_query(x: dict) -> str:
            return self.llm.invoke(
                self.build_site_query_prompt(
                    question=x["question"],
                    context=x.get("context", ""),
                    conversation_history=x.get("conversation_history", ""),
                )
            )

        def answer_pdf_query(x: dict) -> str:
            return self.llm.invoke(
                self.build_unified_prompt(
                    # The pass-through step always sets "context" (possibly to
                    # ""), so fall back on emptiness, not on a missing key.
                    context=x.get("context") or "No context provided.",
                    question=x["question"],
                    conversation_history=x.get("conversation_history", ""),
                    User_Info=x.get("User_Info", ""),
                    enrolled_courses=x.get("enrolled_courses", ""),
                )
            )

        branching_logic = RunnableBranch(
            (lambda x: x["topic"] == "user_info", RunnableLambda(answer_user_info)),
            (lambda x: x["topic"] == "site_query", RunnableLambda(answer_site_query)),
            RunnableLambda(answer_pdf_query),
        )

        full_chain = (
            RunnableParallel({
                "topic": router_node,
                # Pass all incoming variables straight through to the branches
                "question": lambda x: x["question"],
                "context": lambda x: x.get("context", ""),
                "conversation_history": lambda x: x.get("conversation_history", ""),
                "User_Info": lambda x: x.get("User_Info", ""),
                "enrolled_courses": lambda x: x.get("enrolled_courses", ""),
            })
            | branching_logic
            | StrOutputParser()
        )
        return full_chain