File size: 3,447 Bytes
9a22bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import unittest
import app as app_module
from app import app
import io
import time

class SecurityTestCase(unittest.TestCase):
    """Security regression tests for the app's download and upload routes.

    Covers filename validation on ``/download/<name>``, path-traversal
    attempts, upload rate limiting, and content (magic-byte) checks on
    ``/upload``.
    """

    def setUp(self):
        """Put the app in testing mode and reset shared module state.

        ``app.config``, ``app_module.model_loaded`` and
        ``app_module.upload_counts`` are module-level globals, so their prior
        state is restored via ``addCleanup`` after each test; otherwise the
        mutations below would leak into later tests in the same process.
        """
        previous_testing = app.config.get('TESTING', False)
        # Falls back to False if the attribute was not previously defined.
        previous_model_loaded = getattr(app_module, 'model_loaded', False)

        def restore_globals():
            app.config['TESTING'] = previous_testing
            app_module.model_loaded = previous_model_loaded
            # Don't let this test's recorded uploads count against others.
            app_module.upload_counts.clear()

        self.addCleanup(restore_globals)

        app.config['TESTING'] = True
        self.client = app.test_client()
        # Reset rate limiting counts so each test starts with a fresh quota.
        app_module.upload_counts.clear()
        # Mock model_loaded to True to bypass model initialization in tests.
        app_module.model_loaded = True

    def _post_upload(self, content, filename):
        """POST ``content`` to ``/upload`` as a multipart ``file`` field.

        Args:
            content: Raw bytes to send as the file body.
            filename: Client-side filename (its extension is what the
                server's type checks see).

        Returns:
            The test client's response object.
        """
        # Build a fresh BytesIO per request; a stream that has already been
        # read by one request is not rewound for the next.
        data = {'file': (io.BytesIO(content), filename)}
        return self.client.post(
            '/upload', data=data, content_type='multipart/form-data'
        )

    def test_download_valid_filename(self):
        """A well-formed name passes validation and yields 404 when absent."""
        # Even if the file doesn't exist, it should pass the regex and reach
        # send_from_directory (404). NOTE(review): this assumes no file with
        # this name exists in the download directory during tests.
        valid_uuid = "a" * 32
        response = self.client.get(f'/download/colorized_{valid_uuid}.png')
        self.assertEqual(response.status_code, 404)

    def test_download_invalid_format(self):
        """Malformed download filenames are rejected with 400."""
        valid_uuid = "a" * 32

        # Too short UUID: also verify the error message is surfaced.
        response = self.client.get('/download/colorized_abc123.png')
        self.assertEqual(response.status_code, 400)
        self.assertIn(b'Invalid filename format', response.data)

        bad_names = {
            'missing prefix': f'{valid_uuid}.png',
            'invalid extension': f'colorized_{valid_uuid}.exe',
        }
        for label, name in bad_names.items():
            with self.subTest(label):
                response = self.client.get(f'/download/{name}')
                self.assertEqual(response.status_code, 400)

    def test_path_traversal_prevention(self):
        """Traversal and null-byte tricks never serve a file."""
        # Encoded slashes may be caught by the regex (400) or rejected by
        # Flask's routing before reaching the handler (404).
        response = self.client.get('/download/..%2f..%2fapp.py')
        self.assertIn(response.status_code, [400, 404])

        # A name that definitely reaches the handler but carries a trailing
        # encoded null byte must fail validation.
        response = self.client.get(
            '/download/colorized_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.png%00'
        )
        self.assertEqual(response.status_code, 400)

    def test_rate_limiting(self):
        """The 11th upload within the rate-limit window returns 429."""
        # The limit is 10 uploads per 60 seconds.
        for attempt in range(10):
            with self.subTest(attempt=attempt):
                response = self._post_upload(b"fake image data", 'test.png')
                # May be 400 or 500 depending on model state, but must not
                # be rate limited yet.
                self.assertNotEqual(response.status_code, 429)

        # 11th request should be rate limited.
        response = self._post_upload(b"fake image data", 'test.png')
        self.assertEqual(response.status_code, 429)
        self.assertIn(b'Rate limit exceeded', response.data)

    def test_upload_invalid_magic_bytes(self):
        """Uploads whose bytes don't match their image extension get 400."""
        cases = [
            (b"this is not a jpeg", 'test.jpg'),
            (b"this is not a png", 'test.png'),
        ]
        for content, filename in cases:
            with self.subTest(filename=filename):
                response = self._post_upload(content, filename)
                self.assertEqual(response.status_code, 400)
                self.assertIn(b'Invalid image content', response.data)

if __name__ == '__main__':
    # Allow running this test module directly (python <this file>) in
    # addition to discovery via a test runner.
    unittest.main()