j-js commited on
Commit
ccbb9de
·
verified ·
1 Parent(s): 850fc95

Update conversation_logic.py

Browse files
Files changed (1) hide show
  1. conversation_logic.py +405 -597
conversation_logic.py CHANGED
@@ -1,628 +1,436 @@
1
  from __future__ import annotations
2
 
 
3
  import re
4
- from typing import Any, Dict, List, Optional, Set
5
-
6
- from context_parser import detect_intent, intent_to_help_mode
7
- from formatting import format_reply
8
- from generator_engine import GeneratorEngine
9
- from models import RetrievedChunk, SolverResult
10
- from quant_solver import is_quant_question, solve_quant
11
- from question_classifier import classify_question
12
- from retrieval_engine import RetrievalEngine
13
-
14
-
15
- RETRIEVAL_ALLOWED_INTENTS = {
16
- "walkthrough",
17
- "step_by_step",
18
- "explain",
19
- "method",
20
- "hint",
21
- "definition",
22
- "concept",
23
- "instruction",
24
- }
25
-
26
- DIRECT_SOLVE_PATTERNS = [
27
- r"\bsolve\b",
28
- r"\bwhat is\b",
29
- r"\bfind\b",
30
- r"\bgive (?:me )?the answer\b",
31
- r"\bjust the answer\b",
32
- r"\banswer only\b",
33
- r"\bcalculate\b",
34
- ]
35
-
36
- STRUCTURE_KEYWORDS = {
37
- "algebra": [
38
- "equation", "solve", "isolate", "variable", "linear", "expression",
39
- "unknown", "algebra", "substitute", "rearrange",
40
- ],
41
- "percent": [
42
- "percent", "%", "percentage", "increase", "decrease", "of",
43
- ],
44
- "ratio": [
45
- "ratio", "proportion", "proportional", "part", "share",
46
- ],
47
- "statistics": [
48
- "mean", "median", "mode", "range", "average", "standard deviation",
49
- ],
50
- "probability": [
51
- "probability", "chance", "likely", "odds", "event",
52
- ],
53
- "geometry": [
54
- "triangle", "circle", "angle", "area", "perimeter", "radius", "diameter",
55
- ],
56
- "number_properties": [
57
- "integer", "odd", "even", "prime", "divisible", "factor", "multiple",
58
- ],
59
- "number_theory": [
60
- "integer", "odd", "even", "prime", "divisible", "factor", "multiple", "remainder",
61
- ],
62
- "sequence": [
63
- "sequence", "geometric", "arithmetic", "term", "series",
64
- ],
65
- "quant": [
66
- "equation", "solve", "value", "integer", "ratio", "percent",
67
- ],
68
- "data": [
69
- "data", "mean", "median", "trend", "chart", "table", "correlation",
70
- ],
71
- "verbal": [
72
- "grammar", "meaning", "author", "argument", "sentence", "word",
73
- ],
74
- "reasoning": [
75
- "argument", "assume", "conclusion", "evidence", "author",
76
- ],
77
- "vocabulary": [
78
- "meaning", "definition", "word", "closest in meaning",
79
- ],
80
- "grammar": [
81
- "grammar", "sentence", "verb", "agreement", "idiom", "modifier",
82
- ],
83
- }
84
-
85
- INTENT_KEYWORDS = {
86
- "walkthrough": ["walkthrough", "work through", "step by step", "full working"],
87
- "step_by_step": ["step", "first step", "next step", "step by step"],
88
- "explain": ["explain", "why", "understand"],
89
- "method": ["method", "approach", "how do i solve", "how to solve"],
90
- "hint": ["hint", "nudge", "clue"],
91
- "definition": ["define", "definition", "what does", "what is meant by"],
92
- "concept": ["concept", "idea", "principle", "rule"],
93
- "instruction": ["how do i", "how to", "what should i do first", "what step", "first step"],
94
- }
95
-
96
- MISMATCH_TERMS = {
97
- "algebra": [
98
- "absolute value", "modulus", "square root", "quadratic", "inequality",
99
- "roots", "parabola", "simultaneous equations",
100
- ],
101
- "percent": ["triangle", "circle", "prime", "absolute value"],
102
- "ratio": ["absolute value", "quadratic", "circle"],
103
- "statistics": ["absolute value", "prime", "triangle"],
104
- "probability": ["absolute value", "circle area", "quadratic"],
105
- "geometry": ["absolute value", "prime", "median salary"],
106
- "number_properties": ["circle", "triangle", "absolute value"],
107
- "number_theory": ["circle", "triangle", "median salary"],
108
- }
109
-
110
-
111
- def _normalize_classified_topic(
112
- topic: Optional[str],
113
- category: Optional[str],
114
- question_text: str,
115
- ) -> Optional[str]:
116
- t = (topic or "").strip().lower()
117
- q = (question_text or "").lower()
118
- c = (category or "").strip()
119
-
120
- if t not in {"general_quant", "general", "unknown", ""}:
121
- return topic
122
-
123
- if "%" in q or "percent" in q:
124
- return "percent"
125
-
126
- if "ratio" in q or ":" in q:
127
- return "ratio"
128
-
129
- if "probability" in q or "chosen at random" in q:
130
- return "probability"
131
-
132
- if "divisible" in q or "remainder" in q or "prime" in q or "factor" in q:
133
- return "number_theory"
134
-
135
- if "circle" in q or "triangle" in q or "perimeter" in q or "area" in q or "circumference" in q:
136
- return "geometry"
137
-
138
- if "mean" in q or "median" in q or "average" in q or "sales" in q or "revenue" in q:
139
- if c == "Quantitative":
140
- return "statistics"
141
- return "data"
142
-
143
- if "=" in q or "what is x" in q or "what is y" in q or "integer" in q:
144
- return "algebra"
145
-
146
- if c == "DataInsight":
147
- return "data"
148
-
149
- if c == "Verbal":
150
- return "verbal"
151
-
152
- if c == "Quantitative":
153
- return "quant"
154
-
155
- return "general"
156
-
157
- return topic
158
-
159
-
160
- def _teaching_lines(chunks: List[RetrievedChunk]) -> List[str]:
161
- lines: List[str] = []
162
- for chunk in chunks:
163
- text = (chunk.text or "").strip().replace("\n", " ")
164
- if len(text) > 220:
165
- text = text[:217].rstrip() + "…"
166
- topic = chunk.topic or "general"
167
- lines.append(f"- {topic}: {text}")
168
- return lines
169
-
170
-
171
- def _compose_reply(
172
- result: SolverResult,
173
- intent: str,
174
- reveal_answer: bool,
175
- verbosity: float,
176
- category: Optional[str] = None,
177
- question_type: Optional[str] = None,
178
- ) -> str:
179
- steps = result.steps or []
180
- internal = result.internal_answer or result.answer_value or ""
181
-
182
- if intent == "hint":
183
- return steps[0] if steps else "Start by identifying what the question is really asking."
184
-
185
- if intent == "instruction":
186
- if steps:
187
- return f"First step: {steps[0]}"
188
- return "First, identify the key relationship or comparison in the question."
189
-
190
- if intent == "definition":
191
- if steps:
192
- return f"Here is the idea in context:\n- {steps[0]}"
193
- return "This is asking for the meaning of the term or idea in the question."
194
-
195
- if intent in {"walkthrough", "step_by_step", "explain", "method", "concept"}:
196
- if not steps:
197
- if reveal_answer and internal:
198
- return f"The result is {internal}."
199
- return "I can explain the method, but I do not have enough structured steps yet."
200
 
201
- shown_steps = steps if verbosity >= 0.66 else steps[: min(3, len(steps))]
202
- body = "\n".join(f"- {s}" for s in shown_steps)
 
 
203
 
204
- if reveal_answer and internal:
205
- return f"Walkthrough:\n{body}\n\nThat gives {internal}."
206
- return f"Walkthrough:\n{body}"
207
 
208
- if reveal_answer and internal:
209
- if result.answer_value and str(result.answer_value).startswith("x ="):
210
- return f"The result is {result.answer_value}."
211
- if result.answer_value:
212
- return f"The answer is {result.answer_value}."
213
- return f"The result is {internal}."
214
-
215
- if steps:
216
- return steps[0]
217
-
218
- if category == "Verbal":
219
- return "I can help analyse the wording or logic, but I do not have a full verbal solver yet."
220
-
221
- if category == "DataInsight":
222
- return "I can help reason through the data, but I cannot confidently solve this from the current parse alone yet."
223
-
224
- return "I can help with this, but I cannot confidently solve it from the current parse alone yet."
225
-
226
-
227
- def _normalize_text(text: str) -> str:
228
- return re.sub(r"\s+", " ", (text or "").strip().lower())
229
-
230
-
231
- def _extract_keywords(text: str) -> Set[str]:
232
- raw = re.findall(r"[a-zA-Z][a-zA-Z0-9_+-]*", (text or "").lower())
233
- stop = {
234
- "the", "a", "an", "is", "are", "to", "of", "for", "and", "or", "in", "on",
235
- "at", "by", "this", "that", "it", "be", "do", "i", "me", "my", "you",
236
- "how", "what", "why", "give", "show", "please", "can",
237
- }
238
- return {w for w in raw if len(w) > 2 and w not in stop}
239
 
 
 
 
 
 
 
 
 
 
240
 
241
- def _infer_structure_terms(
242
- question_text: str,
243
- topic: Optional[str],
244
- question_type: Optional[str],
245
- ) -> List[str]:
246
- terms: List[str] = []
247
 
248
- if topic and topic in STRUCTURE_KEYWORDS:
249
- terms.extend(STRUCTURE_KEYWORDS[topic])
250
 
251
- if question_type:
252
- terms.extend(question_type.replace("_", " ").split())
253
 
254
- q = (question_text or "").lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
- if "=" in q:
257
- terms.extend(["equation", "solve"])
258
- if "x" in q or "y" in q:
259
- terms.extend(["variable", "isolate"])
260
- if "/" in q or "divide" in q:
261
- terms.extend(["divide", "undo operations"])
262
- if "*" in q or "times" in q or "multiply" in q:
263
- terms.extend(["multiply", "undo operations"])
264
- if "%" in q or "percent" in q:
265
- terms.extend(["percent", "percentage"])
266
- if "ratio" in q:
267
- terms.extend(["ratio", "proportion"])
268
- if "mean" in q or "average" in q:
269
- terms.extend(["mean", "average"])
270
- if "median" in q:
271
- terms.extend(["median"])
272
- if "probability" in q:
273
- terms.extend(["probability"])
274
- if "remainder" in q or "divisible" in q:
275
- terms.extend(["remainder", "divisible"])
276
 
277
- return list(dict.fromkeys(terms))
 
 
 
 
 
 
278
 
279
 
280
- def _infer_mismatch_terms(topic: Optional[str], question_text: str) -> List[str]:
281
- if not topic or topic not in MISMATCH_TERMS:
282
- return []
 
283
 
284
- q = (question_text or "").lower()
285
- return [term for term in MISMATCH_TERMS[topic] if term not in q]
 
 
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
- def _intent_keywords(intent: str) -> List[str]:
289
- return INTENT_KEYWORDS.get(intent, [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
 
291
 
292
- def _is_direct_solve_request(text: str, intent: str) -> bool:
293
- if intent == "answer":
294
- return True
295
 
296
- t = _normalize_text(text)
297
- if any(re.search(p, t) for p in DIRECT_SOLVE_PATTERNS):
298
- if not any(
299
- word in t
300
- for word in ["how", "explain", "why", "method", "hint", "define", "definition", "step"]
301
- ):
302
- return True
303
- return False
304
 
 
 
 
 
 
 
 
 
 
 
305
 
306
- def should_retrieve(
307
- intent: str,
308
- solved: bool,
309
- raw_user_text: str,
310
- category: Optional[str] = None,
311
- ) -> bool:
312
- if _is_direct_solve_request(raw_user_text, intent):
313
- return (not solved) and category in {"Verbal", "DataInsight"}
 
 
314
 
315
- if intent in RETRIEVAL_ALLOWED_INTENTS:
316
- return True
317
 
318
- if not solved and category in {"Verbal", "DataInsight"}:
319
- return True
320
 
321
- return False
 
 
322
 
 
 
 
323
 
324
- def _score_chunk(
325
- chunk: RetrievedChunk,
326
- intent: str,
327
- topic: Optional[str],
328
- question_text: str,
329
- question_type: Optional[str] = None,
330
- ) -> float:
331
- text = f"{chunk.topic} {chunk.text}".lower()
332
- score = 0.0
333
-
334
- if topic:
335
- chunk_topic = (chunk.topic or "").lower()
336
- if chunk_topic == topic.lower():
337
- score += 4.0
338
- elif topic.lower() in text:
339
- score += 2.0
340
-
341
- for term in _infer_structure_terms(question_text, topic, question_type):
342
- if term.lower() in text:
343
- score += 1.5
344
-
345
- for term in _intent_keywords(intent):
346
- if term.lower() in text:
347
- score += 1.2
348
-
349
- overlap = sum(1 for kw in _extract_keywords(question_text) if kw in text)
350
- score += min(overlap * 0.4, 3.0)
351
-
352
- for bad in _infer_mismatch_terms(topic, question_text):
353
- if bad.lower() in text:
354
- score -= 2.5
355
-
356
- return score
357
-
358
-
359
- def _filter_retrieved_chunks(
360
- chunks: List[RetrievedChunk],
361
- intent: str,
362
- topic: Optional[str],
363
- question_text: str,
364
- question_type: Optional[str] = None,
365
- min_score: float = 3.2,
366
- max_chunks: int = 3,
367
- ) -> List[RetrievedChunk]:
368
- scored: List[tuple[float, RetrievedChunk]] = []
369
- normalized_topic = (topic or "").lower()
370
-
371
- for chunk in chunks:
372
- chunk_topic = (chunk.topic or "").lower()
373
-
374
- if normalized_topic and normalized_topic not in {"general", "unknown", "general_quant"}:
375
- if chunk_topic == "general":
376
- continue
377
-
378
- s = _score_chunk(chunk, intent, topic, question_text, question_type)
379
- if s >= min_score:
380
- scored.append((s, chunk))
381
-
382
- scored.sort(key=lambda x: x[0], reverse=True)
383
- filtered = [chunk for _, chunk in scored[:max_chunks]]
384
-
385
- if filtered:
386
- return filtered
387
-
388
- fallback: List[tuple[float, RetrievedChunk]] = []
389
- for chunk in chunks:
390
- s = _score_chunk(chunk, intent, topic, question_text, question_type)
391
- if s >= 2.0:
392
- fallback.append((s, chunk))
393
-
394
- fallback.sort(key=lambda x: x[0], reverse=True)
395
- return [chunk for _, chunk in fallback[:max_chunks]]
396
-
397
-
398
- def _build_retrieval_query(
399
- raw_user_text: str,
400
- question_text: str,
401
- intent: str,
402
- topic: Optional[str],
403
- solved: bool,
404
- question_type: Optional[str] = None,
405
- category: Optional[str] = None,
406
- ) -> str:
407
- parts: List[str] = []
408
-
409
- base = (question_text or "").strip() or (raw_user_text or "").strip()
410
- if base:
411
- parts.append(base)
412
-
413
- if category:
414
- parts.append(category)
415
-
416
- if topic:
417
- parts.append(topic)
418
-
419
- if question_type:
420
- parts.append(question_type.replace("_", " "))
421
-
422
- if intent in {"definition", "concept"}:
423
- parts.append("definition concept explanation")
424
- elif intent in {"walkthrough", "step_by_step", "method", "instruction"}:
425
- parts.append("method steps worked example")
426
- elif intent == "hint":
427
- parts.append("hint strategy first step")
428
- elif intent == "explain":
429
- parts.append("explanation reasoning")
430
- elif not solved:
431
- parts.append("teaching explanation method")
432
-
433
- return " ".join(parts).strip()
434
-
435
-
436
- class ConversationEngine:
437
- def __init__(
438
- self,
439
- retriever: Optional[RetrievalEngine] = None,
440
- generator: Optional[GeneratorEngine] = None,
441
- **kwargs,
442
- ) -> None:
443
- self.retriever = retriever
444
- self.generator = generator
445
-
446
- def generate_response(
447
- self,
448
- raw_user_text: Optional[str] = None,
449
- tone: float = 0.5,
450
- verbosity: float = 0.5,
451
- transparency: float = 0.5,
452
- intent: Optional[str] = None,
453
- help_mode: Optional[str] = None,
454
- retrieval_context: Optional[List[RetrievedChunk]] = None,
455
- chat_history: Optional[List[Dict[str, Any]]] = None,
456
- question_text: Optional[str] = None,
457
- options_text: Optional[List[str]] = None,
458
- **kwargs,
459
- ) -> SolverResult:
460
- solver_input = (question_text or raw_user_text or "").strip()
461
- user_text = (raw_user_text or "").strip()
462
-
463
- category = kwargs.get("category")
464
- classification = classify_question(
465
- question_text=solver_input,
466
- category=category,
467
- )
468
- inferred_category = classification.get("category") or category
469
-
470
- if not inferred_category:
471
- q = solver_input.lower()
472
- if any(
473
- k in q
474
- for k in [
475
- "percent",
476
- "%",
477
- "ratio",
478
- "divisible",
479
- "remainder",
480
- "probability",
481
- "circle",
482
- "triangle",
483
- "=",
484
- ]
485
- ):
486
- inferred_category = "Quantitative"
487
- elif any(
488
- k in q
489
- for k in [
490
- "sales",
491
- "revenue",
492
- "median",
493
- "mean",
494
- "chart",
495
- "table",
496
- "scatter",
497
- "distribution",
498
- ]
499
- ):
500
- inferred_category = "DataInsight"
501
- else:
502
- inferred_category = "General"
503
-
504
- question_topic = _normalize_classified_topic(
505
- classification.get("topic"),
506
- inferred_category,
507
- solver_input,
508
- )
509
- question_type = classification.get("type")
510
-
511
- resolved_intent = intent or detect_intent(user_text, help_mode)
512
- resolved_help_mode = help_mode or intent_to_help_mode(resolved_intent)
513
- reveal_answer = resolved_help_mode == "answer" or transparency >= 0.8
514
-
515
- result = SolverResult(
516
- domain="general",
517
- solved=False,
518
- help_mode=resolved_help_mode,
519
- answer_letter=None,
520
- answer_value=None,
521
- topic=question_topic,
522
- used_retrieval=False,
523
- used_generator=False,
524
- internal_answer=None,
525
- steps=[],
526
- teaching_chunks=[],
527
- meta={},
528
- )
529
 
530
- selected_chunks: List[RetrievedChunk] = []
531
-
532
- if inferred_category == "Quantitative" or is_quant_question(solver_input):
533
- solved_result = solve_quant(solver_input)
534
- if solved_result is not None:
535
- result = solved_result
536
- result.help_mode = resolved_help_mode
537
- if not result.topic or result.topic in {"general_quant", "general", "unknown"}:
538
- result.topic = question_topic
539
- if result.domain == "general":
540
- result.domain = "quant"
541
-
542
- reply = _compose_reply(
543
- result=result,
544
- intent=resolved_intent,
545
- reveal_answer=reveal_answer,
546
- verbosity=verbosity,
547
- category=inferred_category,
548
- question_type=question_type,
549
  )
550
-
551
- allow_retrieval = should_retrieve(
552
- intent=resolved_intent,
553
- solved=bool(result.solved),
554
- raw_user_text=user_text or solver_input,
555
- category=inferred_category,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
  )
557
-
558
- if allow_retrieval and retrieval_context:
559
- filtered = _filter_retrieved_chunks(
560
- chunks=retrieval_context,
561
- intent=resolved_intent,
562
- topic=result.topic,
563
- question_text=solver_input,
564
- question_type=question_type,
565
- )
566
- if filtered:
567
- selected_chunks = filtered
568
- result.used_retrieval = True
569
- result.teaching_chunks = filtered
570
-
571
- elif allow_retrieval and self.retriever is not None:
572
- retrieved = self.retriever.search(
573
- query=_build_retrieval_query(
574
- raw_user_text=user_text,
575
- question_text=solver_input,
576
- intent=resolved_intent,
577
- topic=result.topic,
578
- solved=bool(result.solved),
579
- question_type=question_type,
580
- category=inferred_category,
581
- ),
582
- topic=result.topic or "",
583
- intent=resolved_intent,
584
- k=6,
585
- )
586
- filtered = _filter_retrieved_chunks(
587
- chunks=retrieved,
588
- intent=resolved_intent,
589
- topic=result.topic,
590
- question_text=solver_input,
591
- question_type=question_type,
592
- )
593
- if filtered:
594
- selected_chunks = filtered
595
- result.used_retrieval = True
596
- result.teaching_chunks = filtered
597
-
598
- if selected_chunks and resolved_help_mode != "answer":
599
- reply = f"{reply}\n\nRelevant study notes:\n" + "\n".join(_teaching_lines(selected_chunks))
600
-
601
- if not result.solved and self.generator is not None:
602
- try:
603
- generated = self.generator.generate(
604
- user_text=user_text or solver_input,
605
- intent=resolved_intent,
606
- topic=result.topic,
607
- chat_history=chat_history or [],
608
- )
609
- if generated and generated.strip():
610
- reply = generated.strip()
611
- result.used_generator = True
612
- except Exception:
613
- pass
614
-
615
- reply = format_reply(reply, tone, verbosity, transparency, resolved_help_mode)
616
-
617
- result.reply = reply
618
- result.help_mode = resolved_help_mode
619
- result.meta = {
620
- "intent": resolved_intent,
621
- "question_text": question_text or "",
622
- "options_count": len(options_text or []),
623
- "category": inferred_category,
624
- "question_type": question_type,
625
- "classified_topic": question_topic,
626
- }
627
-
628
- return result
 
1
  from __future__ import annotations
2
 
3
+ import math
4
  import re
5
+ from statistics import mean, median
6
+ from typing import Dict, List, Optional, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ try:
9
+ import sympy as sp
10
+ except Exception:
11
+ sp = None
12
 
13
+ from models import SolverResult
14
+ from utils import clean_math_text, normalize_spaces
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ def extract_choices(text: str) -> Dict[str, str]:
18
+ text = text or ""
19
+ matches = list(
20
+ re.finditer(
21
+ r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)",
22
+ text,
23
+ )
24
+ )
25
+ return {m.group(1).upper(): normalize_spaces(m.group(2)) for m in matches}
26
 
 
 
 
 
 
 
27
 
28
+ def has_answer_choices(text: str) -> bool:
29
+ return len(extract_choices(text)) >= 3
30
 
 
 
31
 
32
+ def is_quant_question(text: str) -> bool:
33
+ lower = clean_math_text(text).lower()
34
+ keywords = [
35
+ "solve", "equation", "percent", "ratio", "probability", "mean", "median",
36
+ "average", "sum", "difference", "product", "quotient", "triangle", "circle",
37
+ "rectangle", "area", "perimeter", "volume", "algebra", "integer", "divisible",
38
+ "number", "fraction", "decimal", "geometry", "distance", "speed", "work",
39
+ "remainder", "discount",
40
+ ]
41
+ if any(k in lower for k in keywords):
42
+ return True
43
+ if "=" in lower and re.search(r"[a-z]", lower):
44
+ return True
45
+ if re.search(r"\d", lower) and ("?" in lower or has_answer_choices(lower)):
46
+ return True
47
+ return False
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ def _prepare_expression(expr: str) -> str:
51
+ expr = clean_math_text(expr).strip()
52
+ expr = expr.replace("^", "**")
53
+ expr = re.sub(r"(\d)\s*\(", r"\1*(", expr)
54
+ expr = re.sub(r"\)\s*(\d)", r")*\1", expr)
55
+ expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr)
56
+ return expr
57
 
58
 
59
+ def _extract_equation(text: str) -> Optional[str]:
60
+ cleaned = clean_math_text(text)
61
+ if "=" not in cleaned:
62
+ return None
63
 
64
+ patterns = [
65
+ r"([A-Za-z0-9\.\+\-\*/\^\(\)\s]*[a-zA-Z][A-Za-z0-9\.\+\-\*/\^\(\)\s]*=[A-Za-z0-9\.\+\-\*/\^\(\)\s]+)",
66
+ r"([0-9A-Za-z\.\+\-\*/\^\(\)\s]+=[0-9A-Za-z\.\+\-\*/\^\(\)\s]+)",
67
+ ]
68
 
69
+ for pattern in patterns:
70
+ for m in re.finditer(pattern, cleaned):
71
+ candidate = m.group(1).strip()
72
+ if re.search(r"[a-z]", candidate.lower()) and not candidate.lower().startswith(
73
+ ("how do", "can you", "please", "what is", "solve ")
74
+ ):
75
+ return candidate
76
+
77
+ eq_index = cleaned.find("=")
78
+ left = re.findall(r"[A-Za-z0-9\.\+\-\*/\^\(\)\s]+$", cleaned[:eq_index])
79
+ right = re.findall(r"^[A-Za-z0-9\.\+\-\*/\^\(\)\s]+", cleaned[eq_index + 1:])
80
+ if left and right:
81
+ candidate = left[0].strip().split()[-1] + " = " + right[0].strip().split()[0]
82
+ if re.search(r"[a-z]", candidate.lower()):
83
+ return candidate
84
+ return None
85
+
86
+
87
+ def _parse_number(text: str) -> Optional[float]:
88
+ raw = clean_math_text(text).strip().lower()
89
+
90
+ pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw.replace(" ", ""))
91
+ if pct:
92
+ return float(pct.group(1)) / 100.0
93
+
94
+ frac = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw)
95
+ if frac:
96
+ den = float(frac.group(2))
97
+ if den == 0:
98
+ return None
99
+ return float(frac.group(1)) / den
100
+
101
+ try:
102
+ return float(
103
+ eval(
104
+ _prepare_expression(raw),
105
+ {"__builtins__": {}},
106
+ {"sqrt": math.sqrt, "pi": math.pi},
107
+ )
108
+ )
109
+ except Exception:
110
+ return None
111
+
112
+
113
+ def _best_choice(answer_value: float, choices: Dict[str, str]) -> Optional[str]:
114
+ best_letter = None
115
+ best_diff = float("inf")
116
+
117
+ for letter, raw in choices.items():
118
+ parsed = _parse_number(raw)
119
+ if parsed is None:
120
+ continue
121
+ diff = abs(parsed - answer_value)
122
+ if diff < best_diff:
123
+ best_diff = diff
124
+ best_letter = letter
125
+
126
+ if best_letter is not None and best_diff <= 1e-6:
127
+ return best_letter
128
+ return None
129
+
130
+
131
+ def _make_result(
132
+ *,
133
+ topic: str,
134
+ answer_value: str,
135
+ internal_answer: Optional[str] = None,
136
+ steps: Optional[List[str]] = None,
137
+ choices_text: str = "",
138
+ ) -> SolverResult:
139
+ answer_float = _parse_number(answer_value)
140
+ choices = extract_choices(choices_text)
141
+ answer_letter = _best_choice(answer_float, choices) if (answer_float is not None and choices) else None
142
+
143
+ return SolverResult(
144
+ domain="quant",
145
+ solved=True,
146
+ topic=topic,
147
+ answer_value=answer_value,
148
+ answer_letter=answer_letter,
149
+ internal_answer=internal_answer or answer_value,
150
+ steps=steps or [],
151
+ )
152
+
153
+
154
+ def _solve_successive_percent(text: str) -> Optional[SolverResult]:
155
+ lower = clean_math_text(text).lower()
156
+
157
+ pattern = re.findall(
158
+ r"(increase|decrease|discount|mark(?:ed)?\s*up|mark(?:ed)?\s*down|rise|fall)\s+by\s+(\d+(?:\.\d+)?)\s*(?:%|percent)",
159
+ lower,
160
+ )
161
+ if len(pattern) < 2:
162
+ pattern = re.findall(
163
+ r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+(increase|decrease|discount|rise|fall)",
164
+ lower,
165
+ )
166
+ pattern = [(op, pct) for pct, op in pattern]
167
+
168
+ if len(pattern) < 2:
169
+ return None
170
+
171
+ multiplier = 1.0
172
+ step_lines: List[str] = []
173
+
174
+ for op, pct_raw in pattern:
175
+ pct = float(pct_raw)
176
+ if any(k in op for k in ["decrease", "discount", "down", "fall"]):
177
+ factor = 1 - pct / 100.0
178
+ step_lines.append(f"A {pct:g}% decrease means multiply by {factor:g}.")
179
+ else:
180
+ factor = 1 + pct / 100.0
181
+ step_lines.append(f"A {pct:g}% increase means multiply by {factor:g}.")
182
+ multiplier *= factor
183
+
184
+ net_change = (multiplier - 1.0) * 100.0
185
+ direction = "increase" if net_change >= 0 else "decrease"
186
+ magnitude = abs(net_change)
187
+
188
+ return _make_result(
189
+ topic="percent",
190
+ answer_value=f"{magnitude:g}%",
191
+ internal_answer=f"net {direction} of {magnitude:g}%",
192
+ steps=step_lines + [f"The combined multiplier gives a net {direction} of {magnitude:g}%."],
193
+ choices_text=text,
194
+ )
195
+
196
+
197
+ def _extract_ratio_labels(text: str) -> Optional[Tuple[str, str]]:
198
+ m = re.search(r"ratio of ([a-z ]+?) to ([a-z ]+?) is \d+\s*:\s*\d+", text.lower())
199
+ if not m:
200
+ return None
201
+ left = normalize_spaces(m.group(1)).rstrip("s")
202
+ right = normalize_spaces(m.group(2)).rstrip("s")
203
+ return left, right
204
+
205
+
206
+ def _solve_ratio_total(text: str) -> Optional[SolverResult]:
207
+ lower = clean_math_text(text).lower()
208
+
209
+ ratio_match = re.search(r"(\d+)\s*:\s*(\d+)", lower)
210
+ total_match = re.search(r"(?:total|altogether|in all|sum)\s*(?:is|=|of)?\s*(\d+)", lower)
211
+
212
+ if not ratio_match or not total_match:
213
+ return None
214
+
215
+ a = int(ratio_match.group(1))
216
+ b = int(ratio_match.group(2))
217
+ total = int(total_match.group(1))
218
+
219
+ part_sum = a + b
220
+ if part_sum == 0:
221
+ return None
222
+
223
+ unit = total / part_sum
224
+ left_value = a * unit
225
+ right_value = b * unit
226
+
227
+ labels = _extract_ratio_labels(lower)
228
+ requested_value = left_value
229
+ requested_label = "first quantity"
230
+
231
+ if labels:
232
+ left_label, right_label = labels
233
+ if left_label in lower and re.search(rf"how many {re.escape(left_label)}", lower):
234
+ requested_value = left_value
235
+ requested_label = left_label
236
+ elif right_label in lower and re.search(rf"how many {re.escape(right_label)}", lower):
237
+ requested_value = right_value
238
+ requested_label = right_label
239
+ else:
240
+ requested_value = left_value
241
+ requested_label = left_label
242
+
243
+ return _make_result(
244
+ topic="ratio",
245
+ answer_value=f"{requested_value:g}",
246
+ internal_answer=f"{requested_label} = {requested_value:g}",
247
+ steps=[
248
+ f"Add the ratio parts: {a} + {b} = {part_sum}.",
249
+ f"Each ratio unit is {total} / {part_sum} = {unit:g}.",
250
+ f"Multiply by the required ratio part to get {requested_value:g}.",
251
+ ],
252
+ choices_text=text,
253
+ )
254
+
255
+
256
+ def _solve_remainder(text: str) -> Optional[SolverResult]:
257
+ lower = clean_math_text(text).lower()
258
+
259
+ m = re.search(r"remainder .*? when (\d+) is divided by (\d+)", lower)
260
+ if not m:
261
+ m = re.search(r"(\d+)\s*(?:mod|%)\s*(\d+)", lower)
262
+ if not m:
263
+ return None
264
+
265
+ a = int(m.group(1))
266
+ b = int(m.group(2))
267
+ if b == 0:
268
+ return None
269
+
270
+ r = a % b
271
+
272
+ return _make_result(
273
+ topic="number_theory",
274
+ answer_value=str(r),
275
+ internal_answer=str(r),
276
+ steps=[
277
+ f"Divide {a} by {b}.",
278
+ f"The remainder is {a} mod {b} = {r}.",
279
+ ],
280
+ choices_text=text,
281
+ )
282
+
283
+
284
+ def _solve_percent(text: str) -> Optional[SolverResult]:
285
+ lower = clean_math_text(text).lower()
286
+ choices = extract_choices(text)
287
+
288
+ m = re.search(r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(?:a\s+)?number\s+is\s+(\d+(?:\.\d+)?)", lower)
289
+ if m:
290
+ p = float(m.group(1))
291
+ value = float(m.group(2))
292
+ ans = value / (p / 100.0)
293
+ answer_letter = _best_choice(ans, choices) if choices else None
294
+
295
+ return SolverResult(
296
+ domain="quant",
297
+ solved=True,
298
+ topic="percent",
299
+ answer_value=f"{ans:g}",
300
+ answer_letter=answer_letter,
301
+ internal_answer=f"{ans:g}",
302
+ steps=[
303
+ "Let the number be n.",
304
+ f"Write {p}% of n as {p / 100:g}n.",
305
+ f"Set {p / 100:g}n = {value} and solve for n.",
306
+ ],
307
+ )
308
 
309
+ m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower)
310
+ if m:
311
+ p = float(m.group(1))
312
+ n = float(m.group(2))
313
+ ans = p / 100.0 * n
314
+ answer_letter = _best_choice(ans, choices) if choices else None
315
+
316
+ return SolverResult(
317
+ domain="quant",
318
+ solved=True,
319
+ topic="percent",
320
+ answer_value=f"{ans:g}",
321
+ answer_letter=answer_letter,
322
+ internal_answer=f"{ans:g}",
323
+ steps=[
324
+ f"Convert {p}% to {p / 100:g}.",
325
+ f"Multiply by {n}.",
326
+ ],
327
+ )
328
 
329
+ return None
330
 
 
 
 
331
 
332
+ def _solve_mean_median(text: str) -> Optional[SolverResult]:
333
+ lower = clean_math_text(text).lower()
334
+ nums = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", lower)]
335
+ if not nums:
336
+ return None
 
 
 
337
 
338
+ if "mean" in lower or "average" in lower:
339
+ ans = mean(nums)
340
+ return SolverResult(
341
+ domain="quant",
342
+ solved=True,
343
+ topic="statistics",
344
+ answer_value=f"{ans:g}",
345
+ internal_answer=f"{ans:g}",
346
+ steps=["Add the values.", f"Divide by {len(nums)}."],
347
+ )
348
 
349
+ if "median" in lower:
350
+ ans = median(nums)
351
+ return SolverResult(
352
+ domain="quant",
353
+ solved=True,
354
+ topic="statistics",
355
+ answer_value=f"{ans:g}",
356
+ internal_answer=f"{ans:g}",
357
+ steps=["Order the values.", "Take the middle value."],
358
+ )
359
 
360
+ return None
 
361
 
 
 
362
 
363
+ def _solve_linear_equation(text: str) -> Optional[SolverResult]:
364
+ if sp is None:
365
+ return None
366
 
367
+ expr = _extract_equation(text)
368
+ if not expr:
369
+ return None
370
 
371
+ try:
372
+ lhs, rhs = expr.split("=", 1)
373
+ symbols = sorted(set(re.findall(r"\b[a-z]\b", expr)))
374
+ if not symbols:
375
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
+ var_name = symbols[0]
378
+ var = sp.symbols(var_name)
379
+ sol = sp.solve(
380
+ sp.Eq(sp.sympify(_prepare_expression(lhs)), sp.sympify(_prepare_expression(rhs))),
381
+ var,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  )
383
+ if not sol:
384
+ return None
385
+
386
+ value = sol[0]
387
+ try:
388
+ as_float = float(value)
389
+ except Exception:
390
+ as_float = None
391
+
392
+ choices = extract_choices(text)
393
+
394
+ return SolverResult(
395
+ domain="quant",
396
+ solved=True,
397
+ topic="algebra",
398
+ answer_value=str(value),
399
+ answer_letter=_best_choice(as_float, choices) if (as_float is not None and choices) else None,
400
+ internal_answer=f"{var_name} = {value}",
401
+ steps=[
402
+ "Treat the statement as an equation.",
403
+ "Undo operations on both sides to isolate the variable.",
404
+ f"That gives {var_name} = {value}.",
405
+ ],
406
  )
407
+ except Exception:
408
+ return None
409
+
410
+
411
+ def solve_quant(text: str) -> SolverResult:
412
+ text = text or ""
413
+
414
+ for fn in (
415
+ _solve_successive_percent,
416
+ _solve_ratio_total,
417
+ _solve_remainder,
418
+ _solve_percent,
419
+ _solve_mean_median,
420
+ _solve_linear_equation,
421
+ ):
422
+ result = fn(text)
423
+ if result is not None:
424
+ return result
425
+
426
+ return SolverResult(
427
+ domain="quant",
428
+ solved=False,
429
+ topic="general_quant",
430
+ reply="This looks quantitative, but it does not match a strong rule-based pattern yet.",
431
+ steps=[
432
+ "Identify the quantity the question wants.",
433
+ "Translate the wording into an equation, ratio, or diagram.",
434
+ "Carry out the calculation carefully.",
435
+ ],
436
+ )