# Synesthesia / tests/test_api_security.py
# Source page residue: author Ashiedu, commit 0490201 "Sync unified workbench" (verified).
import os
import unittest
from unittest.mock import patch, MagicMock
class MockFastAPI:
    """Minimal stand-in for ``fastapi.FastAPI`` used while importing ``api.server``.

    It is installed unconditionally, whether or not FastAPI is installed.
    Every ``add_middleware`` call is recorded in ``middleware_stack`` so tests
    can inspect the configured middleware. Route and event decorators are
    no-ops that return the decorated function unchanged. They accept (and
    ignore) any extra positional or keyword arguments, such as
    ``response_model=`` or ``status_code=``, that the real FastAPI API allows.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments (title, version, ...) are accepted and ignored.
        self.middleware_stack = []

    def add_middleware(self, middleware_class, **kwargs):
        """Record ``(middleware_class, kwargs)`` instead of building a middleware."""
        self.middleware_stack.append((middleware_class, kwargs))

    def on_event(self, event_type, *args, **kwargs):
        """Return an identity decorator; event handlers are never run."""
        return lambda f: f

    def get(self, path, *args, **kwargs):
        """Return an identity decorator for a GET route."""
        return lambda f: f

    def post(self, path, *args, **kwargs):
        """Return an identity decorator for a POST route."""
        return lambda f: f

    def websocket(self, path, *args, **kwargs):
        """Return an identity decorator for a websocket route."""
        return lambda f: f
import sys

# Modules that api.server imports (directly or indirectly) and that this test
# replaces with MagicMocks. They are stubbed only while api.server is being
# imported.
_STUBBED_MODULES = (
    "runtime.camera_capture",
    "runtime.mic_capture",
    "runtime.gemma_prompt_engine",
    "runtime.magenta_generation",
    "runtime.clip_scheduler",
    "fastapi",
    "fastapi.middleware.cors",
    "uvicorn",
)

# Environment variables that feed the CORS configuration asserted below.
_CORS_ENV_VARS = ("ALLOWED_ORIGINS", "ALLOW_CREDENTIALS")

# Parent of this file's tests/ directory, presumed to be the repository root
# that contains the ``api`` package. The working directory is still appended
# as a fallback, matching the previous behaviour.
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

for _import_root in (_REPO_ROOT, os.getcwd()):
    if _import_root not in sys.path:
        sys.path.append(_import_root)


def _import_server_app(env=None):
    """Import ``api.server`` from scratch with stubbed dependencies; return its ``app``.

    ``fastapi.FastAPI`` is replaced by ``MockFastAPI``, so the returned app
    records its ``add_middleware`` calls in ``middleware_stack``. The variables
    in ``_CORS_ENV_VARS`` are removed first, so the result does not depend on
    the caller's shell. Then ``env`` (if given) is applied. Both
    ``sys.modules`` and ``os.environ`` are restored once the import finishes.

    Args:
        env: Optional mapping of environment variables visible during the import.

    Returns:
        The ``app`` object defined by ``api.server``.
    """
    stubs = {name: MagicMock() for name in _STUBBED_MODULES}
    stubs["fastapi"].FastAPI = MockFastAPI
    with patch.dict("sys.modules", stubs), patch.dict(os.environ):
        for name in _CORS_ENV_VARS:
            os.environ.pop(name, None)
        os.environ.update(env or {})
        # Drop any cached copy so api.server's module body (which registers the
        # middleware) runs again under the stubs and environment set above.
        sys.modules.pop("api.server", None)
        from api.server import app as server_app
    return server_app


app = _import_server_app()


def _find_cors_config(server_app):
    """Return the kwargs of the CORS middleware registered on ``server_app``, or None.

    ``CORSMiddleware`` is itself a MagicMock here, so the class cannot be
    matched. The middleware is identified by its ``allow_origins`` keyword instead.
    """
    return next(
        (kwargs for _cls, kwargs in server_app.middleware_stack if "allow_origins" in kwargs),
        None,
    )


class TestAPISecurity(unittest.TestCase):
    """Check the CORS policy that ``api.server`` installs."""

    def test_cors_configuration(self):
        """With no CORS env vars set, only http://localhost:1420 is allowed, with credentials."""
        cors_config = _find_cors_config(app)
        self.assertIsNotNone(cors_config, "CORSMiddleware not found in app")
        self.assertEqual(cors_config["allow_origins"], ["http://localhost:1420"])
        self.assertTrue(cors_config["allow_credentials"])

    def test_cors_env_override(self):
        """ALLOWED_ORIGINS / ALLOW_CREDENTIALS must reach the middleware api.server installs."""
        # Re-import the server under the overridden environment, so the
        # server's own parsing is exercised rather than a copy of it here.
        overridden_app = _import_server_app({
            "ALLOWED_ORIGINS": "http://myapp.com,http://another.com",
            "ALLOW_CREDENTIALS": "False",
        })
        cors_config = _find_cors_config(overridden_app)
        self.assertIsNotNone(cors_config, "CORSMiddleware not found in app")
        self.assertEqual(cors_config["allow_origins"], ["http://myapp.com", "http://another.com"])
        self.assertFalse(cors_config["allow_credentials"])
if __name__ == "__main__":
    # Executing this file directly runs its tests with the stdlib runner;
    # when imported by a test collector, __name__ differs and this is skipped.
    unittest.main(module="__main__")