File size: 22,859 Bytes
80d8c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
"""Tests for JDG 01–06 scoring functions."""

from __future__ import annotations

from replicalab.agents.lab_manager_policy import check_feasibility
from replicalab.models import Protocol, RewardBreakdown
from replicalab.scenarios import generate_scenario
from replicalab.scenarios.templates import AllowedSubstitution, HiddenReferenceSpec
from replicalab.scoring import (
    build_reward_breakdown,
    compute_total_reward,
    explain_reward,
    score_feasibility,
    score_fidelity,
    score_rigor,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _scenario(template: str = "ml_benchmark", difficulty: str = "easy"):
    """Generate a scenario for *template* / *difficulty* using the fixed seed 42."""
    scenario = generate_scenario(template=template, difficulty=difficulty, seed=42)
    return scenario


def _good_protocol(scenario) -> Protocol:
    """Assemble a protocol from the scenario's own lab observation and hidden spec.

    The technique is the first 60 characters of the spec summary (or a fixed
    fallback when the summary is empty), the duration is the lab time limit
    clamped into the range 1..2 days, at most one available equipment item and
    one in-stock reagent are requested, and the rationale quotes the first two
    required elements plus the target metric and value.
    """
    lab_view = scenario.lab_manager_observation
    reference = scenario.hidden_reference_spec

    if reference.summary:
        technique = reference.summary[:60]
    else:
        technique = "replication_plan"

    # Request only the first item of each inventory, or nothing if it is empty.
    equipment = []
    if lab_view.equipment_available:
        equipment = list(lab_view.equipment_available[:1])
    reagents = []
    if lab_view.reagents_in_stock:
        reagents = list(lab_view.reagents_in_stock[:1])

    covered = ", ".join(reference.required_elements[:2])
    rationale = (
        f"Plan addresses: {covered}. "
        f"Target metric: {reference.target_metric}. "
        f"Target value: {reference.target_value}. "
        "Stay within budget and schedule."
    )

    return Protocol(
        sample_size=10,
        controls=["baseline", "ablation"],
        technique=technique,
        duration_days=max(1, min(2, lab_view.time_limit_days)),
        required_equipment=equipment,
        required_reagents=reagents,
        rationale=rationale,
    )


def _bad_protocol() -> Protocol:
    """Return a bare-bones protocol: one sample, no controls, no resources, a two-word rationale."""
    fields = {
        "sample_size": 1,
        "controls": [],
        "technique": "unknown_method",
        "duration_days": 1,
        "required_equipment": [],
        "required_reagents": [],
        "rationale": "No plan.",
    }
    return Protocol(**fields)


def _awful_protocol(scenario) -> Protocol:
    """Return a protocol with no controls that also overruns the lab's time limit.

    The duration is set five days past ``time_limit_days`` and the requested
    equipment and reagent carry placeholder names (presumably absent from any
    generated lab inventory).
    """
    time_limit = scenario.lab_manager_observation.time_limit_days
    overrun_days = time_limit + 5
    return Protocol(
        sample_size=200,
        controls=[],
        technique="imaginary_method",
        duration_days=overrun_days,
        required_equipment=["Imaginary Device"],
        required_reagents=["Imaginary Reagent"],
        rationale="No.",
    )


# ---------------------------------------------------------------------------
# JDG 01 — score_rigor
# ---------------------------------------------------------------------------


def test_rigor_good_protocol_scores_higher_than_bad() -> None:
    """Rigor ranks the scenario-aligned protocol above the minimal one; both stay in [0, 1]."""
    scn = _scenario("ml_benchmark", "easy")
    aligned = _good_protocol(scn)
    minimal = _bad_protocol()

    aligned_rigor = score_rigor(aligned, scn)
    minimal_rigor = score_rigor(minimal, scn)

    assert aligned_rigor > minimal_rigor
    for value in (aligned_rigor, minimal_rigor):
        assert 0.0 <= value <= 1.0


def test_rigor_is_deterministic() -> None:
    """Scoring rigor twice on identical inputs yields equal values."""
    scn = _scenario("ml_benchmark", "medium")
    plan = _good_protocol(scn)

    runs = [score_rigor(plan, scn) for _ in range(2)]

    assert runs[0] == runs[1]


def test_rigor_empty_controls_reduces_score() -> None:
    """Rigor with two controls is at least as high as rigor with a single control.

    NOTE(review): despite the test name, the reduced variant keeps one control
    (``["only_one"]``) rather than an empty list, and the comparison is
    non-strict (``>=``). Confirm whether an empty list was intended.
    """
    scn = _scenario("math_reasoning", "easy")
    two_controls = _good_protocol(scn)
    fewer_controls = two_controls.model_copy(update={"controls": ["only_one"]})

    rigor_two = score_rigor(two_controls, scn)
    rigor_fewer = score_rigor(fewer_controls, scn)

    assert rigor_two >= rigor_fewer


def test_rigor_short_rationale_reduces_score() -> None:
    """Replacing the rationale with ``"OK."`` must strictly lower the rigor score."""
    scn = _scenario("finance_trading", "easy")
    detailed = _good_protocol(scn)
    terse = detailed.model_copy(update={"rationale": "OK."})

    detailed_rigor = score_rigor(detailed, scn)
    terse_rigor = score_rigor(terse, scn)

    assert detailed_rigor > terse_rigor


def test_rigor_all_domains_return_valid_range() -> None:
    """Rigor stays within [0, 1] for every template/difficulty pair at seed 99."""
    templates = ("ml_benchmark", "math_reasoning", "finance_trading")
    difficulties = ("easy", "medium", "hard")
    combos = [(t, d) for t in templates for d in difficulties]

    for template, difficulty in combos:
        scenario = generate_scenario(seed=99, template=template, difficulty=difficulty)
        score = score_rigor(_good_protocol(scenario), scenario)
        assert 0.0 <= score <= 1.0, f"{template}/{difficulty}: {score}"


# ---------------------------------------------------------------------------
# JDG 02 — score_feasibility
# ---------------------------------------------------------------------------


def test_feasibility_viable_protocol_scores_high() -> None:
    """A scenario-aligned protocol scores above 0.7 on feasibility and stays in [0, 1]."""
    scn = _scenario("ml_benchmark", "easy")
    plan = _good_protocol(scn)

    feasibility = score_feasibility(plan, scn)

    assert feasibility > 0.7
    assert 0.0 <= feasibility <= 1.0


def test_feasibility_infeasible_protocol_scores_lower() -> None:
    """Feasibility drops strictly when sample size, duration, and equipment are pushed out of reach."""
    scn = _scenario("ml_benchmark", "easy")
    workable = _good_protocol(scn)

    # Oversized sample, a schedule five days beyond the lab's limit, and an
    # equipment name that is presumably absent from the lab inventory.
    overrides = {
        "sample_size": 200,
        "duration_days": scn.lab_manager_observation.time_limit_days + 5,
        "required_equipment": ["Imaginary Device"],
    }
    unworkable = workable.model_copy(update=overrides)

    workable_score = score_feasibility(workable, scn)
    unworkable_score = score_feasibility(unworkable, scn)

    assert workable_score > unworkable_score


def test_feasibility_accepts_precomputed_check() -> None:
    """Passing a ``check_feasibility`` result via ``check=`` gives the same score as omitting it."""
    scn = _scenario("finance_trading", "easy")
    plan = _good_protocol(scn)
    precomputed = check_feasibility(plan, scn)

    supplied = score_feasibility(plan, scn, check=precomputed)
    computed_internally = score_feasibility(plan, scn)

    assert supplied == computed_internally


def test_feasibility_is_deterministic() -> None:
    """Scoring feasibility twice on identical inputs yields equal values."""
    scn = _scenario("math_reasoning", "medium")
    plan = _good_protocol(scn)

    runs = [score_feasibility(plan, scn) for _ in range(2)]

    assert runs[0] == runs[1]


def test_feasibility_partial_credit_for_near_budget() -> None:
    """A sample size of 40 scores at least as well on feasibility as a sample size of 200.

    The two sizes are meant to model "slightly over" versus "far over" budget;
    the comparison is non-strict (``>=``).
    """
    scn = _scenario("ml_benchmark", "easy")
    baseline = _good_protocol(scn)

    near = baseline.model_copy(update={"sample_size": 40})
    far = baseline.model_copy(update={"sample_size": 200})

    near_score = score_feasibility(near, scn)
    far_score = score_feasibility(far, scn)

    assert near_score >= far_score


def test_feasibility_all_domains_return_valid_range() -> None:
    """Feasibility stays within [0, 1] for every template/difficulty pair at seed 99."""
    templates = ("ml_benchmark", "math_reasoning", "finance_trading")
    difficulties = ("easy", "medium", "hard")
    combos = [(t, d) for t in templates for d in difficulties]

    for template, difficulty in combos:
        scenario = generate_scenario(seed=99, template=template, difficulty=difficulty)
        score = score_feasibility(_good_protocol(scenario), scenario)
        assert 0.0 <= score <= 1.0, f"{template}/{difficulty}: {score}"


# ---------------------------------------------------------------------------
# JDG 03 — score_fidelity
# ---------------------------------------------------------------------------


def test_fidelity_aligned_protocol_scores_higher() -> None:
    """Fidelity ranks the scenario-aligned protocol above the minimal one; both stay in [0, 1]."""
    scn = _scenario("ml_benchmark", "easy")
    on_spec = _good_protocol(scn)
    off_spec = _bad_protocol()

    on_spec_fidelity = score_fidelity(on_spec, scn)
    off_spec_fidelity = score_fidelity(off_spec, scn)

    assert on_spec_fidelity > off_spec_fidelity
    for value in (on_spec_fidelity, off_spec_fidelity):
        assert 0.0 <= value <= 1.0


def test_fidelity_is_deterministic() -> None:
    """Scoring fidelity twice on identical inputs yields equal values."""
    scn = _scenario("finance_trading", "hard")
    plan = _good_protocol(scn)

    runs = [score_fidelity(plan, scn) for _ in range(2)]

    assert runs[0] == runs[1]


def test_fidelity_substitution_gets_partial_credit() -> None:
    """A rationale naming an allowed substitution outscores a protocol that names nothing.

    The first substitution listed by the scenario is used; its alternative and
    the spec's target metric are written into the rationale of an otherwise
    scenario-aligned protocol, which must score strictly higher on fidelity
    than the minimal protocol.
    """
    scenario = _scenario("math_reasoning", "easy")
    spec = scenario.hidden_reference_spec

    # Map each substitutable original (lower-cased) to its permitted alternative.
    # Note: this does not check that the original is one of the spec's
    # required elements.
    sub_map = {
        sub.original.lower(): sub.alternative
        for sub in scenario.allowed_substitutions
    }

    if not sub_map or not spec.required_elements:
        # NOTE(review): this bare return makes the test pass vacuously when the
        # scenario has no substitutions or no required elements; an explicit
        # skip would make that visible in the test report.
        return

    # Alternative belonging to the first substitution entry (insertion order).
    first_sub_alt = next(iter(sub_map.values()))

    with_sub = _good_protocol(scenario).model_copy(update={
        "rationale": f"We will use {first_sub_alt} instead. " + spec.target_metric,
    })
    without_anything = _bad_protocol()

    score_sub = score_fidelity(with_sub, scenario)
    score_miss = score_fidelity(without_anything, scenario)

    assert score_sub > score_miss


def test_fidelity_mentioning_target_metric_improves_score() -> None:
    """A rationale that cites the target metric scores at least as well as a generic one.

    The baseline protocol's rationale includes the spec's target metric and
    value; the variant swaps it for a sentence that names no metric. The
    comparison is non-strict (``>=``), so equal scores also pass.
    """
    scenario = _scenario("ml_benchmark", "easy")

    with_metric = _good_protocol(scenario)
    without_metric = with_metric.model_copy(update={
        "rationale": "Generic plan without any specific metric mentioned.",
    })

    score_with = score_fidelity(with_metric, scenario)
    score_without = score_fidelity(without_metric, scenario)

    assert score_with >= score_without


def test_fidelity_all_domains_return_valid_range() -> None:
    """Fidelity stays within [0, 1] for every template/difficulty pair at seed 99."""
    templates = ("ml_benchmark", "math_reasoning", "finance_trading")
    difficulties = ("easy", "medium", "hard")
    combos = [(t, d) for t in templates for d in difficulties]

    for template, difficulty in combos:
        scenario = generate_scenario(seed=99, template=template, difficulty=difficulty)
        score = score_fidelity(_good_protocol(scenario), scenario)
        assert 0.0 <= score <= 1.0, f"{template}/{difficulty}: {score}"


# ---------------------------------------------------------------------------
# Cross-scorer consistency
# ---------------------------------------------------------------------------


def test_all_scores_between_zero_and_one_for_bad_protocol() -> None:
    """All three scorers stay within [0, 1] for the minimal protocol on hard scenarios (seed 7)."""
    scorers = (score_rigor, score_feasibility, score_fidelity)

    for template in ("ml_benchmark", "math_reasoning", "finance_trading"):
        scenario = generate_scenario(seed=7, template=template, difficulty="hard")
        bad = _bad_protocol()

        # Evaluated in order: rigor, feasibility, fidelity.
        r, fe, fi = [scorer(bad, scenario) for scorer in scorers]

        assert 0.0 <= r <= 1.0, f"rigor {template}: {r}"
        assert 0.0 <= fe <= 1.0, f"feasibility {template}: {fe}"
        assert 0.0 <= fi <= 1.0, f"fidelity {template}: {fi}"


def test_good_protocol_dominates_bad_on_rigor_and_fidelity() -> None:
    """The aligned protocol strictly beats the minimal one on rigor and on fidelity.

    Feasibility is deliberately not compared here. Per the original author's
    note, a protocol requesting no equipment and no reagents has nothing that
    can fail a feasibility check, so it is trivially feasible; rigor and
    fidelity are the axes expected to penalize an empty plan.
    """
    scn = _scenario("ml_benchmark", "easy")
    aligned = _good_protocol(scn)
    minimal = _bad_protocol()

    aligned_rigor = score_rigor(aligned, scn)
    minimal_rigor = score_rigor(minimal, scn)
    assert aligned_rigor > minimal_rigor

    aligned_fidelity = score_fidelity(aligned, scn)
    minimal_fidelity = score_fidelity(minimal, scn)
    assert aligned_fidelity > minimal_fidelity


def test_good_protocol_beats_awful_protocol_on_all_scores_and_total_reward() -> None:
    """The aligned protocol strictly beats the awful one on each scorer and on total reward."""
    scn = _scenario("ml_benchmark", "easy")
    aligned = _good_protocol(scn)
    hopeless = _awful_protocol(scn)

    round_args = {"rounds_used": 2, "max_rounds": 6}
    aligned_bd = build_reward_breakdown(aligned, scn, **round_args)
    hopeless_bd = build_reward_breakdown(hopeless, scn, **round_args)

    # Per-axis comparisons in the order rigor, feasibility, fidelity.
    for scorer in (score_rigor, score_feasibility, score_fidelity):
        assert scorer(aligned, scn) > scorer(hopeless, scn)

    aligned_total = compute_total_reward(aligned_bd)
    hopeless_total = compute_total_reward(hopeless_bd)
    assert aligned_total > hopeless_total


def test_rigor_explicit_success_criteria_mentions_improve_score() -> None:
    """A rationale that names the scenario's success criteria strictly outscores a generic one on rigor.

    The scenario is patched with two explicit success criteria and a reference
    spec that has no required or flexible elements.
    """
    patched_spec = HiddenReferenceSpec(
        summary="risk-aware replication plan",
        required_elements=[],
        flexible_elements=[],
        target_metric="sharpe ratio",
        target_value="> 1.5",
    )
    overrides = {
        "success_criteria": ["risk-adjusted return", "drawdown control"],
        "hidden_reference_spec": patched_spec,
    }
    scn = _scenario("finance_trading", "easy").model_copy(update=overrides)

    vague_rationale = "Follow a generic plan with basic checks."
    # Mentions both patched success criteria verbatim.
    targeted_rationale = (
        "Optimize for risk-adjusted return while preserving drawdown control "
        "through explicit checkpoints."
    )

    vague = _good_protocol(scn).model_copy(update={"rationale": vague_rationale})
    targeted = vague.model_copy(update={"rationale": targeted_rationale})

    targeted_rigor = score_rigor(targeted, scn)
    vague_rigor = score_rigor(vague, scn)

    assert targeted_rigor > vague_rigor


def test_feasibility_partial_equipment_credit_sits_between_full_and_total_miss() -> None:
    """Feasibility orders strictly: all equipment available > one of two available > none available."""
    scn = _scenario("ml_benchmark", "easy")
    inventory = list(scn.lab_manager_observation.equipment_available)
    assert inventory, "scenario must expose at least one available equipment item"
    real_item = inventory[0]

    # Three equipment requests layered on the same base protocol.
    all_available = _good_protocol(scn).model_copy(
        update={"required_equipment": [real_item]}
    )
    half_available = all_available.model_copy(
        update={"required_equipment": [real_item, "Imaginary Device"]}
    )
    none_available = all_available.model_copy(
        update={"required_equipment": ["Imaginary Device", "Missing Device"]}
    )

    all_score = score_feasibility(all_available, scn)
    half_score = score_feasibility(half_available, scn)
    none_score = score_feasibility(none_available, scn)

    assert all_score > half_score > none_score


def test_fidelity_direct_match_beats_substitution_and_miss() -> None:
    """Fidelity orders strictly: required element named > its allowed substitute named > neither named.

    The scenario is patched so that ``alphaprobe`` is the only required element
    and ``betaprobe`` is its sole allowed substitution.
    """
    patched_spec = HiddenReferenceSpec(
        summary="structured proof plan",
        required_elements=["alphaprobe"],
        flexible_elements=[],
        target_metric="accuracy",
        target_value="0.95",
    )
    swap = AllowedSubstitution(
        original="alphaprobe",
        alternative="betaprobe",
        condition="when the primary resource is booked",
        tradeoff="backup sensor is slower",
    )
    scn = _scenario("math_reasoning", "easy").model_copy(
        update={
            "hidden_reference_spec": patched_spec,
            "allowed_substitutions": [swap],
        }
    )

    # Baseline rationale names neither probe, so it serves as the "miss" case.
    neither = Protocol(
        sample_size=10,
        controls=["baseline", "ablation"],
        technique="structured proof plan",
        duration_days=1,
        required_equipment=[],
        required_reagents=[],
        rationale="Target accuracy 0.95 with explicit evaluation.",
    )
    names_original = neither.model_copy(
        update={"rationale": neither.rationale + " Use the alphaprobe."}
    )
    names_substitute = neither.model_copy(
        update={"rationale": neither.rationale + " Use the betaprobe."}
    )

    original_score = score_fidelity(names_original, scn)
    substitute_score = score_fidelity(names_substitute, scn)
    neither_score = score_fidelity(neither, scn)

    assert original_score > substitute_score > neither_score


# ---------------------------------------------------------------------------
# JDG 04 — compute_total_reward
# ---------------------------------------------------------------------------


def test_total_reward_perfect_beats_broken() -> None:
    """Total reward for the aligned protocol strictly exceeds that of the minimal protocol."""
    scn = _scenario("ml_benchmark", "easy")
    aligned = _good_protocol(scn)
    minimal = _bad_protocol()

    round_args = {"rounds_used": 1, "max_rounds": 6}
    aligned_bd = build_reward_breakdown(aligned, scn, **round_args)
    minimal_bd = build_reward_breakdown(minimal, scn, **round_args)

    aligned_total = compute_total_reward(aligned_bd)
    minimal_total = compute_total_reward(minimal_bd)
    assert aligned_total > minimal_total


def test_zero_feasibility_zeroes_base() -> None:
    """With feasibility at 0 and the other two scores at 1, the total reward is exactly 0."""
    breakdown = RewardBreakdown(rigor=1.0, feasibility=0.0, fidelity=1.0)
    total = compute_total_reward(breakdown)
    assert total == 0.0


def test_efficiency_bonus_higher_when_faster() -> None:
    """Using 1 of 6 rounds yields a strictly higher total reward than using 5 of 6."""
    scn = _scenario()
    plan = _good_protocol(scn)

    quick_bd = build_reward_breakdown(plan, scn, rounds_used=1, max_rounds=6)
    slow_bd = build_reward_breakdown(plan, scn, rounds_used=5, max_rounds=6)

    quick_total = compute_total_reward(quick_bd)
    slow_total = compute_total_reward(slow_bd)
    assert quick_total > slow_total


def test_penalty_subtraction_exact() -> None:
    """An all-ones breakdown carrying penalties of 2.0 and 0.5 totals exactly 7.5."""
    named_penalties = {"invalid_tool_use": 2.0, "unsupported_claim": 0.5}
    breakdown = RewardBreakdown(
        rigor=1.0,
        feasibility=1.0,
        fidelity=1.0,
        penalties=named_penalties,
    )

    # Expected arithmetic per the original author: 10*1*1*1 - 2.5.
    assert compute_total_reward(breakdown) == 7.5


def test_total_reward_clamps_at_zero() -> None:
    """A penalty of 50.0 on low component scores yields a total of exactly 0, not a negative value."""
    oversized_penalty = {"massive_penalty": 50.0}
    breakdown = RewardBreakdown(
        rigor=0.1,
        feasibility=0.1,
        fidelity=0.1,
        penalties=oversized_penalty,
    )
    total = compute_total_reward(breakdown)
    assert total == 0.0


def test_breakdown_determinism() -> None:
    """Two breakdowns built from identical inputs produce equal total rewards."""
    scn = _scenario("finance_trading", "medium")
    plan = _good_protocol(scn)

    breakdowns = [
        build_reward_breakdown(plan, scn, rounds_used=3, max_rounds=6)
        for _ in range(2)
    ]
    first_total = compute_total_reward(breakdowns[0])
    second_total = compute_total_reward(breakdowns[1])

    assert first_total == second_total


# ---------------------------------------------------------------------------
# JDG 05 — build_reward_breakdown
# ---------------------------------------------------------------------------


def test_breakdown_accepts_external_penalties() -> None:
    """A penalty passed through ``penalties=`` shows up in the breakdown with its value intact."""
    scn = _scenario()
    plan = _good_protocol(scn)
    injected = {"invalid_tool_use": 1.0}

    breakdown = build_reward_breakdown(
        plan,
        scn,
        rounds_used=2,
        max_rounds=6,
        penalties=injected,
    )

    assert "invalid_tool_use" in breakdown.penalties
    assert breakdown.penalties["invalid_tool_use"] == 1.0


def test_breakdown_no_penalties_by_default() -> None:
    """When no penalties are supplied, the breakdown's penalties mapping equals ``{}``."""
    scn = _scenario()
    plan = _good_protocol(scn)

    breakdown = build_reward_breakdown(plan, scn, rounds_used=2, max_rounds=6)
    recorded = breakdown.penalties

    assert recorded == {}


def test_breakdown_matches_with_and_without_precomputed_feasibility_check() -> None:
    """Supplying a ``check_feasibility`` result via ``check=`` yields an equal breakdown."""
    scn = _scenario("ml_benchmark", "medium")
    plan = _good_protocol(scn)
    feasibility_result = check_feasibility(plan, scn)

    # Round settings shared by both calls.
    round_args = {"rounds_used": 3, "max_rounds": 6}

    supplied = build_reward_breakdown(plan, scn, check=feasibility_result, **round_args)
    computed_internally = build_reward_breakdown(plan, scn, **round_args)

    assert supplied == computed_internally


# ---------------------------------------------------------------------------
# JDG 06 — explain_reward
# ---------------------------------------------------------------------------


def test_explain_mentions_all_rubric_components() -> None:
    """The explanation contains each rubric label and each score formatted to two decimals."""
    breakdown = RewardBreakdown(rigor=0.8, feasibility=0.6, fidelity=0.9)
    explanation = explain_reward(breakdown)

    expected_fragments = (
        "Rigor:",
        "Feasibility:",
        "Fidelity:",
        "0.80",
        "0.60",
        "0.90",
    )
    for fragment in expected_fragments:
        assert fragment in explanation


def test_explain_includes_penalties() -> None:
    """The explanation lists each penalty by its space-separated name and negated two-decimal amount."""
    named_penalties = {"invalid_tool_use": 1.0, "unsupported_claim": 0.5}
    breakdown = RewardBreakdown(
        rigor=0.5,
        feasibility=0.5,
        fidelity=0.5,
        penalties=named_penalties,
    )
    explanation = explain_reward(breakdown)

    # Penalty keys are expected with underscores rendered as spaces.
    for fragment in ("invalid tool use", "unsupported claim", "-1.00", "-0.50"):
        assert fragment in explanation


def test_explain_no_penalties_message() -> None:
    """A breakdown without penalties produces the text ``No penalties applied``."""
    breakdown = RewardBreakdown(rigor=1.0, feasibility=1.0, fidelity=1.0)
    explanation = explain_reward(breakdown)

    assert "No penalties applied" in explanation


def test_explain_includes_efficiency_bonus() -> None:
    """A non-zero efficiency bonus is reported with its label and a signed two-decimal value."""
    breakdown = RewardBreakdown(
        rigor=0.7,
        feasibility=0.7,
        fidelity=0.7,
        efficiency_bonus=0.8,
    )
    explanation = explain_reward(breakdown)

    for fragment in ("Efficiency bonus", "+0.80"):
        assert fragment in explanation


def test_explain_omits_efficiency_bonus_when_zero() -> None:
    """With no efficiency bonus set on the breakdown, the explanation never mentions one."""
    breakdown = RewardBreakdown(rigor=0.7, feasibility=0.7, fidelity=0.7)
    explanation = explain_reward(breakdown)

    assert "Efficiency bonus" not in explanation


def test_explain_shows_total_reward() -> None:
    """An all-ones breakdown is explained with the text ``Total reward: 10.00``."""
    breakdown = RewardBreakdown(rigor=1.0, feasibility=1.0, fidelity=1.0)
    explanation = explain_reward(breakdown)

    assert "Total reward: 10.00" in explanation


def test_explain_tier_labels() -> None:
    """Scores of 0.85, 0.5 and 0.25 surface the labels strong, moderate and weak."""
    mixed = RewardBreakdown(rigor=0.85, feasibility=0.5, fidelity=0.25)
    explanation = explain_reward(mixed)

    # Expected label per component: rigor 0.85, feasibility 0.5, fidelity 0.25.
    for label in ("strong", "moderate", "weak"):
        assert label in explanation


def test_explain_deterministic() -> None:
    """Explaining the same breakdown twice returns equal text."""
    breakdown = RewardBreakdown(
        rigor=0.6,
        feasibility=0.4,
        fidelity=0.8,
        efficiency_bonus=0.5,
        penalties={"timeout": 0.3},
    )
    first_text = explain_reward(breakdown)
    second_text = explain_reward(breakdown)

    assert first_text == second_text


def test_explain_with_real_breakdown() -> None:
    """A breakdown produced by ``build_reward_breakdown`` explains with all four section labels."""
    scn = _scenario("ml_benchmark", "easy")
    plan = _good_protocol(scn)
    breakdown = build_reward_breakdown(plan, scn, rounds_used=2, max_rounds=6)
    explanation = explain_reward(breakdown)

    for label in ("Rigor:", "Feasibility:", "Fidelity:", "Total reward:"):
        assert label in explanation