File size: 2,538 Bytes
0490201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import sys
import unittest
from unittest.mock import MagicMock, patch

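# Minimal stand-in for fastapi.FastAPI. It is installed unconditionally (not only when FastAPI is missing) so the tests can inspect middleware registration without a real app.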
class MockFastAPI:
    """Minimal test double for ``fastapi.FastAPI``.

    Every ``add_middleware`` call is recorded in ``middleware_stack`` as a
    ``(middleware_class, kwargs)`` tuple so tests can inspect how middleware
    was configured. The route and event decorators are no-ops: they return
    the decorated function unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Constructor options (title, version, ...) are accepted and ignored.
        self.middleware_stack = []

    def add_middleware(self, middleware_class, **kwargs):
        """Record *middleware_class* together with its keyword options."""
        self.middleware_stack.append((middleware_class, kwargs))

    @staticmethod
    def _identity(func):
        """Decorator shared by the route/event stubs: return *func* as-is."""
        return func

    # The decorator factories below accept and ignore any extra arguments
    # (response_model=..., status_code=..., tags=..., name=...) so server code
    # that passes them can still be imported against this stub.

    def on_event(self, event_type, *args, **kwargs):
        """Stub for ``FastAPI.on_event``; the handler is not registered."""
        return self._identity

    def get(self, path, *args, **kwargs):
        """Stub for ``FastAPI.get``; the route is not registered."""
        return self._identity

    def post(self, path, *args, **kwargs):
        """Stub for ``FastAPI.post``; the route is not registered."""
        return self._identity

    def websocket(self, path, *args, **kwargs):
        """Stub for ``FastAPI.websocket``; the route is not registered."""
        return self._identity

# Stub out the hardware/ML runtime modules and the web framework so that
# api.server can be imported in a plain test environment. The fastapi stub
# exposes MockFastAPI as FastAPI, so the server builds a recording app.
_fastapi_stub = MagicMock()
_fastapi_stub.FastAPI = MockFastAPI

_IMPORT_STUBS = {
    "runtime.camera_capture": MagicMock(),
    "runtime.mic_capture": MagicMock(),
    "runtime.gemma_prompt_engine": MagicMock(),
    "runtime.magenta_generation": MagicMock(),
    "runtime.clip_scheduler": MagicMock(),
    "fastapi": _fastapi_stub,
    "fastapi.middleware.cors": MagicMock(),
    "uvicorn": MagicMock(),
}

# patch.dict restores sys.modules and os.environ when the block exits, so
# neither the stubs nor the environment tweaks leak beyond this import.
with patch.dict("sys.modules", _IMPORT_STUBS), patch.dict(os.environ):
    # Drop CORS settings that may be exported in the developer's shell so the
    # server module is imported with its defaults, which
    # test_cors_configuration asserts.
    for _var in ("ALLOWED_ORIGINS", "ALLOW_CREDENTIALS"):
        os.environ.pop(_var, None)

    # Make the project root (assumed to be the working directory) importable.
    sys.path.append(os.getcwd())
    from api.server import app

class TestAPISecurity(unittest.TestCase):
    """Checks on the CORS configuration of the API server."""

    def test_cors_configuration(self):
        """The app registers CORS middleware with default origin and credentials on."""
        # The middleware class is a MagicMock under the import stubs, so the
        # CORS entry is identified by its keyword arguments instead.
        cors_config = next(
            (
                kwargs
                for _middleware_class, kwargs in app.middleware_stack
                if "allow_origins" in kwargs
            ),
            None,
        )
        self.assertIsNotNone(cors_config, "CORSMiddleware not found in app")

        # Default should be http://localhost:1420
        self.assertEqual(cors_config["allow_origins"], ["http://localhost:1420"])
        # assertIs rather than assertTrue: a truthy non-bool (for example the
        # raw string "False") would otherwise slip through.
        self.assertIs(cors_config["allow_credentials"], True)

    @patch.dict(
        os.environ,
        {
            "ALLOWED_ORIGINS": "http://myapp.com,http://another.com",
            "ALLOW_CREDENTIALS": "False",
        },
    )
    def test_cors_env_override(self):
        """Environment overrides for origins and credentials parse as expected."""
        # api.server was imported once at module load, so its app already
        # reflects the environment at that time; the env loading and parsing
        # are re-run here instead.
        # NOTE(review): this presumably mirrors the parsing in api.server but
        # does not exercise it; keep the two in sync.
        from ML_Pipeline.shared.env import apply_defaults

        env_config = apply_defaults()
        raw_origins = env_config.get("ALLOWED_ORIGINS")
        raw_credentials = env_config.get("ALLOW_CREDENTIALS")
        # Fail with a clear message rather than an AttributeError on None.
        self.assertIsNotNone(raw_origins, "ALLOWED_ORIGINS missing from env config")
        self.assertIsNotNone(raw_credentials, "ALLOW_CREDENTIALS missing from env config")

        allowed_origins = [origin.strip() for origin in raw_origins.split(",")]
        allow_credentials = raw_credentials.lower() == "true"

        self.assertEqual(allowed_origins, ["http://myapp.com", "http://another.com"])
        self.assertFalse(allow_credentials)


if __name__ == "__main__":
    unittest.main()