j-js committed on
Commit
c0df734
·
verified ·
1 Parent(s): 0e7b568

Update conversation_logic.py

Browse files
Files changed (1) hide show
  1. conversation_logic.py +113 -68
conversation_logic.py CHANGED
@@ -12,10 +12,6 @@ from retrieval_engine import RetrievalEngine
12
  from utils import short_lines
13
 
14
 
15
- # -----------------------------
16
- # Retrieval intent configuration
17
- # -----------------------------
18
-
19
  RETRIEVAL_ALLOWED_INTENTS = {
20
  "walkthrough",
21
  "step_by_step",
@@ -39,26 +35,64 @@ DIRECT_SOLVE_PATTERNS = [
39
 
40
  STRUCTURE_KEYWORDS = {
41
  "algebra": [
42
- "equation", "solve", "isolate", "variable", "linear", "expression",
43
- "unknown", "algebra", "substitute", "rearrange"
 
 
 
 
 
 
 
 
44
  ],
45
  "percent": [
46
- "percent", "%", "percentage", "increase", "decrease", "of"
 
 
 
 
 
47
  ],
48
  "ratio": [
49
- "ratio", "proportion", "proportional", "part", "share"
 
 
 
 
50
  ],
51
  "statistics": [
52
- "mean", "median", "mode", "range", "average", "standard deviation"
 
 
 
 
 
53
  ],
54
  "probability": [
55
- "probability", "chance", "likely", "odds", "event"
 
 
 
 
56
  ],
57
  "geometry": [
58
- "triangle", "circle", "angle", "area", "perimeter", "radius", "diameter"
 
 
 
 
 
 
59
  ],
60
  "number_properties": [
61
- "integer", "odd", "even", "prime", "divisible", "factor", "multiple"
 
 
 
 
 
 
62
  ],
63
  }
64
 
@@ -70,46 +104,61 @@ INTENT_KEYWORDS = {
70
  "hint": ["hint", "nudge", "clue"],
71
  "definition": ["define", "definition", "what does", "what is meant by"],
72
  "concept": ["concept", "idea", "principle", "rule"],
73
- "instruction": ["how do i", "how to", "what should i do first", "what step"],
74
  }
75
 
76
  MISMATCH_TERMS = {
77
  "algebra": [
78
- "absolute value", "modulus", "square root", "quadratic", "inequality",
79
- "roots", "parabola", "simultaneous equations"
 
 
 
 
 
 
80
  ],
81
  "percent": [
82
- "triangle", "circle", "prime", "absolute value"
 
 
 
83
  ],
84
  "ratio": [
85
- "absolute value", "quadratic", "circle"
 
 
86
  ],
87
  "statistics": [
88
- "absolute value", "prime", "triangle"
 
 
89
  ],
90
  "probability": [
91
- "absolute value", "circle area", "quadratic"
 
 
92
  ],
93
  "geometry": [
94
- "absolute value", "prime", "median salary"
 
 
95
  ],
96
  "number_properties": [
97
- "circle", "triangle", "absolute value"
 
 
98
  ],
99
  }
100
 
101
 
102
- # -----------------------------
103
- # Reply building
104
- # -----------------------------
105
-
106
  def _teaching_lines(chunks: List[RetrievedChunk]) -> List[str]:
107
- lines = []
108
  for chunk in chunks:
109
- text = chunk.text.strip().replace("\n", " ")
110
  if len(text) > 220:
111
  text = text[:217].rstrip() + "…"
112
- topic = getattr(chunk, "topic", "general") or "general"
113
  lines.append(f"- {topic}: {text}")
114
  return lines
115
 
@@ -134,7 +183,7 @@ def _compose_quant_reply(
134
  if intent == "definition":
135
  if steps:
136
  return f"Here is the idea in context:\n- {steps[0]}"
137
- return "This means identifying the mathematical idea being used and expressing it clearly."
138
 
139
  if intent in {"walkthrough", "step_by_step", "explain", "method", "concept"}:
140
  if not steps:
@@ -151,7 +200,6 @@ def _compose_quant_reply(
151
  return f"Walkthrough:\n{body}\n\nThat gives {internal}."
152
  return f"Walkthrough:\n{body}"
153
 
154
- # answer/default
155
  if reveal_answer and internal:
156
  if result.answer_value and str(result.answer_value).startswith("x ="):
157
  return f"The result is {result.answer_value}."
@@ -165,30 +213,27 @@ def _compose_quant_reply(
165
  return "I can help with this, but I cannot confidently solve it from the current parse alone yet."
166
 
167
 
168
- # -----------------------------
169
- # Intent / retrieval helpers
170
- # -----------------------------
171
-
172
  def _normalize_text(text: str) -> str:
173
  return re.sub(r"\s+", " ", (text or "").strip().lower())
174
 
175
 
176
  def _extract_keywords(text: str) -> Set[str]:
177
- raw = re.findall(r"[a-zA-Z][a-zA-Z0-9_+-]*", text.lower())
178
  stop = {
179
  "the", "a", "an", "is", "are", "to", "of", "for", "and", "or", "in", "on",
180
  "at", "by", "this", "that", "it", "be", "do", "i", "me", "my", "you",
181
- "how", "what", "why", "give", "show", "please", "can"
182
  }
183
  return {w for w in raw if len(w) > 2 and w not in stop}
184
 
185
 
186
  def _infer_structure_terms(question_text: str, topic: Optional[str]) -> List[str]:
187
  terms: List[str] = []
 
188
  if topic and topic in STRUCTURE_KEYWORDS:
189
  terms.extend(STRUCTURE_KEYWORDS[topic])
190
 
191
- q = question_text.lower()
192
 
193
  if "=" in q:
194
  terms.extend(["equation", "solve"])
@@ -207,11 +252,14 @@ def _infer_structure_terms(question_text: str, topic: Optional[str]) -> List[str
207
  def _infer_mismatch_terms(topic: Optional[str], question_text: str) -> List[str]:
208
  if not topic or topic not in MISMATCH_TERMS:
209
  return []
210
- q = question_text.lower()
211
- terms = []
 
 
212
  for term in MISMATCH_TERMS[topic]:
213
  if term not in q:
214
  terms.append(term)
 
215
  return terms
216
 
217
 
@@ -226,7 +274,10 @@ def _is_direct_solve_request(text: str, intent: str) -> bool:
226
  t = _normalize_text(text)
227
 
228
  if any(re.search(p, t) for p in DIRECT_SOLVE_PATTERNS):
229
- if not any(word in t for word in ["how", "explain", "why", "method", "hint", "define", "definition", "step"]):
 
 
 
230
  return True
231
 
232
  return False
@@ -251,34 +302,29 @@ def _score_chunk(
251
  topic: Optional[str],
252
  question_text: str,
253
  ) -> float:
254
- text = f"{getattr(chunk, 'topic', '')} {chunk.text}".lower()
255
  score = 0.0
256
 
257
- # topic match
258
  if topic:
259
- chunk_topic = (getattr(chunk, "topic", "") or "").lower()
260
  if chunk_topic == topic.lower():
261
  score += 4.0
262
  elif topic.lower() in text:
263
  score += 2.0
264
 
265
- # structure match
266
  structure_terms = _infer_structure_terms(question_text, topic)
267
  for term in structure_terms:
268
  if term.lower() in text:
269
  score += 1.5
270
 
271
- # intent match
272
  for term in _intent_keywords(intent):
273
  if term.lower() in text:
274
  score += 1.2
275
 
276
- # question keyword overlap
277
  q_keywords = _extract_keywords(question_text)
278
  overlap = sum(1 for kw in q_keywords if kw in text)
279
  score += min(overlap * 0.4, 3.0)
280
 
281
- # penalties for obvious mismatch
282
  mismatch_terms = _infer_mismatch_terms(topic, question_text)
283
  for bad in mismatch_terms:
284
  if bad.lower() in text:
@@ -295,7 +341,8 @@ def _filter_retrieved_chunks(
295
  min_score: float = 2.5,
296
  max_chunks: int = 3,
297
  ) -> List[RetrievedChunk]:
298
- scored = []
 
299
  for chunk in chunks:
300
  s = _score_chunk(chunk, intent, topic, question_text)
301
  if s >= min_score:
@@ -314,7 +361,7 @@ def _build_retrieval_query(
314
  ) -> str:
315
  parts: List[str] = []
316
 
317
- base = question_text.strip() if question_text.strip() else raw_user_text.strip()
318
  if base:
319
  parts.append(base)
320
 
@@ -335,10 +382,6 @@ def _build_retrieval_query(
335
  return " ".join(parts).strip()
336
 
337
 
338
- # -----------------------------
339
- # Public entry point
340
- # -----------------------------
341
-
342
  def generate_response(
343
  raw_user_text: str,
344
  tone: float = 0.5,
@@ -355,21 +398,23 @@ def generate_response(
355
 
356
  intent = detect_intent(user_text)
357
  help_mode = intent_to_help_mode(intent)
358
-
359
  reveal_answer = help_mode == "answer" or transparency >= 0.8
360
 
361
  result = SolverResult(
362
  domain="general",
363
  solved=False,
 
364
  answer_letter=None,
365
  answer_value=None,
 
 
 
366
  internal_answer=None,
367
  steps=[],
368
- topic=None,
 
369
  )
370
 
371
- used_retrieval = False
372
- used_generator = False
373
  selected_chunks: List[RetrievedChunk] = []
374
 
375
  if is_quant_question(solver_input):
@@ -388,7 +433,6 @@ def generate_response(
388
  raw_user_text=user_text or solver_input,
389
  )
390
 
391
- # Use passed-in retrieval context only if retrieval is allowed
392
  if allow_retrieval and retrieval_context:
393
  filtered = _filter_retrieved_chunks(
394
  chunks=retrieval_context,
@@ -398,9 +442,9 @@ def generate_response(
398
  )
399
  if filtered:
400
  selected_chunks = filtered
401
- used_retrieval = True
 
402
 
403
- # Otherwise retrieve fresh if allowed
404
  elif allow_retrieval and retrieval_engine is not None:
405
  query = _build_retrieval_query(
406
  raw_user_text=user_text,
@@ -418,13 +462,12 @@ def generate_response(
418
  )
419
  if filtered:
420
  selected_chunks = filtered
421
- used_retrieval = True
 
422
 
423
- # Add teaching notes only if they survived filtering
424
  if selected_chunks:
425
  reply = f"{reply}\n\nRelevant study notes:\n" + "\n".join(_teaching_lines(selected_chunks))
426
 
427
- # Optional generator fallback for non-quant / weak cases
428
  if not result.solved and generator_engine is not None:
429
  try:
430
  generated = generator_engine.generate(
@@ -435,7 +478,7 @@ def generate_response(
435
  )
436
  if generated and generated.strip():
437
  reply = generated.strip()
438
- used_generator = True
439
  except Exception:
440
  pass
441
 
@@ -446,16 +489,18 @@ def generate_response(
446
  transparency=transparency,
447
  )
448
 
 
 
449
  return {
450
- "reply": short_lines(reply),
451
  "meta": {
452
  "domain": result.domain,
453
  "solved": result.solved,
454
- "help_mode": help_mode,
455
  "answer_letter": result.answer_letter,
456
  "answer_value": result.answer_value,
457
  "topic": result.topic,
458
- "used_retrieval": used_retrieval,
459
- "used_generator": used_generator,
460
  },
461
  }
 
12
  from utils import short_lines
13
 
14
 
 
 
 
 
15
  RETRIEVAL_ALLOWED_INTENTS = {
16
  "walkthrough",
17
  "step_by_step",
 
35
 
36
  STRUCTURE_KEYWORDS = {
37
  "algebra": [
38
+ "equation",
39
+ "solve",
40
+ "isolate",
41
+ "variable",
42
+ "linear",
43
+ "expression",
44
+ "unknown",
45
+ "algebra",
46
+ "substitute",
47
+ "rearrange",
48
  ],
49
  "percent": [
50
+ "percent",
51
+ "%",
52
+ "percentage",
53
+ "increase",
54
+ "decrease",
55
+ "of",
56
  ],
57
  "ratio": [
58
+ "ratio",
59
+ "proportion",
60
+ "proportional",
61
+ "part",
62
+ "share",
63
  ],
64
  "statistics": [
65
+ "mean",
66
+ "median",
67
+ "mode",
68
+ "range",
69
+ "average",
70
+ "standard deviation",
71
  ],
72
  "probability": [
73
+ "probability",
74
+ "chance",
75
+ "likely",
76
+ "odds",
77
+ "event",
78
  ],
79
  "geometry": [
80
+ "triangle",
81
+ "circle",
82
+ "angle",
83
+ "area",
84
+ "perimeter",
85
+ "radius",
86
+ "diameter",
87
  ],
88
  "number_properties": [
89
+ "integer",
90
+ "odd",
91
+ "even",
92
+ "prime",
93
+ "divisible",
94
+ "factor",
95
+ "multiple",
96
  ],
97
  }
98
 
 
104
  "hint": ["hint", "nudge", "clue"],
105
  "definition": ["define", "definition", "what does", "what is meant by"],
106
  "concept": ["concept", "idea", "principle", "rule"],
107
+ "instruction": ["how do i", "how to", "what should i do first", "what step", "first step"],
108
  }
109
 
110
  MISMATCH_TERMS = {
111
  "algebra": [
112
+ "absolute value",
113
+ "modulus",
114
+ "square root",
115
+ "quadratic",
116
+ "inequality",
117
+ "roots",
118
+ "parabola",
119
+ "simultaneous equations",
120
  ],
121
  "percent": [
122
+ "triangle",
123
+ "circle",
124
+ "prime",
125
+ "absolute value",
126
  ],
127
  "ratio": [
128
+ "absolute value",
129
+ "quadratic",
130
+ "circle",
131
  ],
132
  "statistics": [
133
+ "absolute value",
134
+ "prime",
135
+ "triangle",
136
  ],
137
  "probability": [
138
+ "absolute value",
139
+ "circle area",
140
+ "quadratic",
141
  ],
142
  "geometry": [
143
+ "absolute value",
144
+ "prime",
145
+ "median salary",
146
  ],
147
  "number_properties": [
148
+ "circle",
149
+ "triangle",
150
+ "absolute value",
151
  ],
152
  }
153
 
154
 
 
 
 
 
155
  def _teaching_lines(chunks: List[RetrievedChunk]) -> List[str]:
156
+ lines: List[str] = []
157
  for chunk in chunks:
158
+ text = (chunk.text or "").strip().replace("\n", " ")
159
  if len(text) > 220:
160
  text = text[:217].rstrip() + "…"
161
+ topic = chunk.topic or "general"
162
  lines.append(f"- {topic}: {text}")
163
  return lines
164
 
 
183
  if intent == "definition":
184
  if steps:
185
  return f"Here is the idea in context:\n- {steps[0]}"
186
+ return "This is asking for the meaning of the term or operation in the problem."
187
 
188
  if intent in {"walkthrough", "step_by_step", "explain", "method", "concept"}:
189
  if not steps:
 
200
  return f"Walkthrough:\n{body}\n\nThat gives {internal}."
201
  return f"Walkthrough:\n{body}"
202
 
 
203
  if reveal_answer and internal:
204
  if result.answer_value and str(result.answer_value).startswith("x ="):
205
  return f"The result is {result.answer_value}."
 
213
  return "I can help with this, but I cannot confidently solve it from the current parse alone yet."
214
 
215
 
 
 
 
 
216
  def _normalize_text(text: str) -> str:
217
  return re.sub(r"\s+", " ", (text or "").strip().lower())
218
 
219
 
220
  def _extract_keywords(text: str) -> Set[str]:
221
+ raw = re.findall(r"[a-zA-Z][a-zA-Z0-9_+-]*", (text or "").lower())
222
  stop = {
223
  "the", "a", "an", "is", "are", "to", "of", "for", "and", "or", "in", "on",
224
  "at", "by", "this", "that", "it", "be", "do", "i", "me", "my", "you",
225
+ "how", "what", "why", "give", "show", "please", "can",
226
  }
227
  return {w for w in raw if len(w) > 2 and w not in stop}
228
 
229
 
230
  def _infer_structure_terms(question_text: str, topic: Optional[str]) -> List[str]:
231
  terms: List[str] = []
232
+
233
  if topic and topic in STRUCTURE_KEYWORDS:
234
  terms.extend(STRUCTURE_KEYWORDS[topic])
235
 
236
+ q = (question_text or "").lower()
237
 
238
  if "=" in q:
239
  terms.extend(["equation", "solve"])
 
252
  def _infer_mismatch_terms(topic: Optional[str], question_text: str) -> List[str]:
253
  if not topic or topic not in MISMATCH_TERMS:
254
  return []
255
+
256
+ q = (question_text or "").lower()
257
+ terms: List[str] = []
258
+
259
  for term in MISMATCH_TERMS[topic]:
260
  if term not in q:
261
  terms.append(term)
262
+
263
  return terms
264
 
265
 
 
274
  t = _normalize_text(text)
275
 
276
  if any(re.search(p, t) for p in DIRECT_SOLVE_PATTERNS):
277
+ if not any(
278
+ word in t
279
+ for word in ["how", "explain", "why", "method", "hint", "define", "definition", "step"]
280
+ ):
281
  return True
282
 
283
  return False
 
302
  topic: Optional[str],
303
  question_text: str,
304
  ) -> float:
305
+ text = f"{chunk.topic} {chunk.text}".lower()
306
  score = 0.0
307
 
 
308
  if topic:
309
+ chunk_topic = (chunk.topic or "").lower()
310
  if chunk_topic == topic.lower():
311
  score += 4.0
312
  elif topic.lower() in text:
313
  score += 2.0
314
 
 
315
  structure_terms = _infer_structure_terms(question_text, topic)
316
  for term in structure_terms:
317
  if term.lower() in text:
318
  score += 1.5
319
 
 
320
  for term in _intent_keywords(intent):
321
  if term.lower() in text:
322
  score += 1.2
323
 
 
324
  q_keywords = _extract_keywords(question_text)
325
  overlap = sum(1 for kw in q_keywords if kw in text)
326
  score += min(overlap * 0.4, 3.0)
327
 
 
328
  mismatch_terms = _infer_mismatch_terms(topic, question_text)
329
  for bad in mismatch_terms:
330
  if bad.lower() in text:
 
341
  min_score: float = 2.5,
342
  max_chunks: int = 3,
343
  ) -> List[RetrievedChunk]:
344
+ scored: List[tuple[float, RetrievedChunk]] = []
345
+
346
  for chunk in chunks:
347
  s = _score_chunk(chunk, intent, topic, question_text)
348
  if s >= min_score:
 
361
  ) -> str:
362
  parts: List[str] = []
363
 
364
+ base = question_text.strip() if (question_text or "").strip() else (raw_user_text or "").strip()
365
  if base:
366
  parts.append(base)
367
 
 
382
  return " ".join(parts).strip()
383
 
384
 
 
 
 
 
385
  def generate_response(
386
  raw_user_text: str,
387
  tone: float = 0.5,
 
398
 
399
  intent = detect_intent(user_text)
400
  help_mode = intent_to_help_mode(intent)
 
401
  reveal_answer = help_mode == "answer" or transparency >= 0.8
402
 
403
  result = SolverResult(
404
  domain="general",
405
  solved=False,
406
+ help_mode=help_mode,
407
  answer_letter=None,
408
  answer_value=None,
409
+ topic=None,
410
+ used_retrieval=False,
411
+ used_generator=False,
412
  internal_answer=None,
413
  steps=[],
414
+ teaching_chunks=[],
415
+ meta={},
416
  )
417
 
 
 
418
  selected_chunks: List[RetrievedChunk] = []
419
 
420
  if is_quant_question(solver_input):
 
433
  raw_user_text=user_text or solver_input,
434
  )
435
 
 
436
  if allow_retrieval and retrieval_context:
437
  filtered = _filter_retrieved_chunks(
438
  chunks=retrieval_context,
 
442
  )
443
  if filtered:
444
  selected_chunks = filtered
445
+ result.used_retrieval = True
446
+ result.teaching_chunks = filtered
447
 
 
448
  elif allow_retrieval and retrieval_engine is not None:
449
  query = _build_retrieval_query(
450
  raw_user_text=user_text,
 
462
  )
463
  if filtered:
464
  selected_chunks = filtered
465
+ result.used_retrieval = True
466
+ result.teaching_chunks = filtered
467
 
 
468
  if selected_chunks:
469
  reply = f"{reply}\n\nRelevant study notes:\n" + "\n".join(_teaching_lines(selected_chunks))
470
 
 
471
  if not result.solved and generator_engine is not None:
472
  try:
473
  generated = generator_engine.generate(
 
478
  )
479
  if generated and generated.strip():
480
  reply = generated.strip()
481
+ result.used_generator = True
482
  except Exception:
483
  pass
484
 
 
489
  transparency=transparency,
490
  )
491
 
492
+ result.reply = short_lines(reply)
493
+
494
  return {
495
+ "reply": result.reply,
496
  "meta": {
497
  "domain": result.domain,
498
  "solved": result.solved,
499
+ "help_mode": result.help_mode,
500
  "answer_letter": result.answer_letter,
501
  "answer_value": result.answer_value,
502
  "topic": result.topic,
503
+ "used_retrieval": result.used_retrieval,
504
+ "used_generator": result.used_generator,
505
  },
506
  }