# GameAI / conversation_logic.py
# Revision 8a54dc5 ("Update conversation_logic.py"), uploaded by j-js.
from __future__ import annotations

import logging
import re
from typing import Any, Dict, List, Optional, Set

from context_parser import detect_intent, intent_to_help_mode
from formatting import format_reply
from generator_engine import GeneratorEngine
from models import RetrievedChunk, SolverResult
from quant_solver import is_quant_question
from question_classifier import classify_question, normalize_category
from retrieval_engine import RetrievalEngine
from solver_router import route_solver
# Intents for which should_retrieve() always permits study-note retrieval
# (unless the message is classified as a direct "just solve it" request).
RETRIEVAL_ALLOWED_INTENTS = {
    "walkthrough", "step_by_step",
    "explain", "method",
    "hint",
    "definition", "concept",
    "instruction",
}
# Regular expressions (applied to normalized, lower-cased user text) that
# signal the user wants the problem solved rather than taught.
DIRECT_SOLVE_PATTERNS = [
    r"\bsolve\b",
    r"\bwhat is\b",
    r"\bfind\b",
    r"\bgive (?:me )?the answer\b",
    r"\bjust the answer\b",
    r"\banswer only\b",
    r"\bcalculate\b",
]
# Per-topic vocabulary that _infer_structure_terms() adds to the list of terms
# a relevant study note is expected to mention.
STRUCTURE_KEYWORDS = {
    "algebra": [
        "equation", "solve", "isolate", "variable",
        "linear", "expression", "unknown", "algebra",
    ],
    "percent": ["percent", "%", "percentage", "increase", "decrease"],
    "ratio": ["ratio", "proportion", "part", "share"],
    "statistics": ["mean", "median", "mode", "range", "average"],
    "probability": ["probability", "chance", "odds"],
    "geometry": [
        "triangle", "circle", "angle", "area",
        "perimeter", "radius", "diameter",
    ],
    "number_theory": [
        "integer", "odd", "even", "prime",
        "divisible", "factor", "multiple", "remainder",
    ],
    "sequence": ["sequence", "geometric", "arithmetic", "term", "series"],
    "quant": ["equation", "solve", "value", "integer", "ratio", "percent"],
    "data": [
        "data", "mean", "median", "trend",
        "chart", "table", "correlation",
    ],
    "verbal": ["grammar", "meaning", "author", "argument", "sentence", "word"],
}
# Phrases associated with each help intent; _score_chunk() rewards study notes
# whose text contains them.
INTENT_KEYWORDS = {
    "walkthrough": [
        "walkthrough", "work through", "step by step", "full working",
    ],
    "step_by_step": ["step", "first step", "next step", "step by step"],
    "explain": ["explain", "why", "understand"],
    "method": ["method", "approach", "how do i solve", "how to solve"],
    "hint": ["hint", "nudge", "clue"],
    "definition": ["define", "definition", "what does", "what is meant by"],
    "concept": ["concept", "idea", "principle", "rule"],
    "instruction": [
        "how do i", "how to", "what should i do first",
        "what step", "first step",
    ],
}
# Per-topic phrases that suggest a study note belongs to a different topic.
# _score_chunk() penalises notes containing them, unless the question itself
# uses the phrase (see _infer_mismatch_terms()).
MISMATCH_TERMS = {
    "algebra": [
        "absolute value", "modulus", "square root", "quadratic",
        "inequality", "roots", "parabola",
    ],
    "percent": ["triangle", "circle", "prime", "absolute value"],
    "ratio": ["absolute value", "quadratic", "circle"],
    "statistics": ["absolute value", "prime", "triangle"],
    "probability": ["absolute value", "circle area", "quadratic"],
    "geometry": ["absolute value", "prime", "median salary"],
    "number_theory": ["circle", "triangle", "median salary"],
}
def _normalize_classified_topic(topic: Optional[str], category: Optional[str], question_text: str) -> str:
    """Return a usable topic label for a classified question.

    A specific classifier topic is returned as-is (stripped, lower-cased).
    When the classifier only produced a generic label ("general_quant",
    "general", "unknown" or nothing), the topic is guessed from substrings of
    the question text; the checks run in a fixed order and the first hit wins.
    If nothing matches, the normalized category decides the fallback label.

    Args:
        topic: Topic reported by the classifier, possibly None.
        category: Question category; passed through ``normalize_category``.
        question_text: The question being asked.

    Returns:
        A lower-case topic label; "general" when nothing else applies.
    """
    cleaned_topic = (topic or "").strip().lower()
    lowered_question = (question_text or "").lower()
    normalized_category = normalize_category(category)

    if cleaned_topic not in {"general_quant", "general", "unknown", ""}:
        return cleaned_topic

    def mentions(*needles: str) -> bool:
        # Plain substring test against the lower-cased question.
        return any(needle in lowered_question for needle in needles)

    if mentions("%", "percent"):
        return "percent"
    if mentions("ratio", ":"):
        return "ratio"
    if mentions("probability", "chosen at random"):
        return "probability"
    if mentions("divisible", "remainder", "prime", "factor"):
        return "number_theory"
    if mentions("circle", "triangle", "perimeter", "area", "circumference"):
        return "geometry"
    if mentions("mean", "median", "average", "sales", "revenue"):
        # Same vocabulary, different label depending on the category.
        if normalized_category == "Quantitative":
            return "statistics"
        return "data"
    if mentions("=", "what is x", "what is y", "integer"):
        return "algebra"

    # No textual cue: fall back to a label derived from the category.
    category_fallbacks = (
        ("DataInsight", "data"),
        ("Verbal", "verbal"),
        ("Quantitative", "quant"),
    )
    for label, fallback in category_fallbacks:
        if normalized_category == label:
            return fallback
    return "general"
def _teaching_lines(chunks: List[RetrievedChunk]) -> List[str]:
lines: List[str] = []
for chunk in chunks:
text = (chunk.text or "").strip().replace("\n", " ")
if len(text) > 220:
text = text[:217].rstrip() + "…"
topic = chunk.topic or "general"
lines.append(f"- {topic}: {text}")
return lines
def _compose_reply(
result: SolverResult,
intent: str,
reveal_answer: bool,
verbosity: float,
category: Optional[str] = None,
) -> str:
steps = result.steps or []
internal = result.internal_answer or result.answer_value or ""
if intent == "hint":
return steps[0] if steps else "Start by identifying what the question is really asking."
if intent == "instruction":
if steps:
return f"First step: {steps[0]}"
return "First, identify the key relationship or comparison in the question."
if intent == "definition":
if steps:
return f"Here is the idea in context:\n- {steps[0]}"
return "This is asking for the meaning of the term or idea in the question."
if intent in {"walkthrough", "step_by_step", "explain", "method", "concept"}:
if not steps:
if reveal_answer and internal:
return f"The result is {internal}."
return "I can explain the method, but I do not have enough structured steps yet."
shown_steps = steps if verbosity >= 0.66 else steps[: min(3, len(steps))]
body = "\n".join(f"- {s}" for s in shown_steps)
if reveal_answer and internal:
return f"{body}\n\nThat gives {internal}."
return body
if reveal_answer and internal:
if result.answer_value:
return f"The answer is {result.answer_value}."
return f"The result is {internal}."
if steps:
return steps[0]
if normalize_category(category) == "Verbal":
return "I can help analyse the wording or logic, but I do not have a full verbal solver yet."
if normalize_category(category) == "DataInsight":
return "I can help reason through the data, but I cannot confidently solve this from the current parse alone yet."
return "I can help with this, but I cannot confidently solve it from the current parse alone yet."
def _normalize_text(text: str) -> str:
return re.sub(r"\s+", " ", (text or "").strip().lower())
def _extract_keywords(text: str) -> Set[str]:
raw = re.findall(r"[a-zA-Z][a-zA-Z0-9_+-]*", (text or "").lower())
stop = {
"the", "a", "an", "is", "are", "to", "of", "for", "and", "or", "in", "on", "at", "by", "this", "that",
"it", "be", "do", "i", "me", "my", "you", "how", "what", "why", "give", "show", "please", "can",
}
return {w for w in raw if len(w) > 2 and w not in stop}
def _infer_structure_terms(question_text: str, topic: Optional[str], question_type: Optional[str]) -> List[str]:
terms: List[str] = []
if topic and topic in STRUCTURE_KEYWORDS:
terms.extend(STRUCTURE_KEYWORDS[topic])
if question_type:
terms.extend(question_type.replace("_", " ").split())
q = (question_text or "").lower()
if "=" in q:
terms.extend(["equation", "solve"])
if "x" in q or "y" in q:
terms.extend(["variable", "isolate"])
if "/" in q or "divide" in q:
terms.extend(["divide", "undo operations"])
if "*" in q or "times" in q or "multiply" in q:
terms.extend(["multiply", "undo operations"])
if "%" in q or "percent" in q:
terms.extend(["percent", "percentage"])
if "ratio" in q:
terms.extend(["ratio", "proportion"])
if "mean" in q or "average" in q:
terms.extend(["mean", "average"])
if "median" in q:
terms.extend(["median"])
if "probability" in q:
terms.extend(["probability"])
if "remainder" in q or "divisible" in q:
terms.extend(["remainder", "divisible"])
return list(dict.fromkeys(terms))
def _infer_mismatch_terms(topic: Optional[str], question_text: str) -> List[str]:
if not topic or topic not in MISMATCH_TERMS:
return []
q = (question_text or "").lower()
return [term for term in MISMATCH_TERMS[topic] if term not in q]
def _intent_keywords(intent: str) -> List[str]:
    """Return the phrases listed for *intent* in ``INTENT_KEYWORDS``, or ``[]`` if none."""
    try:
        return INTENT_KEYWORDS[intent]
    except KeyError:
        return []
def _is_direct_solve_request(text: str, intent: str) -> bool:
if intent == "answer":
return True
t = _normalize_text(text)
if any(re.search(p, t) for p in DIRECT_SOLVE_PATTERNS):
if not any(word in t for word in ["how", "explain", "why", "method", "hint", "define", "definition", "step"]):
return True
return False
def should_retrieve(intent: str, solved: bool, raw_user_text: str, category: Optional[str] = None) -> bool:
    """Decide whether study notes should be looked up for this turn.

    Args:
        intent: Resolved help intent.
        solved: Whether the solver produced a solution.
        raw_user_text: The user's message, used to spot direct-solve requests.
        category: Question category; passed through ``normalize_category``.

    Returns:
        For direct-solve requests: True only when the question is unsolved and
        its category is "Verbal" or "DataInsight". Otherwise: True when the
        intent is in ``RETRIEVAL_ALLOWED_INTENTS`` or that same
        unsolved-Verbal/DataInsight condition holds.
    """
    normalized_category = normalize_category(category)

    def unsolved_without_solver() -> bool:
        # Unsolved Verbal / Data Insight questions: notes are the only help.
        return (not solved) and normalized_category in {"Verbal", "DataInsight"}

    if _is_direct_solve_request(raw_user_text, intent):
        return unsolved_without_solver()
    return intent in RETRIEVAL_ALLOWED_INTENTS or unsolved_without_solver()
def _score_chunk(
    chunk: RetrievedChunk,
    intent: str,
    topic: Optional[str],
    question_text: str,
    question_type: Optional[str] = None,
) -> float:
    """Score how relevant a retrieved chunk is to the current question.

    Scoring, all by substring match against the chunk's lower-cased
    "topic text" string:
        * +4.0 if the chunk topic equals *topic*, else +2.0 if *topic* occurs
          anywhere in the chunk;
        * +1.5 per structure term (see ``_infer_structure_terms``);
        * +1.2 per intent phrase (see ``_intent_keywords``);
        * +0.4 per question keyword found, capped at +3.0;
        * -2.5 per mismatch phrase (see ``_infer_mismatch_terms``).

    Args:
        chunk: Candidate study note; ``.topic`` and ``.text`` may be None.
        intent: Resolved help intent.
        topic: Topic of the question, if known.
        question_text: The question being asked.
        question_type: Optional classifier type.

    Returns:
        The relevance score; may be negative.
    """
    # Treat missing fields as empty: formatting None directly would put the
    # literal word "none" into the haystack, where it can match keywords
    # such as "one".
    haystack = f"{chunk.topic or ''} {chunk.text or ''}".lower()
    score = 0.0

    if topic:
        wanted_topic = topic.lower()
        if (chunk.topic or "").lower() == wanted_topic:
            score += 4.0
        elif wanted_topic in haystack:
            score += 2.0

    for term in _infer_structure_terms(question_text, topic, question_type):
        if term.lower() in haystack:
            score += 1.5

    for phrase in _intent_keywords(intent):
        if phrase.lower() in haystack:
            score += 1.2

    overlap = sum(1 for keyword in _extract_keywords(question_text) if keyword in haystack)
    score += min(overlap * 0.4, 3.0)

    for off_topic in _infer_mismatch_terms(topic, question_text):
        if off_topic.lower() in haystack:
            score -= 2.5

    return score
def _filter_retrieved_chunks(
chunks: List[RetrievedChunk],
intent: str,
topic: Optional[str],
question_text: str,
question_type: Optional[str] = None,
min_score: float = 3.2,
max_chunks: int = 3,
) -> List[RetrievedChunk]:
scored: List[tuple[float, RetrievedChunk]] = []
normalized_topic = (topic or "").lower()
for chunk in chunks:
chunk_topic = (chunk.topic or "").lower()
if normalized_topic and normalized_topic not in {"general", "unknown", "general_quant"}:
if chunk_topic == "general":
continue
s = _score_chunk(chunk, intent, topic, question_text, question_type)
if s >= min_score:
scored.append((s, chunk))
scored.sort(key=lambda x: x[0], reverse=True)
filtered = [chunk for _, chunk in scored[:max_chunks]]
if filtered:
return filtered
fallback: List[tuple[float, RetrievedChunk]] = []
for chunk in chunks:
s = _score_chunk(chunk, intent, topic, question_text, question_type)
if s >= 2.0:
fallback.append((s, chunk))
fallback.sort(key=lambda x: x[0], reverse=True)
return [chunk for _, chunk in fallback[:max_chunks]]
def _build_retrieval_query(
    raw_user_text: str,
    question_text: str,
    intent: str,
    topic: Optional[str],
    solved: bool,
    question_type: Optional[str] = None,
    category: Optional[str] = None,
) -> str:
    """Assemble the search string sent to the retriever.

    The query is the question itself (or, when there is none, the user's
    message with a leading conversational wrapper such as "how do I solve"
    removed), followed by the normalized category (unless "General"), the
    topic, the question type, and a fixed intent-specific phrase.

    Args:
        raw_user_text: The user's message.
        question_text: The extracted question; preferred over the raw text.
        intent: Resolved help intent.
        topic: Topic label, if known.
        solved: Whether the solver succeeded; only affects the trailing phrase
            for intents without a dedicated one.
        question_type: Optional classifier type (underscores become spaces).
        category: Question category; passed through ``normalize_category``.

    Returns:
        The space-joined query string.
    """
    parts: List[str] = []
    raw = (raw_user_text or "").strip()
    question = (question_text or "").strip()

    # Prefer the actual math content, not the conversational wrapper.
    if question:
        parts.append(question)
    elif raw:
        wrappers = (
            "how do i solve",
            "how to solve",
            "solve",
            "can you solve",
            "walk me through",
            "explain",
            "help me solve",
            "show me how to solve",
        )
        # The trailing \b keeps a wrapper from eating the start of a longer
        # word ("Solved examples" must not become "d examples").
        wrapper_pattern = r"(?:" + "|".join(re.escape(w) for w in wrappers) + r")\b"
        match = re.match(wrapper_pattern, raw, flags=re.IGNORECASE)
        cleaned = raw[match.end():].strip(" :.-?") if match else raw
        # If stripping the wrapper leaves nothing, search on the raw text.
        parts.append(cleaned or raw)

    normalized_category = normalize_category(category)
    if normalized_category and normalized_category != "General":
        parts.append(normalized_category)
    if topic:
        parts.append(topic)
    if question_type:
        parts.append(question_type.replace("_", " "))

    if intent in {"definition", "concept"}:
        parts.append("definition concept explanation")
    elif intent in {"walkthrough", "step_by_step", "method", "instruction"}:
        parts.append("equation solving method isolate variable worked example")
    elif intent == "hint":
        parts.append("equation solving hint first step isolate variable")
    elif intent == "explain":
        parts.append("equation solving explanation reasoning")
    elif not solved:
        parts.append("teaching explanation method")

    return " ".join(parts).strip()
class ConversationEngine:
    """Orchestrates one tutoring turn: classify, solve, retrieve, generate, format.

    Attributes:
        retriever: Optional search backend used when the caller supplies no
            retrieval context.
        generator: Optional text generator used when the solver did not solve
            the question.
    """

    def __init__(
        self,
        retriever: Optional[RetrievalEngine] = None,
        generator: Optional[GeneratorEngine] = None,
        **kwargs,
    ) -> None:
        """Store the optional collaborators; extra keyword arguments are ignored."""
        self.retriever = retriever
        self.generator = generator

    def _select_chunks(
        self,
        *,
        user_text: str,
        solver_input: str,
        intent: str,
        result: SolverResult,
        question_type: Optional[str],
        category: Optional[str],
        retrieval_context: Optional[List[RetrievedChunk]],
    ) -> List[RetrievedChunk]:
        """Return the filtered study notes for this turn (possibly empty).

        Caller-supplied ``retrieval_context`` takes precedence; the retriever
        is queried only when no context was supplied.
        """
        if retrieval_context:
            candidates = retrieval_context
        elif self.retriever is not None:
            candidates = self.retriever.search(
                query=_build_retrieval_query(
                    raw_user_text=user_text,
                    question_text=solver_input,
                    intent=intent,
                    topic=result.topic,
                    solved=bool(result.solved),
                    question_type=question_type,
                    category=category,
                ),
                topic=result.topic or "",
                intent=intent,
                k=6,
            ) or []  # a retriever returning None means "nothing found"
        else:
            return []

        return _filter_retrieved_chunks(
            chunks=candidates,
            intent=intent,
            topic=result.topic,
            question_text=solver_input,
            question_type=question_type,
        )

    def generate_response(
        self,
        raw_user_text: Optional[str] = None,
        tone: float = 0.5,
        verbosity: float = 0.5,
        transparency: float = 0.5,
        intent: Optional[str] = None,
        help_mode: Optional[str] = None,
        retrieval_context: Optional[List[RetrievedChunk]] = None,
        chat_history: Optional[List[Dict[str, Any]]] = None,
        question_text: Optional[str] = None,
        options_text: Optional[List[str]] = None,
        **kwargs,
    ) -> SolverResult:
        """Produce the assistant's reply for one user turn.

        Args:
            raw_user_text: The user's message.
            tone: Passed to ``format_reply``.
            verbosity: Controls how many steps are shown; also passed to
                ``format_reply``.
            transparency: Values of 0.8 and above allow the answer to be
                revealed; also passed to ``format_reply``.
            intent: Explicit intent; detected from the message when omitted.
            help_mode: Explicit help mode; derived from the intent when omitted.
            retrieval_context: Pre-retrieved study notes; when given, the
                retriever is not queried.
            chat_history: Prior turns, forwarded to the generator.
            question_text: The question to solve; defaults to the user message.
            options_text: Answer options; only their count is recorded.
            **kwargs: ``category`` is read if present; other keys are ignored.

        Returns:
            A ``SolverResult`` with ``reply``, ``help_mode`` and ``meta`` set.
        """
        solver_input = (question_text or raw_user_text or "").strip()
        user_text = (raw_user_text or "").strip()

        # --- classification -------------------------------------------------
        category = normalize_category(kwargs.get("category"))
        classification = classify_question(question_text=solver_input, category=category)
        inferred_category = normalize_category(classification.get("category") or category)
        question_topic = _normalize_classified_topic(
            classification.get("topic"),
            inferred_category,
            solver_input,
        )
        question_type = classification.get("type")

        resolved_intent = intent or detect_intent(user_text, help_mode)
        resolved_help_mode = help_mode or intent_to_help_mode(resolved_intent)
        reveal_answer = resolved_help_mode == "answer" or transparency >= 0.8

        # Placeholder result, used as-is when no solver applies or succeeds.
        result = SolverResult(
            domain="general",
            solved=False,
            help_mode=resolved_help_mode,
            answer_letter=None,
            answer_value=None,
            topic=question_topic,
            used_retrieval=False,
            used_generator=False,
            internal_answer=None,
            steps=[],
            teaching_chunks=[],
            meta={},
        )

        # --- solving --------------------------------------------------------
        if inferred_category == "Quantitative" or is_quant_question(solver_input):
            solved_result = route_solver(solver_input)
            if solved_result is not None:
                result = solved_result
                result.help_mode = resolved_help_mode
                # Keep a specific solver topic; replace only generic labels.
                if not result.topic or result.topic in {"general_quant", "general", "unknown"}:
                    result.topic = question_topic
                result.domain = "quant"

        reply = _compose_reply(
            result=result,
            intent=resolved_intent,
            reveal_answer=reveal_answer,
            verbosity=verbosity,
            category=inferred_category,
        )

        # --- retrieval ------------------------------------------------------
        selected_chunks: List[RetrievedChunk] = []
        allow_retrieval = should_retrieve(
            intent=resolved_intent,
            solved=bool(result.solved),
            raw_user_text=user_text or solver_input,
            category=inferred_category,
        )
        if allow_retrieval:
            selected_chunks = self._select_chunks(
                user_text=user_text,
                solver_input=solver_input,
                intent=resolved_intent,
                result=result,
                question_type=question_type,
                category=inferred_category,
                retrieval_context=retrieval_context,
            )
            if selected_chunks:
                result.used_retrieval = True
                result.teaching_chunks = selected_chunks

        # Study notes are appended except in pure "answer" mode.
        if selected_chunks and resolved_help_mode != "answer":
            reply = f"{reply}\n\nRelevant study notes:\n" + "\n".join(_teaching_lines(selected_chunks))

        # --- generation -----------------------------------------------------
        if not result.solved and self.generator is not None:
            try:
                generated = self.generator.generate(
                    user_text=user_text or solver_input,
                    question_text=solver_input,
                    topic=result.topic or "",
                    intent=resolved_intent,
                    retrieval_context=selected_chunks,
                    chat_history=chat_history or [],
                )
            except Exception:
                # Generation is best-effort: keep the composed reply, but
                # record the failure instead of hiding it.
                logging.getLogger(__name__).exception(
                    "Generator failed; falling back to the composed reply"
                )
            else:
                if generated and generated.strip():
                    reply = generated.strip()
                    result.used_generator = True

        # --- finalisation ---------------------------------------------------
        reply = format_reply(reply, tone, verbosity, transparency, resolved_help_mode)
        result.reply = reply
        result.help_mode = resolved_help_mode
        # NOTE: this replaces whatever meta the solver result carried.
        result.meta = {
            "intent": resolved_intent,
            "question_text": question_text or "",
            "options_count": len(options_text or []),
            "category": inferred_category,
            "question_type": question_type,
            "classified_topic": question_topic,
        }
        return result