j-js commited on
Commit
b45f0ac
·
verified ·
1 Parent(s): 9bed299

Update solver_standard_deviation.py

Browse files
Files changed (1) hide show
  1. solver_standard_deviation.py +498 -44
solver_standard_deviation.py CHANGED
@@ -1,69 +1,523 @@
1
  from __future__ import annotations
2
 
 
3
  import re
4
  from statistics import pstdev
5
- from typing import Optional, List
6
 
7
  from models import SolverResult
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def _nums(text: str) -> List[float]:
11
- return [float(x) for x in re.findall(r"-?\d+(?:\.\d+)?", text)]
12
 
13
 
14
- def solve_standard_deviation(text: str) -> Optional[SolverResult]:
15
- lower = (text or "").lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- if "standard deviation" not in lower and "std dev" not in lower:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  return None
19
 
20
- nums = _nums(lower)
 
 
 
 
 
 
21
 
22
- # Conceptual rule: adding same constant does not change SD
23
- if any(p in lower for p in ["add the same", "increased by the same", "decreased by the same", "plus a constant"]):
24
- result = "unchanged"
25
- return SolverResult(
26
- domain="quant",
27
  solved=True,
28
- topic="standard_deviation",
29
- answer_value=result,
30
- internal_answer=result,
31
  steps=[
32
- "Adding or subtracting the same constant shifts all values equally.",
33
- "The spread does not change, so standard deviation is unchanged.",
 
34
  ],
35
  )
36
 
37
- # Conceptual rule: multiplying all by a constant scales SD
38
- if any(p in lower for p in ["multiplied by", "scaled by", "increase by the same percent", "decrease by the same percent"]):
39
- m = re.search(r"(multiplied by|scaled by)\s*(-?\d+(?:\.\d+)?)", lower)
40
- if m:
41
- factor = abs(float(m.group(2)))
42
- return SolverResult(
43
- domain="quant",
44
- solved=True,
45
- topic="standard_deviation",
46
- answer_value=f"multiplied by {factor:g}",
47
- internal_answer=f"multiplied by {factor:g}",
48
- steps=[
49
- "When every value is multiplied by a constant, standard deviation is multiplied by the absolute value of that constant.",
50
- ],
51
- )
52
-
53
- # Numeric SD if list-like
54
- if len(nums) >= 2:
55
- result = pstdev(nums)
56
- return SolverResult(
57
- domain="quant",
58
  solved=True,
59
- topic="standard_deviation",
60
- answer_value=f"{result:g}",
61
- internal_answer=f"{result:g}",
62
  steps=[
63
- "Find the mean.",
64
- "Compute squared deviations from the mean.",
65
- "Average them and take the square root.",
66
  ],
67
  )
68
 
69
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import math
4
  import re
5
  from statistics import pstdev
6
+ from typing import Optional, List, Tuple
7
 
8
  from models import SolverResult
9
 
10
 
11
+ # -----------------------------
12
+ # Basic parsing helpers
13
+ # -----------------------------
14
+
15
+ _NUMBER_RE = r"-?\d+(?:\.\d+)?"
16
+
17
+ _STD_PHRASES = [
18
+ "standard deviation",
19
+ "std dev",
20
+ "std. dev",
21
+ "stdev",
22
+ "sd ",
23
+ " s.d.",
24
+ ]
25
+
26
+ _COMPARE_WORDS = [
27
+ "greater",
28
+ "larger",
29
+ "higher",
30
+ "smaller",
31
+ "lower",
32
+ "less",
33
+ "same",
34
+ "equal",
35
+ "compare",
36
+ "comparison",
37
+ ]
38
+
39
+ _SET_LABEL_RE = re.compile(
40
+ rf"""
41
+ (?:
42
+ \b([A-Z])\b\s*[:=]\s* # A: 1,2,3
43
+ |
44
+ \bset\s+([A-Z])\b\s*[:=]?\s* # Set A: 1,2,3
45
+ |
46
+ \bgroup\s+([A-Z])\b\s*[:=]?\s* # Group A: 1,2,3
47
+ )
48
+ ([^\n;|]+)
49
+ """,
50
+ re.IGNORECASE | re.VERBOSE,
51
+ )
52
+
53
+
54
+ def _clean(text: str) -> str:
55
+ return re.sub(r"\s+", " ", (text or "").strip().lower())
56
+
57
+
58
  def _nums(text: str) -> List[float]:
59
+ return [float(x) for x in re.findall(_NUMBER_RE, text)]
60
 
61
 
62
+ def _is_close(a: float, b: float, tol: float = 1e-9) -> bool:
63
+ return abs(a - b) <= tol
64
+
65
+
66
+ def _all_equal(vals: List[float]) -> bool:
67
+ return bool(vals) and all(_is_close(v, vals[0]) for v in vals)
68
+
69
+
70
+ def _mean(vals: List[float]) -> float:
71
+ return sum(vals) / len(vals)
72
+
73
+
74
+ def _spread_score(vals: List[float]) -> float:
75
+ """
76
+ Cheap comparison proxy for spread. For same-length sets,
77
+ pstdev is best, but this helper can still support quick comparisons.
78
+ """
79
+ if not vals:
80
+ return 0.0
81
+ return pstdev(vals)
82
+
83
+
84
+ def _safe_number_text(x: float) -> str:
85
+ if _is_close(x, round(x)):
86
+ return str(int(round(x)))
87
+ return f"{x:.6g}"
88
+
89
+
90
+ def _mentions_standard_deviation(lower: str) -> bool:
91
+ return any(p in lower for p in _STD_PHRASES)
92
+
93
+
94
+ def _mentions_variability(lower: str) -> bool:
95
+ return any(
96
+ p in lower
97
+ for p in [
98
+ "spread",
99
+ "more spread out",
100
+ "less spread out",
101
+ "dispersion",
102
+ "variability",
103
+ "variation",
104
+ ]
105
+ )
106
+
107
+
108
+ def _extract_labeled_sets(text: str) -> List[Tuple[str, List[float]]]:
109
+ sets: List[Tuple[str, List[float]]] = []
110
+ for m in _SET_LABEL_RE.finditer(text):
111
+ label = (m.group(1) or m.group(2) or "").upper()
112
+ body = m.group(3)
113
+ nums = _nums(body)
114
+ if len(nums) >= 2:
115
+ sets.append((label, nums))
116
+ return sets
117
+
118
+
119
+ def _extract_braced_sets(text: str) -> List[List[float]]:
120
+ groups = re.findall(r"\{([^{}]+)\}|\(([^()]+)\)|\[([^\[\]]+)\]", text)
121
+ out: List[List[float]] = []
122
+ for g in groups:
123
+ body = next((part for part in g if part), "")
124
+ nums = _nums(body)
125
+ if len(nums) >= 2:
126
+ out.append(nums)
127
+ return out
128
+
129
+
130
+ def _describe_shift_rule() -> List[str]:
131
+ return [
132
+ "Adding or subtracting the same constant shifts every value equally.",
133
+ "That changes the center, but not the spread.",
134
+ "So the standard deviation stays unchanged.",
135
+ ]
136
+
137
 
138
+ def _describe_scale_rule(factor: float) -> List[str]:
139
+ return [
140
+ "Multiplying or dividing every value rescales every distance from the mean by the same factor.",
141
+ f"So the standard deviation is multiplied by |{_safe_number_text(factor)}|.",
142
+ "The key idea is that spread scales with the absolute value of the multiplier.",
143
+ ]
144
+
145
+
146
+ def _build_result(
147
+ *,
148
+ solved: bool,
149
+ internal_answer: Optional[str],
150
+ steps: List[str],
151
+ answer_value: Optional[str] = None,
152
+ ) -> SolverResult:
153
+ # Keep answer_value intentionally non-revealing for direct numeric solves.
154
+ return SolverResult(
155
+ domain="quant",
156
+ solved=solved,
157
+ topic="standard_deviation",
158
+ answer_value=answer_value if answer_value is not None else "computed internally",
159
+ internal_answer=internal_answer,
160
+ steps=steps,
161
+ )
162
+
163
+
164
+ # -----------------------------
165
+ # Pattern detectors
166
+ # -----------------------------
167
+
168
+ def _detect_add_sub_constant(lower: str) -> bool:
169
+ return any(
170
+ p in lower
171
+ for p in [
172
+ "add the same",
173
+ "added the same",
174
+ "increased by the same",
175
+ "decreased by the same",
176
+ "plus a constant",
177
+ "minus a constant",
178
+ "subtract the same",
179
+ "subtracted the same",
180
+ "add 5 to every",
181
+ "subtract 5 from every",
182
+ "each value is increased by",
183
+ "each value is decreased by",
184
+ "every value is increased by",
185
+ "every value is decreased by",
186
+ ]
187
+ )
188
+
189
+
190
+ def _detect_scaling(lower: str) -> Optional[float]:
191
+ patterns = [
192
+ r"(?:multiplied by|scaled by|times)\s*(" + _NUMBER_RE + r")",
193
+ r"(?:each|every)\s+value\s+(?:is\s+)?multiplied\s+by\s*(" + _NUMBER_RE + r")",
194
+ r"(?:each|every)\s+value\s+(?:is\s+)?divided\s+by\s*(" + _NUMBER_RE + r")",
195
+ ]
196
+
197
+ for pat in patterns:
198
+ m = re.search(pat, lower)
199
+ if m:
200
+ val = float(m.group(1))
201
+ if "divided by" in m.group(0):
202
+ if not _is_close(val, 0.0):
203
+ return 1.0 / val
204
+ return val
205
+
206
+ # Percent scaling language
207
+ m = re.search(r"(increase|decrease)\s+by\s+(\d+(?:\.\d+)?)\s*percent", lower)
208
+ if m:
209
+ pct = float(m.group(2)) / 100.0
210
+ if m.group(1) == "increase":
211
+ return 1.0 + pct
212
+ return 1.0 - pct
213
+
214
+ return None
215
+
216
+
217
+ def _detect_zero_sd_prompt(lower: str) -> bool:
218
+ return any(
219
+ p in lower
220
+ for p in [
221
+ "standard deviation is 0",
222
+ "std dev is 0",
223
+ "zero standard deviation",
224
+ "when is the standard deviation zero",
225
+ ]
226
+ )
227
+
228
+
229
+ def _detect_outlier_prompt(lower: str) -> bool:
230
+ return "outlier" in lower or "extreme value" in lower
231
+
232
+
233
+ def _detect_same_mean_diff_spread(lower: str) -> bool:
234
+ return (
235
+ ("same mean" in lower or "equal mean" in lower)
236
+ and any(p in lower for p in ["more spread", "less spread", "farther from the mean", "closer to the mean"])
237
+ )
238
+
239
+
240
+ def _detect_compare_sets(lower: str) -> bool:
241
+ return any(w in lower for w in _COMPARE_WORDS) and (
242
+ "set" in lower or "group" in lower or "list" in lower or "data set" in lower
243
+ )
244
+
245
+
246
+ # -----------------------------
247
+ # Solver blocks
248
+ # -----------------------------
249
+
250
+ def _solve_conceptual_constant_shift(lower: str) -> Optional[SolverResult]:
251
+ if not _detect_add_sub_constant(lower):
252
+ return None
253
+
254
+ return _build_result(
255
+ solved=True,
256
+ answer_value="unchanged",
257
+ internal_answer="unchanged",
258
+ steps=_describe_shift_rule(),
259
+ )
260
+
261
+
262
+ def _solve_conceptual_scaling(lower: str) -> Optional[SolverResult]:
263
+ factor = _detect_scaling(lower)
264
+ if factor is None:
265
  return None
266
 
267
+ return _build_result(
268
+ solved=True,
269
+ answer_value=f"scaled by |{_safe_number_text(factor)}|",
270
+ internal_answer=f"scaled by |{_safe_number_text(factor)}|",
271
+ steps=_describe_scale_rule(factor),
272
+ )
273
+
274
 
275
+ def _solve_zero_standard_deviation(lower: str, nums: List[float]) -> Optional[SolverResult]:
276
+ if nums and _all_equal(nums):
277
+ return _build_result(
 
 
278
  solved=True,
279
+ answer_value="zero",
280
+ internal_answer="0",
 
281
  steps=[
282
+ "All values are identical, so every value is exactly at the mean.",
283
+ "That means every deviation from the mean is 0.",
284
+ "So the standard deviation is 0.",
285
  ],
286
  )
287
 
288
+ if _detect_zero_sd_prompt(lower):
289
+ return _build_result(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  solved=True,
291
+ answer_value="all values equal",
292
+ internal_answer="standard deviation is zero exactly when all values are equal",
 
293
  steps=[
294
+ "Standard deviation measures how far values are from the mean.",
295
+ "It is zero only when every value has zero distance from the mean.",
296
+ "That happens exactly when all values are the same.",
297
  ],
298
  )
299
 
300
+ return None
301
+
302
+
303
+ def _solve_outlier_concept(lower: str) -> Optional[SolverResult]:
304
+ if not _detect_outlier_prompt(lower):
305
+ return None
306
+
307
+ return _build_result(
308
+ solved=True,
309
+ answer_value="typically increases",
310
+ internal_answer="adding or making an outlier more extreme typically increases standard deviation",
311
+ steps=[
312
+ "Standard deviation increases when values lie farther from the mean.",
313
+ "An outlier is an unusually distant value, so it usually increases spread.",
314
+ "So introducing a more extreme outlier typically increases the standard deviation.",
315
+ ],
316
+ )
317
+
318
+
319
+ def _solve_labeled_set_comparison(text: str, lower: str) -> Optional[SolverResult]:
320
+ sets = _extract_labeled_sets(text)
321
+
322
+ if len(sets) < 2:
323
+ return None
324
+ if not (_detect_compare_sets(lower) or _mentions_standard_deviation(lower) or _mentions_variability(lower)):
325
+ return None
326
+
327
+ scored = [(label, vals, _spread_score(vals)) for label, vals in sets]
328
+ scored_sorted = sorted(scored, key=lambda t: t[2])
329
+
330
+ smallest = scored_sorted[0]
331
+ largest = scored_sorted[-1]
332
+
333
+ if _is_close(smallest[2], largest[2]):
334
+ answer = "equal"
335
+ internal = "equal standard deviation"
336
+ steps = [
337
+ "Compare how far each set’s values lie from its own mean.",
338
+ "After measuring the spreads, the sets have equal spread.",
339
+ "So their standard deviations are equal.",
340
+ ]
341
+ else:
342
+ wants_small = any(w in lower for w in ["smaller", "lower", "less"])
343
+ chosen = smallest if wants_small else largest
344
+ answer = chosen[0]
345
+ internal = chosen[0]
346
+ steps = [
347
+ "For comparison questions, focus on spread rather than just the mean.",
348
+ "The set whose values sit farther from its mean has the larger standard deviation.",
349
+ f"Internal comparison identifies set {chosen[0]} as the correct choice.",
350
+ ]
351
+
352
+ return _build_result(
353
+ solved=True,
354
+ answer_value=answer,
355
+ internal_answer=internal,
356
+ steps=steps,
357
+ )
358
+
359
+
360
+ def _solve_braced_set_comparison(text: str, lower: str) -> Optional[SolverResult]:
361
+ sets = _extract_braced_sets(text)
362
+ if len(sets) != 2:
363
+ return None
364
+ if not (_detect_compare_sets(lower) or "which" in lower):
365
+ return None
366
+
367
+ s1 = _spread_score(sets[0])
368
+ s2 = _spread_score(sets[1])
369
+
370
+ if _is_close(s1, s2):
371
+ answer = "equal"
372
+ internal = "equal standard deviation"
373
+ else:
374
+ wants_small = any(w in lower for w in ["smaller", "lower", "less"])
375
+ if wants_small:
376
+ answer = "first set" if s1 < s2 else "second set"
377
+ internal = answer
378
+ else:
379
+ answer = "first set" if s1 > s2 else "second set"
380
+ internal = answer
381
+
382
+ return _build_result(
383
+ solved=True,
384
+ answer_value=answer,
385
+ internal_answer=internal,
386
+ steps=[
387
+ "Compare distance from each set’s mean, not just the raw values.",
388
+ "The more spread-out set has the larger standard deviation.",
389
+ "The choice above is determined internally from that spread comparison.",
390
+ ],
391
+ )
392
+
393
+
394
+ def _solve_same_mean_spread_concept(lower: str) -> Optional[SolverResult]:
395
+ if not _detect_same_mean_diff_spread(lower):
396
+ return None
397
+
398
+ return _build_result(
399
+ solved=True,
400
+ answer_value="the more spread-out set",
401
+ internal_answer="with same mean, the more spread-out set has larger standard deviation",
402
+ steps=[
403
+ "If two sets have the same mean, standard deviation depends on how far values sit from that mean.",
404
+ "Values farther from the mean create larger deviations.",
405
+ "So the more spread-out set has the larger standard deviation.",
406
+ ],
407
+ )
408
+
409
+
410
+ def _solve_symmetric_spacing_concept(text: str, lower: str) -> Optional[SolverResult]:
411
+ # Lightweight conceptual handling for classic GMAT patterns such as:
412
+ # {m-d, m, m+d} vs {m-2d, m, m+2d}
413
+ if "equally spaced" not in lower and "symmetric" not in lower and "centered at" not in lower:
414
+ return None
415
+
416
+ nums = _nums(text)
417
+ if len(nums) < 3:
418
+ return None
419
+
420
+ return _build_result(
421
+ solved=True,
422
+ answer_value="greater spacing means greater SD",
423
+ internal_answer="for symmetric equally spaced sets, larger common distance from center means larger SD",
424
+ steps=[
425
+ "For symmetric sets, the mean is the center point.",
426
+ "Standard deviation is driven by how far the outer values are from that center.",
427
+ "So if one set has larger equal spacing from the center, it has the larger standard deviation.",
428
+ ],
429
+ )
430
+
431
+
432
+ def _solve_direct_numeric(nums: List[float], lower: str) -> Optional[SolverResult]:
433
+ if len(nums) < 2:
434
+ return None
435
+
436
+ # Avoid hijacking transformation questions that happen to include numbers.
437
+ if _detect_add_sub_constant(lower) or _detect_scaling(lower) is not None:
438
+ return None
439
+
440
+ sd = pstdev(nums)
441
+
442
+ return _build_result(
443
+ solved=True,
444
+ answer_value="computed internally",
445
+ internal_answer=_safe_number_text(sd),
446
+ steps=[
447
+ "Find the mean of the data set.",
448
+ "Measure each value’s distance from the mean and square those distances.",
449
+ "Average those squared deviations, then take the square root.",
450
+ "The exact numeric standard deviation has been computed internally.",
451
+ ],
452
+ )
453
+
454
+
455
+ # -----------------------------
456
+ # Public solver
457
+ # -----------------------------
458
+
459
+ def solve_standard_deviation(text: str) -> Optional[SolverResult]:
460
+ lower = _clean(text)
461
+
462
+ if not (
463
+ _mentions_standard_deviation(lower)
464
+ or _mentions_variability(lower)
465
+ or "variance" in lower
466
+ or "outlier" in lower
467
+ ):
468
+ return None
469
+
470
+ nums = _nums(text)
471
+
472
+ # 1. Core conceptual transformations
473
+ for block in (
474
+ _solve_conceptual_constant_shift,
475
+ _solve_conceptual_scaling,
476
+ ):
477
+ result = block(lower)
478
+ if result is not None:
479
+ return result
480
+
481
+ # 2. Zero / all-equal concept
482
+ result = _solve_zero_standard_deviation(lower, nums)
483
+ if result is not None:
484
+ return result
485
+
486
+ # 3. Outlier concept
487
+ result = _solve_outlier_concept(lower)
488
+ if result is not None:
489
+ return result
490
+
491
+ # 4. Comparison-style questions
492
+ result = _solve_labeled_set_comparison(text, lower)
493
+ if result is not None:
494
+ return result
495
+
496
+ result = _solve_braced_set_comparison(text, lower)
497
+ if result is not None:
498
+ return result
499
+
500
+ result = _solve_same_mean_spread_concept(lower)
501
+ if result is not None:
502
+ return result
503
+
504
+ result = _solve_symmetric_spacing_concept(text, lower)
505
+ if result is not None:
506
+ return result
507
+
508
+ # 5. Exact numeric computation from a visible list
509
+ result = _solve_direct_numeric(nums, lower)
510
+ if result is not None:
511
+ return result
512
+
513
+ # 6. Fallback conceptual explanation
514
+ return _build_result(
515
+ solved=False,
516
+ answer_value="not fully resolved",
517
+ internal_answer=None,
518
+ steps=[
519
+ "This looks like a standard deviation question, so focus on spread around the mean.",
520
+ "Check whether the task is about a transformation, a comparison of spreads, or an exact computation.",
521
+ "If you want exact solving coverage for a missed pattern, add a dedicated parsing block for that wording.",
522
+ ],
523
+ )