File size: 7,582 Bytes
1bc3f18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from typing import Any
from pydantic import Field
from langchain_core.language_models import LLM
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from stores.llm.LLMProviderFactory import LLMProviderFactory
from config import get_settings

class ProviderLLMWrapper(LLM):
    """Expose a project LLM provider through LangChain's ``LLM`` interface.

    ``provider`` must offer ``generate_text(prompt)``. Its result may be a
    plain string or a dict holding the text under ``"response"``; anything
    else is rejected with ``ValueError``.
    """

    provider: Any = Field(..., description="The wrapped LLM provider")

    def _call(self, prompt: str, stop=None, run_manager=None, **kwargs: Any) -> str:
        """Generate text for ``prompt`` and return it as a string.

        Args:
            prompt: Fully rendered prompt text.
            stop: Optional sequence of stop strings. ``generate_text`` does
                not accept them, so the output is truncated client-side at
                the earliest occurrence of any of them.
            run_manager: LangChain callback manager; accepted so LangChain
                can pass it, not used here.
            **kwargs: Extra invocation options forwarded by LangChain;
                accepted so ``invoke(..., **kwargs)`` does not raise
                ``TypeError``, but ignored because ``generate_text`` only
                takes the prompt.

        Returns:
            The generated text.

        Raises:
            ValueError: If the provider returns ``None``, a dict without a
                ``"response"`` value, or a value that is not a string.
        """
        result = self.provider.generate_text(prompt)
        if result is None:
            raise ValueError("LLM provider returned None (likely due to timeout or error)")
        if isinstance(result, dict):
            response = result.get("response")
            if response is None:
                raise ValueError(f"LLM provider returned dict without 'response' key: {result.keys()}")
            text = response
        else:
            text = result
        # Both the plain and the dict form must ultimately yield text; LangChain
        # wraps the return value in a string-typed Generation.
        if not isinstance(text, str):
            raise ValueError(f"Unexpected LLM response type: {type(text).__name__}")
        return self._truncate_at_stop(text, stop)

    @staticmethod
    def _truncate_at_stop(text: str, stop) -> str:
        """Cut ``text`` at the earliest occurrence of any non-empty stop string."""
        if not stop:
            return text
        hits = [pos for pos in (text.find(token) for token in stop if token) if pos != -1]
        return text[:min(hits)] if hits else text

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "custom-provider"

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count as the number of whitespace-separated words."""
        return len(text.split())

class AssistantRagGen:
    """Route a user question to a prompt template and answer it with one LLM.

    ``get_chain()`` builds a LangChain pipeline that first classifies the
    question into one of ``valid_routes`` (``user_info``, ``site_query``,
    ``pdf_query``) via ``robust_router`` and then answers it with the
    matching prompt builder.
    """

    # How many times the router asks the LLM before falling back.
    ROUTER_MAX_ATTEMPTS = 3
    # Route used when no attempt yields a recognisable category.
    DEFAULT_ROUTE = "pdf_query"
    # Characters that models often wrap around a one-word answer (code ticks,
    # quotes, trailing punctuation, bold markers); never part of a route name.
    _ROUTE_DECORATION = "`'\".,:;*"

    def __init__(self):
        config = get_settings()
        self.factory = LLMProviderFactory(config)
        self.generator = self.factory.create(config.GENERATION_BACKEND)
        self.generator.set_generation_model(config.GENERATION_MODEL_ID)
        self.llm = ProviderLLMWrapper(provider=self.generator)
        self.valid_routes = {"user_info", "site_query", "pdf_query"}

    def build_router_prompt(self, user_prompt: str) -> str:
        """Return the classification prompt that asks the LLM for one route name.

        Args:
            user_prompt: The raw user question, embedded verbatim at the end.
        """
        return f"""You are a query routing classifier. Your sole job is to categorize a user's question into exactly one routing category.

    ## Categories

    | Category     | Routes questions about...                                                                 |
    |--------------|-------------------------------------------------------------------------------------------|
    | `user_info`  | Personal profile, enrolled courses, username, role, learning progress, achievements       |
    | `site_query` | Platform features, website navigation, rules, policies, FAQs, general platform knowledge  |
    | `pdf_query`  | Document content, uploaded files, PDF search, lesson materials, reading resources         |

    ## Examples

    user_info  -> "What courses am I enrolled in?"
    user_info  -> "What is my current progress in the Python course?"
    site_query -> "How do I reset my password?"
    site_query -> "What are the platform's refund policies?"
    pdf_query  -> "What does the document say about recursion?"
    pdf_query  -> "Find me the section on neural networks in the materials"

    ## Decision Rules

    1. If the question involves the **current user's personal data** -> `user_info`
    2. If the question is about **how the platform works** -> `site_query`
    3. If the question requires **reading or searching a document** -> `pdf_query`
    4. When ambiguous, prefer `pdf_query` over `site_query`, and `user_info` over both.

    ## Output Format

    Respond with a single lowercase word. No punctuation. No explanation. No whitespace.

    Valid outputs: user_info | site_query | pdf_query

    Question: {user_prompt}
    """

    def build_unified_prompt(
        self,
        context: str,
        question: str,
        conversation_history: str = "",
        User_Info: str = "",
        enrolled_courses: str = "",
    ) -> str:
        """Return the document-grounded (``pdf_query``) answering prompt.

        Args:
            context: Retrieved material; rendered as "No context provided."
                when empty.
            question: The current user question.
            conversation_history: Prior turns, or "" (rendered as "None").
            User_Info: User profile text, or "" (rendered as "None").
            enrolled_courses: The user's courses, or "" (rendered as "None").
                The prompt's rules reference them, so they must be included.
        """
        return f"""
    You are a helpful university assistant.

    Rules:
    - Use the provided context FIRST.
    - Use conversation history to understand follow-up questions.
    - If the question is about the user, use the User_Info and enrolled_courses.
    - If the answer is not in the context, say:
    "Not found in the provided materials."
    Then add:
    "From my own information:" and answer briefly.
    - Be concise and clear.

    Conversation History:
    {conversation_history if conversation_history else "None"}

    User Info:
    {User_Info if User_Info else "None"}

    Enrolled Courses:
    {enrolled_courses if enrolled_courses else "None"}

    Context:
    {context if context else "No context provided."}

    Current Question:
    {question}

    Answer:
    """

    def build_user_info_prompt(
        self,
        question: str,
        conversation_history: str = "",
        User_Info: str = "",
        enrolled_courses: str = "",
    ) -> str:
        """Return the account-inquiry (``user_info``) answering prompt.

        Args:
            question: The current user question.
            conversation_history: Prior turns, or "" (rendered as "None").
            User_Info: User profile text, or "" (rendered as "None").
            enrolled_courses: The user's courses, or "" (rendered as "None").
                The prompt instructs the model to use them.
        """
        return f"""
    You are a university assistant handling a user account inquiry. 
    Use the provided User Info and Enrolled Courses to answer the question accurately.

    Conversation History:
    {conversation_history if conversation_history else "None"}

    User Info:
    {User_Info if User_Info else "None"}

    Enrolled Courses:
    {enrolled_courses if enrolled_courses else "None"}

    Current Question:
    {question}

    Answer:
    """

    def build_site_query_prompt(self, question: str, context: str = "", conversation_history: str = "") -> str:
        """Return the platform/site (``site_query``) answering prompt.

        Args:
            question: The current user question.
            context: Site material, or "" (rendered as "None").
            conversation_history: Prior turns, or "" (rendered as "None").
        """
        return f"""
    You are a university assistant handling a platform or site-related question.
    Provide clear instructions, rules, or general information about how the university platform works.

    Conversation History:
    {conversation_history if conversation_history else "None"}

    Current Question:
    {question}

    Site Context:
    {context if context else "None"}

    Answer:
    """

    def robust_router(self, input_data: dict) -> str:
        """Classify ``input_data["question"]`` into one of ``valid_routes``.

        The LLM is asked up to ``ROUTER_MAX_ATTEMPTS`` times and each reply is
        parsed with ``_parse_route``. If no attempt yields a valid route,
        ``DEFAULT_ROUTE`` is returned. Exceptions raised by the LLM call are
        not caught and propagate to the caller.

        Args:
            input_data: Chain input; must contain a ``"question"`` key.

        Returns:
            A member of ``valid_routes``.
        """
        # The prompt depends only on the question, so build it once.
        prompt = self.build_router_prompt(input_data["question"])
        for _ in range(self.ROUTER_MAX_ATTEMPTS):
            route = self._parse_route(self.llm.invoke(prompt))
            if route is not None:
                return route
        return self.DEFAULT_ROUTE

    def _parse_route(self, reply):
        """Return the first valid route found in ``reply``, or ``None``.

        The reply is lower-cased and split on whitespace; each token is
        stripped of surrounding code-tick/quote/punctuation characters before
        being compared with ``valid_routes``. This accepts replies such as
        "`user_info`", "site_query." or "pdf_query" followed by an
        explanation, which the prompt forbids but models still produce.
        """
        for token in reply.lower().split():
            candidate = token.strip(self._ROUTE_DECORATION)
            if candidate in self.valid_routes:
                return candidate
        return None

    def get_chain(self):
        """Build the routing + answering runnable.

        Input: a dict with ``"question"`` and optionally ``"context"``,
        ``"conversation_history"``, ``"User_Info"`` and ``"enrolled_courses"``.
        Output: the generated answer as a string.
        """
        router_node = RunnableLambda(self.robust_router)

        user_info_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_user_info_prompt(
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ))

        site_query_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_site_query_prompt(
                question=x["question"],
                context=x.get("context", ""),
                conversation_history=x.get("conversation_history", ""),
            )
        ))

        pdf_query_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_unified_prompt(
                # The parallel step below always sets "context" (possibly to
                # ""), so the empty-context fallback lives in the prompt builder.
                context=x.get("context", ""),
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ))

        branching_logic = RunnableBranch(
            (lambda x: x["topic"] == "user_info", user_info_chain),
            (lambda x: x["topic"] == "site_query", site_query_chain),
            pdf_query_chain,
        )

        full_chain = (
            RunnableParallel({
                "topic": router_node,
                # Pass all incoming variables straight through to the branches
                "question": lambda x: x["question"],
                "context": lambda x: x.get("context", ""),
                "conversation_history": lambda x: x.get("conversation_history", ""),
                "User_Info": lambda x: x.get("User_Info", ""),
                "enrolled_courses": lambda x: x.get("enrolled_courses", ""),
            })
            | branching_logic
            | StrOutputParser()
        )

        return full_chain