marionette / tests /test_api.py
RemiFabre
Refactor marionette into modules, fix audio sync, improve tests
dbc544f
"""Tier 1 — Backend unit tests for Marionette.
Run without hardware, without daemon, in under 5 seconds.
Tests the HTTP API layer, state machine, data validation, and persistence.
"""
import json
import threading
from io import BytesIO
from pathlib import Path
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from fastapi.testclient import TestClient
from marionette.main import (
Marionette,
RecordingRequest,
_slugify,
create_app,
DEFAULT_DURATION,
COUNTDOWN_SECONDS,
MOTION_SAMPLE_RATE,
DATASET_DATA_SUBDIR,
)
# ──────── Utility function tests ──────────────────────────────────────
class TestSlugify:
    """Unit tests for the ``_slugify`` label-to-identifier helper."""

    def test_simple_lowercase(self):
        """Letters are lower-cased and a space becomes a hyphen."""
        assert _slugify("Hello World") == "hello-world"

    def test_special_chars(self):
        """Punctuation is replaced by hyphens; a trailing one is dropped."""
        assert _slugify("my@move#1!") == "my-move-1"

    def test_leading_trailing_hyphens(self):
        """Hyphens at either end of the input are stripped."""
        assert _slugify("---test---") == "test"

    def test_empty_string(self):
        """An empty label falls back to the literal slug ``take``."""
        assert _slugify("") == "take"

    def test_unicode(self):
        """Accented letters are not transliterated; each one acts as a separator."""
        # The literal is real UTF-8 text (the file previously carried a
        # mis-decoded copy of these characters).
        result = _slugify("café résumé")
        assert result == "caf-r-sum"

    def test_already_slugified(self):
        """A value that is already a slug is returned unchanged."""
        assert _slugify("gentle-nod") == "gentle-nod"

    def test_numbers(self):
        """Digits are kept in the slug."""
        assert _slugify("take 42") == "take-42"
# ──────── State endpoint tests ────────────────────────────────────────
class TestStateEndpoint:
    """Checks on the payload served by ``GET /api/state`` for a fresh app."""

    def test_returns_200(self, client: TestClient):
        response = client.get("/api/state")
        assert response.status_code == 200

    def test_initial_mode_is_idle(self, client: TestClient):
        state = client.get("/api/state").json()
        assert state["mode"] == "idle"

    def test_initial_message(self, client: TestClient):
        state = client.get("/api/state").json()
        assert state["message"] == "Ready to capture moves"

    def test_state_shape(self, client: TestClient):
        """All listed top-level keys must be present; extra keys are tolerated."""
        state = client.get("/api/state").json()
        expected_keys = {
            "server_time",
            "mode",
            "message",
            "active_move",
            "phase_start_at",
            "phase_end_at",
            "countdown_ends_at",
            "recording_started_at",
            "recording_duration",
            "recording_stats",
            "pending_recording",
            "pending_playback",
            "moves",
            "config",
            "datasets",
        }
        missing = expected_keys.difference(state.keys())
        assert not missing

    def test_server_time_present(self, client: TestClient):
        """``server_time`` is a float within five seconds of the local clock."""
        import time

        state = client.get("/api/state").json()
        server_time = state["server_time"]
        assert isinstance(server_time, float)
        drift = abs(server_time - time.time())
        assert drift < 5.0

    def test_idle_phase_timing_null(self, client: TestClient):
        state = client.get("/api/state").json()
        for key in ("phase_start_at", "phase_end_at"):
            assert state[key] is None

    def test_config_shape(self, client: TestClient):
        """The config section mirrors the module constants and exposes model info."""
        cfg = client.get("/api/state").json()["config"]
        assert cfg["default_duration"] == DEFAULT_DURATION
        assert cfg["countdown_seconds"] == COUNTDOWN_SECONDS
        assert cfg["motion_sample_rate"] == MOTION_SAMPLE_RATE
        assert isinstance(cfg["audio_available"], bool)
        assert "motion_models" in cfg

    def test_initial_moves_empty(self, client: TestClient):
        state = client.get("/api/state").json()
        assert state["moves"] == []

    def test_initial_no_pending(self, client: TestClient):
        state = client.get("/api/state").json()
        for key in ("pending_recording", "pending_playback"):
            assert state[key] is None
class TestStartingUpMode:
    """Behaviour of the API while the app reports the ``starting_up`` mode.

    Each test forces the mode through ``Marionette._set_state`` and restores
    the idle state in a ``finally`` clause so that a failing assertion does
    not leave the shared instance stuck in ``starting_up``.
    """

    def test_starting_up_rejects_recording(self, client: TestClient, marionette: Marionette):
        """Commands are rejected while the robot is still starting up."""
        marionette._set_state(mode="starting_up", message="Starting up…", active_move=None)
        try:
            resp = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
            assert resp.status_code == 409
        finally:
            marionette._set_idle_state()

    def test_starting_up_rejects_dataset_changes(self, client: TestClient, marionette: Marionette):
        """Creating a dataset is answered with 409 during start-up."""
        marionette._set_state(mode="starting_up", message="Starting up…", active_move=None)
        try:
            resp = client.post("/api/datasets", json={"name": "test"})
            assert resp.status_code == 409
        finally:
            marionette._set_idle_state()

    def test_starting_up_state_visible(self, client: TestClient, marionette: Marionette):
        """The state endpoint exposes the start-up mode and its message."""
        marionette._set_state(mode="starting_up", message="Starting up…", active_move=None)
        try:
            data = client.get("/api/state").json()
            assert data["mode"] == "starting_up"
            assert "starting" in data["message"].lower()
        finally:
            marionette._set_idle_state()
# ──────── Recording endpoint tests ────────────────────────────────────
class TestRecordEndpoint:
    """Request validation and queuing behaviour of ``POST /api/record``."""

    @staticmethod
    def _post_record(client: TestClient, **overrides):
        """POST a record request: 3 s, no audio, with ``overrides`` applied on top."""
        payload = {"duration": 3.0, "record_audio": False}
        payload.update(overrides)
        return client.post("/api/record", json=payload)

    def test_accept_basic_recording(self, client: TestClient):
        response = self._post_record(client)
        assert response.status_code == 200
        body = response.json()
        assert body["accepted"] is True
        assert "move_id" in body

    def test_mode_becomes_queued(self, client: TestClient):
        self._post_record(client)
        mode = client.get("/api/state").json()["mode"]
        assert mode == "queued"

    def test_reject_when_busy(self, client: TestClient):
        """A second request while the first is still queued gets 409."""
        first = self._post_record(client)
        assert first.status_code == 200
        # The app is now in "queued" mode, so the next request must be refused.
        second = self._post_record(client)
        assert second.status_code == 409

    def test_reject_invalid_duration_too_low(self, client: TestClient):
        response = self._post_record(client, duration=0.1)
        assert response.status_code == 422  # Pydantic validation

    def test_reject_invalid_duration_too_high(self, client: TestClient):
        response = self._post_record(client, duration=999.0)
        assert response.status_code == 422

    def test_accept_duration_edge_cases(self, client: TestClient, marionette: Marionette):
        """Durations of 0.6 s and 300.0 s are both accepted."""
        # Just above minimum
        low = self._post_record(client, duration=0.6)
        assert low.status_code == 200

        # Clear the queued request so the second submission is not refused as busy.
        marionette._set_idle_state()
        marionette._pending_recording = None

        # At maximum
        high = self._post_record(client, duration=300.0)
        assert high.status_code == 200

    def test_custom_label(self, client: TestClient):
        """A slug-safe label is echoed back and reused as the move id."""
        body = self._post_record(client, label="happy-dance").json()
        assert body["label"] == "happy-dance"
        assert body["move_id"] == "happy-dance"

    def test_label_collision_appends_index(
        self, client: TestClient, marionette: Marionette, tmp_dataset_root: Path
    ):
        """When ``<label>.json`` already exists the move id gets a ``-1`` suffix."""
        # Pre-create the file the requested label would map to.
        target_dir = tmp_dataset_root / "local_dataset" / DATASET_DATA_SUBDIR
        target_dir.mkdir(parents=True, exist_ok=True)
        (target_dir / "happy-dance.json").write_text("{}")

        body = self._post_record(client, label="happy-dance").json()
        assert body["move_id"] == "happy-dance-1"

    def test_preferred_duration_saved(self, client: TestClient, marionette: Marionette):
        self._post_record(client, duration=7.5)
        assert marionette._preferred_duration == 7.5
# ──────── Playback endpoint tests ─────────────────────────────────────
class TestPlayEndpoint:
    """Acceptance and rejection rules of ``POST /api/play``."""

    def test_reject_missing_move(self, client: TestClient):
        response = client.post("/api/play", json={"move_id": "nonexistent"})
        assert response.status_code == 404

    def test_accept_existing_move(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """A move present in the active dataset can be queued for playback."""
        # Drop a move file into the dataset and let the app rescan it.
        target = marionette._dataset_dir / "test-move.json"
        target.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        response = client.post("/api/play", json={"move_id": "test-move"})
        assert response.status_code == 200
        assert response.json()["accepted"] is True

    def test_reject_play_when_busy(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """A second play request while the first is queued gets 409."""
        serialized = json.dumps(sample_move_json)
        (marionette._dataset_dir / "test-move.json").write_text(serialized)
        marionette._refresh_recordings()

        play_request = {"move_id": "test-move"}
        # First play is accepted
        first = client.post("/api/play", json=play_request)
        assert first.status_code == 200
        # Second play is rejected (mode is queued)
        second = client.post("/api/play", json=play_request)
        assert second.status_code == 409
# ──────── Stop endpoints ──────────────────────────────────────────────
class TestStopEndpoints:
    """Behaviour of the ``/api/play/stop`` and ``/api/record/stop`` endpoints."""

    def test_stop_playback_when_not_playing(self, client: TestClient):
        response = client.post("/api/play/stop")
        assert response.status_code == 200
        assert response.json()["stopped"] is False

    def test_stop_recording_when_not_recording(self, client: TestClient):
        response = client.post("/api/record/stop")
        assert response.status_code == 200
        assert response.json()["stopped"] is False

    def test_stop_recording_while_queued(self, client: TestClient, marionette: Marionette):
        """Stopping a queued recording should cancel it and return to idle."""
        # Submit a recording so the app enters the queued state.
        queued = client.post(
            "/api/record",
            json={"label": "test-queued", "duration": 5.0, "record_audio": False},
        )
        assert queued.status_code == 200
        assert queued.json()["accepted"] is True

        # Confirm the precondition before issuing the stop.
        mode_before = client.get("/api/state").json()["mode"]
        assert mode_before == "queued"

        # The stop request must succeed even though capture has not begun.
        stopped = client.post("/api/record/stop")
        assert stopped.status_code == 200
        assert stopped.json()["stopped"] is True

        # The app is idle again and the message mentions the cancellation.
        after = client.get("/api/state").json()
        assert after["mode"] == "idle"
        assert "cancelled" in after["message"].lower()
# ──────── Move deletion tests ─────────────────────────────────────────
class TestMoveDelete:
    """Tests for ``DELETE /api/moves/{move_id}``."""

    def test_delete_existing_move(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """Deleting a known move returns 200 and removes its JSON file."""
        json_file = marionette._dataset_dir / "to-delete.json"
        json_file.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        response = client.delete("/api/moves/to-delete")
        assert response.status_code == 200
        assert not json_file.exists()

    def test_delete_nonexistent_move(self, client: TestClient):
        response = client.delete("/api/moves/nonexistent")
        assert response.status_code == 404

    def test_delete_removes_wav(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """The sibling ``.wav`` file is deleted together with the move JSON."""
        folder = marionette._dataset_dir
        json_file = folder / "with-audio.json"
        wav_file = folder / "with-audio.wav"
        json_file.write_text(json.dumps(sample_move_json))
        wav_file.write_bytes(b"RIFF" + b"\x00" * 100)
        marionette._refresh_recordings()

        client.delete("/api/moves/with-audio")
        assert not json_file.exists()
        assert not wav_file.exists()

    def test_delete_updates_move_list(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """The ``moves`` list in the state shrinks from one entry to none."""
        (marionette._dataset_dir / "test-move.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        moves_before = client.get("/api/state").json()["moves"]
        assert len(moves_before) == 1

        client.delete("/api/moves/test-move")

        moves_after = client.get("/api/state").json()["moves"]
        assert len(moves_after) == 0
# ──────── Dataset management tests ────────────────────────────────────
class TestDatasets:
    """Dataset creation, selection, root changes and the ``origin`` attribute."""

    def test_initial_default_dataset(self, client: TestClient):
        """A fresh app already has an active dataset and at least one entry."""
        datasets = client.get("/api/state").json()["datasets"]
        assert datasets["active_id"] is not None
        assert len(datasets["entries"]) >= 1

    def test_create_dataset(self, client: TestClient):
        """The display name is slugified into the folder name."""
        response = client.post("/api/datasets", json={"name": "My Dances"})
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "created"
        assert body["dataset"]["folder"] == "my-dances"

    def test_create_duplicate_dataset_rejected(self, client: TestClient):
        payload = {"name": "dances"}
        client.post("/api/datasets", json=payload)
        duplicate = client.post("/api/datasets", json=payload)
        assert duplicate.status_code == 409

    def test_select_dataset(self, client: TestClient):
        """Selecting a dataset other than the one just created returns 200."""
        # Create a second dataset
        created = client.post("/api/datasets", json={"name": "second"})
        new_id = created.json()["dataset"]["id"]

        # Pick the first entry that is not the freshly created dataset.
        entries = client.get("/api/state").json()["datasets"]["entries"]
        other_ids = [entry["id"] for entry in entries if entry["id"] != new_id]
        original_id = other_ids[0]

        response = client.post("/api/datasets/select", json={"dataset_id": original_id})
        assert response.status_code == 200

    def test_select_nonexistent_dataset(self, client: TestClient):
        response = client.post("/api/datasets/select", json={"dataset_id": "fake"})
        assert response.status_code == 404

    def test_dataset_root_change(self, client: TestClient, tmp_path: Path):
        """Pointing the root at an existing directory is echoed back verbatim."""
        new_root = tmp_path / "new_root"
        new_root.mkdir()
        root_str = str(new_root)

        response = client.post("/api/datasets/root", json={"path": root_str})
        assert response.status_code == 200
        assert response.json()["root_path"] == root_str

    def test_default_dataset_origin_is_local(self, client: TestClient):
        entries = client.get("/api/state").json()["datasets"]["entries"]
        origins = [entry.get("origin") for entry in entries]
        assert all(origin == "local" for origin in origins)

    def test_created_dataset_origin_is_local(self, client: TestClient):
        response = client.post("/api/datasets", json={"name": "my-local"})
        assert response.status_code == 200
        assert response.json()["dataset"]["origin"] == "local"

    def test_record_blocked_on_downloaded_dataset(
        self, client: TestClient, marionette: Marionette
    ):
        """Recording should be rejected when the active dataset is downloaded."""
        # Register a dataset tagged as downloaded and make it the active one.
        downloaded = marionette._create_dataset_internal(
            "dl-test", "Downloaded Test", origin="downloaded"
        )
        marionette._select_dataset(downloaded.dataset_id)
        marionette._refresh_recordings()

        response = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
        assert response.status_code == 409
        assert "downloaded" in response.json()["detail"].lower()

    def test_record_allowed_on_local_dataset(
        self, client: TestClient, marionette: Marionette
    ):
        """Recording should be accepted when the active dataset is local."""
        # The default dataset is used as-is; no setup is required.
        response = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
        assert response.status_code == 200
        assert response.json()["accepted"] is True

    def test_origin_persisted_in_registry(
        self, marionette: Marionette, tmp_registry: Path
    ):
        """Origin field should be saved and restored from registry."""
        marionette._create_dataset_internal("persisted-dl", "Persisted DL", origin="downloaded")
        marionette._save_dataset_registry()

        registry = json.loads(tmp_registry.read_text(encoding="utf-8"))
        saved = next(item for item in registry["datasets"] if item["folder"] == "persisted-dl")
        assert saved["origin"] == "downloaded"
# ──────── Registry persistence tests ──────────────────────────────────
class TestRegistryPersistence:
    """The registry file is written on start-up and read back by later instances."""

    def test_registry_created_on_init(self, tmp_registry: Path, tmp_dataset_root: Path):
        """Building the app writes a registry with ``active`` and ``datasets`` keys."""
        create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        assert tmp_registry.exists()
        registry = json.loads(tmp_registry.read_text())
        for key in ("active", "datasets"):
            assert key in registry

    def test_registry_survives_restart(self, tmp_registry: Path, tmp_dataset_root: Path):
        """A dataset created by one app instance is listed by the next one."""
        # First instance creates a dataset
        first_app, _ = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        TestClient(first_app).post("/api/datasets", json={"name": "persistent-ds"})

        # Second instance should see it
        second_app, _ = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        entries = TestClient(second_app).get("/api/state").json()["datasets"]["entries"]
        folders = [entry["folder"] for entry in entries]
        assert "persistent-ds" in folders

    def test_preferred_duration_persisted(self, tmp_registry: Path, tmp_dataset_root: Path):
        """The duration of the last record request is restored after re-creation."""
        first_app, _ = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        TestClient(first_app).post(
            "/api/record", json={"duration": 8.5, "record_audio": False}
        )

        # Re-create and check
        _, reloaded = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        assert reloaded._preferred_duration == 8.5
# ──────── Moves list / refresh tests ──────────────────────────────────
class TestMovesRefresh:
    """Move files written to the dataset folder show up after a refresh."""

    @staticmethod
    def _find_move(client: TestClient, move_id: str) -> dict:
        """Return the state entry whose ``id`` equals ``move_id``."""
        moves = client.get("/api/state").json()["moves"]
        return next(entry for entry in moves if entry["id"] == move_id)

    def test_moves_appear_after_file_creation(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        (marionette._dataset_dir / "my-move.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        moves = client.get("/api/state").json()["moves"]
        listed_ids = [entry["id"] for entry in moves]
        assert "my-move" in listed_ids

    def test_move_duration_computed_correctly(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """The reported duration lies strictly between 4.5 s and 5.5 s."""
        (marionette._dataset_dir / "timed.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        move = self._find_move(client, "timed")
        # Original note: 500 frames at 100Hz = 4.99s (last timestamp is 4.99);
        # the fixture itself is defined outside this file.
        assert 4.5 < move["duration"] < 5.5

    def test_move_has_audio_flag(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """A ``.wav`` next to the JSON sets ``has_audio`` to ``True``."""
        folder = marionette._dataset_dir
        (folder / "audio-move.json").write_text(json.dumps(sample_move_json))
        (folder / "audio-move.wav").write_bytes(b"RIFF" + b"\x00" * 100)
        marionette._refresh_recordings()

        move = self._find_move(client, "audio-move")
        assert move["has_audio"] is True

    def test_move_without_audio(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """Without a sibling ``.wav`` the flag is ``False``."""
        (marionette._dataset_dir / "silent-move.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        move = self._find_move(client, "silent-move")
        assert move["has_audio"] is False
# ──────── Settings tests ───────────────────────────
class TestExperiments:
    """Tests for the motion-model config and the ``/api/experiments`` endpoint."""

    def test_motion_models_present(self, client: TestClient):
        """``lead_compensation`` is the active model and has a params entry."""
        config = client.get("/api/state").json()["config"]
        motion_models = config["motion_models"]
        assert motion_models["active"] == "lead_compensation"
        assert "lead_compensation" in motion_models["params"]

    def test_update_duration(self, client: TestClient):
        response = client.post("/api/experiments", json={"duration_seconds": 10.0})
        assert response.status_code == 200
        assert response.json()["preferred_duration"] == 10.0

    def test_no_changes(self, client: TestClient):
        """An empty payload is reported as ``unchanged``."""
        response = client.post("/api/experiments", json={})
        assert response.json()["status"] == "unchanged"
class TestWelcomeMessages:
    """Tests for the ``welcome_messages`` setting (accepted values 0, 1 and 2)."""

    @staticmethod
    def _current_value(client: TestClient):
        """Read ``welcome_messages`` from the config section of the state."""
        return client.get("/api/state").json()["config"]["welcome_messages"]

    def test_default_is_two(self, client: TestClient):
        """Default welcome_messages is 2 (full greeting)."""
        assert self._current_value(client) == 2

    def test_set_to_zero(self, client: TestClient):
        response = client.post("/api/experiments", json={"welcome_messages": 0})
        assert response.status_code == 200
        assert self._current_value(client) == 0

    def test_set_to_one(self, client: TestClient):
        response = client.post("/api/experiments", json={"welcome_messages": 1})
        assert response.status_code == 200
        assert self._current_value(client) == 1

    def test_set_to_two(self, client: TestClient):
        response = client.post("/api/experiments", json={"welcome_messages": 2})
        assert response.status_code == 200
        assert self._current_value(client) == 2

    def test_clamped_above_max(self, client: TestClient):
        """Values above 2 are rejected by Pydantic validation."""
        response = client.post("/api/experiments", json={"welcome_messages": 5})
        assert response.status_code == 422

    def test_clamped_below_min(self, client: TestClient):
        """Negative values are rejected by Pydantic validation."""
        response = client.post("/api/experiments", json={"welcome_messages": -1})
        assert response.status_code == 422

    def test_persisted_in_registry(self, client: TestClient, marionette: Marionette):
        """welcome_messages is saved to and loaded from the dataset registry."""
        client.post("/api/experiments", json={"welcome_messages": 0})
        registry_file = Path(marionette._registry_path)
        registry = json.loads(registry_file.read_text())
        assert registry["welcome_messages"] == 0
# ──────── Hugging Face login tests ────────────────────────────────────
class TestHfAutoLogin:
    """Automatic detection and caching of the Hugging Face username.

    The ``marionette.datasets`` module attributes are replaced with
    ``patch.object`` so they are restored on exit from the ``with`` block,
    including when an assertion fails. Each test that depends on detection
    first clears ``_hf_checked`` / ``_hf_username`` on the app instance.
    """

    def test_state_includes_hf_username_key(self, client: TestClient):
        config = client.get("/api/state").json()["config"]
        assert "hf_username" in config
        # Value is None when not logged in, or a string when logged in
        assert config["hf_username"] is None or isinstance(config["hf_username"], str)

    def test_auto_detected_username_in_state(self, client: TestClient, marionette: Marionette):
        """When HF login is detected, username appears in state."""
        # ``import ... as md`` binds only ``md``, so the ``marionette`` fixture
        # argument is not shadowed by the package of the same name.
        import marionette.datasets as md

        marionette._hf_checked = False
        marionette._hf_username = None
        with patch.object(md, "hf_whoami", lambda: {"name": "testuser"}):
            config = client.get("/api/state").json()["config"]
            assert config["hf_username"] == "testuser"

    def test_cached_after_first_check(self, marionette: Marionette):
        """HF login check is cached after first call."""
        import marionette.datasets as md

        call_count = 0

        def counting_whoami():
            nonlocal call_count
            call_count += 1
            return {"name": "cached-user"}

        marionette._hf_checked = False
        marionette._hf_username = None
        with patch.object(md, "hf_whoami", counting_whoami):
            result1 = marionette._check_hf_login()
            result2 = marionette._check_hf_login()
            assert result1 == "cached-user"
            assert result2 == "cached-user"
            # The second lookup must not have called whoami again.
            assert call_count == 1

    def test_sync_without_username_uses_autodetected(
        self, client: TestClient, marionette: Marionette, tmp_path: Path
    ):
        """Sync endpoint uses auto-detected username when none provided."""
        import marionette.datasets as md

        marionette._hf_checked = False
        marionette._hf_username = None
        with patch.object(md, "hf_whoami", lambda: {"name": "auto-user"}):
            # The requested move does not exist, so the request cannot succeed;
            # what matters is which error comes back.
            resp = client.post("/api/datasets/sync", json={"move_ids": ["nonexistent"]})
            # 404 = move not found (got past the username validation)
            assert resp.status_code == 404

    def test_sync_without_username_and_no_login_returns_400(
        self, client: TestClient, marionette: Marionette
    ):
        """Sync without username and no HF login returns 400."""
        import marionette.datasets as md

        marionette._hf_checked = False
        marionette._hf_username = None
        # ``hf_whoami`` is replaced by None (not a callable) and the token
        # lookup returns None, exactly as in the hand-written patching.
        with patch.object(md, "hf_whoami", None), patch.object(
            md, "hf_get_token", lambda: None
        ):
            resp = client.post("/api/datasets/sync", json={"move_ids": ["some-move"]})
            assert resp.status_code == 400
            assert "not logged in" in resp.json()["detail"].lower()
class TestHfTokenLogin:
    """Tests for saving and deleting a Hugging Face token through the API.

    Hugging Face helpers on ``marionette.datasets`` are swapped out with
    ``patch.object`` so the originals are restored on exit from each ``with``
    block, including when an assertion fails.
    """

    def test_save_token_missing_prefix(self, client: TestClient):
        """Token that doesn't start with hf_ is rejected."""
        resp = client.post("/api/hf-auth/save-token", json={"token": "bad_token_value"})
        assert resp.status_code == 422

    def test_save_token_too_short(self, client: TestClient):
        """Token shorter than 5 chars is rejected by Pydantic."""
        resp = client.post("/api/hf-auth/save-token", json={"token": "hf_"})
        assert resp.status_code == 422

    def test_save_token_success(self, client: TestClient, marionette: Marionette):
        """Valid token saves and returns username."""
        # ``import ... as md`` binds only ``md``; the ``marionette`` fixture
        # argument keeps referring to the app instance.
        import marionette.datasets as md

        with patch.object(
            md, "hf_login", lambda token, add_to_git_credential: None
        ), patch.object(md, "hf_whoami", lambda: {"name": "token-user"}):
            resp = client.post("/api/hf-auth/save-token", json={"token": "hf_valid_token_12345"})
            assert resp.status_code == 200
            data = resp.json()
            assert data["status"] == "logged_in"
            assert data["username"] == "token-user"
            # State should now reflect the username
            config = client.get("/api/state").json()["config"]
            assert config["hf_username"] == "token-user"

    def test_delete_token_success(self, client: TestClient, marionette: Marionette):
        """Logout clears the cached username."""
        import marionette.datasets as md

        marionette._hf_username = "some-user"
        marionette._hf_checked = True
        with patch.object(md, "hf_logout", lambda: None):
            resp = client.delete("/api/hf-auth/token")
            assert resp.status_code == 200
            assert resp.json()["status"] == "logged_out"
            assert marionette._hf_username is None
            assert marionette._hf_checked is False

    def test_delete_token_clears_state(self, client: TestClient, marionette: Marionette):
        """After logout, state reports no username."""
        import marionette.datasets as md

        marionette._hf_username = "old-user"
        marionette._hf_checked = True
        # whoami returns None after logout
        with patch.object(md, "hf_logout", lambda: None), patch.object(
            md, "hf_whoami", lambda: None
        ):
            client.delete("/api/hf-auth/token")
            config = client.get("/api/state").json()["config"]
            assert config["hf_username"] is None
class TestSensorData:
    """Test for the ``GET /sensor_data`` endpoint."""

    def test_returns_empty(self, client: TestClient):
        """The endpoint answers 200 with an empty JSON object."""
        response = client.get("/sensor_data")
        assert response.status_code == 200
        payload = response.json()
        assert payload == {}
# ──────── Upload audio tests ─────────────────────────────────────────
class TestUploadAudio:
    """Tests for ``POST /api/upload-audio`` and use of the returned upload id."""

    @staticmethod
    def _upload(client: TestClient, filename: str, payload: bytes, content_type: str):
        """POST ``payload`` as the multipart ``file`` field and return the response."""
        return client.post(
            "/api/upload-audio",
            files={"file": (filename, BytesIO(payload), content_type)},
        )

    def test_upload_wav_returns_upload_id(self, client: TestClient):
        from conftest import make_wav_bytes

        response = self._upload(client, "test.wav", make_wav_bytes(1.0), "audio/wav")
        assert response.status_code == 200
        body = response.json()
        assert "upload_id" in body
        assert body["filename"] == "test.wav"

    def test_upload_wav_duration_extracted(self, client: TestClient):
        """A 2 s WAV is reported within half a second of 2.0 when a duration is given."""
        from conftest import make_wav_bytes

        response = self._upload(client, "two-sec.wav", make_wav_bytes(2.0), "audio/wav")
        assert response.status_code == 200
        reported = response.json().get("duration")
        # Duration may be None if soundfile is unavailable, skip check in that case
        if reported is not None:
            assert abs(reported - 2.0) < 0.5

    def test_upload_rejects_unsupported_format(self, client: TestClient):
        response = self._upload(client, "notes.txt", b"hello", "text/plain")
        assert response.status_code == 400

    def test_upload_rejects_empty_filename(self, client: TestClient):
        response = self._upload(client, "", b"data", "audio/wav")
        assert response.status_code in (400, 422)

    def test_upload_mp3_accepted(self, client: TestClient):
        # A minimal fake MP3 — server accepts based on extension
        fake_mp3 = b"\xff\xfb\x90\x00" + b"\x00" * 100
        response = self._upload(client, "song.mp3", fake_mp3, "audio/mpeg")
        # Accepted (200) or 500 if soundfile can't parse — never 400 for extension
        assert response.status_code in (200, 500)

    def test_uploaded_audio_id_usable_in_record(self, client: TestClient, marionette: Marionette):
        """The id returned by an upload is accepted by ``/api/record``."""
        from conftest import make_wav_bytes

        uploaded = self._upload(client, "rec.wav", make_wav_bytes(1.0), "audio/wav")
        assert uploaded.status_code == 200
        upload_id = uploaded.json()["upload_id"]

        record_payload = {
            "duration": 3.0,
            "record_audio": False,
            "uploaded_audio_id": upload_id,
        }
        response = client.post("/api/record", json=record_payload)
        assert response.status_code == 200
        assert response.json()["accepted"] is True

    def test_record_with_invalid_upload_id(self, client: TestClient):
        record_payload = {
            "duration": 3.0,
            "record_audio": False,
            "uploaded_audio_id": "nonexistent-uuid",
        }
        response = client.post("/api/record", json=record_payload)
        assert response.status_code == 400
class TestTimingSync:
    """Ordering between capture start, audio playback and the playback clock."""

    def test_uploaded_audio_waits_for_capture_start(
        self, marionette: Marionette, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ):
        """Audio preload is now done in _perform_recording before countdown.
        _run_capture_and_save receives the pre-built audio thread and events.
        Verify the audio thread waits for audio_start (fired by on_capture_start)."""
        wav_path = tmp_path / "uploaded.wav"
        wav_path.write_bytes(b"fake")
        request = RecordingRequest(
            move_id="sync-test",
            label="sync-test",
            description="",
            duration=0.2,
            record_audio=False,
            record_motion=True,
            uploaded_audio_path=wav_path,
        )
        events: list[tuple[str, bool | None]] = []

        class _FakeReachy:
            class media:
                @staticmethod
                def get_output_audio_samplerate():
                    return 16000

        audio_start = threading.Event()
        audio_stop = threading.Event()
        # Set by the fake player once it has logged its "before" sample, so the
        # test can be sure that sample was taken before capture starts.
        audio_waiting = threading.Event()

        def fake_play_preloaded_wav(
            _reachy, _wav_data, _stop_event, chunk_duration=0.02, pipeline_ready=False, start_signal=None
        ):
            events.append(("audio_before_wait", start_signal.is_set() if start_signal else None))
            audio_waiting.set()
            if start_signal is not None:
                start_signal.wait(timeout=1.0)
            events.append(("audio_after_wait", start_signal.is_set() if start_signal else None))

        audio_thread = threading.Thread(
            target=fake_play_preloaded_wav,
            args=(_FakeReachy(), (np.zeros(8, dtype=np.float32), 16000), audio_stop),
            kwargs={"pipeline_ready": True, "start_signal": audio_start},
            daemon=True,
        )
        audio_thread.start()
        # Thread.start() does not guarantee the target has executed its first
        # statement; without this wait the "before" sample could be taken
        # after the start signal has already fired.
        assert audio_waiting.wait(timeout=1.0), "audio thread never reached its wait point"

        def fake_capture_motion(
            _reachy,
            _stop_event,
            _duration,
            _record_audio,
            record_motion=True,
            on_capture_start=None,
        ):
            if on_capture_start is not None:
                on_capture_start()
            return [0.0, 0.01], [{"head": [], "antennas": [], "body_yaw": 0.0, "check_collision": False}], [], None

        monkeypatch.setattr(marionette, "_capture_motion", fake_capture_motion)
        monkeypatch.setattr(marionette, "_save_recording", lambda *_a, **_k: None)
        monkeypatch.setattr(marionette, "_refresh_recordings", lambda *_a, **_k: None)

        marionette._run_capture_and_save(
            _FakeReachy(), threading.Event(), request,
            audio_thread=audio_thread,
            audio_start=audio_start,
            audio_stop=audio_stop,
        )
        # The fake player blocks for at most 1 s, so this join is bounded. It
        # ensures the "after" event has been appended before it is asserted.
        audio_thread.join(timeout=2.0)
        assert not audio_thread.is_alive()

        assert ("audio_before_wait", False) in events
        assert ("audio_after_wait", True) in events

    def test_stream_playback_ignores_wall_clock_time(
        self, marionette: Marionette, monkeypatch: pytest.MonkeyPatch
    ):
        """``_stream_playback`` returns True even when ``time.time`` raises."""

        class _FakeMove:
            duration = 0.03

            def evaluate(self, _t: float):
                return np.eye(4), [0.0, 0.0], 0.0

        class _FakeReachy:
            def set_target_head_pose(self, _head):
                return None

            def set_target_body_yaw(self, _yaw: float):
                return None

            def set_target_antenna_joint_positions(self, _antennas):
                return None

        def _forbidden_wall_clock():
            # Any read of the wall clock during playback fails the test.
            raise RuntimeError("wall clock used")

        monkeypatch.setattr("marionette.recording.time.time", _forbidden_wall_clock)
        marionette._playback_cancel_event.clear()
        assert marionette._stream_playback(_FakeReachy(), _FakeMove()) is True
# ──────── Lead compensation endpoint tests ────────────────────────────────
class TestMotionModelEndpoint:
    """Tests for ``POST /api/motion-model/lead`` and persistence of its parameters."""

    def test_update_lead_compensation_params(self, client: TestClient):
        """Updated lead frames are returned and then visible in the state."""
        requested = {"lead_frames_head": 3, "lead_frames_antennas": 1}
        response = client.post("/api/motion-model/lead", json=requested)
        assert response.status_code == 200

        body = response.json()
        assert body["status"] == "updated"
        assert body["params"]["lead_frames_head"] == 3
        assert body["params"]["lead_frames_antennas"] == 1

        config = client.get("/api/state").json()["config"]
        stored = config["motion_models"]["params"]["lead_compensation"]
        assert stored["lead_frames_head"] == 3
        assert stored["lead_frames_antennas"] == 1

    def test_lead_params_persisted_in_registry(
        self, tmp_registry: Path, tmp_dataset_root: Path
    ):
        """Parameters set through one app instance are loaded by the next one."""
        first_app, _ = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        TestClient(first_app).post(
            "/api/motion-model/lead",
            json={"lead_frames_head": 4, "lead_frames_antennas": 2},
        )

        _, reloaded = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        stored = reloaded._motion_model_registry.get_model_params("lead_compensation")
        assert stored["lead_frames_head"] == 4
        assert stored["lead_frames_antennas"] == 2
# ──────── Community datasets tests ───────────────────────────────────
class TestCommunityDatasets:
    """Tests for the community listing and dataset download endpoints."""

    def test_community_returns_empty_list(self, client: TestClient, marionette: Marionette):
        """When HTTP fetch returns empty, endpoint returns an empty list."""
        # patch.object restores the original method on exit, including when
        # an assertion fails.
        with patch.object(marionette, "_fetch_community_datasets_http", lambda: []):
            resp = client.get("/api/datasets/community")
            assert resp.status_code == 200
            # May be empty list (HfApi also not available in test env)
            assert isinstance(resp.json()["datasets"], list)

    def test_download_rejects_invalid_repo_id(self, client: TestClient):
        """A repo id without a slash is answered with 400."""
        resp = client.post("/api/datasets/download", json={"repo_id": "no-slash"})
        assert resp.status_code == 400

    def test_download_removes_existing_folder_on_redownload(self, client: TestClient, marionette: Marionette):
        """A failed re-download still removes the entry it was meant to replace."""
        # Create a dataset first
        client.post("/api/datasets", json={"name": "existing-ds"})
        old_ids = {e.dataset_id for e in marionette._datasets.values() if e.folder == "existing-ds"}
        assert old_ids, "dataset should exist"

        # Re-download replaces the old entry (will fail at HF fetch, but the old entry
        # should already be removed before the network call).
        resp = client.post("/api/datasets/download", json={
            "repo_id": "someone/existing-ds",
            "name": "existing-ds",
        })
        # Network call fails → 502, but old entry was cleaned up
        assert resp.status_code == 502
        remaining = {e.dataset_id for e in marionette._datasets.values() if e.folder == "existing-ds"}
        assert remaining.isdisjoint(old_ids), "old dataset entry should have been removed"
# ──────── Corrupt data tests ────────────────────────────────────────
class TestCorruptData:
    """Malformed recording files and a corrupt registry must not break loading."""

    def test_malformed_json_skipped(self, marionette: Marionette):
        """A file that is not valid JSON never shows up in the recordings map."""
        broken = marionette._dataset_dir / "bad-file.json"
        broken.write_text("{{{", encoding="utf-8")

        marionette._refresh_recordings()

        assert "bad-file" not in marionette._recordings

    def test_json_missing_time_key(self, marionette: Marionette):
        """A recording without a "time" key, if loaded at all, has zero duration."""
        target = marionette._dataset_dir / "no-time.json"
        target.write_text(json.dumps({"description": "test"}), encoding="utf-8")

        marionette._refresh_recordings()

        recordings = marionette._recordings
        if "no-time" not in recordings:
            return
        # File is loaded but with duration 0 (empty timestamps)
        assert recordings["no-time"].duration == 0.0

    def test_json_empty_time_array(self, marionette: Marionette):
        """A recording with an empty "time" array, if loaded at all, has zero duration."""
        target = marionette._dataset_dir / "empty-time.json"
        contents = {"time": [], "set_target_data": []}
        target.write_text(json.dumps(contents), encoding="utf-8")

        marionette._refresh_recordings()

        recordings = marionette._recordings
        if "empty-time" not in recordings:
            return
        assert recordings["empty-time"].duration == 0.0

    def test_corrupt_registry_recovers(self, tmp_path: Path):
        """create_app copes with an unparsable registry file and still has a dataset."""
        registry_file = tmp_path / "corrupt_reg.json"
        registry_file.write_text("NOT VALID JSON {{{", encoding="utf-8")
        dataset_root = tmp_path / "ds"
        dataset_root.mkdir()

        _, recovered = create_app(registry_path=registry_file, dataset_root=dataset_root)

        # Should have recovered with defaults
        assert recovered._active_dataset_id is not None
        assert len(recovered._datasets) >= 1
# ──────── Concurrent state change tests ─────────────────────────────
class TestConcurrentStateChanges:
    """Requests that conflict with the current mode must be answered with 409.

    Every test forces a busy mode and resets it in a ``finally`` block, so a
    failing assertion cannot leave the shared ``marionette`` object busy.
    """

    def test_play_while_queued_rejected(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """Playback is refused while a recording is queued."""
        data_dir = marionette._dataset_dir
        (data_dir / "play-test.json").write_text(
            json.dumps(sample_move_json), encoding="utf-8"
        )
        marionette._refresh_recordings()
        try:
            # Submit a recording to enter queued state
            client.post("/api/record", json={"duration": 3.0, "record_audio": False})
            assert client.get("/api/state").json()["mode"] == "queued"

            resp = client.post("/api/play", json={"move_id": "play-test"})
            assert resp.status_code == 409
        finally:
            # Same reset the other tests in this file use after queuing a recording.
            marionette._set_idle_state()
            marionette._pending_recording = None

    def test_record_while_playing_rejected(
        self, client: TestClient, marionette: Marionette
    ):
        """A record request is refused while the mode is "playing"."""
        marionette._set_state(mode="playing", message="Playing…", active_move="x")
        try:
            resp = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
            assert resp.status_code == 409
        finally:
            marionette._set_idle_state()

    def test_sync_while_busy_rejected(
        self, client: TestClient, marionette: Marionette
    ):
        """A dataset sync is refused while the mode is "recording"."""
        marionette._set_state(mode="recording", message="Recording…", active_move=None)
        try:
            resp = client.post("/api/datasets/sync", json={"move_ids": ["x"]})
            assert resp.status_code == 409
        finally:
            marionette._set_idle_state()

    def test_dataset_root_change_while_busy(
        self, client: TestClient, marionette: Marionette, tmp_path: Path
    ):
        """Changing the dataset root is refused while the mode is "recording"."""
        marionette._set_state(mode="recording", message="Recording…", active_move=None)
        try:
            resp = client.post("/api/datasets/root", json={"path": str(tmp_path)})
            assert resp.status_code == 409
        finally:
            marionette._set_idle_state()
# ──────── Duration edge case tests ──────────────────────────────────
class TestDurationEdgeCases:
    """Boundary values for the ``duration`` field of POST /api/record."""

    @staticmethod
    def _post_record(client: TestClient, seconds: float):
        """Send a motion-only record request with the given duration."""
        return client.post(
            "/api/record", json={"duration": seconds, "record_audio": False}
        )

    def test_duration_just_above_minimum(self, client: TestClient):
        """0.51 s is accepted."""
        assert self._post_record(client, 0.51).status_code == 200

    def test_duration_at_maximum(self, client: TestClient, marionette: Marionette):
        """300.0 s is accepted."""
        assert self._post_record(client, 300.0).status_code == 200

    def test_duration_at_minimum_boundary_rejected(self, client: TestClient):
        """Pydantic field has gt=0.5, so exactly 0.5 should be rejected."""
        assert self._post_record(client, 0.5).status_code == 422
# ──────── Sync dataset extended tests ───────────────────────────────
class TestSyncDatasetExtended:
    """Additional checks for POST /api/datasets/sync (plus one record guard).

    The sync tests replace ``marionette.datasets.hf_whoami`` and reset the
    cached Hugging Face identity on the ``marionette`` object. Everything they
    touch is restored in ``finally`` so no fake identity or emptied state
    outlives the test.
    """

    def test_sync_empty_move_ids_rejected(self, client: TestClient):
        """Pydantic min_items=1 should reject empty move_ids."""
        resp = client.post("/api/datasets/sync", json={"move_ids": []})
        assert resp.status_code == 422

    def test_sync_nonexistent_moves(self, client: TestClient, marionette: Marionette):
        """Syncing a move id that does not exist is answered with 404."""
        import marionette.datasets as md

        original_whoami = md.hf_whoami
        saved_checked = marionette._hf_checked
        saved_username = marionette._hf_username
        try:
            # Reset the cached identity while the stubbed whoami is in place.
            marionette._hf_checked = False
            marionette._hf_username = None
            md.hf_whoami = lambda: {"name": "testuser"}

            resp = client.post("/api/datasets/sync", json={
                "move_ids": ["fake-move-id"],
            })
            assert resp.status_code == 404
        finally:
            md.hf_whoami = original_whoami
            marionette._hf_checked = saved_checked
            marionette._hf_username = saved_username

    def test_sync_no_active_dataset(self, client: TestClient, marionette: Marionette):
        """With no active dataset and no recordings, syncing an unknown move gives 404."""
        import marionette.datasets as md

        original_whoami = md.hf_whoami
        saved_checked = marionette._hf_checked
        saved_username = marionette._hf_username
        saved_active_id = marionette._active_dataset_id
        saved_recordings = marionette._recordings
        try:
            md.hf_whoami = lambda: {"name": "testuser"}
            marionette._hf_checked = False
            marionette._hf_username = None
            # Clear the active dataset and the recordings map for this request.
            marionette._active_dataset_id = None
            marionette._recordings = {}

            resp = client.post("/api/datasets/sync", json={
                "move_ids": ["any-move"],
            })
            # The expected 404 is the "move not found" answer (recordings is empty);
            # this test does not pin a distinct "no active dataset" error.
            assert resp.status_code == 404
        finally:
            marionette._recordings = saved_recordings
            marionette._active_dataset_id = saved_active_id
            marionette._hf_checked = saved_checked
            marionette._hf_username = saved_username
            md.hf_whoami = original_whoami

    def test_record_on_downloaded_dataset_rejected(
        self, client: TestClient, marionette: Marionette
    ):
        """Recording into a dataset whose origin is "downloaded" is refused with 409."""
        entry = marionette._create_dataset_internal("dl-sync-test", "DL Sync Test", origin="downloaded")
        marionette._select_dataset(entry.dataset_id)
        marionette._refresh_recordings()

        resp = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
        assert resp.status_code == 409
        # The error text must mention that the dataset is a downloaded one.
        assert "downloaded" in resp.json()["detail"].lower()
# ──────── API contract tests (refactoring protection) ──────────────
class TestApiContracts:
    """Verify exact response shapes that the frontend depends on.

    These tests protect against accidental field renames or removals
    during frontend refactoring.
    """

    def test_state_top_level_keys(self, client: TestClient):
        """/api/state exposes every top-level key listed below."""
        data = client.get("/api/state").json()
        required = {
            "server_time", "mode", "message", "active_move",
            "phase_start_at", "phase_end_at",
            "countdown_ends_at",
            "recording_started_at", "recording_duration", "recording_stats",
            "pending_recording", "pending_playback",
            "moves", "config", "datasets",
        }
        # Subset check, written the same way as the other contract tests.
        assert required.issubset(data.keys()), (
            f"Missing keys: {required - data.keys()}"
        )

    def test_config_keys(self, client: TestClient):
        """The "config" object of /api/state exposes every key listed below."""
        config = client.get("/api/state").json()["config"]
        required = {
            "default_duration", "preferred_duration", "countdown_seconds",
            "motion_sample_rate", "audio_available",
            "active_dataset_path", "dataset_root_path",
            "hf_username", "motion_models",
        }
        assert required.issubset(config.keys()), (
            f"Missing config keys: {required - config.keys()}"
        )

    def test_datasets_payload_shape(self, client: TestClient):
        """The "datasets" object has active_id, root_path and a list of entries."""
        datasets = client.get("/api/state").json()["datasets"]
        assert "active_id" in datasets
        assert "root_path" in datasets
        assert "entries" in datasets
        assert isinstance(datasets["entries"], list)

    def test_dataset_entry_keys(self, client: TestClient):
        """The first dataset entry exposes id, label, path, folder and origin."""
        entries = client.get("/api/state").json()["datasets"]["entries"]
        assert len(entries) >= 1, "Should have at least one default dataset"
        entry = entries[0]
        required = {"id", "label", "path", "folder", "origin"}
        assert required.issubset(entry.keys()), (
            f"Missing dataset entry keys: {required - entry.keys()}"
        )

    def test_move_payload_keys(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """A move listed by /api/state exposes every key listed below."""
        data_dir = marionette._dataset_dir
        (data_dir / "contract-test.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        moves = client.get("/api/state").json()["moves"]
        assert len(moves) >= 1
        move = moves[0]
        required = {
            "id", "label", "duration", "created_at",
            "has_audio", "description", "is_uploaded",
        }
        assert required.issubset(move.keys()), (
            f"Missing move keys: {required - move.keys()}"
        )

    def test_error_responses_have_detail(self, client: TestClient):
        """Validation errors (422) include a 'detail' field.

        Only 422 responses are exercised here; other 4xx codes are not
        covered by this test.
        """
        # 422: invalid record payload (negative duration)
        resp = client.post("/api/record", json={"duration": -1})
        assert resp.status_code == 422
        assert "detail" in resp.json()

        # 422: play request without a move_id
        resp = client.post("/api/play", json={})
        assert resp.status_code == 422
        assert "detail" in resp.json()

    def test_record_response_shape(self, client: TestClient):
        """An accepted record request returns accepted, move_id and label."""
        resp = client.post("/api/record", json={
            "duration": 2.0, "record_audio": False,
        })
        assert resp.status_code == 200
        data = resp.json()
        assert "accepted" in data
        assert "move_id" in data
        assert "label" in data

    def test_create_dataset_response_shape(self, client: TestClient):
        """Creating a dataset returns a status plus a dataset with id, label and folder."""
        resp = client.post("/api/datasets", json={"name": "contract-ds"})
        assert resp.status_code == 200
        data = resp.json()
        assert "status" in data
        assert "dataset" in data
        ds = data["dataset"]
        assert "id" in ds
        assert "label" in ds
        assert "folder" in ds

    def test_experiments_response_shape(self, client: TestClient):
        """POST /api/experiments returns status and preferred_duration."""
        resp = client.post("/api/experiments", json={"duration_seconds": 7.0})
        assert resp.status_code == 200
        data = resp.json()
        assert "status" in data
        assert "preferred_duration" in data
# ──────── Audio-only recording tests ─────────────────────────────────
class TestAudioOnlyRecording:
    """Checks for recordings made with ``record_motion=False`` (microphone only)."""

    def test_record_motion_defaults_true(self, client: TestClient):
        """A request that omits record_motion is accepted and the mode becomes queued."""
        request_body = {"duration": 2.0, "record_audio": False}
        resp = client.post("/api/record", json=request_body)
        assert resp.status_code == 200

        # Mode transitions to queued (record_motion defaults to True)
        mode = client.get("/api/state").json()["mode"]
        assert mode == "queued"

    def test_record_motion_false_accepted(self, client: TestClient, marionette: Marionette):
        """record_motion=False is accepted and carried onto the pending request."""
        request_body = {
            "duration": 2.0,
            "record_audio": True,
            "record_motion": False,
        }
        resp = client.post("/api/record", json=request_body)
        assert resp.status_code == 200
        assert resp.json()["accepted"] is True

        # The pending request should have record_motion=False
        pending = marionette._pending_recording
        assert pending is not None
        assert pending.record_motion is False

        # Cleanup
        marionette._set_idle_state()
        marionette._pending_recording = None

    def test_audio_only_save_has_flag(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """Saving a request with record_motion=False writes audio_only=True to the JSON."""
        request = RecordingRequest(
            move_id="audio-only-test",
            label="Audio Only Test",
            description="test",
            duration=2.0,
            record_audio=True,
            record_motion=False,
        )
        sample_times = [i * 0.01 for i in range(200)]
        no_frames = []  # No motion frames
        marionette._save_recording(request, sample_times, no_frames, [], None)

        saved_file = marionette._dataset_dir / "audio-only-test.json"
        assert saved_file.exists()

        saved = json.loads(saved_file.read_text())
        assert saved.get("audio_only") is True
        assert saved["set_target_data"] == []
        assert len(saved["time"]) == 200

        # Cleanup
        saved_file.unlink(missing_ok=True)

    def test_audio_only_metadata_flag(
        self, client: TestClient, marionette: Marionette
    ):
        """A file flagged audio_only is reported with audio_only=True by /api/state."""
        move_file = marionette._dataset_dir / "ao-test.json"
        contents = {
            "description": "audio only test",
            "audio_only": True,
            "time": [0.0, 0.01, 0.02],
            "set_target_data": [],
        }
        move_file.write_text(json.dumps(contents))
        marionette._refresh_recordings()

        listed = client.get("/api/state").json()["moves"]
        matches = [m for m in listed if m["id"] == "ao-test"]
        found = matches[0] if matches else None
        assert found is not None
        assert found["audio_only"] is True

        # Cleanup
        move_file.unlink(missing_ok=True)
        marionette._refresh_recordings()

    def test_normal_recording_no_audio_only_flag(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """A regular motion recording is reported with audio_only=False."""
        move_file = marionette._dataset_dir / "normal-test.json"
        move_file.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        listed = client.get("/api/state").json()["moves"]
        matches = [m for m in listed if m["id"] == "normal-test"]
        found = matches[0] if matches else None
        assert found is not None
        assert found["audio_only"] is False

        # Cleanup
        move_file.unlink(missing_ok=True)
        marionette._refresh_recordings()
# ──────── Stop endpoint tests (extended) ─────────────────────────────
class TestStopEndpointsExtended:
    """Edge-case checks around /api/record/stop and /api/play/stop."""

    def test_state_is_idle_after_record_stop(self, client: TestClient, marionette: Marionette):
        """Queue a recording, stop it, reset by hand, and expect mode=idle."""
        # Queue a recording
        queued = client.post(
            "/api/record", json={"duration": 5.0, "record_audio": False}
        )
        assert queued.status_code == 200

        # Stop it
        client.post("/api/record/stop")

        # Manually set idle since there's no robot thread to process
        marionette._set_idle_state()
        marionette._pending_recording = None

        mode = client.get("/api/state").json()["mode"]
        assert mode == "idle"

    def test_state_is_idle_after_play_stop(self, client: TestClient, marionette: Marionette):
        """Calling /api/play/stop while idle succeeds and leaves the mode idle."""
        mode_before = client.get("/api/state").json()["mode"]
        assert mode_before == "idle"

        # Stop when already idle should be harmless
        stop_resp = client.post("/api/play/stop")
        assert stop_resp.status_code == 200

        mode_after = client.get("/api/state").json()["mode"]
        assert mode_after == "idle"
# ──────── Version endpoint tests ─────────────────────────────────────
class TestVersionEndpoint:
    """Checks for GET /api/version."""

    def test_version_returns_200(self, client: TestClient):
        """The endpoint answers with HTTP 200."""
        assert client.get("/api/version").status_code == 200

    def test_version_has_required_fields(self, client: TestClient):
        """The payload contains version, platform and hostname."""
        payload = client.get("/api/version").json()
        for field in ("version", "platform", "hostname"):
            assert field in payload

    def test_version_is_string(self, client: TestClient):
        """The version field is a non-empty string."""
        version = client.get("/api/version").json()["version"]
        assert isinstance(version, str)
        assert len(version) > 0
# ──────── Robot audio endpoint tests ─────────────────────────────
class TestRobotAudioEndpoints:
    """Tests for GET /api/robot-audio and POST /api/robot-audio/select.

    Tests that write files into the dataset directory remove them in a
    ``finally`` block, so a failing assertion does not leave files behind.
    """

    def test_robot_audio_list_empty_by_default(self, client: TestClient):
        """The listing answers 200 with a "files" list (its contents are not checked)."""
        resp = client.get("/api/robot-audio")
        assert resp.status_code == 200
        data = resp.json()
        assert "files" in data
        assert isinstance(data["files"], list)

    def test_robot_audio_list_finds_audio_only_wav(
        self, client: TestClient, marionette: Marionette
    ):
        """Audio-only recordings appear in robot-audio list."""
        from conftest import make_wav_bytes

        wav_path = marionette._dataset_dir / "robot-test.wav"
        json_path = marionette._dataset_dir / "robot-test.json"
        try:
            wav_path.write_bytes(make_wav_bytes(1.0))
            json_path.write_text(json.dumps({
                "time": [0.0], "set_target_data": [], "audio_only": True,
            }))
            marionette._refresh_recordings()

            resp = client.get("/api/robot-audio")
            assert resp.status_code == 200
            files = resp.json()["files"]
            names = [f["name"] for f in files]
            assert "robot-test" in names
            # Each file should have name, path, duration_seconds
            for f in files:
                assert "name" in f
                assert "path" in f
                assert "duration_seconds" in f
        finally:
            # Cleanup
            wav_path.unlink(missing_ok=True)
            json_path.unlink(missing_ok=True)
            marionette._refresh_recordings()

    def test_robot_audio_excludes_non_audio_only(
        self, client: TestClient, marionette: Marionette
    ):
        """Normal motion moves with WAV files should NOT appear in robot-audio list."""
        from conftest import make_wav_bytes

        wav_path = marionette._dataset_dir / "motion-move.wav"
        json_path = marionette._dataset_dir / "motion-move.json"
        try:
            wav_path.write_bytes(make_wav_bytes(1.0))
            json_path.write_text(json.dumps({
                "time": [0.0, 0.01],
                "set_target_data": [
                    {"head": [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]], "antennas": [0,0]},
                    {"head": [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]], "antennas": [0,0]},
                ],
            }))
            marionette._refresh_recordings()

            resp = client.get("/api/robot-audio")
            assert resp.status_code == 200
            files = resp.json()["files"]
            names = [f["name"] for f in files]
            assert "motion-move" not in names
        finally:
            # Cleanup
            wav_path.unlink(missing_ok=True)
            json_path.unlink(missing_ok=True)
            marionette._refresh_recordings()

    def test_robot_audio_select_returns_upload_id(
        self, client: TestClient, marionette: Marionette
    ):
        """POST /api/robot-audio/select with valid WAV returns upload_id."""
        from conftest import make_wav_bytes

        wav_path = marionette._dataset_dir / "selectable.wav"
        try:
            wav_path.write_bytes(make_wav_bytes(1.5))

            resp = client.post(
                "/api/robot-audio/select",
                json={"path": str(wav_path)},
            )
            assert resp.status_code == 200
            data = resp.json()
            assert "upload_id" in data
            assert data["filename"] == "selectable.wav"
        finally:
            # Cleanup
            wav_path.unlink(missing_ok=True)

    def test_robot_audio_select_rejects_path_traversal(self, client: TestClient):
        """A path outside the dataset/temp dirs is rejected (403 or 404 accepted)."""
        resp = client.post(
            "/api/robot-audio/select",
            json={"path": "/etc/passwd"},
        )
        assert resp.status_code in (403, 404)

    def test_robot_audio_select_rejects_nonexistent(self, client: TestClient):
        """A non-existent file is rejected (403 or 404 accepted)."""
        resp = client.post(
            "/api/robot-audio/select",
            json={"path": "/tmp/does-not-exist-at-all.wav"},
        )
        assert resp.status_code in (403, 404)

    def test_robot_audio_select_rejects_non_wav(
        self, client: TestClient, marionette: Marionette
    ):
        """Non-WAV file within allowed dirs is rejected."""
        txt_path = marionette._dataset_dir / "notes.txt"
        try:
            txt_path.write_text("not audio")

            resp = client.post(
                "/api/robot-audio/select",
                json={"path": str(txt_path)},
            )
            assert resp.status_code == 400
        finally:
            # Cleanup
            txt_path.unlink(missing_ok=True)
# ──────── Move metadata tests ────────────────────────────────────
class TestMoveMetadata:
    """Checks on per-move metadata and a few related audio flows."""

    def test_move_created_at_is_float(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """created_at is reported as a positive float."""
        move_file = marionette._dataset_dir / "meta-test.json"
        move_file.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        listed = client.get("/api/state").json()["moves"]
        matches = [m for m in listed if m["id"] == "meta-test"]
        found = matches[0] if matches else None
        assert found is not None
        assert isinstance(found["created_at"], float)
        assert found["created_at"] > 0

        # Cleanup
        move_file.unlink(missing_ok=True)
        marionette._refresh_recordings()

    def test_move_created_at_is_recent(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """A move file written just now reports a created_at close to the current time."""
        import time

        move_file = marionette._dataset_dir / "recent-test.json"
        move_file.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()

        listed = client.get("/api/state").json()["moves"]
        matches = [m for m in listed if m["id"] == "recent-test"]
        found = matches[0] if matches else None
        assert found is not None
        # Should be within last 60 seconds
        age = abs(found["created_at"] - time.time())
        assert age < 60

        # Cleanup
        move_file.unlink(missing_ok=True)
        marionette._refresh_recordings()

    def test_audio_only_move_in_robot_audio_list(
        self, client: TestClient, marionette: Marionette
    ):
        """An audio-only move that has a WAV next to it is listed by /api/robot-audio."""
        from conftest import make_wav_bytes

        json_file = marionette._dataset_dir / "ao-robot.json"
        wav_file = marionette._dataset_dir / "ao-robot.wav"
        contents = {
            "description": "audio only for robot list",
            "audio_only": True,
            "time": [0.0, 0.01, 0.02],
            "set_target_data": [],
        }
        json_file.write_text(json.dumps(contents))
        wav_file.write_bytes(make_wav_bytes(1.0))
        marionette._refresh_recordings()

        resp = client.get("/api/robot-audio")
        assert resp.status_code == 200
        listed_names = [item["name"] for item in resp.json()["files"]]
        assert "ao-robot" in listed_names

        # Cleanup
        json_file.unlink(missing_ok=True)
        wav_file.unlink(missing_ok=True)
        marionette._refresh_recordings()

    def test_record_with_uploaded_audio_accepted(
        self, client: TestClient, marionette: Marionette
    ):
        """A record request that references a previously uploaded audio id is accepted."""
        from conftest import make_wav_bytes

        # Step 1: upload a WAV and keep the id handed back by the server.
        wav_payload = make_wav_bytes(2.0)
        upload_resp = client.post(
            "/api/upload-audio",
            files={"file": ("track.wav", BytesIO(wav_payload), "audio/wav")},
        )
        assert upload_resp.status_code == 200
        upload_id = upload_resp.json()["upload_id"]

        # Step 2: queue a recording that points at the uploaded audio.
        record_body = {
            "duration": 2.0,
            "record_audio": False,
            "record_motion": True,
            "uploaded_audio_id": upload_id,
            "label": "with-audio",
        }
        resp = client.post("/api/record", json=record_body)
        assert resp.status_code == 200
        assert resp.json()["accepted"] is True

        pending = marionette._pending_recording
        assert pending is not None
        assert pending.uploaded_audio_path is not None

        # Cleanup
        marionette._set_idle_state()
        marionette._pending_recording = None
# ──────── Audio analysis unit tests ───────────────────────────────
class TestAudioAnalysis:
    """Tests for the ``audio_analysis`` helpers, plus one dataset-ordering check."""

    def test_generate_sync_test_audio(self):
        """The generated 8 s clip has the expected length, 5 beeps, and signal at each beep."""
        from audio_analysis import generate_sync_test_audio

        audio, beep_times = generate_sync_test_audio(sr=48000, duration=8.0)

        assert audio.shape == (8 * 48000,)
        assert len(beep_times) == 5
        # Audio should have non-zero samples at beep positions
        for beep_at in beep_times:
            start = int(beep_at * 48000)
            window = audio[start : start + 4800]
            assert np.max(np.abs(window)) > 0.1

    def test_detect_beep_onsets_synthetic(self):
        """At least 4 onsets are detected, each within 100 ms of a generated beep."""
        from audio_analysis import detect_beep_onsets, generate_sync_test_audio

        sample_rate = 48000
        audio, expected_times = generate_sync_test_audio(sr=sample_rate, duration=8.0)
        detected = detect_beep_onsets(audio, sample_rate, freq=1000.0)

        # Five beeps are generated; one miss is tolerated.
        assert len(detected) >= 4, f"Only detected {len(detected)}/5 beeps"
        # Each detected onset should be within 100ms of an expected time
        for dt in detected:
            nearest = min(abs(dt - expected) for expected in expected_times)
            assert nearest < 0.1, f"Detected onset {dt:.3f}s doesn't match any expected beep"

    def test_generate_collision_trajectory(self):
        """Antenna values are near 0 at a beep time and away from 0 between beeps."""
        from audio_analysis import generate_collision_trajectory

        timestamps, frames = generate_collision_trajectory([1.0, 2.5, 4.0], duration=5.0)

        assert len(timestamps) == 500  # 5s * 100Hz
        assert len(frames) == 500

        # At a beep time, antennas should be at 0 (collided)
        antennas = frames[int(1.0 * 100)]["antennas"]
        assert abs(antennas[0]) < 0.05, f"Antennas should be near 0 at collision: {antennas}"

        # Between beeps, antennas should be apart
        antennas = frames[int(1.5 * 100)]["antennas"]
        assert abs(antennas[0]) > 0.1, f"Antennas should be apart between beeps: {antennas}"

    def test_measure_sync_offsets(self):
        """Collisions 50 ms after every beep give 5 matches and a mean offset near 50 ms."""
        from audio_analysis import measure_sync_offsets

        beeps = [1.0, 2.5, 4.0, 5.0, 7.0]
        # Collisions arrive 50ms after each beep
        collisions = [1.05, 2.55, 4.05, 5.05, 7.05]

        report = measure_sync_offsets(beeps, collisions)

        assert report["n_matched"] == 5
        assert abs(report["mean_offset_ms"] - 50.0) < 1.0
        assert report["max_offset_ms"] < 55.0

    def test_measure_sync_offsets_missing_collisions(self):
        """With only two collisions, two of the five beeps are matched."""
        from audio_analysis import measure_sync_offsets

        beeps = [1.0, 2.5, 4.0, 5.0, 7.0]
        collisions = [1.02, 4.01]  # Only 2 matched

        report = measure_sync_offsets(beeps, collisions)

        assert report["n_matched"] == 2
        assert report["n_beeps"] == 5

    def test_datasets_payload_sorted_by_label(self, marionette: Marionette, tmp_path: Path):
        """Entries from _datasets_payload are ordered by label, compared case-insensitively."""
        root = marionette._dataset_root
        # Create the two folders in reverse alphabetical order before reloading the registry.
        for folder_name in ("zzz_first_label", "aaa_last_label"):
            (root / folder_name).mkdir(exist_ok=True)
        marionette._load_dataset_registry()

        entries = marionette._datasets_payload()["entries"]
        labels = [entry.get("label", entry["id"]) for entry in entries]
        assert labels == sorted(labels, key=str.lower), f"Entries not sorted by label: {labels}"