Spaces:
Sleeping
Sleeping
File size: 2,126 Bytes
46eecf4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 | """Tests for observation model."""
import pytest
from app.core.observation import Observation, TaskContext, ExtractedField, MemoryContext
def test_observation_creation() -> None:
    """Build an Observation from identifiers only and check its basic fields."""
    observation = Observation(episode_id="ep_001", task_id="task_001", step_number=0)

    assert observation.episode_id == "ep_001"
    assert observation.step_number == 0
    # extraction_progress was not passed, so the test expects it to be 0.0.
    assert observation.extraction_progress == 0.0
def test_observation_with_task_context() -> None:
    """Check that an Observation exposes the TaskContext it was built with."""
    context = TaskContext(
        task_id="task_001",
        task_name="Extract Product",
        task_type="extraction",
        target_fields=["name", "price"],
        required_fields=["name"],
    )
    observation = Observation(
        episode_id="ep_001",
        task_id="task_001",
        step_number=0,
        task_context=context,
    )

    attached = observation.task_context
    assert attached is not None
    assert attached.task_name == "Extract Product"
def test_observation_extraction_tracking() -> None:
    """Check per-field extraction lookup and the field-name -> value mapping."""
    name_field = ExtractedField(field_name="name", value="Test Product", confidence=0.95)
    observation = Observation(
        episode_id="ep_001",
        task_id="task_001",
        step_number=1,
        extracted_so_far=[name_field],
        fields_remaining=["price", "description"],
    )

    assert len(observation.extracted_so_far) == 1
    # "name" is in extracted_so_far; "price" only appears in fields_remaining.
    assert observation.is_field_extracted("name")
    assert not observation.is_field_extracted("price")

    values_by_field = observation.get_extraction_dict()
    assert values_by_field["name"] == "Test Product"
def test_observation_context_summary() -> None:
    """Check that the context summary mentions the step number and the current URL."""
    observation = Observation(
        episode_id="ep_001",
        task_id="task_001",
        step_number=5,
        current_url="https://example.com",
        extraction_progress=0.5,
        extracted_so_far=[ExtractedField(field_name="name", value="Test")],
        fields_remaining=["price"],
    )

    summary = observation.get_context_summary()
    for expected in ("Step 5", "example.com"):
        assert expected in summary
|