"""Tests for the CORS settings that ``api.server`` registers on its app.

``fastapi``, ``uvicorn`` and the ``runtime.*`` modules are replaced with
stand-ins in ``sys.modules`` while ``api.server`` is imported, so these
tests only inspect the keyword arguments passed to ``add_middleware``.
"""

import os
import sys
import unittest
from unittest.mock import MagicMock, patch


def _identity_decorator(func):
    """Return *func* unchanged; used as a no-op route/event decorator."""
    return func


class MockFastAPI:
    """Stand-in for ``fastapi.FastAPI`` that records middleware registrations.

    The route and event registration methods hand back a decorator that
    returns the decorated function untouched.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and ignored.
        # Each entry is a ``(middleware_class, kwargs)`` tuple.
        self.middleware_stack = []

    def add_middleware(self, middleware_class, **kwargs):
        """Record the middleware class together with its keyword options."""
        self.middleware_stack.append((middleware_class, kwargs))

    def on_event(self, event_type):
        """Return a no-op decorator for event handlers."""
        return _identity_decorator

    def get(self, path):
        """Return a no-op decorator for GET routes."""
        return _identity_decorator

    def post(self, path):
        """Return a no-op decorator for POST routes."""
        return _identity_decorator

    def websocket(self, path):
        """Return a no-op decorator for websocket routes."""
        return _identity_decorator


# Modules swapped for MagicMock instances while ``api.server`` is imported.
_STUBBED_MODULES = (
    "runtime.camera_capture",
    "runtime.mic_capture",
    "runtime.gemma_prompt_engine",
    "runtime.magenta_generation",
    "runtime.clip_scheduler",
    "fastapi",
    "fastapi.middleware.cors",
    "uvicorn",
)

# Mock the runtime and fastapi modules
with patch.dict("sys.modules", {name: MagicMock() for name in _STUBBED_MODULES}):
    # These imports resolve to the MagicMock entries registered above.
    import fastapi
    import fastapi.middleware.cors

    # Must be set before api.server is imported so it picks up the stub class.
    fastapi.FastAPI = MockFastAPI

    sys.path.append(os.getcwd())
    from api.server import app


class TestAPISecurity(unittest.TestCase):
    """Checks on the CORS options registered by ``api.server``."""

    # Verifies the CORS options captured when api.server was imported.
    def test_cors_configuration(self):
        # The middleware class itself is a Mock and cannot be compared,
        # so the CORS entry is recognised by its "allow_origins" option.
        cors_config = next(
            (
                options
                for _, options in app.middleware_stack
                if "allow_origins" in options
            ),
            None,
        )

        self.assertIsNotNone(cors_config, "CORSMiddleware not found in app")

        # Default should be http://localhost:1420
        self.assertEqual(cors_config["allow_origins"], ["http://localhost:1420"])
        self.assertTrue(cors_config["allow_credentials"])

    # Verifies that environment overrides are parsed into the expected values.
    @patch.dict(
        os.environ,
        {
            "ALLOWED_ORIGINS": "http://myapp.com,http://another.com",
            "ALLOW_CREDENTIALS": "False",
        },
    )
    def test_cors_env_override(self):
        # api.server was already imported, so rather than re-importing it the
        # parsing steps are repeated here on a fresh apply_defaults() result.
        from ML_Pipeline.shared.env import apply_defaults

        env_config = apply_defaults()

        raw_origins = env_config.get("ALLOWED_ORIGINS")
        allowed_origins = [origin.strip() for origin in raw_origins.split(",")]

        raw_credentials = env_config.get("ALLOW_CREDENTIALS")
        allow_credentials = raw_credentials.lower() == "true"

        self.assertEqual(
            allowed_origins, ["http://myapp.com", "http://another.com"]
        )
        self.assertFalse(allow_credentials)


if __name__ == "__main__":
    unittest.main()