# explainer-env/tests/test_models.py
# kgdrathan's picture
# Upload folder using huggingface_hub
# 8fa7af1 verified
"""Tests for Action/Observation model creation and validation."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS
from models import ExplainerAction, ExplainerObservation
def test_action_explore():
    """An explore action keeps the fields it was given and defaults the rest.

    Checks that ``action_type``, ``tool``, ``query`` and ``intent`` round-trip
    unchanged, that ``code`` is the empty string and that ``format`` is
    ``None`` when neither is supplied.
    """
    supplied = {
        "action_type": "explore",
        "tool": "search_arxiv",
        "query": "attention mechanism",
        "intent": "visual intuition",
    }
    action = ExplainerAction(**supplied)

    # Every explicitly supplied field must come back unchanged.
    for field_name, expected in supplied.items():
        assert getattr(action, field_name) == expected

    # Fields that were not supplied fall back to their defaults.
    assert action.code == ""
    assert action.format is None
def test_action_generate_marimo():
    """A marimo generate action keeps its type/format; narration is empty when omitted."""
    marimo_source = "import marimo as mo\napp = mo.App()"
    action = ExplainerAction(
        action_type="generate",
        format="marimo",
        code=marimo_source,
    )

    assert action.action_type == "generate"
    assert action.format == "marimo"
    # No narration was passed, so the empty-string default is expected.
    assert action.narration == ""
def test_action_generate_manim():
    """A manim generate action keeps its format and carries a non-empty narration."""
    manim_source = "from manim import *\nclass S(Scene): pass"
    spoken_text = "First we show the scene."
    action = ExplainerAction(
        action_type="generate",
        format="manim",
        code=manim_source,
        narration=spoken_text,
    )

    assert action.format == "manim"
    # Only non-emptiness is checked here, not the exact narration text.
    assert action.narration != ""
def test_action_repair():
    """A repair action exposes its action type and the supplied repair notes."""
    repair_kwargs = {
        "action_type": "repair",
        "format": "marimo",
        "code": "x = 1",
        "repair_notes": "fixed syntax",
    }
    action = ExplainerAction(**repair_kwargs)

    assert action.action_type == "repair"
    assert action.repair_notes == "fixed syntax"
def test_observation_defaults():
    """An observation built with no arguments exposes the expected defaults.

    The step budgets are compared against ``MAX_EXPLORE_STEPS`` and
    ``MAX_REPAIR_STEPS`` from ``constants`` rather than hard-coded numbers.
    """
    observation = ExplainerObservation()
    expected_defaults = {
        "topic": "",
        "tier": "beginner",
        "phase": "explore",
        "explore_steps_left": MAX_EXPLORE_STEPS,
        "repair_attempts_left": MAX_REPAIR_STEPS,
    }

    for field_name, expected in expected_defaults.items():
        assert getattr(observation, field_name) == expected

    # Identity check: ``done`` must be the boolean False, not merely falsy.
    assert observation.done is False
def test_observation_full():
    """A fully populated observation returns the values it was built with.

    All fields are passed to the constructor, but only ``topic``, ``phase``,
    ``explore_steps_left`` and ``reward`` are asserted afterwards.
    """
    populated = {
        "topic": "Gradient Descent",
        "content": "GD iteratively updates params.",
        "tier": "intermediate",
        "keywords": "gradient,learning rate",
        "data_available": True,
        "phase": "generate",
        "feedback": "looks good",
        "search_results": "paper1...",
        "explored_context": "accumulated...",
        "explore_steps_left": 1,
        "done": True,
        "reward": 0.85,
    }
    observation = ExplainerObservation(**populated)

    for field_name in ("topic", "phase", "explore_steps_left", "reward"):
        assert getattr(observation, field_name) == populated[field_name]
if __name__ == "__main__":
    # Allow running this file directly, without a test runner: execute each
    # test in definition order; the first failing assert aborts the run.
    for check in (
        test_action_explore,
        test_action_generate_marimo,
        test_action_generate_manim,
        test_action_repair,
        test_observation_defaults,
        test_observation_full,
    ):
        check()
    print("PASS: test_models (6/6)")