keyshift-api / tests /test_api.py
balakrishna567's picture
feat: FastAPI /analyze — MERT boundaries + KS key detection + guitar info
925d02f
Raw
History Blame Contribute Delete
1.67 kB
import io, numpy as np, soundfile as sf, pytest
from unittest.mock import MagicMock, patch
def make_wav(duration=30.0, sr=22050, freq=440.0, amplitude=0.5) -> bytes:
    """Return an in-memory mono WAV file containing a pure sine tone.

    Args:
        duration: Length of the clip in seconds; the sample count is
            ``int(sr * duration)``.
        sr: Sample rate in Hz.
        freq: Tone frequency in Hz (440 Hz by default).
        amplitude: Peak amplitude of the float signal written to the file.

    Returns:
        The encoded WAV file as raw bytes, suitable for a multipart upload.
    """
    n_samples = int(sr * duration)
    # Sample times spaced exactly 1/sr apart. The previous
    # np.linspace(0, duration, n_samples) included the endpoint, which made
    # every step duration/(n_samples - 1) and slightly detuned the tone
    # relative to the sample rate declared in the WAV header.
    t = np.arange(n_samples) / sr
    y = (amplitude * np.sin(2 * np.pi * freq * t)).astype(np.float32)
    buf = io.BytesIO()
    # format must be given explicitly: there is no filename to infer it from.
    sf.write(buf, y, sr, format="WAV")
    return buf.getvalue()
@pytest.fixture
def client():
    """Yield a ``TestClient`` for ``app.main.app`` with the MERT encoder mocked.

    ``app.main.MERTEncoder`` is patched to return a ``MagicMock`` whose
    ``encode_batch`` returns a fixed embedding matrix, so no real model is
    loaded. The cached ``app.main._encoder`` is forced to ``None`` for the
    duration of the test so that ``get_encoder()`` constructs the patched
    encoder. ``patch.object`` restores its previous value on teardown.
    """
    import app.main as main_module
    from fastapi.testclient import TestClient

    mock_enc = MagicMock()
    # Seeded generator: the fake embeddings are identical on every run, so
    # whatever the app derives from them is reproducible. The previous
    # unseeded np.random.randn made results vary between runs.
    # NOTE(review): the 11 x 768 shape presumably matches the number of
    # windows produced for make_wav()'s default 30 s clip and the encoder's
    # embedding width — confirm against app.main.
    rng = np.random.default_rng(0)
    mock_enc.encode_batch.return_value = rng.standard_normal((11, 768))
    with patch("app.main.MERTEncoder", return_value=mock_enc), patch.object(
        main_module, "_encoder", None
    ):
        yield TestClient(main_module.app)
def test_analyze_200(client):
    """A valid WAV upload to /analyze is answered with HTTP 200."""
    upload = {"file": ("t.wav", make_wav(), "audio/wav")}
    response = client.post("/analyze", files=upload)
    assert response.status_code == 200
def test_response_structure(client):
    """/analyze returns the top-level summary keys and at least one segment."""
    upload = {"file": ("t.wav", make_wav(), "audio/wav")}
    response = client.post("/analyze", files=upload)
    # Check the status first: an error response carries an error payload
    # (typically FastAPI's {"detail": ...}), and the key check below would
    # otherwise fail with a misleading message.
    assert response.status_code == 200, response.text
    data = response.json()
    expected_keys = {"segments", "duration", "dominant_key", "dominant_mode"}
    missing = expected_keys.difference(data)
    assert not missing, f"missing keys: {sorted(missing)}"
    assert len(data["segments"]) > 0
def test_segment_has_guitar_info(client):
    """The first segment lists 7 scale notes, 5 pentatonic notes, and at
    least one fretboard position."""
    upload = {"file": ("t.wav", make_wav(), "audio/wav")}
    response = client.post("/analyze", files=upload)
    assert response.status_code == 200, response.text
    segments = response.json()["segments"]
    # Guard the [0] index so an empty result fails with a clear message
    # instead of an IndexError.
    assert segments, "expected at least one segment"
    seg = segments[0]
    assert len(seg["scale_notes"]) == 7
    assert len(seg["pentatonic_notes"]) == 5
    assert len(seg["fretboard_positions"]) > 0
def test_rejects_non_audio(client):
    """Uploading a plain-text file to /analyze is rejected with HTTP 422."""
    text_upload = ("t.txt", b"nope", "text/plain")
    response = client.post("/analyze", files={"file": text_upload})
    assert response.status_code == 422