Spaces:
Running
Running
| """Router node for classifying questions and directing to appropriate handlers.""" | |
| import string | |
| from typing import Literal | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from src.data_processing.formatting import format_choices | |
| from src.state import GraphState | |
| from src.utils.llm import get_small_model | |
| from src.utils.logging import print_log | |
| from src.utils.prompts import load_prompt | |
| def _find_refusal_option(state: GraphState) -> str | None: | |
| """Find refusal option in choices and return corresponding letter.""" | |
| all_choices = state["all_choices"] | |
| option_labels = list(string.ascii_uppercase[:len(all_choices)]) | |
| refusal_patterns = [ | |
| "tôi không thể", "không thể trả lời", "không thể cung cấp", "không thể chia sẻ", | |
| "từ chối trả lời", "từ chối cung cấp", | |
| "nằm ngoài phạm vi", "không thuộc phạm vi", "tôi là mô hình ngôn ngữ", | |
| "hành vi vi phạm", "trái pháp luật", "không hỗ trợ", | |
| ] | |
| for i, choice in enumerate(all_choices): | |
| txt = choice.lower().strip() | |
| if any(p in txt for p in refusal_patterns): | |
| return option_labels[i] | |
| return None | |
| def _classify_with_llm(state: GraphState) -> str: | |
| """Classify question using LLM.""" | |
| choices_text = format_choices(state["all_choices"]) | |
| llm = get_small_model() | |
| system_prompt = load_prompt("router.j2", "system") | |
| user_prompt = load_prompt("router.j2", "user", question=state["question"], choices=choices_text) | |
| # Escape curly braces to prevent LangChain from parsing them as variables | |
| system_prompt = system_prompt.replace("{", "{{").replace("}", "}}") | |
| user_prompt = user_prompt.replace("{", "{{").replace("}", "}}") | |
| prompt = ChatPromptTemplate.from_messages([ | |
| ("system", system_prompt), | |
| ("human", user_prompt), | |
| ]) | |
| chain = prompt | llm | |
| response = chain.invoke({}) | |
| return response.content.strip().lower() | |
| def router_node(state: GraphState) -> dict: | |
| """Analyze question and determine routing path. Returns answer immediately for toxic content.""" | |
| question = state["question"].lower() | |
| # Fast-track: Direct answer for reading comprehension | |
| direct_keywords = ["đoạn thông tin", "đoạn văn", "bài đọc", "căn cứ vào đoạn", "theo đoạn"] | |
| if any(k in question for k in direct_keywords) and len(question.split()) > 50: | |
| print_log(" [Router] Fast-track: Direct Answer (Found Context block)") | |
| return {"route": "direct"} | |
| # Fast-track: Math/Logic for LaTeX or math keywords | |
| math_signals = [ | |
| "$", "\\frac", "^", "=", "tính giá trị", "biểu thức", "phương trình", | |
| "hàm số", "đạo hàm", "xác suất", "lãi suất", "vận tốc", "gia tốc", | |
| "điện trở", "gam", "mol", "nguyên tử khối", "gdp", "lạm phát", "công suất" | |
| ] | |
| if any(s in question for s in math_signals): | |
| print_log(" [Router] Fast-track: Math (Keywords/LaTeX detected)") | |
| return {"route": "math"} | |
| print_log(" [Router] Slow-track: Using LLM to classify...") | |
| try: | |
| route = _classify_with_llm(state) | |
| print_log(f" [Router] LLM Decision: {route}") | |
| if "direct" in route: | |
| route_type = "direct" | |
| elif "math" in route or "logic" in route: | |
| route_type = "math" | |
| elif "toxic" in route: | |
| refusal_answer = _find_refusal_option(state) | |
| if refusal_answer: | |
| print_log(f" [Router] Toxic detected, found refusal option: {refusal_answer}") | |
| return {"route": "toxic", "answer": refusal_answer} | |
| print_log(" [Router] Toxic detected, no refusal option found, defaulting to A") | |
| return {"route": "toxic", "answer": "A"} | |
| else: | |
| route_type = "rag" | |
| return {"route": route_type} | |
| except Exception as e: | |
| print_log(f" [Router] Error: {e}. Fallback to RAG.") | |
| return {"route": "rag"} | |
| def route_question(state: GraphState) -> Literal["knowledge_rag", "logic_solver", "direct_answer", "__end__"]: | |
| """Conditional edge function to route to appropriate node based on state route.""" | |
| route = state.get("route", "rag") | |
| answer = state.get("answer") | |
| if route == "toxic": | |
| return "__end__" | |
| if route == "direct": | |
| return "direct_answer" | |
| if route == "math": | |
| return "logic_solver" | |
| # Fallback to direct_answer for RAG questions (no vector DB in production) | |
| # Direct agent can answer general knowledge questions using LLM knowledge | |
| return "direct_answer" | |