File size: 2,657 Bytes
eb1ebe6
 
 
 
 
 
 
8fa7af1
eb1ebe6
 
 
 
43f41de
 
 
 
 
 
eb1ebe6
43f41de
eb1ebe6
43f41de
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
 
 
 
 
 
 
 
 
 
 
eb1ebe6
 
 
 
 
8fa7af1
 
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
eb1ebe6
 
43f41de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Tests for Action/Observation model creation and validation."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS
from models import ExplainerAction, ExplainerObservation


def test_action_explore():
    """An explore action keeps tool/query/intent and leaves code empty, format unset."""
    explore_fields = {
        "action_type": "explore",
        "tool": "search_arxiv",
        "query": "attention mechanism",
        "intent": "visual intuition",
    }
    action = ExplainerAction(**explore_fields)

    # Every argument given to the constructor must be readable back unchanged.
    for attr_name, expected_value in explore_fields.items():
        assert getattr(action, attr_name) == expected_value, attr_name

    # Fields not supplied fall back to their model defaults.
    assert action.code == ""
    assert action.format is None


def test_action_generate_marimo():
    """A marimo generate action stores its format and code; narration defaults to ""."""
    marimo_code = "import marimo as mo\napp = mo.App()"
    a = ExplainerAction(
        action_type="generate",
        format="marimo",
        code=marimo_code,
    )
    assert a.action_type == "generate"
    assert a.format == "marimo"
    # The generated source is the payload of a generate action; make sure it
    # round-trips through the model (previously unchecked).
    assert a.code == marimo_code
    # narration was not supplied, so it should keep its empty default.
    assert a.narration == ""


def test_action_generate_manim():
    """A manim generate action stores its format, code and narration verbatim."""
    manim_code = "from manim import *\nclass S(Scene): pass"
    narration = "First we show the scene."
    a = ExplainerAction(
        action_type="generate",
        format="manim",
        code=manim_code,
        narration=narration,
    )
    assert a.action_type == "generate"
    assert a.format == "manim"
    assert a.code == manim_code
    # Compare against the exact input: a bare `!= ""` check would also pass
    # if the narration were truncated or replaced by some other value.
    assert a.narration == narration


def test_action_repair():
    """A repair action keeps its action type and the supplied repair notes."""
    notes = "fixed syntax"
    repair = ExplainerAction(
        action_type="repair", format="marimo", code="x = 1", repair_notes=notes
    )
    assert (repair.action_type, repair.repair_notes) == ("repair", notes)


def test_observation_defaults():
    """A bare observation uses the model defaults.

    The step counters must equal MAX_EXPLORE_STEPS / MAX_REPAIR_STEPS.
    """
    observation = ExplainerObservation()
    expected_defaults = {
        "topic": "",
        "tier": "beginner",
        "phase": "explore",
        "explore_steps_left": MAX_EXPLORE_STEPS,
        "repair_attempts_left": MAX_REPAIR_STEPS,
    }
    for field_name, default in expected_defaults.items():
        assert getattr(observation, field_name) == default, field_name
    # Identity check: `done` must be the bool False, not merely falsy.
    assert observation.done is False


def test_observation_full():
    """Every field passed to the constructor can be read back unchanged."""
    fields = {
        "topic": "Gradient Descent",
        "content": "GD iteratively updates params.",
        "tier": "intermediate",
        "keywords": "gradient,learning rate",
        "data_available": True,
        "phase": "generate",
        "feedback": "looks good",
        "search_results": "paper1...",
        "explored_context": "accumulated...",
        "explore_steps_left": 1,
        "done": True,
        "reward": 0.85,
    }
    obs = ExplainerObservation(**fields)
    # Previously only 4 of these 12 fields were checked; verify all of them so
    # a regression in any one field is caught. The reward is a stored literal,
    # not a computed float, so exact equality is appropriate.
    for name, expected in fields.items():
        actual = getattr(obs, name)
        assert actual == expected, f"{name}: expected {expected!r}, got {actual!r}"


if __name__ == "__main__":
    # Collect the module's test_* functions in definition order (dict
    # insertion order) so newly added tests run automatically and the
    # reported count never drifts from reality. Snapshot globals() first,
    # because binding `_tests` below adds a new module global.
    _tests = [
        fn
        for name, fn in list(globals().items())
        if name.startswith("test_") and callable(fn)
    ]
    for _test in _tests:
        _test()
    print(f"PASS: test_models ({len(_tests)}/{len(_tests)})")