"""Tests for Action/Observation model creation and validation."""

import sys
from pathlib import Path

# Make the project root (parent of this tests directory) importable so the
# project-local modules below resolve when run directly or under pytest.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS  # noqa: E402
from models import ExplainerAction, ExplainerObservation  # noqa: E402


def test_action_explore():
    """An explore action keeps its tool/query/intent and defaults the rest."""
    a = ExplainerAction(
        action_type="explore",
        tool="search_arxiv",
        query="attention mechanism",
        intent="visual intuition",
    )
    assert a.action_type == "explore"
    assert a.tool == "search_arxiv"
    assert a.query == "attention mechanism"
    assert a.intent == "visual intuition"
    assert a.code == ""
    assert a.format is None


def test_action_generate_marimo():
    """A marimo generate action stores format and leaves narration empty."""
    a = ExplainerAction(
        action_type="generate",
        format="marimo",
        code="import marimo as mo\napp = mo.App()",
    )
    assert a.action_type == "generate"
    assert a.format == "marimo"
    assert a.narration == ""


def test_action_generate_manim():
    """A manim generate action stores format and the supplied narration."""
    a = ExplainerAction(
        action_type="generate",
        format="manim",
        code="from manim import *\nclass S(Scene): pass",
        narration="First we show the scene.",
    )
    assert a.format == "manim"
    assert a.narration != ""


def test_action_repair():
    """A repair action stores its repair notes."""
    a = ExplainerAction(
        action_type="repair",
        format="marimo",
        code="x = 1",
        repair_notes="fixed syntax",
    )
    assert a.action_type == "repair"
    assert a.repair_notes == "fixed syntax"


def test_observation_defaults():
    """A bare observation starts in the explore phase with full step budgets."""
    obs = ExplainerObservation()
    assert obs.topic == ""
    assert obs.tier == "beginner"
    assert obs.phase == "explore"
    assert obs.explore_steps_left == MAX_EXPLORE_STEPS
    assert obs.repair_attempts_left == MAX_REPAIR_STEPS
    assert obs.done is False


def test_observation_full():
    """Explicitly supplied observation fields are stored as given."""
    obs = ExplainerObservation(
        topic="Gradient Descent",
        content="GD iteratively updates params.",
        tier="intermediate",
        keywords="gradient,learning rate",
        data_available=True,
        phase="generate",
        feedback="looks good",
        search_results="paper1...",
        explored_context="accumulated...",
        explore_steps_left=1,
        done=True,
        reward=0.85,
    )
    assert obs.topic == "Gradient Descent"
    assert obs.tier == "intermediate"
    assert obs.phase == "generate"
    assert obs.explore_steps_left == 1
    assert obs.done is True
    assert obs.reward == 0.85


if __name__ == "__main__":
    # Single source of truth for the manual runner so the reported count
    # cannot drift from the number of tests actually executed.
    _TESTS = (
        test_action_explore,
        test_action_generate_marimo,
        test_action_generate_manim,
        test_action_repair,
        test_observation_defaults,
        test_observation_full,
    )
    for _test in _TESTS:
        _test()
    print(f"PASS: test_models ({len(_TESTS)}/{len(_TESTS)})")