| from __future__ import annotations |
|
|
| import re |
| from typing import Any, Dict, List, Optional, Set |
|
|
| from context_parser import detect_intent, intent_to_help_mode |
| from formatting import format_reply |
| from generator_engine import GeneratorEngine |
| from models import RetrievedChunk, SolverResult |
| from quant_solver import is_quant_question |
| from solver_router import route_solver |
| from question_classifier import classify_question, normalize_category |
| from retrieval_engine import RetrievalEngine |
|
|
|
|
# Intents for which study-note retrieval is permitted; should_retrieve()
# tests membership in this set.
RETRIEVAL_ALLOWED_INTENTS: Set[str] = {
    "concept",
    "definition",
    "explain",
    "hint",
    "instruction",
    "method",
    "step_by_step",
    "walkthrough",
}
|
|
# Regular expressions (applied to lower-cased, whitespace-normalised user text)
# that indicate the user wants the answer itself rather than teaching help.
DIRECT_SOLVE_PATTERNS: List[str] = [
    r"\bsolve\b",
    r"\bwhat is\b",
    r"\bfind\b",
    r"\bgive (?:me )?the answer\b",
    r"\bjust the answer\b",
    r"\banswer only\b",
    r"\bcalculate\b",
]
|
|
# Topic -> vocabulary that a relevant study note is expected to contain.
# _infer_structure_terms() seeds its term list from the entry for the topic.
STRUCTURE_KEYWORDS: Dict[str, List[str]] = {
    "algebra": [
        "equation", "solve", "isolate", "variable",
        "linear", "expression", "unknown", "algebra",
    ],
    "percent": [
        "percent", "%", "percentage", "increase", "decrease",
    ],
    "ratio": [
        "ratio", "proportion", "part", "share",
    ],
    "statistics": [
        "mean", "median", "mode", "range", "average",
    ],
    "probability": [
        "probability", "chance", "odds",
    ],
    "geometry": [
        "triangle", "circle", "angle", "area",
        "perimeter", "radius", "diameter",
    ],
    "number_theory": [
        "integer", "odd", "even", "prime",
        "divisible", "factor", "multiple", "remainder",
    ],
    "sequence": [
        "sequence", "geometric", "arithmetic", "term", "series",
    ],
    "quant": [
        "equation", "solve", "value", "integer", "ratio", "percent",
    ],
    "data": [
        "data", "mean", "median", "trend", "chart", "table", "correlation",
    ],
    "verbal": [
        "grammar", "meaning", "author", "argument", "sentence", "word",
    ],
}
|
|
# Intent -> phrases whose presence in a chunk earns a score bonus in
# _score_chunk(); looked up through _intent_keywords().
INTENT_KEYWORDS: Dict[str, List[str]] = {
    "walkthrough": [
        "walkthrough", "work through", "step by step", "full working",
    ],
    "step_by_step": [
        "step", "first step", "next step", "step by step",
    ],
    "explain": [
        "explain", "why", "understand",
    ],
    "method": [
        "method", "approach", "how do i solve", "how to solve",
    ],
    "hint": [
        "hint", "nudge", "clue",
    ],
    "definition": [
        "define", "definition", "what does", "what is meant by",
    ],
    "concept": [
        "concept", "idea", "principle", "rule",
    ],
    "instruction": [
        "how do i", "how to", "what should i do first", "what step", "first step",
    ],
}
|
|
# Topic -> phrases that suggest a chunk belongs to a different kind of problem.
# _infer_mismatch_terms() keeps only those absent from the question, and
# _score_chunk() subtracts points for each one found in a chunk.
MISMATCH_TERMS: Dict[str, List[str]] = {
    "algebra": [
        "absolute value", "modulus", "square root", "quadratic",
        "inequality", "roots", "parabola",
    ],
    "percent": [
        "triangle", "circle", "prime", "absolute value",
    ],
    "ratio": [
        "absolute value", "quadratic", "circle",
    ],
    "statistics": [
        "absolute value", "prime", "triangle",
    ],
    "probability": [
        "absolute value", "circle area", "quadratic",
    ],
    "geometry": [
        "absolute value", "prime", "median salary",
    ],
    "number_theory": [
        "circle", "triangle", "median salary",
    ],
}
|
|
|
|
| def _normalize_classified_topic(topic: Optional[str], category: Optional[str], question_text: str) -> str: |
| t = (topic or "").strip().lower() |
| q = (question_text or "").lower() |
| c = normalize_category(category) |
|
|
| if t not in {"general_quant", "general", "unknown", ""}: |
| return t |
|
|
| if "%" in q or "percent" in q: |
| return "percent" |
| if "ratio" in q or ":" in q: |
| return "ratio" |
| if "probability" in q or "chosen at random" in q: |
| return "probability" |
| if "divisible" in q or "remainder" in q or "prime" in q or "factor" in q: |
| return "number_theory" |
| if any(k in q for k in ["circle", "triangle", "perimeter", "area", "circumference"]): |
| return "geometry" |
| if any(k in q for k in ["mean", "median", "average", "sales", "revenue"]): |
| return "statistics" if c == "Quantitative" else "data" |
| if "=" in q or "what is x" in q or "what is y" in q or "integer" in q: |
| return "algebra" |
|
|
| if c == "DataInsight": |
| return "data" |
| if c == "Verbal": |
| return "verbal" |
| if c == "Quantitative": |
| return "quant" |
|
|
| return "general" |
|
|
|
|
| def _teaching_lines(chunks: List[RetrievedChunk]) -> List[str]: |
| lines: List[str] = [] |
| for chunk in chunks: |
| text = (chunk.text or "").strip().replace("\n", " ") |
| if len(text) > 220: |
| text = text[:217].rstrip() + "…" |
| topic = chunk.topic or "general" |
| lines.append(f"- {topic}: {text}") |
| return lines |
|
|
|
|
| def _compose_reply( |
| result: SolverResult, |
| intent: str, |
| reveal_answer: bool, |
| verbosity: float, |
| category: Optional[str] = None, |
| ) -> str: |
| steps = result.steps or [] |
| internal = result.internal_answer or result.answer_value or "" |
|
|
| if intent == "hint": |
| return steps[0] if steps else "Start by identifying what the question is really asking." |
|
|
| if intent == "instruction": |
| if steps: |
| return f"First step: {steps[0]}" |
| return "First, identify the key relationship or comparison in the question." |
|
|
| if intent == "definition": |
| if steps: |
| return f"Here is the idea in context:\n- {steps[0]}" |
| return "This is asking for the meaning of the term or idea in the question." |
|
|
| if intent in {"walkthrough", "step_by_step", "explain", "method", "concept"}: |
| if not steps: |
| if reveal_answer and internal: |
| return f"The result is {internal}." |
| return "I can explain the method, but I do not have enough structured steps yet." |
|
|
| shown_steps = steps if verbosity >= 0.66 else steps[: min(3, len(steps))] |
| body = "\n".join(f"- {s}" for s in shown_steps) |
|
|
| if reveal_answer and internal: |
| return f"{body}\n\nThat gives {internal}." |
| return body |
|
|
| if reveal_answer and internal: |
| if result.answer_value: |
| return f"The answer is {result.answer_value}." |
| return f"The result is {internal}." |
|
|
| if steps: |
| return steps[0] |
|
|
| if normalize_category(category) == "Verbal": |
| return "I can help analyse the wording or logic, but I do not have a full verbal solver yet." |
|
|
| if normalize_category(category) == "DataInsight": |
| return "I can help reason through the data, but I cannot confidently solve this from the current parse alone yet." |
|
|
| return "I can help with this, but I cannot confidently solve it from the current parse alone yet." |
|
|
|
|
| def _normalize_text(text: str) -> str: |
| return re.sub(r"\s+", " ", (text or "").strip().lower()) |
|
|
|
|
| def _extract_keywords(text: str) -> Set[str]: |
| raw = re.findall(r"[a-zA-Z][a-zA-Z0-9_+-]*", (text or "").lower()) |
| stop = { |
| "the", "a", "an", "is", "are", "to", "of", "for", "and", "or", "in", "on", "at", "by", "this", "that", |
| "it", "be", "do", "i", "me", "my", "you", "how", "what", "why", "give", "show", "please", "can", |
| } |
| return {w for w in raw if len(w) > 2 and w not in stop} |
|
|
|
|
| def _infer_structure_terms(question_text: str, topic: Optional[str], question_type: Optional[str]) -> List[str]: |
| terms: List[str] = [] |
|
|
| if topic and topic in STRUCTURE_KEYWORDS: |
| terms.extend(STRUCTURE_KEYWORDS[topic]) |
|
|
| if question_type: |
| terms.extend(question_type.replace("_", " ").split()) |
|
|
| q = (question_text or "").lower() |
| if "=" in q: |
| terms.extend(["equation", "solve"]) |
| if "x" in q or "y" in q: |
| terms.extend(["variable", "isolate"]) |
| if "/" in q or "divide" in q: |
| terms.extend(["divide", "undo operations"]) |
| if "*" in q or "times" in q or "multiply" in q: |
| terms.extend(["multiply", "undo operations"]) |
| if "%" in q or "percent" in q: |
| terms.extend(["percent", "percentage"]) |
| if "ratio" in q: |
| terms.extend(["ratio", "proportion"]) |
| if "mean" in q or "average" in q: |
| terms.extend(["mean", "average"]) |
| if "median" in q: |
| terms.extend(["median"]) |
| if "probability" in q: |
| terms.extend(["probability"]) |
| if "remainder" in q or "divisible" in q: |
| terms.extend(["remainder", "divisible"]) |
|
|
| return list(dict.fromkeys(terms)) |
|
|
|
|
| def _infer_mismatch_terms(topic: Optional[str], question_text: str) -> List[str]: |
| if not topic or topic not in MISMATCH_TERMS: |
| return [] |
| q = (question_text or "").lower() |
| return [term for term in MISMATCH_TERMS[topic] if term not in q] |
|
|
|
|
def _intent_keywords(intent: str) -> List[str]:
    """Return the ``INTENT_KEYWORDS`` phrases for ``intent``, or an empty list if it has none."""
    if intent in INTENT_KEYWORDS:
        return INTENT_KEYWORDS[intent]
    return []
|
|
|
|
| def _is_direct_solve_request(text: str, intent: str) -> bool: |
| if intent == "answer": |
| return True |
|
|
| t = _normalize_text(text) |
| if any(re.search(p, t) for p in DIRECT_SOLVE_PATTERNS): |
| if not any(word in t for word in ["how", "explain", "why", "method", "hint", "define", "definition", "step"]): |
| return True |
| return False |
|
|
|
|
def should_retrieve(intent: str, solved: bool, raw_user_text: str, category: Optional[str] = None) -> bool:
    """Decide whether study notes should be retrieved for this turn.

    Direct-solve requests retrieve only when the question is unsolved and its
    normalised category is "Verbal" or "DataInsight". Any other request
    retrieves when the intent is in ``RETRIEVAL_ALLOWED_INTENTS`` or under
    that same unsolved Verbal/DataInsight condition.

    Args:
        intent: Resolved intent label.
        solved: Whether a solver produced a result.
        raw_user_text: The user's message, checked for direct-solve wording.
        category: Raw category label; normalised before use.
    """
    normalized_category = normalize_category(category)
    unsolved_non_quant = (not solved) and normalized_category in {"Verbal", "DataInsight"}

    if _is_direct_solve_request(raw_user_text, intent):
        return unsolved_non_quant

    return intent in RETRIEVAL_ALLOWED_INTENTS or unsolved_non_quant
|
|
|
|
def _score_chunk(
    chunk: RetrievedChunk,
    intent: str,
    topic: Optional[str],
    question_text: str,
    question_type: Optional[str] = None,
) -> float:
    """Score how relevant a retrieved chunk is to the current question.

    The score is the sum of:
      * +4.0 when the chunk topic equals ``topic``, else +2.0 when ``topic``
        appears anywhere in the chunk's topic/text;
      * +1.5 for each structure term found in the chunk;
      * +1.2 for each intent keyword found in the chunk;
      * +0.4 per question keyword found in the chunk, capped at +3.0;
      * -2.5 for each mismatch term found in the chunk.

    Args:
        chunk: Chunk whose ``topic`` and ``text`` attributes are read.
        intent: Resolved intent label.
        topic: Question topic; may be None.
        question_text: The question being answered.
        question_type: Optional question type label.

    Returns:
        The relevance score; it can be negative.
    """
    # Treat a missing topic/text as empty rather than letting the f-string
    # render it as the literal word "None", which could match a keyword.
    chunk_topic = (chunk.topic or "").lower()
    chunk_body = (chunk.text or "").lower()
    text = f"{chunk_topic} {chunk_body}"
    score = 0.0

    if topic:
        wanted_topic = topic.lower()
        if chunk_topic == wanted_topic:
            score += 4.0
        elif wanted_topic in text:
            score += 2.0

    for term in _infer_structure_terms(question_text, topic, question_type):
        if term.lower() in text:
            score += 1.5

    for term in _intent_keywords(intent):
        if term.lower() in text:
            score += 1.2

    overlap = sum(1 for kw in _extract_keywords(question_text) if kw in text)
    score += min(overlap * 0.4, 3.0)

    for bad in _infer_mismatch_terms(topic, question_text):
        if bad.lower() in text:
            score -= 2.5

    return score
|
|
|
|
def _filter_retrieved_chunks(
    chunks: List[RetrievedChunk],
    intent: str,
    topic: Optional[str],
    question_text: str,
    question_type: Optional[str] = None,
    min_score: float = 3.2,
    max_chunks: int = 3,
) -> List[RetrievedChunk]:
    """Keep the highest-scoring retrieved chunks for the question.

    The strict pass keeps chunks scoring at least ``min_score`` and, when the
    question has a specific topic, drops chunks whose own topic is "general".
    If that pass keeps nothing, a fallback pass accepts any chunk (including
    "general" ones) scoring at least 2.0.

    Args:
        chunks: Candidate chunks; None is treated as empty.
        intent: Resolved intent label.
        topic: Question topic; may be None.
        question_text: The question being answered.
        question_type: Optional question type label.
        min_score: Threshold for the strict pass.
        max_chunks: Maximum number of chunks returned.

    Returns:
        Up to ``max_chunks`` chunks, best score first; ties keep input order.
    """
    normalized_topic = (topic or "").lower()
    topic_is_specific = bool(normalized_topic) and normalized_topic not in {
        "general",
        "unknown",
        "general_quant",
    }

    # Score each chunk exactly once; both passes below reuse these scores.
    scored: List[tuple[float, RetrievedChunk]] = [
        (_score_chunk(chunk, intent, topic, question_text, question_type), chunk)
        for chunk in (chunks or [])
    ]

    def _best(pairs):
        # sorted() is stable, so equally scored chunks keep their input order.
        ranked = sorted(pairs, key=lambda pair: pair[0], reverse=True)
        return [chunk for _, chunk in ranked[:max_chunks]]

    strict = _best(
        (score, chunk)
        for score, chunk in scored
        if score >= min_score
        and not (topic_is_specific and (chunk.topic or "").lower() == "general")
    )
    if strict:
        return strict

    return _best((score, chunk) for score, chunk in scored if score >= 2.0)
|
|
|
|
def _build_retrieval_query(
    raw_user_text: str,
    question_text: str,
    intent: str,
    topic: Optional[str],
    solved: bool,
    question_type: Optional[str] = None,
    category: Optional[str] = None,
) -> str:
    """Assemble the search string sent to the retriever.

    The query is the space-joined sequence of: the question text (or, when
    there is none, the user text with a leading request phrase removed), the
    normalised category unless it is "General", the topic, the question type
    with underscores as spaces, and an intent-specific suffix.

    Args:
        raw_user_text: The user's message.
        question_text: The question being answered; preferred over the raw text.
        intent: Resolved intent label, which selects the suffix.
        topic: Question topic; may be None.
        solved: Whether a solver produced a result; unsolved questions with
            no intent-specific suffix get a generic teaching suffix.
        question_type: Optional question type label.
        category: Raw category label.
    """
    pieces: List[str] = []

    stripped_raw = (raw_user_text or "").strip()
    stripped_question = (question_text or "").strip()

    if stripped_question:
        pieces.append(stripped_question)
    elif stripped_raw:
        # Leading request phrases, tried in this order; the first prefix match wins.
        lead_ins = (
            "how do i solve",
            "how to solve",
            "solve",
            "can you solve",
            "walk me through",
            "explain",
            "help me solve",
            "show me how to solve",
        )
        lowered_raw = stripped_raw.lower()
        matched = next((phrase for phrase in lead_ins if lowered_raw.startswith(phrase)), None)
        if matched is None:
            remainder = stripped_raw
        else:
            remainder = stripped_raw[len(matched):].strip(" :.-?")
        # If stripping the lead-in leaves nothing, keep the full message.
        pieces.append(remainder or stripped_raw)

    normalized_category = normalize_category(category)
    if normalized_category and normalized_category != "General":
        pieces.append(normalized_category)

    if topic:
        pieces.append(topic)

    if question_type:
        pieces.append(question_type.replace("_", " "))

    # NOTE(review): several suffixes are algebra-flavoured ("isolate variable")
    # and are appended regardless of topic — confirm this is intended.
    method_suffix = "equation solving method isolate variable worked example"
    suffix_by_intent = {
        "definition": "definition concept explanation",
        "concept": "definition concept explanation",
        "walkthrough": method_suffix,
        "step_by_step": method_suffix,
        "method": method_suffix,
        "instruction": method_suffix,
        "hint": "equation solving hint first step isolate variable",
        "explain": "equation solving explanation reasoning",
    }
    suffix = suffix_by_intent.get(intent)
    if suffix is not None:
        pieces.append(suffix)
    elif not solved:
        pieces.append("teaching explanation method")

    return " ".join(pieces).strip()
|
|
class ConversationEngine:
    """Coordinates classification, solving, retrieval and generation for one user turn."""

    def __init__(
        self,
        retriever: Optional[RetrievalEngine] = None,
        generator: Optional[GeneratorEngine] = None,
        **kwargs,
    ) -> None:
        """Store the optional collaborators.

        Args:
            retriever: Searched for study notes when the caller supplies no
                retrieval context; may be None.
            generator: Used to write a reply when no solver solved the
                question; may be None.
            **kwargs: Accepted and ignored.
        """
        self.retriever = retriever
        self.generator = generator

    def generate_response(
        self,
        raw_user_text: Optional[str] = None,
        tone: float = 0.5,
        verbosity: float = 0.5,
        transparency: float = 0.5,
        intent: Optional[str] = None,
        help_mode: Optional[str] = None,
        retrieval_context: Optional[List[RetrievedChunk]] = None,
        chat_history: Optional[List[Dict[str, Any]]] = None,
        question_text: Optional[str] = None,
        options_text: Optional[List[str]] = None,
        **kwargs,
    ) -> SolverResult:
        """Produce the reply for one user turn.

        Steps: classify the question, resolve intent and help mode, run the
        quant solver when applicable, compose a base reply, optionally attach
        filtered study notes, let the generator replace the reply when the
        question is unsolved, then format the reply.

        Args:
            raw_user_text: The user's message.
            tone: Passed through to ``format_reply``.
            verbosity: Controls how many steps are shown; also passed to
                ``format_reply``.
            transparency: Values of 0.8 and above allow the answer to be
                revealed; also passed to ``format_reply``.
            intent: Explicit intent; detected from the user text when omitted.
            help_mode: Explicit help mode; derived from the intent when omitted.
            retrieval_context: Pre-fetched chunks; when non-empty they are
                used instead of querying the retriever.
            chat_history: Previous turns, forwarded to the generator.
            question_text: The question to solve; falls back to
                ``raw_user_text``.
            options_text: Answer options; only their count is recorded in
                ``meta``.
            **kwargs: Only ``category`` is read.

        Returns:
            The solver result with ``reply``, ``help_mode`` and ``meta`` set.
        """
        solver_input = (question_text or raw_user_text or "").strip()
        user_text = (raw_user_text or "").strip()

        category = normalize_category(kwargs.get("category"))
        classification = classify_question(question_text=solver_input, category=category)
        inferred_category = normalize_category(classification.get("category") or category)

        question_topic = _normalize_classified_topic(
            classification.get("topic"),
            inferred_category,
            solver_input,
        )
        question_type = classification.get("type")

        resolved_intent = intent or detect_intent(user_text, help_mode)
        resolved_help_mode = help_mode or intent_to_help_mode(resolved_intent)
        reveal_answer = resolved_help_mode == "answer" or transparency >= 0.8

        # Default (unsolved) result; replaced below if a solver handles the question.
        result = SolverResult(
            domain="general",
            solved=False,
            help_mode=resolved_help_mode,
            answer_letter=None,
            answer_value=None,
            topic=question_topic,
            used_retrieval=False,
            used_generator=False,
            internal_answer=None,
            steps=[],
            teaching_chunks=[],
            meta={},
        )

        selected_chunks: List[RetrievedChunk] = []

        if inferred_category == "Quantitative" or is_quant_question(solver_input):
            solved_result = route_solver(solver_input)
            if solved_result is not None:
                result = solved_result
                result.help_mode = resolved_help_mode
                # Prefer the classifier's topic over a generic solver topic.
                if not result.topic or result.topic in {"general_quant", "general", "unknown"}:
                    result.topic = question_topic
                result.domain = "quant"

        reply = _compose_reply(
            result=result,
            intent=resolved_intent,
            reveal_answer=reveal_answer,
            verbosity=verbosity,
            category=inferred_category,
        )

        allow_retrieval = should_retrieve(
            intent=resolved_intent,
            solved=bool(result.solved),
            raw_user_text=user_text or solver_input,
            category=inferred_category,
        )

        # Caller-supplied context takes precedence; the retriever is queried
        # only when no context was supplied.
        candidates: List[RetrievedChunk] = []
        if allow_retrieval:
            if retrieval_context:
                candidates = retrieval_context
            elif self.retriever is not None:
                # "or []" guards against a retriever that returns None.
                candidates = self.retriever.search(
                    query=_build_retrieval_query(
                        raw_user_text=user_text,
                        question_text=solver_input,
                        intent=resolved_intent,
                        topic=result.topic,
                        solved=bool(result.solved),
                        question_type=question_type,
                        category=inferred_category,
                    ),
                    topic=result.topic or "",
                    intent=resolved_intent,
                    k=6,
                ) or []

        if candidates:
            filtered = _filter_retrieved_chunks(
                chunks=candidates,
                intent=resolved_intent,
                topic=result.topic,
                question_text=solver_input,
                question_type=question_type,
            )
            if filtered:
                selected_chunks = filtered
                result.used_retrieval = True
                result.teaching_chunks = filtered

        # Study notes are not appended when the user asked for the answer only.
        if selected_chunks and resolved_help_mode != "answer":
            reply = f"{reply}\n\nRelevant study notes:\n" + "\n".join(_teaching_lines(selected_chunks))

        if not result.solved and self.generator is not None:
            try:
                generated = self.generator.generate(
                    user_text=user_text or solver_input,
                    question_text=solver_input,
                    topic=result.topic or "",
                    intent=resolved_intent,
                    retrieval_context=selected_chunks,
                    chat_history=chat_history or [],
                )
                if generated and generated.strip():
                    reply = generated.strip()
                    result.used_generator = True
            except Exception:
                # Generation is best-effort: keep the composed reply, but log
                # the failure instead of discarding it silently.
                import logging

                logging.getLogger(__name__).exception(
                    "Generator failed; keeping the composed reply."
                )

        reply = format_reply(reply, tone, verbosity, transparency, resolved_help_mode)

        result.reply = reply
        result.help_mode = resolved_help_mode
        result.meta = {
            "intent": resolved_intent,
            "question_text": question_text or "",
            "options_count": len(options_text or []),
            "category": inferred_category,
            "question_type": question_type,
            "classified_topic": question_topic,
        }

        return result