import json import os import requests from copy import deepcopy from pathlib import Path import pytest import time api_key = os.environ.get("API_KEY") port = os.environ.get("PORT", 9001) base_url = os.environ.get("BASE_URL", "http://localhost") with open(os.path.join(Path(__file__).resolve().parent, "sam_tests.json"), "r") as f: TESTS = json.load(f) @pytest.mark.skip(reason="SAM testing is broken, to be fixed in future release") @pytest.mark.parametrize("test", TESTS) def test_sam(test): payload = deepcopy(test["payload"]) payload["api_key"] = api_key response = requests.post( f"{base_url}:{port}/sam/{test['type']}", json=payload, ) try: response.raise_for_status() data = response.json() if test["type"] == "embed_image": try: assert "embeddings" in data except: print(f"Invalid response: {data}, expected 'embeddings' in response") try: assert len(data["embeddings"]) == len( test["expected_response"]["embeddings"] ) except: print( f"Invalid response: {data}, expected length of embeddings to be {len(test['expected_response']['embeddings'])}, got {len(data['embeddings'])}" ) if test["type"] == "segment_image": try: assert "masks" in data except: print(f"Invalid response: {data}, expected 'masks' in response") try: assert data["masks"] == test["expected_response"]["masks"] except: print( f"Invalid response: {data}, expected masks to be {test['expected_response']['masks']}, got {data['masks']}" ) except Exception as e: raise e @pytest.fixture(scope="session", autouse=True) def setup(): try: res = requests.get(f"{base_url}:{port}") res.raise_for_status() success = True except: success = False waited = 0 while not success: print("Waiting for server to start...") time.sleep(5) waited += 5 try: res = requests.get(f"{base_url}:{port}") res.raise_for_status() success = True except: success = False if waited > 30: raise Exception("Test server failed to start") if __name__ == "__main__": test_sam()