Spaces:
Running
Running
| """Tier 1 β Backend unit tests for Marionette. | |
| Run without hardware, without daemon, in under 5 seconds. | |
| Tests the HTTP API layer, state machine, data validation, and persistence. | |
| """ | |
| import json | |
| import threading | |
| from io import BytesIO | |
| from pathlib import Path | |
| from unittest.mock import MagicMock, patch | |
| import numpy as np | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| from marionette.main import ( | |
| Marionette, | |
| RecordingRequest, | |
| _slugify, | |
| create_app, | |
| DEFAULT_DURATION, | |
| COUNTDOWN_SECONDS, | |
| MOTION_SAMPLE_RATE, | |
| DATASET_DATA_SUBDIR, | |
| ) | |
# ──────── Utility function tests ──────────────────────────────────────
class TestSlugify:
    """Unit tests for the ``_slugify`` label-normalisation helper."""

    def test_simple_lowercase(self):
        """Words are lowercased and joined with hyphens."""
        slug = _slugify("Hello World")
        assert slug == "hello-world"

    def test_special_chars(self):
        """Punctuation collapses into single hyphens."""
        assert _slugify("my@move#1!") == "my-move-1"

    def test_leading_trailing_hyphens(self):
        """Hyphens at either end are stripped."""
        assert _slugify("---test---") == "test"

    def test_empty_string(self):
        """An empty input falls back to the default slug."""
        assert _slugify("") == "take"

    def test_unicode(self):
        """Non-ASCII characters are dropped rather than transliterated."""
        slug = _slugify("cafΓ© rΓ©sumΓ©")
        assert slug == "caf-r-sum"

    def test_already_slugified(self):
        """A valid slug passes through unchanged."""
        assert _slugify("gentle-nod") == "gentle-nod"

    def test_numbers(self):
        """Digits are preserved."""
        assert _slugify("take 42") == "take-42"
# ──────── State endpoint tests ────────────────────────────────────────
class TestStateEndpoint:
    """GET /api/state returns the full application snapshot."""

    def test_returns_200(self, client: TestClient):
        response = client.get("/api/state")
        assert response.status_code == 200

    def test_initial_mode_is_idle(self, client: TestClient):
        snapshot = client.get("/api/state").json()
        assert snapshot["mode"] == "idle"

    def test_initial_message(self, client: TestClient):
        snapshot = client.get("/api/state").json()
        assert snapshot["message"] == "Ready to capture moves"

    def test_state_shape(self, client: TestClient):
        """Every documented top-level key is present in the snapshot."""
        snapshot = client.get("/api/state").json()
        expected_keys = {
            "server_time", "mode", "message", "active_move",
            "phase_start_at", "phase_end_at",
            "countdown_ends_at",
            "recording_started_at", "recording_duration", "recording_stats",
            "pending_recording", "pending_playback",
            "moves", "config", "datasets",
        }
        assert expected_keys.issubset(snapshot.keys())

    def test_server_time_present(self, client: TestClient):
        import time

        server_time = client.get("/api/state").json()["server_time"]
        assert isinstance(server_time, float)
        # Should be close to current time (within 5 seconds)
        assert abs(server_time - time.time()) < 5.0

    def test_idle_phase_timing_null(self, client: TestClient):
        snapshot = client.get("/api/state").json()
        assert snapshot["phase_start_at"] is None
        assert snapshot["phase_end_at"] is None

    def test_config_shape(self, client: TestClient):
        """Config block echoes the module-level constants."""
        config = client.get("/api/state").json()["config"]
        assert config["default_duration"] == DEFAULT_DURATION
        assert config["countdown_seconds"] == COUNTDOWN_SECONDS
        assert config["motion_sample_rate"] == MOTION_SAMPLE_RATE
        assert isinstance(config["audio_available"], bool)
        assert "motion_models" in config

    def test_initial_moves_empty(self, client: TestClient):
        assert client.get("/api/state").json()["moves"] == []

    def test_initial_no_pending(self, client: TestClient):
        snapshot = client.get("/api/state").json()
        assert snapshot["pending_recording"] is None
        assert snapshot["pending_playback"] is None
class TestStartingUpMode:
    """Behaviour while the robot is in the ``starting_up`` phase.

    Each test forces the marionette into ``starting_up`` and restores the
    idle state in a ``finally`` block, so a failing assertion cannot leak
    the forced mode into subsequent tests (the original reset ran only on
    the success path).
    """

    def test_starting_up_rejects_recording(self, client: TestClient, marionette: Marionette):
        """Commands are rejected while the robot is still starting up."""
        marionette._set_state(mode="starting_up", message="Starting upβ¦", active_move=None)
        try:
            resp = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
            assert resp.status_code == 409
        finally:
            marionette._set_idle_state()

    def test_starting_up_rejects_dataset_changes(self, client: TestClient, marionette: Marionette):
        """Dataset creation is also blocked during startup."""
        marionette._set_state(mode="starting_up", message="Starting upβ¦", active_move=None)
        try:
            resp = client.post("/api/datasets", json={"name": "test"})
            assert resp.status_code == 409
        finally:
            marionette._set_idle_state()

    def test_starting_up_state_visible(self, client: TestClient, marionette: Marionette):
        """The startup phase is reported via /api/state."""
        marionette._set_state(mode="starting_up", message="Starting upβ¦", active_move=None)
        try:
            data = client.get("/api/state").json()
            assert data["mode"] == "starting_up"
            assert "starting" in data["message"].lower()
        finally:
            marionette._set_idle_state()
# ──────── Recording endpoint tests ────────────────────────────────────
class TestRecordEndpoint:
    """POST /api/record β validation, queueing, and label handling."""

    def test_accept_basic_recording(self, client: TestClient):
        response = client.post("/api/record", json={
            "duration": 3.0,
            "record_audio": False,
        })
        assert response.status_code == 200
        body = response.json()
        assert body["accepted"] is True
        assert "move_id" in body

    def test_mode_becomes_queued(self, client: TestClient):
        client.post("/api/record", json={"duration": 3.0, "record_audio": False})
        assert client.get("/api/state").json()["mode"] == "queued"

    def test_reject_when_busy(self, client: TestClient):
        payload = {"duration": 3.0, "record_audio": False}
        # First recording is accepted
        assert client.post("/api/record", json=payload).status_code == 200
        # Second recording is rejected (mode is now "queued")
        assert client.post("/api/record", json=payload).status_code == 409

    def test_reject_invalid_duration_too_low(self, client: TestClient):
        response = client.post("/api/record", json={"duration": 0.1, "record_audio": False})
        assert response.status_code == 422  # Pydantic validation

    def test_reject_invalid_duration_too_high(self, client: TestClient):
        response = client.post("/api/record", json={"duration": 999.0, "record_audio": False})
        assert response.status_code == 422

    def test_accept_duration_edge_cases(self, client: TestClient, marionette: Marionette):
        # Just above minimum
        response = client.post("/api/record", json={"duration": 0.6, "record_audio": False})
        assert response.status_code == 200
        # Reset for next test
        marionette._set_idle_state()
        marionette._pending_recording = None
        # At maximum
        response = client.post("/api/record", json={"duration": 300.0, "record_audio": False})
        assert response.status_code == 200

    def test_custom_label(self, client: TestClient):
        response = client.post("/api/record", json={
            "duration": 3.0,
            "record_audio": False,
            "label": "happy-dance",
        })
        body = response.json()
        assert body["label"] == "happy-dance"
        assert body["move_id"] == "happy-dance"

    def test_label_collision_appends_index(
        self, client: TestClient, marionette: Marionette, tmp_dataset_root: Path
    ):
        # Create a file that would collide
        collision_dir = tmp_dataset_root / "local_dataset" / DATASET_DATA_SUBDIR
        collision_dir.mkdir(parents=True, exist_ok=True)
        (collision_dir / "happy-dance.json").write_text("{}")
        response = client.post("/api/record", json={
            "duration": 3.0,
            "record_audio": False,
            "label": "happy-dance",
        })
        assert response.json()["move_id"] == "happy-dance-1"

    def test_preferred_duration_saved(self, client: TestClient, marionette: Marionette):
        client.post("/api/record", json={"duration": 7.5, "record_audio": False})
        assert marionette._preferred_duration == 7.5
# ──────── Playback endpoint tests ─────────────────────────────────────
class TestPlayEndpoint:
    """POST /api/play β playback request validation."""

    def test_reject_missing_move(self, client: TestClient):
        response = client.post("/api/play", json={"move_id": "nonexistent"})
        assert response.status_code == 404

    def test_accept_existing_move(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        # Write a move file to the dataset
        move_file = marionette._dataset_dir / "test-move.json"
        move_file.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        response = client.post("/api/play", json={"move_id": "test-move"})
        assert response.status_code == 200
        assert response.json()["accepted"] is True

    def test_reject_play_when_busy(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        move_file = marionette._dataset_dir / "test-move.json"
        move_file.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        # First play is accepted
        assert client.post("/api/play", json={"move_id": "test-move"}).status_code == 200
        # Second play is rejected (mode is queued)
        assert client.post("/api/play", json={"move_id": "test-move"}).status_code == 409
# ──────── Stop endpoints ──────────────────────────────────────────────
class TestStopEndpoints:
    """POST /api/play/stop and /api/record/stop."""

    def test_stop_playback_when_not_playing(self, client: TestClient):
        response = client.post("/api/play/stop")
        assert response.status_code == 200
        assert response.json()["stopped"] is False

    def test_stop_recording_when_not_recording(self, client: TestClient):
        response = client.post("/api/record/stop")
        assert response.status_code == 200
        assert response.json()["stopped"] is False

    def test_stop_recording_while_queued(self, client: TestClient, marionette: Marionette):
        """Stopping a queued recording should cancel it and return to idle."""
        # Put the app in queued state by submitting a recording
        submit = client.post(
            "/api/record",
            json={"label": "test-queued", "duration": 5.0, "record_audio": False},
        )
        assert submit.status_code == 200
        assert submit.json()["accepted"] is True
        # Verify we're in queued mode
        assert client.get("/api/state").json()["mode"] == "queued"
        # Stop should work even in queued mode
        stop = client.post("/api/record/stop")
        assert stop.status_code == 200
        assert stop.json()["stopped"] is True
        # Should be back to idle
        final_state = client.get("/api/state").json()
        assert final_state["mode"] == "idle"
        assert "cancelled" in final_state["message"].lower()
# ──────── Move deletion tests ─────────────────────────────────────────
class TestMoveDelete:
    """DELETE /api/moves/{id} β file removal and list refresh."""

    def test_delete_existing_move(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        move_file = marionette._dataset_dir / "to-delete.json"
        move_file.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        response = client.delete("/api/moves/to-delete")
        assert response.status_code == 200
        assert not move_file.exists()

    def test_delete_nonexistent_move(self, client: TestClient):
        assert client.delete("/api/moves/nonexistent").status_code == 404

    def test_delete_removes_wav(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        dataset_dir = marionette._dataset_dir
        (dataset_dir / "with-audio.json").write_text(json.dumps(sample_move_json))
        (dataset_dir / "with-audio.wav").write_bytes(b"RIFF" + b"\x00" * 100)
        marionette._refresh_recordings()
        client.delete("/api/moves/with-audio")
        # Both the motion data and its companion audio must be gone.
        assert not (dataset_dir / "with-audio.json").exists()
        assert not (dataset_dir / "with-audio.wav").exists()

    def test_delete_updates_move_list(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        (marionette._dataset_dir / "test-move.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        assert len(client.get("/api/state").json()["moves"]) == 1
        client.delete("/api/moves/test-move")
        assert len(client.get("/api/state").json()["moves"]) == 0
# ──────── Dataset management tests ────────────────────────────────────
class TestDatasets:
    """Dataset creation, selection, root changes, and origin handling."""

    def test_initial_default_dataset(self, client: TestClient):
        datasets = client.get("/api/state").json()["datasets"]
        assert datasets["active_id"] is not None
        assert len(datasets["entries"]) >= 1

    def test_create_dataset(self, client: TestClient):
        response = client.post("/api/datasets", json={"name": "My Dances"})
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "created"
        assert body["dataset"]["folder"] == "my-dances"

    def test_create_duplicate_dataset_rejected(self, client: TestClient):
        client.post("/api/datasets", json={"name": "dances"})
        duplicate = client.post("/api/datasets", json={"name": "dances"})
        assert duplicate.status_code == 409

    def test_select_dataset(self, client: TestClient):
        # Create a second dataset
        created = client.post("/api/datasets", json={"name": "second"})
        new_id = created.json()["dataset"]["id"]
        # Default is auto-selected after create, so select the original
        entries = client.get("/api/state").json()["datasets"]["entries"]
        original_id = next(e["id"] for e in entries if e["id"] != new_id)
        response = client.post("/api/datasets/select", json={"dataset_id": original_id})
        assert response.status_code == 200

    def test_select_nonexistent_dataset(self, client: TestClient):
        response = client.post("/api/datasets/select", json={"dataset_id": "fake"})
        assert response.status_code == 404

    def test_dataset_root_change(self, client: TestClient, tmp_path: Path):
        new_root = tmp_path / "new_root"
        new_root.mkdir()
        response = client.post("/api/datasets/root", json={"path": str(new_root)})
        assert response.status_code == 200
        assert response.json()["root_path"] == str(new_root)

    def test_default_dataset_origin_is_local(self, client: TestClient):
        entries = client.get("/api/state").json()["datasets"]["entries"]
        assert all(entry.get("origin") == "local" for entry in entries)

    def test_created_dataset_origin_is_local(self, client: TestClient):
        response = client.post("/api/datasets", json={"name": "my-local"})
        assert response.status_code == 200
        assert response.json()["dataset"]["origin"] == "local"

    def test_record_blocked_on_downloaded_dataset(
        self, client: TestClient, marionette: Marionette
    ):
        """Recording should be rejected when the active dataset is downloaded."""
        # Create a dataset and tag it as downloaded
        entry = marionette._create_dataset_internal("dl-test", "Downloaded Test", origin="downloaded")
        marionette._select_dataset(entry.dataset_id)
        marionette._refresh_recordings()
        response = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
        assert response.status_code == 409
        assert "downloaded" in response.json()["detail"].lower()

    def test_record_allowed_on_local_dataset(
        self, client: TestClient, marionette: Marionette
    ):
        """Recording should be accepted when the active dataset is local."""
        # Default dataset is local, just verify recording is accepted
        response = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
        assert response.status_code == 200
        assert response.json()["accepted"] is True

    def test_origin_persisted_in_registry(
        self, marionette: Marionette, tmp_registry: Path
    ):
        """Origin field should be saved and restored from registry."""
        marionette._create_dataset_internal("persisted-dl", "Persisted DL", origin="downloaded")
        marionette._save_dataset_registry()
        registry = json.loads(tmp_registry.read_text(encoding="utf-8"))
        persisted = next(d for d in registry["datasets"] if d["folder"] == "persisted-dl")
        assert persisted["origin"] == "downloaded"
# ──────── Registry persistence tests ──────────────────────────────────
class TestRegistryPersistence:
    """The dataset registry file survives application restarts."""

    def test_registry_created_on_init(self, tmp_registry: Path, tmp_dataset_root: Path):
        create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        assert tmp_registry.exists()
        registry = json.loads(tmp_registry.read_text())
        assert "active" in registry
        assert "datasets" in registry

    def test_registry_survives_restart(self, tmp_registry: Path, tmp_dataset_root: Path):
        # First instance creates a dataset
        first_app, _ = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        TestClient(first_app).post("/api/datasets", json={"name": "persistent-ds"})
        # Second instance should see it
        second_app, _ = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        state = TestClient(second_app).get("/api/state").json()
        folders = [entry["folder"] for entry in state["datasets"]["entries"]]
        assert "persistent-ds" in folders

    def test_preferred_duration_persisted(self, tmp_registry: Path, tmp_dataset_root: Path):
        first_app, _ = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        TestClient(first_app).post("/api/record", json={"duration": 8.5, "record_audio": False})
        # Re-create and check
        _, reloaded = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        assert reloaded._preferred_duration == 8.5
# ──────── Moves list / refresh tests ──────────────────────────────────
class TestMovesRefresh:
    """The moves list mirrors the files in the active dataset directory."""

    def test_moves_appear_after_file_creation(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        (marionette._dataset_dir / "my-move.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        listed_ids = [m["id"] for m in client.get("/api/state").json()["moves"]]
        assert "my-move" in listed_ids

    def test_move_duration_computed_correctly(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        (marionette._dataset_dir / "timed.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        moves = client.get("/api/state").json()["moves"]
        timed = next(m for m in moves if m["id"] == "timed")
        # 500 frames at 100Hz = 4.99s (last timestamp is 4.99)
        assert 4.5 < timed["duration"] < 5.5

    def test_move_has_audio_flag(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        dataset_dir = marionette._dataset_dir
        (dataset_dir / "audio-move.json").write_text(json.dumps(sample_move_json))
        (dataset_dir / "audio-move.wav").write_bytes(b"RIFF" + b"\x00" * 100)
        marionette._refresh_recordings()
        moves = client.get("/api/state").json()["moves"]
        with_audio = next(m for m in moves if m["id"] == "audio-move")
        assert with_audio["has_audio"] is True

    def test_move_without_audio(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        (marionette._dataset_dir / "silent-move.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        moves = client.get("/api/state").json()["moves"]
        silent = next(m for m in moves if m["id"] == "silent-move")
        assert silent["has_audio"] is False
# ──────── Settings tests ──────────────────────────────────────────────
class TestExperiments:
    """POST /api/experiments β tunable runtime settings."""

    def test_motion_models_present(self, client: TestClient):
        motion_models = client.get("/api/state").json()["config"]["motion_models"]
        assert motion_models["active"] == "lead_compensation"
        assert "lead_compensation" in motion_models["params"]

    def test_update_duration(self, client: TestClient):
        response = client.post("/api/experiments", json={"duration_seconds": 10.0})
        assert response.status_code == 200
        assert response.json()["preferred_duration"] == 10.0

    def test_no_changes(self, client: TestClient):
        response = client.post("/api/experiments", json={})
        assert response.json()["status"] == "unchanged"
class TestWelcomeMessages:
    """The welcome_messages greeting level (0, 1, or 2) via /api/experiments."""

    def _apply(self, client: TestClient, value: int):
        """Post a welcome_messages update and return the raw response."""
        return client.post("/api/experiments", json={"welcome_messages": value})

    def _read(self, client: TestClient) -> int:
        """Read the current welcome_messages value from state."""
        return client.get("/api/state").json()["config"]["welcome_messages"]

    def test_default_is_two(self, client: TestClient):
        """Default welcome_messages is 2 (full greeting)."""
        assert self._read(client) == 2

    def test_set_to_zero(self, client: TestClient):
        assert self._apply(client, 0).status_code == 200
        assert self._read(client) == 0

    def test_set_to_one(self, client: TestClient):
        assert self._apply(client, 1).status_code == 200
        assert self._read(client) == 1

    def test_set_to_two(self, client: TestClient):
        assert self._apply(client, 2).status_code == 200
        assert self._read(client) == 2

    def test_clamped_above_max(self, client: TestClient):
        """Values above 2 are rejected by Pydantic validation."""
        assert self._apply(client, 5).status_code == 422

    def test_clamped_below_min(self, client: TestClient):
        """Negative values are rejected by Pydantic validation."""
        assert self._apply(client, -1).status_code == 422

    def test_persisted_in_registry(self, client: TestClient, marionette: Marionette):
        """welcome_messages is saved to and loaded from the dataset registry."""
        self._apply(client, 0)
        registry = json.loads(Path(marionette._registry_path).read_text())
        assert registry["welcome_messages"] == 0
# ──────── Hugging Face auto-login tests ───────────────────────────────
class TestHfAutoLogin:
    """Auto-detection of an existing Hugging Face login via ``hf_whoami``."""

    def test_state_includes_hf_username_key(self, client: TestClient):
        """The config block always carries an ``hf_username`` key."""
        config = client.get("/api/state").json()["config"]
        assert "hf_username" in config
        # Value is None when not logged in, or a string when logged in
        assert config["hf_username"] is None or isinstance(config["hf_username"], str)

    def test_auto_detected_username_in_state(self, client: TestClient, marionette: Marionette):
        """When HF login is detected, username appears in state."""
        # Reset the cache so the next state read re-runs the login check.
        marionette._hf_checked = False
        marionette._hf_username = None
        import marionette.datasets as md
        original_whoami = md.hf_whoami
        md.hf_whoami = lambda: {"name": "testuser"}
        try:
            config = client.get("/api/state").json()["config"]
            assert config["hf_username"] == "testuser"
        finally:
            # Restore the real hf_whoami even if the assertion fails.
            md.hf_whoami = original_whoami

    def test_cached_after_first_check(self, marionette: Marionette):
        """HF login check is cached after first call."""
        import marionette.datasets as md
        call_count = 0
        original_whoami = md.hf_whoami
        def counting_whoami():
            nonlocal call_count
            call_count += 1
            return {"name": "cached-user"}
        md.hf_whoami = counting_whoami
        marionette._hf_checked = False
        marionette._hf_username = None
        try:
            result1 = marionette._check_hf_login()
            result2 = marionette._check_hf_login()
            assert result1 == "cached-user"
            assert result2 == "cached-user"
            # Only the first call may hit hf_whoami; the second must be cached.
            assert call_count == 1
        finally:
            md.hf_whoami = original_whoami

    def test_sync_without_username_uses_autodetected(
        self, client: TestClient, marionette: Marionette, tmp_path: Path
    ):
        """Sync endpoint uses auto-detected username when none provided."""
        import marionette.datasets as md
        original_whoami = md.hf_whoami
        marionette._hf_checked = False
        marionette._hf_username = None
        md.hf_whoami = lambda: {"name": "auto-user"}
        try:
            # No moves exist, so sync will fail with 404 (move not found),
            # but we verify it gets past the username check
            resp = client.post("/api/datasets/sync", json={"move_ids": ["nonexistent"]})
            # 404 = move not found (got past the username validation)
            assert resp.status_code == 404
        finally:
            md.hf_whoami = original_whoami

    def test_sync_without_username_and_no_login_returns_400(
        self, client: TestClient, marionette: Marionette
    ):
        """Sync without username and no HF login returns 400."""
        import marionette.datasets as md
        original_whoami = md.hf_whoami
        original_get_token = md.hf_get_token
        marionette._hf_checked = False
        marionette._hf_username = None
        # Disable whoami and token lookup so no login can be detected.
        md.hf_whoami = None
        md.hf_get_token = lambda: None
        try:
            resp = client.post("/api/datasets/sync", json={"move_ids": ["some-move"]})
            assert resp.status_code == 400
            assert "not logged in" in resp.json()["detail"].lower()
        finally:
            md.hf_whoami = original_whoami
            md.hf_get_token = original_get_token
class TestHfTokenLogin:
    """Token-based Hugging Face login and logout endpoints."""

    def test_save_token_missing_prefix(self, client: TestClient):
        """Token that doesn't start with hf_ is rejected."""
        response = client.post("/api/hf-auth/save-token", json={"token": "bad_token_value"})
        assert response.status_code == 422

    def test_save_token_too_short(self, client: TestClient):
        """Token shorter than 5 chars is rejected by Pydantic."""
        response = client.post("/api/hf-auth/save-token", json={"token": "hf_"})
        assert response.status_code == 422

    def test_save_token_success(self, client: TestClient, marionette: Marionette):
        """Valid token saves and returns username."""
        import marionette.datasets as md
        saved_login, saved_whoami = md.hf_login, md.hf_whoami
        md.hf_login = lambda token, add_to_git_credential: None
        md.hf_whoami = lambda: {"name": "token-user"}
        try:
            response = client.post("/api/hf-auth/save-token", json={"token": "hf_valid_token_12345"})
            assert response.status_code == 200
            body = response.json()
            assert body["status"] == "logged_in"
            assert body["username"] == "token-user"
            # State should now reflect the username
            config = client.get("/api/state").json()["config"]
            assert config["hf_username"] == "token-user"
        finally:
            md.hf_login, md.hf_whoami = saved_login, saved_whoami

    def test_delete_token_success(self, client: TestClient, marionette: Marionette):
        """Logout clears the cached username."""
        import marionette.datasets as md
        saved_logout = md.hf_logout
        md.hf_logout = lambda: None
        marionette._hf_username = "some-user"
        marionette._hf_checked = True
        try:
            response = client.delete("/api/hf-auth/token")
            assert response.status_code == 200
            assert response.json()["status"] == "logged_out"
            assert marionette._hf_username is None
            assert marionette._hf_checked is False
        finally:
            md.hf_logout = saved_logout

    def test_delete_token_clears_state(self, client: TestClient, marionette: Marionette):
        """After logout, state reports no username."""
        import marionette.datasets as md
        saved_logout, saved_whoami = md.hf_logout, md.hf_whoami
        md.hf_logout = lambda: None
        md.hf_whoami = lambda: None  # whoami returns None after logout
        marionette._hf_username = "old-user"
        marionette._hf_checked = True
        try:
            client.delete("/api/hf-auth/token")
            config = client.get("/api/state").json()["config"]
            assert config["hf_username"] is None
        finally:
            md.hf_logout = saved_logout
            md.hf_whoami = saved_whoami
class TestSensorData:
    """GET /sensor_data is a stub endpoint returning an empty object."""

    def test_returns_empty(self, client: TestClient):
        response = client.get("/sensor_data")
        assert response.status_code == 200
        assert response.json() == {}
# ──────── Upload audio tests ──────────────────────────────────────────
class TestUploadAudio:
    """POST /api/upload-audio and its interaction with /api/record."""

    @staticmethod
    def _upload(client: TestClient, filename: str, payload: bytes, mime: str):
        """Upload raw bytes as a multipart file and return the response."""
        return client.post(
            "/api/upload-audio",
            files={"file": (filename, BytesIO(payload), mime)},
        )

    def test_upload_wav_returns_upload_id(self, client: TestClient):
        from conftest import make_wav_bytes
        response = self._upload(client, "test.wav", make_wav_bytes(1.0), "audio/wav")
        assert response.status_code == 200
        body = response.json()
        assert "upload_id" in body
        assert body["filename"] == "test.wav"

    def test_upload_wav_duration_extracted(self, client: TestClient):
        from conftest import make_wav_bytes
        response = self._upload(client, "two-sec.wav", make_wav_bytes(2.0), "audio/wav")
        assert response.status_code == 200
        duration = response.json().get("duration")
        # Duration may be None if soundfile is unavailable, skip check in that case
        if duration is not None:
            assert abs(duration - 2.0) < 0.5

    def test_upload_rejects_unsupported_format(self, client: TestClient):
        response = self._upload(client, "notes.txt", b"hello", "text/plain")
        assert response.status_code == 400

    def test_upload_rejects_empty_filename(self, client: TestClient):
        response = self._upload(client, "", b"data", "audio/wav")
        assert response.status_code in (400, 422)

    def test_upload_mp3_accepted(self, client: TestClient):
        # A minimal fake MP3 — server accepts based on extension
        fake_mp3 = b"\xff\xfb\x90\x00" + b"\x00" * 100
        response = self._upload(client, "song.mp3", fake_mp3, "audio/mpeg")
        # Accepted (200) or 500 if soundfile can't parse — never 400 for extension
        assert response.status_code in (200, 500)

    def test_uploaded_audio_id_usable_in_record(self, client: TestClient, marionette: Marionette):
        from conftest import make_wav_bytes
        upload_response = self._upload(client, "rec.wav", make_wav_bytes(1.0), "audio/wav")
        assert upload_response.status_code == 200
        record_response = client.post("/api/record", json={
            "duration": 3.0,
            "record_audio": False,
            "uploaded_audio_id": upload_response.json()["upload_id"],
        })
        assert record_response.status_code == 200
        assert record_response.json()["accepted"] is True

    def test_record_with_invalid_upload_id(self, client: TestClient):
        response = client.post("/api/record", json={
            "duration": 3.0,
            "record_audio": False,
            "uploaded_audio_id": "nonexistent-uuid",
        })
        assert response.status_code == 400
class TestTimingSync:
    """Timing-synchronisation behaviour between audio playback and motion capture."""
    def test_uploaded_audio_waits_for_capture_start(
        self, marionette: Marionette, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ):
        """Audio preload is now done in _perform_recording before countdown.
        _run_capture_and_save receives the pre-built audio thread and events.
        Verify the audio thread waits for audio_start (fired by on_capture_start)."""
        # The WAV content is never decoded here (playback is faked below),
        # so placeholder bytes are enough.
        wav_path = tmp_path / "uploaded.wav"
        wav_path.write_bytes(b"fake")
        request = RecordingRequest(
            move_id="sync-test",
            label="sync-test",
            description="",
            duration=0.2,
            record_audio=False,
            record_motion=True,
            uploaded_audio_path=wav_path,
        )
        # Chronological log of (event name, start_signal state) pairs.
        events: list[tuple[str, bool | None]] = []
        class _FakeReachy:
            # Minimal stand-in: only the samplerate accessor is touched.
            class media:
                def get_output_audio_samplerate():
                    return 16000
        audio_start = threading.Event()
        audio_stop = threading.Event()
        def fake_play_preloaded_wav(
            _reachy, _wav_data, _stop_event, chunk_duration=0.02, pipeline_ready=False, start_signal=None
        ):
            # Record the signal state before and after waiting so the test can
            # prove a genuine False -> True transition happened.
            events.append(("audio_before_wait", start_signal.is_set() if start_signal else None))
            if start_signal is not None:
                start_signal.wait(timeout=1.0)
            events.append(("audio_after_wait", start_signal.is_set() if start_signal else None))
        audio_thread = threading.Thread(
            target=fake_play_preloaded_wav,
            args=(_FakeReachy(), (np.zeros(8, dtype=np.float32), 16000), audio_stop),
            kwargs={"pipeline_ready": True, "start_signal": audio_start},
            daemon=True,
        )
        audio_thread.start()
        def fake_capture_motion(
            _reachy,
            _stop_event,
            _duration,
            _record_audio,
            record_motion=True,
            on_capture_start=None,
        ):
            # Fire the capture-start callback immediately, standing in for the
            # moment real sampling begins.
            if on_capture_start is not None:
                on_capture_start()
            return [0.0, 0.01], [{"head": [], "antennas": [], "body_yaw": 0.0, "check_collision": False}], [], None
        monkeypatch.setattr(marionette, "_capture_motion", fake_capture_motion)
        # Persistence is irrelevant to this test -- stub it out.
        monkeypatch.setattr(marionette, "_save_recording", lambda *_a, **_k: None)
        monkeypatch.setattr(marionette, "_refresh_recordings", lambda *_a, **_k: None)
        marionette._run_capture_and_save(
            _FakeReachy(), threading.Event(), request,
            audio_thread=audio_thread,
            audio_start=audio_start,
            audio_stop=audio_stop,
        )
        # The audio thread must have observed the signal unset first, then set.
        assert ("audio_before_wait", False) in events
        assert ("audio_after_wait", True) in events
    def test_stream_playback_ignores_wall_clock_time(
        self, marionette: Marionette, monkeypatch: pytest.MonkeyPatch
    ):
        """_stream_playback must not consult time.time(); patching it to raise
        proves the playback loop is driven by some other clock source."""
        class _FakeMove:
            duration = 0.03
            def evaluate(self, _t: float):
                return np.eye(4), [0.0, 0.0], 0.0
        class _FakeReachy:
            def set_target_head_pose(self, _head):
                return None
            def set_target_body_yaw(self, _yaw: float):
                return None
            def set_target_antenna_joint_positions(self, _antennas):
                return None
        # Any call to time.time() inside marionette.recording raises immediately.
        monkeypatch.setattr("marionette.recording.time.time", lambda: (_ for _ in ()).throw(RuntimeError("wall clock used")))
        marionette._playback_cancel_event.clear()
        assert marionette._stream_playback(_FakeReachy(), _FakeMove()) is True
| # ββββββββ Lead compensation endpoint tests ββββββββββββββββββββββββββββββββ | |
class TestMotionModelEndpoint:
    """Lead-compensation parameter updates via /api/motion-model/lead."""
    def test_update_lead_compensation_params(self, client: TestClient):
        """Posted lead frames are echoed back and visible in /api/state."""
        resp = client.post(
            "/api/motion-model/lead",
            json={"lead_frames_head": 3, "lead_frames_antennas": 1},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "updated"
        assert body["params"]["lead_frames_head"] == 3
        assert body["params"]["lead_frames_antennas"] == 1
        config = client.get("/api/state").json()["config"]
        lead = config["motion_models"]["params"]["lead_compensation"]
        assert lead["lead_frames_head"] == 3
        assert lead["lead_frames_antennas"] == 1
    def test_lead_params_persisted_in_registry(
        self, tmp_registry: Path, tmp_dataset_root: Path
    ):
        """Lead params written through one app instance survive a fresh create_app."""
        first_app, _ = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        TestClient(first_app).post(
            "/api/motion-model/lead",
            json={"lead_frames_head": 4, "lead_frames_antennas": 2},
        )
        _, reloaded = create_app(registry_path=tmp_registry, dataset_root=tmp_dataset_root)
        persisted = reloaded._motion_model_registry.get_model_params("lead_compensation")
        assert persisted["lead_frames_head"] == 4
        assert persisted["lead_frames_antennas"] == 2
| # ββββββββ Community datasets tests βββββββββββββββββββββββββββββββββββ | |
class TestCommunityDatasets:
    """Community dataset listing and download endpoints."""
    def test_community_returns_empty_list(self, client: TestClient, marionette: Marionette):
        """When HTTP fetch returns empty, endpoint returns an empty list."""
        saved_fetch = marionette._fetch_community_datasets_http
        marionette._fetch_community_datasets_http = lambda: []
        try:
            resp = client.get("/api/datasets/community")
            assert resp.status_code == 200
            # May be empty list (HfApi also not available in test env)
            assert isinstance(resp.json()["datasets"], list)
        finally:
            marionette._fetch_community_datasets_http = saved_fetch
    def test_download_rejects_invalid_repo_id(self, client: TestClient):
        """A repo_id lacking the owner/name slash is rejected with 400."""
        resp = client.post("/api/datasets/download", json={"repo_id": "no-slash"})
        assert resp.status_code == 400
    def test_download_removes_existing_folder_on_redownload(self, client: TestClient, marionette: Marionette):
        """Re-downloading purges the prior dataset entry before the network call."""
        client.post("/api/datasets", json={"name": "existing-ds"})
        before = {e.dataset_id for e in marionette._datasets.values() if e.folder == "existing-ds"}
        assert before, "dataset should exist"
        # The download itself fails at the HF fetch (502), but by then the
        # stale entry must already have been removed.
        resp = client.post("/api/datasets/download", json={
            "repo_id": "someone/existing-ds",
            "name": "existing-ds",
        })
        assert resp.status_code == 502
        after = {e.dataset_id for e in marionette._datasets.values() if e.folder == "existing-ds"}
        assert after.isdisjoint(before), "old dataset entry should have been removed"
| # ββββββββ Corrupt data tests ββββββββββββββββββββββββββββββββββββββββ | |
class TestCorruptData:
    """Malformed on-disk data must never crash the refresh pass."""
    def test_malformed_json_skipped(self, marionette: Marionette):
        """Unparseable JSON files are skipped, not loaded."""
        (marionette._dataset_dir / "bad-file.json").write_text("{{{", encoding="utf-8")
        marionette._refresh_recordings()
        assert "bad-file" not in marionette._recordings
    def test_json_missing_time_key(self, marionette: Marionette):
        """A file without a 'time' key loads (if at all) with zero duration."""
        payload = json.dumps({"description": "test"})
        (marionette._dataset_dir / "no-time.json").write_text(payload, encoding="utf-8")
        marionette._refresh_recordings()
        # File is loaded but with duration 0 (empty timestamps)
        if "no-time" in marionette._recordings:
            assert marionette._recordings["no-time"].duration == 0.0
    def test_json_empty_time_array(self, marionette: Marionette):
        """An empty 'time' array yields zero duration."""
        payload = json.dumps({"time": [], "set_target_data": []})
        (marionette._dataset_dir / "empty-time.json").write_text(payload, encoding="utf-8")
        marionette._refresh_recordings()
        if "empty-time" in marionette._recordings:
            assert marionette._recordings["empty-time"].duration == 0.0
    def test_corrupt_registry_recovers(self, tmp_path: Path):
        """A corrupt registry file is replaced with sane defaults at startup."""
        registry = tmp_path / "corrupt_reg.json"
        registry.write_text("NOT VALID JSON {{{", encoding="utf-8")
        dataset_root = tmp_path / "ds"
        dataset_root.mkdir()
        app, m = create_app(registry_path=registry, dataset_root=dataset_root)
        # Should have recovered with defaults
        assert m._active_dataset_id is not None
        assert len(m._datasets) >= 1
| # ββββββββ Concurrent state change tests βββββββββββββββββββββββββββββ | |
class TestConcurrentStateChanges:
    """Mutually exclusive operations are rejected with 409 while busy."""
    def test_play_while_queued_rejected(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """Playback is refused while a recording sits in the queue."""
        (marionette._dataset_dir / "play-test.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        # Queue a recording so the server leaves idle.
        client.post("/api/record", json={"duration": 3.0, "record_audio": False})
        assert client.get("/api/state").json()["mode"] == "queued"
        assert client.post("/api/play", json={"move_id": "play-test"}).status_code == 409
    def test_record_while_playing_rejected(
        self, client: TestClient, marionette: Marionette
    ):
        """Recording is refused while playback is active."""
        marionette._set_state(mode="playing", message="Playing…", active_move="x")
        assert client.post("/api/record", json={"duration": 3.0, "record_audio": False}).status_code == 409
        marionette._set_idle_state()
    def test_sync_while_busy_rejected(
        self, client: TestClient, marionette: Marionette
    ):
        """Dataset sync is refused mid-recording."""
        marionette._set_state(mode="recording", message="Recording…", active_move=None)
        assert client.post("/api/datasets/sync", json={"move_ids": ["x"]}).status_code == 409
        marionette._set_idle_state()
    def test_dataset_root_change_while_busy(
        self, client: TestClient, marionette: Marionette, tmp_path: Path
    ):
        """Changing the dataset root is refused mid-recording."""
        marionette._set_state(mode="recording", message="Recording…", active_move=None)
        assert client.post("/api/datasets/root", json={"path": str(tmp_path)}).status_code == 409
        marionette._set_idle_state()
| # ββββββββ Duration edge case tests ββββββββββββββββββββββββββββββββββ | |
class TestDurationEdgeCases:
    """Boundary checks for the recording duration validator."""
    def test_duration_just_above_minimum(self, client: TestClient):
        """0.51s is strictly above the minimum and accepted."""
        assert client.post("/api/record", json={"duration": 0.51, "record_audio": False}).status_code == 200
    def test_duration_at_maximum(self, client: TestClient, marionette: Marionette):
        """300s is accepted as a valid duration."""
        assert client.post("/api/record", json={"duration": 300.0, "record_audio": False}).status_code == 200
    def test_duration_at_minimum_boundary_rejected(self, client: TestClient):
        """Pydantic field has gt=0.5, so exactly 0.5 should be rejected."""
        assert client.post("/api/record", json={"duration": 0.5, "record_audio": False}).status_code == 422
| # ββββββββ Sync dataset extended tests βββββββββββββββββββββββββββββββ | |
class TestSyncDatasetExtended:
    """Extended coverage for POST /api/datasets/sync edge cases."""
    def test_sync_empty_move_ids_rejected(self, client: TestClient):
        """Pydantic min_items=1 should reject empty move_ids."""
        resp = client.post("/api/datasets/sync", json={"move_ids": []})
        assert resp.status_code == 422
    def test_sync_nonexistent_moves(self, client: TestClient, marionette: Marionette):
        """Syncing an unknown move id fails with 404 once HF auth passes."""
        import marionette.datasets as md
        # Patch the module-level whoami so the HF login check succeeds, and
        # reset the cached auth state; the original is restored in finally.
        original_whoami = md.hf_whoami
        marionette._hf_checked = False
        marionette._hf_username = None
        md.hf_whoami = lambda: {"name": "testuser"}
        try:
            resp = client.post("/api/datasets/sync", json={
                "move_ids": ["fake-move-id"],
            })
            assert resp.status_code == 404
        finally:
            md.hf_whoami = original_whoami
    def test_sync_no_active_dataset(self, client: TestClient, marionette: Marionette):
        """With no active dataset and no recordings, sync reports 404."""
        import marionette.datasets as md
        original_whoami = md.hf_whoami
        md.hf_whoami = lambda: {"name": "testuser"}
        marionette._hf_checked = False
        marionette._hf_username = None
        # Save and clear active dataset
        old_id = marionette._active_dataset_id
        marionette._active_dataset_id = None
        # Clear recordings to avoid "move not found" before "no active dataset"
        marionette._recordings = {}
        try:
            resp = client.post("/api/datasets/sync", json={
                "move_ids": ["any-move"],
            })
            # Should be 404 for the move not found (since recordings is empty)
            assert resp.status_code == 404
        finally:
            marionette._active_dataset_id = old_id
            md.hf_whoami = original_whoami
    def test_record_on_downloaded_dataset_rejected(
        self, client: TestClient, marionette: Marionette
    ):
        """Recording into a dataset of origin='downloaded' is refused (409)."""
        entry = marionette._create_dataset_internal("dl-sync-test", "DL Sync Test", origin="downloaded")
        marionette._select_dataset(entry.dataset_id)
        marionette._refresh_recordings()
        resp = client.post("/api/record", json={"duration": 3.0, "record_audio": False})
        assert resp.status_code == 409
        assert "downloaded" in resp.json()["detail"].lower()
| # ββββββββ API contract tests (refactoring protection) ββββββββββββββ | |
class TestApiContracts:
    """Verify exact response shapes that the frontend depends on.
    These tests protect against accidental field renames or removals
    during frontend refactoring.
    """
    def test_state_top_level_keys(self, client: TestClient):
        """Every top-level state key the frontend reads must be present."""
        data = client.get("/api/state").json()
        required = {
            "server_time", "mode", "message", "active_move",
            "phase_start_at", "phase_end_at",
            "countdown_ends_at",
            "recording_started_at", "recording_duration", "recording_stats",
            "pending_recording", "pending_playback",
            "moves", "config", "datasets",
        }
        # issubset is equivalent to the intersection comparison and matches
        # the style of the other contract tests below.
        assert required.issubset(data.keys()), (
            f"Missing keys: {required - data.keys()}"
        )
    def test_config_keys(self, client: TestClient):
        """The config payload carries every key the frontend settings UI reads."""
        config = client.get("/api/state").json()["config"]
        required = {
            "default_duration", "preferred_duration", "countdown_seconds",
            "motion_sample_rate", "audio_available",
            "active_dataset_path", "dataset_root_path",
            "hf_username", "motion_models",
        }
        assert required.issubset(config.keys()), (
            f"Missing config keys: {required - config.keys()}"
        )
    def test_datasets_payload_shape(self, client: TestClient):
        """The datasets payload has active_id, root_path and a list of entries."""
        datasets = client.get("/api/state").json()["datasets"]
        assert "active_id" in datasets
        assert "root_path" in datasets
        assert "entries" in datasets
        assert isinstance(datasets["entries"], list)
    def test_dataset_entry_keys(self, client: TestClient):
        """Each dataset entry exposes id, label, path, folder and origin."""
        entries = client.get("/api/state").json()["datasets"]["entries"]
        assert len(entries) >= 1, "Should have at least one default dataset"
        entry = entries[0]
        required = {"id", "label", "path", "folder", "origin"}
        assert required.issubset(entry.keys()), (
            f"Missing dataset entry keys: {required - entry.keys()}"
        )
    def test_move_payload_keys(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """Each move payload exposes the fields the move list UI renders."""
        data_dir = marionette._dataset_dir
        (data_dir / "contract-test.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        moves = client.get("/api/state").json()["moves"]
        assert len(moves) >= 1
        move = moves[0]
        required = {
            "id", "label", "duration", "created_at",
            "has_audio", "description", "is_uploaded",
        }
        assert required.issubset(move.keys()), (
            f"Missing move keys: {required - move.keys()}"
        )
    def test_error_responses_have_detail(self, client: TestClient):
        """All 4xx responses should include a 'detail' field."""
        # 422: invalid payload (negative duration)
        resp = client.post("/api/record", json={"duration": -1})
        assert resp.status_code == 422
        assert "detail" in resp.json()
        # 422: play without the required move_id field
        resp = client.post("/api/play", json={})
        assert resp.status_code == 422
        assert "detail" in resp.json()
    def test_record_response_shape(self, client: TestClient):
        """A successful /api/record reply carries accepted, move_id and label."""
        resp = client.post("/api/record", json={
            "duration": 2.0, "record_audio": False,
        })
        assert resp.status_code == 200
        data = resp.json()
        assert "accepted" in data
        assert "move_id" in data
        assert "label" in data
    def test_create_dataset_response_shape(self, client: TestClient):
        """Dataset creation replies with status plus an id/label/folder entry."""
        resp = client.post("/api/datasets", json={"name": "contract-ds"})
        assert resp.status_code == 200
        data = resp.json()
        assert "status" in data
        assert "dataset" in data
        ds = data["dataset"]
        assert "id" in ds
        assert "label" in ds
        assert "folder" in ds
    def test_experiments_response_shape(self, client: TestClient):
        """The experiments endpoint echoes status and preferred_duration."""
        resp = client.post("/api/experiments", json={"duration_seconds": 7.0})
        assert resp.status_code == 200
        data = resp.json()
        assert "status" in data
        assert "preferred_duration" in data
| # ββββββββ Audio-only recording tests βββββββββββββββββββββββββββββββββ | |
class TestAudioOnlyRecording:
    """Tests for the audio-only (mic-only, no motion) recording mode."""
    def test_record_motion_defaults_true(self, client: TestClient):
        """Omitting record_motion leaves it True and the recording queues."""
        assert client.post("/api/record", json={"duration": 2.0, "record_audio": False}).status_code == 200
        # Mode transitions to queued (record_motion defaults to True)
        assert client.get("/api/state").json()["mode"] == "queued"
    def test_record_motion_false_accepted(self, client: TestClient, marionette: Marionette):
        """record_motion=False is accepted and stored on the pending request."""
        resp = client.post("/api/record", json={
            "duration": 2.0,
            "record_audio": True,
            "record_motion": False,
        })
        assert resp.status_code == 200
        assert resp.json()["accepted"] is True
        pending = marionette._pending_recording
        assert pending is not None
        assert pending.record_motion is False
        # Reset server state for the next test.
        marionette._set_idle_state()
        marionette._pending_recording = None
    def test_audio_only_save_has_flag(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """Simulate saving an audio-only recording and verify the JSON has audio_only flag."""
        from marionette.main import RecordingRequest
        request = RecordingRequest(
            move_id="audio-only-test",
            label="Audio Only Test",
            description="test",
            duration=2.0,
            record_audio=True,
            record_motion=False,
        )
        stamps = [i * 0.01 for i in range(200)]
        # No motion frames -- audio-only capture.
        marionette._save_recording(request, stamps, [], [], None)
        saved = marionette._dataset_dir / "audio-only-test.json"
        assert saved.exists()
        payload = json.loads(saved.read_text())
        assert payload.get("audio_only") is True
        assert payload["set_target_data"] == []
        assert len(payload["time"]) == 200
        # Cleanup
        saved.unlink(missing_ok=True)
    def test_audio_only_metadata_flag(
        self, client: TestClient, marionette: Marionette
    ):
        """Verify _refresh_recordings reads audio_only flag correctly."""
        target = marionette._dataset_dir / "ao-test.json"
        target.write_text(json.dumps({
            "description": "audio only test",
            "audio_only": True,
            "time": [0.0, 0.01, 0.02],
            "set_target_data": [],
        }))
        marionette._refresh_recordings()
        moves = client.get("/api/state").json()["moves"]
        found = next((m for m in moves if m["id"] == "ao-test"), None)
        assert found is not None
        assert found["audio_only"] is True
        # Cleanup
        target.unlink(missing_ok=True)
        marionette._refresh_recordings()
    def test_normal_recording_no_audio_only_flag(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """Normal recordings should not have audio_only flag."""
        target = marionette._dataset_dir / "normal-test.json"
        target.write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        moves = client.get("/api/state").json()["moves"]
        found = next((m for m in moves if m["id"] == "normal-test"), None)
        assert found is not None
        assert found["audio_only"] is False
        # Cleanup
        target.unlink(missing_ok=True)
        marionette._refresh_recordings()
| # ββββββββ Stop endpoint tests (extended) βββββββββββββββββββββββββββββ | |
class TestStopEndpointsExtended:
    """Additional stop endpoint tests for edge cases."""
    def test_state_is_idle_after_record_stop(self, client: TestClient, marionette: Marionette):
        """After /api/record/stop, /api/state should return mode=idle."""
        # Queue a recording, then stop it.
        assert client.post("/api/record", json={"duration": 5.0, "record_audio": False}).status_code == 200
        client.post("/api/record/stop")
        # No robot thread runs in tests, so force the idle transition manually.
        marionette._set_idle_state()
        marionette._pending_recording = None
        assert client.get("/api/state").json()["mode"] == "idle"
    def test_state_is_idle_after_play_stop(self, client: TestClient, marionette: Marionette):
        """After /api/play/stop, /api/state should return mode=idle."""
        assert client.get("/api/state").json()["mode"] == "idle"
        # Stop when already idle should be harmless
        assert client.post("/api/play/stop").status_code == 200
        assert client.get("/api/state").json()["mode"] == "idle"
| # ββββββββ Version endpoint tests βββββββββββββββββββββββββββββββββββββ | |
class TestVersionEndpoint:
    """GET /api/version sanity checks."""
    def test_version_returns_200(self, client: TestClient):
        assert client.get("/api/version").status_code == 200
    def test_version_has_required_fields(self, client: TestClient):
        payload = client.get("/api/version").json()
        for key in ("version", "platform", "hostname"):
            assert key in payload
    def test_version_is_string(self, client: TestClient):
        version = client.get("/api/version").json()["version"]
        assert isinstance(version, str)
        assert len(version) > 0
| # ββββββββ Robot audio endpoint tests βββββββββββββββββββββββββββββ | |
class TestRobotAudioEndpoints:
    """Tests for GET /api/robot-audio and POST /api/robot-audio/select."""
    def test_robot_audio_list_empty_by_default(self, client: TestClient):
        """With no audio-only recordings on disk, the endpoint returns a list."""
        resp = client.get("/api/robot-audio")
        assert resp.status_code == 200
        data = resp.json()
        assert "files" in data
        assert isinstance(data["files"], list)
    def test_robot_audio_list_finds_audio_only_wav(
        self, client: TestClient, marionette: Marionette
    ):
        """Audio-only recordings appear in robot-audio list."""
        from conftest import make_wav_bytes
        import json
        # Write a WAV plus matching metadata JSON flagged audio_only directly
        # into the active dataset directory, then refresh the index.
        wav_bytes = make_wav_bytes(1.0)
        wav_path = marionette._dataset_dir / "robot-test.wav"
        wav_path.write_bytes(wav_bytes)
        json_path = marionette._dataset_dir / "robot-test.json"
        json_path.write_text(json.dumps({
            "time": [0.0], "set_target_data": [], "audio_only": True,
        }))
        marionette._refresh_recordings()
        resp = client.get("/api/robot-audio")
        assert resp.status_code == 200
        files = resp.json()["files"]
        names = [f["name"] for f in files]
        assert "robot-test" in names
        # Each file should have name, path, duration_seconds
        for f in files:
            assert "name" in f
            assert "path" in f
            assert "duration_seconds" in f
        # Cleanup
        wav_path.unlink(missing_ok=True)
        json_path.unlink(missing_ok=True)
        marionette._refresh_recordings()
    def test_robot_audio_excludes_non_audio_only(
        self, client: TestClient, marionette: Marionette
    ):
        """Normal motion moves with WAV files should NOT appear in robot-audio list."""
        from conftest import make_wav_bytes
        import json
        # A move with real set_target_data frames is a motion move even though
        # it has a sibling WAV file.
        wav_bytes = make_wav_bytes(1.0)
        wav_path = marionette._dataset_dir / "motion-move.wav"
        wav_path.write_bytes(wav_bytes)
        json_path = marionette._dataset_dir / "motion-move.json"
        json_path.write_text(json.dumps({
            "time": [0.0, 0.01],
            "set_target_data": [
                {"head": [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]], "antennas": [0,0]},
                {"head": [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]], "antennas": [0,0]},
            ],
        }))
        marionette._refresh_recordings()
        resp = client.get("/api/robot-audio")
        assert resp.status_code == 200
        files = resp.json()["files"]
        names = [f["name"] for f in files]
        assert "motion-move" not in names
        # Cleanup
        wav_path.unlink(missing_ok=True)
        json_path.unlink(missing_ok=True)
        marionette._refresh_recordings()
    def test_robot_audio_select_returns_upload_id(
        self, client: TestClient, marionette: Marionette
    ):
        """POST /api/robot-audio/select with valid WAV returns upload_id."""
        from conftest import make_wav_bytes
        wav_bytes = make_wav_bytes(1.5)
        wav_path = marionette._dataset_dir / "selectable.wav"
        wav_path.write_bytes(wav_bytes)
        resp = client.post(
            "/api/robot-audio/select",
            json={"path": str(wav_path)},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "upload_id" in data
        assert data["filename"] == "selectable.wav"
        # Cleanup
        wav_path.unlink(missing_ok=True)
    def test_robot_audio_select_rejects_path_traversal(self, client: TestClient):
        """Path outside dataset/temp dirs is rejected with 403."""
        # 404 is also acceptable if the server checks existence first.
        resp = client.post(
            "/api/robot-audio/select",
            json={"path": "/etc/passwd"},
        )
        assert resp.status_code in (403, 404)
    def test_robot_audio_select_rejects_nonexistent(self, client: TestClient):
        """Non-existent file returns 404."""
        # 403 is also acceptable if the path check fires before existence.
        resp = client.post(
            "/api/robot-audio/select",
            json={"path": "/tmp/does-not-exist-at-all.wav"},
        )
        assert resp.status_code in (403, 404)
    def test_robot_audio_select_rejects_non_wav(
        self, client: TestClient, marionette: Marionette
    ):
        """Non-WAV file within allowed dirs is rejected."""
        txt_path = marionette._dataset_dir / "notes.txt"
        txt_path.write_text("not audio")
        resp = client.post(
            "/api/robot-audio/select",
            json={"path": str(txt_path)},
        )
        assert resp.status_code == 400
        # Cleanup
        txt_path.unlink(missing_ok=True)
| # ββββββββ Move metadata tests ββββββββββββββββββββββββββββββββββββ | |
class TestMoveMetadata:
    """Tests for move metadata fields in the state endpoint."""
    def test_move_created_at_is_float(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """created_at should be a numeric Unix timestamp."""
        data_dir = marionette._dataset_dir
        (data_dir / "meta-test.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        state = client.get("/api/state").json()
        move = next((m for m in state["moves"] if m["id"] == "meta-test"), None)
        assert move is not None
        assert isinstance(move["created_at"], float)
        assert move["created_at"] > 0
        # Cleanup
        (data_dir / "meta-test.json").unlink(missing_ok=True)
        marionette._refresh_recordings()
    def test_move_created_at_is_recent(
        self, client: TestClient, marionette: Marionette, sample_move_json: dict
    ):
        """A freshly created move should have a recent timestamp."""
        import time
        data_dir = marionette._dataset_dir
        (data_dir / "recent-test.json").write_text(json.dumps(sample_move_json))
        marionette._refresh_recordings()
        state = client.get("/api/state").json()
        move = next((m for m in state["moves"] if m["id"] == "recent-test"), None)
        assert move is not None
        # Should be within last 60 seconds
        assert abs(move["created_at"] - time.time()) < 60
        # Cleanup
        (data_dir / "recent-test.json").unlink(missing_ok=True)
        marionette._refresh_recordings()
    def test_audio_only_move_in_robot_audio_list(
        self, client: TestClient, marionette: Marionette
    ):
        """Audio-only recordings with WAV files should appear in robot-audio list.
        NOTE(review): overlaps with TestRobotAudioEndpoints coverage; consider
        consolidating there."""
        from conftest import make_wav_bytes
        data_dir = marionette._dataset_dir
        move_data = {
            "description": "audio only for robot list",
            "audio_only": True,
            "time": [0.0, 0.01, 0.02],
            "set_target_data": [],
        }
        (data_dir / "ao-robot.json").write_text(json.dumps(move_data))
        (data_dir / "ao-robot.wav").write_bytes(make_wav_bytes(1.0))
        marionette._refresh_recordings()
        resp = client.get("/api/robot-audio")
        assert resp.status_code == 200
        names = [f["name"] for f in resp.json()["files"]]
        assert "ao-robot" in names
        # Cleanup
        (data_dir / "ao-robot.json").unlink(missing_ok=True)
        (data_dir / "ao-robot.wav").unlink(missing_ok=True)
        marionette._refresh_recordings()
    def test_record_with_uploaded_audio_accepted(
        self, client: TestClient, marionette: Marionette
    ):
        """Upload audio, then start a recording using that audio ID."""
        from conftest import make_wav_bytes
        wav = make_wav_bytes(2.0)
        upload_resp = client.post(
            "/api/upload-audio",
            files={"file": ("track.wav", BytesIO(wav), "audio/wav")},
        )
        assert upload_resp.status_code == 200
        upload_id = upload_resp.json()["upload_id"]
        resp = client.post("/api/record", json={
            "duration": 2.0,
            "record_audio": False,
            "record_motion": True,
            "uploaded_audio_id": upload_id,
            "label": "with-audio",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["accepted"] is True
        # The pending request must carry the resolved upload path.
        assert marionette._pending_recording is not None
        assert marionette._pending_recording.uploaded_audio_path is not None
        # Cleanup
        marionette._set_idle_state()
        marionette._pending_recording = None
| # ββββββββ Audio analysis unit tests βββββββββββββββββββββββββββββββ | |
class TestAudioAnalysis:
    """Unit tests for the audio analysis utilities (no hardware needed)."""
    def test_generate_sync_test_audio(self):
        """Generated sync audio has the right length and beeps at each marker."""
        from audio_analysis import generate_sync_test_audio
        audio, beep_times = generate_sync_test_audio(sr=48000, duration=8.0)
        assert audio.shape == (8 * 48000,)
        assert len(beep_times) == 5
        # Audio should have non-zero samples at beep positions
        for t in beep_times:
            idx = int(t * 48000)
            assert np.max(np.abs(audio[idx : idx + 4800])) > 0.1
    def test_detect_beep_onsets_synthetic(self):
        """The onset detector recovers the synthetic beeps within 100ms."""
        from audio_analysis import detect_beep_onsets, generate_sync_test_audio
        sr = 48000
        audio, expected_times = generate_sync_test_audio(sr=sr, duration=8.0)
        detected = detect_beep_onsets(audio, sr, freq=1000.0)
        # At least 4 of the 5 beeps must be found (the detector may miss one).
        assert len(detected) >= 4, f"Only detected {len(detected)}/5 beeps"
        # Each detected onset should be within 100ms of an expected time
        for dt in detected:
            distances = [abs(dt - et) for et in expected_times]
            assert min(distances) < 0.1, f"Detected onset {dt:.3f}s doesn't match any expected beep"
    def test_generate_collision_trajectory(self):
        """Antennas collide (near 0) at beep times and separate between them."""
        from audio_analysis import generate_collision_trajectory
        beep_times = [1.0, 2.5, 4.0]
        timestamps, frames = generate_collision_trajectory(beep_times, duration=5.0)
        assert len(timestamps) == 500  # 5s * 100Hz
        assert len(frames) == 500
        # At a beep time, antennas should be at 0 (collided)
        idx_at_beep = int(1.0 * 100)
        antennas = frames[idx_at_beep]["antennas"]
        assert abs(antennas[0]) < 0.05, f"Antennas should be near 0 at collision: {antennas}"
        # Between beeps, antennas should be apart
        idx_between = int(1.5 * 100)
        antennas = frames[idx_between]["antennas"]
        assert abs(antennas[0]) > 0.1, f"Antennas should be apart between beeps: {antennas}"
    def test_measure_sync_offsets(self):
        """A constant 50ms lag is measured to within 1ms."""
        from audio_analysis import measure_sync_offsets
        beep_onsets = [1.0, 2.5, 4.0, 5.0, 7.0]
        # Collisions arrive 50ms after each beep
        collision_onsets = [1.05, 2.55, 4.05, 5.05, 7.05]
        result = measure_sync_offsets(beep_onsets, collision_onsets)
        assert result["n_matched"] == 5
        assert abs(result["mean_offset_ms"] - 50.0) < 1.0
        assert result["max_offset_ms"] < 55.0
    def test_measure_sync_offsets_missing_collisions(self):
        """Unmatched beeps reduce n_matched but are still counted in n_beeps."""
        from audio_analysis import measure_sync_offsets
        beep_onsets = [1.0, 2.5, 4.0, 5.0, 7.0]
        collision_onsets = [1.02, 4.01]  # Only 2 matched
        result = measure_sync_offsets(beep_onsets, collision_onsets)
        assert result["n_matched"] == 2
        assert result["n_beeps"] == 5
    def test_datasets_payload_sorted_by_label(self, marionette: Marionette):
        """Verify _datasets_payload returns entries sorted by label.
        NOTE(review): this is a dataset-registry test that likely belongs in a
        dataset test class rather than TestAudioAnalysis; kept here so the
        pytest node ID stays stable. (Dropped the unused tmp_path fixture.)
        """
        ds_root = marionette._dataset_root
        (ds_root / "zzz_first_label").mkdir(exist_ok=True)
        (ds_root / "aaa_last_label").mkdir(exist_ok=True)
        marionette._load_dataset_registry()
        payload = marionette._datasets_payload()
        labels = [e.get("label", e["id"]) for e in payload["entries"]]
        assert labels == sorted(labels, key=str.lower), f"Entries not sorted by label: {labels}"