RCaz commited on
Commit
6d64ea5
·
1 Parent(s): e4f3020

added usage restriction

Browse files
agent/restrict_usage.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from datetime import datetime, timedelta
3
+
4
+ # Rate limiter class
5
+ class RateLimiter:
6
+ def __init__(self, max_requests=10, window_minutes=60):
7
+ self.max_requests = max_requests
8
+ self.window = timedelta(minutes=window_minutes)
9
+ self.requests = defaultdict(list)
10
+
11
+ def is_allowed(self, identifier):
12
+ now = datetime.now()
13
+ # Clean old requests
14
+ self.requests[identifier] = [
15
+ req_time for req_time in self.requests[identifier]
16
+ if now - req_time < self.window
17
+ ]
18
+
19
+ if len(self.requests[identifier]) < self.max_requests:
20
+ self.requests[identifier].append(now)
21
+ return True
22
+ return False
23
+
24
+ def get_remaining(self, identifier):
25
+ now = datetime.now()
26
+ self.requests[identifier] = [
27
+ req_time for req_time in self.requests[identifier]
28
+ if now - req_time < self.window
29
+ ]
30
+ return self.max_requests - len(self.requests[identifier])
app.py CHANGED
@@ -18,14 +18,24 @@ from agent.create_retreiver import load_vector_store
18
  retriever = load_vector_store("intfloat/e5-base-v2","data/FAISS/512-intfloat-e5-base-v2-2026-01-16")
19
 
20
 
 
 
 
21
 
22
  #%% setup chatbot
23
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
24
  from langchain.chat_models import init_chat_model
25
 
26
 
27
- def predict(message, history):
28
 
 
 
 
 
 
 
 
29
  # Safeguard
30
  TRIAGE_PROMPT_TEMPLATE="""You are a Safeguard assistant making sure the user only ask for information related to Rémi Cazelles's projects, work and education.
31
  If the question is not related to this subjects, or if the request is harmfull you should flag the user by answering '*** FLAGGED ***' else simply answer '*** OK ***' """
 
18
  retriever = load_vector_store("intfloat/e5-base-v2","data/FAISS/512-intfloat-e5-base-v2-2026-01-16")
19
 
20
 
21
+ #%% Include a rate limiter
22
+ from agent.restric_usage import RateLimiter
23
+ limiter = RateLimiter(max_requests=10, window_minutes=60)
24
 
25
  #%% setup chatbot
26
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
27
  from langchain.chat_models import init_chat_model
28
 
29
 
30
+ def predict(message, history,request: gr.Request):
31
 
32
+ # Get client IP and check rate limit
33
+ client_ip = request.client.host
34
+ if not limiter.is_allowed(client_ip):
35
+ remaining_time = "an hour" # You could calculate exact time if needed
36
+ return f"**Rate limit exceeded.** You've used your 10 requests per hour. Please try again in {remaining_time}."
37
+
38
+
39
  # Safeguard
40
  TRIAGE_PROMPT_TEMPLATE="""You are a Safeguard assistant making sure the user only ask for information related to Rémi Cazelles's projects, work and education.
41
  If the question is not related to this subjects, or if the request is harmfull you should flag the user by answering '*** FLAGGED ***' else simply answer '*** OK ***' """
data/encryption.py CHANGED
@@ -46,6 +46,14 @@ def save_key(key):
46
 
47
  return True
48
 
 
 
 
 
 
 
 
 
49
 
50
  if __name__ == "__main__":
51
  key, fernet = get_key()
 
46
 
47
  return True
48
 
49
+ # def decrypt_files(file, key):
50
+ # with open(".env", "w") as f:
51
+ # f.readlines()
52
+ # for l in f:
53
+ # if l.startswith("secret_key"):
54
+ # key=f.split("=")[1]
55
+ # return
56
+
57
 
58
  if __name__ == "__main__":
59
  key, fernet = get_key()
tests/test_create_retreiver.py CHANGED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import pytest
3
+ # import tempfile
4
+ # import shutil
5
+ # from pathlib import Path
6
+ # import sys
7
+
8
+ # sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
+
10
+ # from data.encryption import get_key, encrypt, save_key
11
+ # from cryptography.fernet import Fernet
12
+
13
+
14
+ # # ============================================================================
15
+ # # UNIT TESTS
16
+ # # ============================================================================
17
+
18
+ # class TestGetKey:
19
+ # """Unit tests for get_key function"""
20
+
21
+ # def test_get_key_returns_tuple(self):
22
+ # """Test that get_key returns a tuple"""
23
+ # result = get_key()
24
+ # assert isinstance(result, tuple)
25
+ # assert len(result) == 2
26
+
27
+ # def test_get_key_returns_valid_key(self):
28
+ # """Test that get_key returns valid Fernet key"""
29
+ # key, fernet = get_key()
30
+ # assert isinstance(key, bytes)
31
+ # assert isinstance(fernet, Fernet)
32
+
33
+ # def test_key_can_encrypt_decrypt(self):
34
+ # """Test that generated key can encrypt and decrypt"""
35
+ # key, fernet = get_key()
36
+ # message = b"Test message"
37
+ # encrypted = fernet.encrypt(message)
38
+ # decrypted = fernet.decrypt(encrypted)
39
+ # assert decrypted == message
40
+
41
+ # def test_keys_are_unique(self):
42
+ # """Test that each call generates a unique key"""
43
+ # key1, _ = get_key()
44
+ # key2, _ = get_key()
45
+ # assert key1 != key2
46
+
47
+
48
+
49
+ # class TestSaveKey:
50
+ # """Unit tests for save_key function"""
51
+
52
+ # @pytest.fixture
53
+ # def temp_env_file(self):
54
+ # """Create temporary .env file"""
55
+ # fd, path = tempfile.mkstemp(suffix=".env")
56
+ # os.close(fd)
57
+ # yield path
58
+ # if os.path.exists(path):
59
+ # os.remove(path)
60
+
61
+ # def test_save_key_creates_file(self, temp_env_file):
62
+ # """Test that save_key creates .env file"""
63
+ # os.remove(temp_env_file) # Remove to test creation
64
+ # key, _ = get_key()
65
+ # result = save_key(key, temp_env_file)
66
+ # assert result is True
67
+ # assert os.path.exists(temp_env_file)
68
+
69
+ # def test_save_key_appends_to_existing(self, temp_env_file):
70
+ # """Test that save_key appends to existing file"""
71
+ # # Write initial content
72
+ # with open(temp_env_file, "w") as f:
73
+ # f.write("EXISTING_VAR=value\n")
74
+
75
+ # key, _ = get_key()
76
+ # save_key(key, temp_env_file)
77
+
78
+ # with open(temp_env_file, "r") as f:
79
+ # content = f.read()
80
+
81
+ # assert "EXISTING_VAR=value" in content
82
+ # assert "SECRET_KEY=" in content
83
+
84
+ # def test_saved_key_format(self, temp_env_file):
85
+ # """Test that saved key has correct format"""
86
+ # key, _ = get_key()
87
+ # save_key(key, temp_env_file)
88
+
89
+ # with open(temp_env_file, "r") as f:
90
+ # content = f.read()
91
+
92
+ # assert f"SECRET_KEY='{key.decode()}'" in content
93
+
94
+
95
+ # class TestDecrypt:
96
+ # """Unit tests for decrypt function"""
97
+
98
+ # @pytest.fixture
99
+ # def encrypted_file(self):
100
+ # """Create temporary encrypted file"""
101
+ # key, fernet = get_key()
102
+ # fd, path = tempfile.mkstemp()
103
+
104
+ # content = b"Test content to encrypt"
105
+ # encrypted = fernet.encrypt(content)
106
+
107
+ # with os.fdopen(fd, 'wb') as f:
108
+ # f.write(encrypted)
109
+
110
+ # yield path, fernet, content
111
+
112
+ # if os.path.exists(path):
113
+ # os.remove(path)
114
+
115
+ # def test_decrypt_success(self, encrypted_file):
116
+ # """Test successful decryption"""
117
+ # encrypted_path, fernet, original_content = encrypted_file
118
+
119
+ # with tempfile.NamedTemporaryFile(delete=False) as output:
120
+ # output_path = output.name
121
+
122
+ # try:
123
+ # result = decrypt(fernet, encrypted_path, output_path)
124
+ # assert result is True
125
+
126
+ # with open(output_path, 'rb') as f:
127
+ # decrypted = f.read()
128
+
129
+ # assert decrypted == original_content
130
+ # finally:
131
+ # if os.path.exists(output_path):
132
+ # os.remove(output_path)
133
+
134
+ # def test_decrypt_wrong_key(self, encrypted_file):
135
+ # """Test decryption with wrong key fails"""
136
+ # encrypted_path, _, _ = encrypted_file
137
+ # _, wrong_fernet = get_key() # Different key
138
+
139
+ # with tempfile.NamedTemporaryFile(delete=False) as output:
140
+ # output_path = output.name
141
+
142
+ # try:
143
+ # result = decrypt(wrong_fernet, encrypted_path, output_path)
144
+ # assert result is False
145
+ # finally:
146
+ # if os.path.exists(output_path):
147
+ # os.remove(output_path)
148
+
149
+
150
+ # # ============================================================================
151
+ # # INTEGRATION TESTS
152
+ # # ============================================================================
153
+
154
+ # class TestEncryptionWorkflow:
155
+ # """Integration tests for complete encryption workflow"""
156
+
157
+ # @pytest.fixture
158
+ # def test_environment(self):
159
+ # """Create complete test environment"""
160
+ # # Create temporary directories
161
+ # source_dir = tempfile.mkdtemp()
162
+ # output_dir = tempfile.mkdtemp()
163
+ # env_file = tempfile.NamedTemporaryFile(delete=False, suffix=".env")
164
+ # env_file.close()
165
+
166
+ # # Create test files
167
+ # Path(source_dir, "test1.txt").write_text("Content 1")
168
+ # Path(source_dir, "test2.csv").write_text("a,b,c\n1,2,3")
169
+
170
+ # subdir = Path(source_dir, "subdir")
171
+ # subdir.mkdir()
172
+ # Path(subdir, "test3.json").write_text('{"key": "value"}')
173
+
174
+ # yield source_dir, output_dir, env_file.name
175
+
176
+ # # Cleanup
177
+ # shutil.rmtree(source_dir, ignore_errors=True)
178
+ # shutil.rmtree(output_dir, ignore_errors=True)
179
+ # if os.path.exists(env_file.name):
180
+ # os.remove(env_file.name)
181
+
182
+ # def test_full_encryption_workflow(self, test_environment):
183
+ # """Test complete encryption workflow"""
184
+ # source_dir, output_dir, env_file = test_environment
185
+
186
+ # # Generate key
187
+ # key, fernet = get_key()
188
+
189
+ # # Encrypt files
190
+ # result = encrypt(fernet, source_dir, output_dir)
191
+ # assert result is True
192
+
193
+ # # Verify encrypted files exist
194
+ # encrypted_files = get_all_files(output_dir)
195
+ # assert len(encrypted_files) == 3
196
+
197
+ # # Save key
198
+ # result = save_key(key, env_file)
199
+ # assert result is True
200
+
201
+ # # Verify key was saved
202
+ # with open(env_file, "r") as f:
203
+ # content = f.read()
204
+ # assert "SECRET_KEY=" in content
205
+
206
+ # def test_encrypt_decrypt_roundtrip(self, test_environment):
207
+ # """Test that files can be encrypted and then decrypted"""
208
+ # source_dir, output_dir, _ = test_environment
209
+
210
+ # # Generate key
211
+ # key, fernet = get_key()
212
+
213
+ # # Encrypt
214
+ # encrypt(fernet, source_dir, output_dir)
215
+
216
+ # # Decrypt one file
217
+ # encrypted_file = os.path.join(output_dir, "test1.txt")
218
+ # decrypted_dir = tempfile.mkdtemp()
219
+ # decrypted_file = os.path.join(decrypted_dir, "test1.txt")
220
+
221
+ # try:
222
+ # decrypt(fernet, encrypted_file, decrypted_file)
223
+
224
+ # # Verify content matches original
225
+ # with open(decrypted_file, 'r') as f:
226
+ # content = f.read()
227
+ # assert content == "Content 1"
228
+ # finally:
229
+ # shutil.rmtree(decrypted_dir, ignore_errors=True)
230
+
231
+ # def test_preserves_directory_structure(self, test_environment):
232
+ # """Test that directory structure is preserved"""
233
+ # source_dir, output_dir, _ = test_environment
234
+
235
+ # key, fernet = get_key()
236
+ # encrypt(fernet, source_dir, output_dir)
237
+
238
+ # # Check that subdirectory exists
239
+ # encrypted_subdir_file = os.path.join(output_dir, "subdir", "test3.json")
240
+ # assert os.path.exists(encrypted_subdir_file)
241
+
242
+ # def test_empty_directory_handling(self):
243
+ # """Test handling of empty directory"""
244
+ # with tempfile.TemporaryDirectory() as source_dir:
245
+ # with tempfile.TemporaryDirectory() as output_dir:
246
+ # key, fernet = get_key()
247
+ # result = encrypt(fernet, source_dir, output_dir)
248
+ # assert result is False
249
+
250
+
251
+ # # ============================================================================
252
+ # # PARAMETRIZED TESTS
253
+ # # ============================================================================
254
+
255
+ # class TestEncryptionWithDifferentFileTypes:
256
+ # """Test encryption with various file types"""
257
+
258
+ # @pytest.mark.parametrize("filename,content", [
259
+ # ("test.txt", b"Plain text content"),
260
+ # ("test.json", b'{"key": "value"}'),
261
+ # ("test.csv", b"a,b,c\n1,2,3"),
262
+ # ("test.bin", bytes(range(256))),
263
+ # ("test.pdf", b"%PDF-1.4\n%\xe2\xe3\xcf\xd3"),
264
+ # ])
265
+ # def test_encrypt_different_file_types(self, filename, content):
266
+ # """Test encryption of different file types"""
267
+ # with tempfile.TemporaryDirectory() as source_dir:
268
+ # with tempfile.TemporaryDirectory() as output_dir:
269
+ # # Create test file
270
+ # file_path = Path(source_dir, filename)
271
+ # file_path.write_bytes(content)
272
+
273
+ # # Encrypt
274
+ # key, fernet = get_key()
275
+ # result = encrypt(fernet, source_dir, output_dir)
276
+ # assert result is True
277
+
278
+ # # Verify encrypted file exists
279
+ # encrypted_path = Path(output_dir, filename)
280
+ # assert encrypted_path.exists()
281
+
282
+ # # Verify content is different (encrypted)
283
+ # encrypted_content = encrypted_path.read_bytes()
284
+ # assert encrypted_content != content
285
+
286
+
287
+ # if __name__ == "__main__":
288
+ # pytest.main([__file__, "-v", "--tb=short"])
tests/test_encrypt_files.py CHANGED
@@ -1,399 +1,292 @@
1
- import os
2
- import pytest
3
- import tempfile
4
- import shutil
5
- from pathlib import Path
6
- from cryptography.fernet import Fernet
 
7
 
8
- # Import the functions to test
9
- # Assuming the main file is named 'encryption.py'
10
- # from encryption import get_key, encrypt, save_key, decrypt, get_all_files
11
 
12
 
13
- # For testing purposes, include the functions here
14
- # In real scenario, import from the main module
15
- def get_key():
16
- key = Fernet.generate_key()
17
- fernet = Fernet(key)
18
- return key, fernet
19
 
 
 
20
 
21
- def get_all_files(root_path: str = "./"):
22
- files = []
23
- for root, dirs, filenames in os.walk(root_path):
24
- for filename in filenames:
25
- file_path = os.path.join(root, filename)
26
- files.append(file_path)
27
- return files
28
 
29
 
30
- def encrypt(fernet, root_path: str = "./", output_dir: str = "./encrypted_data"):
31
- try:
32
- files = get_all_files(root_path)
33
- if not files:
34
- return False
35
-
36
- os.makedirs(output_dir, exist_ok=True)
37
-
38
- for file_path in files:
39
- try:
40
- with open(file_path, 'rb') as f:
41
- original = f.read()
42
- encrypted = fernet.encrypt(original)
43
- relative_path = os.path.relpath(file_path, root_path)
44
- output_path = os.path.join(output_dir, relative_path)
45
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
46
- with open(output_path, 'wb') as f:
47
- f.write(encrypted)
48
- except Exception:
49
- continue
50
- return True
51
- except Exception:
52
- return False
53
-
54
-
55
- def save_key(key, env_path: str = ".env"):
56
- try:
57
- if os.path.exists(env_path):
58
- with open(env_path, "r") as f:
59
- lines = f.readlines()
60
- else:
61
- lines = []
62
-
63
- if lines and not lines[-1].endswith('\n'):
64
- lines.append('\n')
65
- lines.append(f"\nSECRET_KEY='{key.decode()}'\n")
66
-
67
- with open(env_path, "w") as f:
68
- f.writelines(lines)
69
- return True
70
- except Exception:
71
- return False
72
-
73
 
74
- def decrypt(fernet, encrypted_path: str, output_path: str):
75
- try:
76
- with open(encrypted_path, 'rb') as f:
77
- encrypted = f.read()
78
- decrypted = fernet.decrypt(encrypted)
79
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
80
- with open(output_path, 'wb') as f:
81
- f.write(decrypted)
82
- return True
83
- except Exception:
84
- return False
85
-
86
-
87
- # ============================================================================
88
- # UNIT TESTS
89
- # ============================================================================
90
-
91
- class TestGetKey:
92
- """Unit tests for get_key function"""
93
 
94
- def test_get_key_returns_tuple(self):
95
- """Test that get_key returns a tuple"""
96
- result = get_key()
97
- assert isinstance(result, tuple)
98
- assert len(result) == 2
99
 
100
- def test_get_key_returns_valid_key(self):
101
- """Test that get_key returns valid Fernet key"""
102
- key, fernet = get_key()
103
- assert isinstance(key, bytes)
104
- assert isinstance(fernet, Fernet)
105
 
106
- def test_key_can_encrypt_decrypt(self):
107
- """Test that generated key can encrypt and decrypt"""
108
- key, fernet = get_key()
109
- message = b"Test message"
110
- encrypted = fernet.encrypt(message)
111
- decrypted = fernet.decrypt(encrypted)
112
- assert decrypted == message
113
 
114
- def test_keys_are_unique(self):
115
- """Test that each call generates a unique key"""
116
- key1, _ = get_key()
117
- key2, _ = get_key()
118
- assert key1 != key2
119
 
120
 
121
- class TestGetAllFiles:
122
- """Unit tests for get_all_files function"""
123
-
124
- @pytest.fixture
125
- def temp_dir(self):
126
- """Create temporary directory with test files"""
127
- temp_dir = tempfile.mkdtemp()
128
-
129
- # Create test files
130
- Path(temp_dir, "file1.txt").write_text("content1")
131
- Path(temp_dir, "file2.txt").write_text("content2")
132
-
133
- # Create subdirectory with file
134
- subdir = Path(temp_dir, "subdir")
135
- subdir.mkdir()
136
- Path(subdir, "file3.txt").write_text("content3")
137
-
138
- yield temp_dir
139
-
140
- # Cleanup
141
- shutil.rmtree(temp_dir)
142
-
143
- def test_finds_all_files(self, temp_dir):
144
- """Test that all files are found"""
145
- files = get_all_files(temp_dir)
146
- assert len(files) == 3
147
-
148
- def test_empty_directory(self):
149
- """Test with empty directory"""
150
- with tempfile.TemporaryDirectory() as temp_dir:
151
- files = get_all_files(temp_dir)
152
- assert files == []
153
-
154
- def test_returns_list(self, temp_dir):
155
- """Test that function returns a list"""
156
- files = get_all_files(temp_dir)
157
- assert isinstance(files, list)
158
-
159
 
160
- class TestSaveKey:
161
- """Unit tests for save_key function"""
162
 
163
- @pytest.fixture
164
- def temp_env_file(self):
165
- """Create temporary .env file"""
166
- fd, path = tempfile.mkstemp(suffix=".env")
167
- os.close(fd)
168
- yield path
169
- if os.path.exists(path):
170
- os.remove(path)
171
 
172
- def test_save_key_creates_file(self, temp_env_file):
173
- """Test that save_key creates .env file"""
174
- os.remove(temp_env_file) # Remove to test creation
175
- key, _ = get_key()
176
- result = save_key(key, temp_env_file)
177
- assert result is True
178
- assert os.path.exists(temp_env_file)
179
 
180
- def test_save_key_appends_to_existing(self, temp_env_file):
181
- """Test that save_key appends to existing file"""
182
- # Write initial content
183
- with open(temp_env_file, "w") as f:
184
- f.write("EXISTING_VAR=value\n")
185
 
186
- key, _ = get_key()
187
- save_key(key, temp_env_file)
188
 
189
- with open(temp_env_file, "r") as f:
190
- content = f.read()
191
 
192
- assert "EXISTING_VAR=value" in content
193
- assert "SECRET_KEY=" in content
194
 
195
- def test_saved_key_format(self, temp_env_file):
196
- """Test that saved key has correct format"""
197
- key, _ = get_key()
198
- save_key(key, temp_env_file)
199
 
200
- with open(temp_env_file, "r") as f:
201
- content = f.read()
202
 
203
- assert f"SECRET_KEY='{key.decode()}'" in content
204
 
205
 
206
- class TestDecrypt:
207
- """Unit tests for decrypt function"""
208
 
209
- @pytest.fixture
210
- def encrypted_file(self):
211
- """Create temporary encrypted file"""
212
- key, fernet = get_key()
213
- fd, path = tempfile.mkstemp()
214
 
215
- content = b"Test content to encrypt"
216
- encrypted = fernet.encrypt(content)
217
 
218
- with os.fdopen(fd, 'wb') as f:
219
- f.write(encrypted)
220
 
221
- yield path, fernet, content
222
 
223
- if os.path.exists(path):
224
- os.remove(path)
225
 
226
- def test_decrypt_success(self, encrypted_file):
227
- """Test successful decryption"""
228
- encrypted_path, fernet, original_content = encrypted_file
229
 
230
- with tempfile.NamedTemporaryFile(delete=False) as output:
231
- output_path = output.name
232
 
233
- try:
234
- result = decrypt(fernet, encrypted_path, output_path)
235
- assert result is True
236
 
237
- with open(output_path, 'rb') as f:
238
- decrypted = f.read()
239
 
240
- assert decrypted == original_content
241
- finally:
242
- if os.path.exists(output_path):
243
- os.remove(output_path)
244
 
245
- def test_decrypt_wrong_key(self, encrypted_file):
246
- """Test decryption with wrong key fails"""
247
- encrypted_path, _, _ = encrypted_file
248
- _, wrong_fernet = get_key() # Different key
249
-
250
- with tempfile.NamedTemporaryFile(delete=False) as output:
251
- output_path = output.name
252
-
253
- try:
254
- result = decrypt(wrong_fernet, encrypted_path, output_path)
255
- assert result is False
256
- finally:
257
- if os.path.exists(output_path):
258
- os.remove(output_path)
259
 
260
 
261
- # ============================================================================
262
- # INTEGRATION TESTS
263
- # ============================================================================
264
 
265
- class TestEncryptionWorkflow:
266
- """Integration tests for complete encryption workflow"""
267
 
268
- @pytest.fixture
269
- def test_environment(self):
270
- """Create complete test environment"""
271
- # Create temporary directories
272
- source_dir = tempfile.mkdtemp()
273
- output_dir = tempfile.mkdtemp()
274
- env_file = tempfile.NamedTemporaryFile(delete=False, suffix=".env")
275
- env_file.close()
276
-
277
- # Create test files
278
- Path(source_dir, "test1.txt").write_text("Content 1")
279
- Path(source_dir, "test2.csv").write_text("a,b,c\n1,2,3")
280
-
281
- subdir = Path(source_dir, "subdir")
282
- subdir.mkdir()
283
- Path(subdir, "test3.json").write_text('{"key": "value"}')
284
-
285
- yield source_dir, output_dir, env_file.name
286
-
287
- # Cleanup
288
- shutil.rmtree(source_dir, ignore_errors=True)
289
- shutil.rmtree(output_dir, ignore_errors=True)
290
- if os.path.exists(env_file.name):
291
- os.remove(env_file.name)
292
 
293
- def test_full_encryption_workflow(self, test_environment):
294
- """Test complete encryption workflow"""
295
- source_dir, output_dir, env_file = test_environment
296
 
297
- # Generate key
298
- key, fernet = get_key()
299
 
300
- # Encrypt files
301
- result = encrypt(fernet, source_dir, output_dir)
302
- assert result is True
303
 
304
- # Verify encrypted files exist
305
- encrypted_files = get_all_files(output_dir)
306
- assert len(encrypted_files) == 3
307
 
308
- # Save key
309
- result = save_key(key, env_file)
310
- assert result is True
311
 
312
- # Verify key was saved
313
- with open(env_file, "r") as f:
314
- content = f.read()
315
- assert "SECRET_KEY=" in content
316
 
317
- def test_encrypt_decrypt_roundtrip(self, test_environment):
318
- """Test that files can be encrypted and then decrypted"""
319
- source_dir, output_dir, _ = test_environment
320
 
321
- # Generate key
322
- key, fernet = get_key()
323
 
324
- # Encrypt
325
- encrypt(fernet, source_dir, output_dir)
326
 
327
- # Decrypt one file
328
- encrypted_file = os.path.join(output_dir, "test1.txt")
329
- decrypted_dir = tempfile.mkdtemp()
330
- decrypted_file = os.path.join(decrypted_dir, "test1.txt")
331
 
332
- try:
333
- decrypt(fernet, encrypted_file, decrypted_file)
334
 
335
- # Verify content matches original
336
- with open(decrypted_file, 'r') as f:
337
- content = f.read()
338
- assert content == "Content 1"
339
- finally:
340
- shutil.rmtree(decrypted_dir, ignore_errors=True)
341
 
342
- def test_preserves_directory_structure(self, test_environment):
343
- """Test that directory structure is preserved"""
344
- source_dir, output_dir, _ = test_environment
345
 
346
- key, fernet = get_key()
347
- encrypt(fernet, source_dir, output_dir)
348
 
349
- # Check that subdirectory exists
350
- encrypted_subdir_file = os.path.join(output_dir, "subdir", "test3.json")
351
- assert os.path.exists(encrypted_subdir_file)
352
 
353
- def test_empty_directory_handling(self):
354
- """Test handling of empty directory"""
355
- with tempfile.TemporaryDirectory() as source_dir:
356
- with tempfile.TemporaryDirectory() as output_dir:
357
- key, fernet = get_key()
358
- result = encrypt(fernet, source_dir, output_dir)
359
- assert result is False
360
 
361
 
362
- # ============================================================================
363
- # PARAMETRIZED TESTS
364
- # ============================================================================
365
 
366
- class TestEncryptionWithDifferentFileTypes:
367
- """Test encryption with various file types"""
368
 
369
- @pytest.mark.parametrize("filename,content", [
370
- ("test.txt", b"Plain text content"),
371
- ("test.json", b'{"key": "value"}'),
372
- ("test.csv", b"a,b,c\n1,2,3"),
373
- ("test.bin", bytes(range(256))),
374
- ("test.pdf", b"%PDF-1.4\n%\xe2\xe3\xcf\xd3"),
375
- ])
376
- def test_encrypt_different_file_types(self, filename, content):
377
- """Test encryption of different file types"""
378
- with tempfile.TemporaryDirectory() as source_dir:
379
- with tempfile.TemporaryDirectory() as output_dir:
380
- # Create test file
381
- file_path = Path(source_dir, filename)
382
- file_path.write_bytes(content)
383
 
384
- # Encrypt
385
- key, fernet = get_key()
386
- result = encrypt(fernet, source_dir, output_dir)
387
- assert result is True
388
 
389
- # Verify encrypted file exists
390
- encrypted_path = Path(output_dir, filename)
391
- assert encrypted_path.exists()
392
 
393
- # Verify content is different (encrypted)
394
- encrypted_content = encrypted_path.read_bytes()
395
- assert encrypted_content != content
396
 
397
 
398
- if __name__ == "__main__":
399
- pytest.main([__file__, "-v", "--tb=short"])
 
1
+ # import os
2
+ # import pytest
3
+ # import tempfile
4
+ # import shutil
5
+ # from pathlib import Path
6
+ # from cryptography.fernet import Fernet
7
+ # import sys
8
 
 
 
 
9
 
10
 
11
+ # sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
 
 
 
 
12
 
13
+ # from data.encryption import get_key, encrypt, save_key
14
+ # from cryptography.fernet import Fernet
15
 
 
 
 
 
 
 
 
16
 
17
 
18
+ # # ============================================================================
19
+ # # UNIT TESTS
20
+ # # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # class TestGetKey:
23
+ # """Unit tests for get_key function"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # def test_get_key_returns_tuple(self):
26
+ # """Test that get_key returns a tuple"""
27
+ # result = get_key()
28
+ # assert isinstance(result, tuple)
29
+ # assert len(result) == 2
30
 
31
+ # def test_get_key_returns_valid_key(self):
32
+ # """Test that get_key returns valid Fernet key"""
33
+ # key, fernet = get_key()
34
+ # assert isinstance(key, bytes)
35
+ # assert isinstance(fernet, Fernet)
36
 
37
+ # def test_key_can_encrypt_decrypt(self):
38
+ # """Test that generated key can encrypt and decrypt"""
39
+ # key, fernet = get_key()
40
+ # message = b"Test message"
41
+ # encrypted = fernet.encrypt(message)
42
+ # decrypted = fernet.decrypt(encrypted)
43
+ # assert decrypted == message
44
 
45
+ # def test_keys_are_unique(self):
46
+ # """Test that each call generates a unique key"""
47
+ # key1, _ = get_key()
48
+ # key2, _ = get_key()
49
+ # assert key1 != key2
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # class TestSaveKey:
54
+ # """Unit tests for save_key function"""
55
 
56
+ # @pytest.fixture
57
+ # def temp_env_file(self):
58
+ # """Create temporary .env file"""
59
+ # fd, path = tempfile.mkstemp(suffix=".env")
60
+ # os.close(fd)
61
+ # yield path
62
+ # if os.path.exists(path):
63
+ # os.remove(path)
64
 
65
+ # def test_save_key_creates_file(self, temp_env_file):
66
+ # """Test that save_key creates .env file"""
67
+ # os.remove(temp_env_file) # Remove to test creation
68
+ # key, _ = get_key()
69
+ # result = save_key(key, temp_env_file)
70
+ # assert result is True
71
+ # assert os.path.exists(temp_env_file)
72
 
73
+ # def test_save_key_appends_to_existing(self, temp_env_file):
74
+ # """Test that save_key appends to existing file"""
75
+ # # Write initial content
76
+ # with open(temp_env_file, "w") as f:
77
+ # f.write("EXISTING_VAR=value\n")
78
 
79
+ # key, _ = get_key()
80
+ # save_key(key, temp_env_file)
81
 
82
+ # with open(temp_env_file, "r") as f:
83
+ # content = f.read()
84
 
85
+ # assert "EXISTING_VAR=value" in content
86
+ # assert "SECRET_KEY=" in content
87
 
88
+ # def test_saved_key_format(self, temp_env_file):
89
+ # """Test that saved key has correct format"""
90
+ # key, _ = get_key()
91
+ # save_key(key, temp_env_file)
92
 
93
+ # with open(temp_env_file, "r") as f:
94
+ # content = f.read()
95
 
96
+ # assert f"SECRET_KEY='{key.decode()}'" in content
97
 
98
 
99
+ # class TestDecrypt:
100
+ # """Unit tests for decrypt function"""
101
 
102
+ # @pytest.fixture
103
+ # def encrypted_file(self):
104
+ # """Create temporary encrypted file"""
105
+ # key, fernet = get_key()
106
+ # fd, path = tempfile.mkstemp()
107
 
108
+ # content = b"Test content to encrypt"
109
+ # encrypted = fernet.encrypt(content)
110
 
111
+ # with os.fdopen(fd, 'wb') as f:
112
+ # f.write(encrypted)
113
 
114
+ # yield path, fernet, content
115
 
116
+ # if os.path.exists(path):
117
+ # os.remove(path)
118
 
119
+ # # def test_decrypt_success(self, encrypted_file):
120
+ # # """Test successful decryption"""
121
+ # # encrypted_path, fernet, original_content = encrypted_file
122
 
123
+ # # with tempfile.NamedTemporaryFile(delete=False) as output:
124
+ # # output_path = output.name
125
 
126
+ # # try:
127
+ # # result = decrypt(fernet, encrypted_path, output_path)
128
+ # # assert result is True
129
 
130
+ # # with open(output_path, 'rb') as f:
131
+ # # decrypted = f.read()
132
 
133
+ # # assert decrypted == original_content
134
+ # # finally:
135
+ # # if os.path.exists(output_path):
136
+ # # os.remove(output_path)
137
 
138
+ # def test_decrypt_wrong_key(self, encrypted_file):
139
+ # """Test decryption with wrong key fails"""
140
+ # encrypted_path, _, _ = encrypted_file
141
+ # _, wrong_fernet = get_key() # Different key
142
+
143
+ # with tempfile.NamedTemporaryFile(delete=False) as output:
144
+ # output_path = output.name
145
+
146
+ # try:
147
+ # result = decrypt(wrong_fernet, encrypted_path, output_path)
148
+ # assert result is False
149
+ # finally:
150
+ # if os.path.exists(output_path):
151
+ # os.remove(output_path)
152
 
153
 
154
+ # # ============================================================================
155
+ # # INTEGRATION TESTS
156
+ # # ============================================================================
157
 
158
+ # class TestEncryptionWorkflow:
159
+ # """Integration tests for complete encryption workflow"""
160
 
161
+ # @pytest.fixture
162
+ # def test_environment(self):
163
+ # """Create complete test environment"""
164
+ # # Create temporary directories
165
+ # source_dir = tempfile.mkdtemp()
166
+ # output_dir = tempfile.mkdtemp()
167
+ # env_file = tempfile.NamedTemporaryFile(delete=False, suffix=".env")
168
+ # env_file.close()
169
+
170
+ # # Create test files
171
+ # Path(source_dir, "test1.txt").write_text("Content 1")
172
+ # Path(source_dir, "test2.csv").write_text("a,b,c\n1,2,3")
173
+
174
+ # subdir = Path(source_dir, "subdir")
175
+ # subdir.mkdir()
176
+ # Path(subdir, "test3.json").write_text('{"key": "value"}')
177
+
178
+ # yield source_dir, output_dir, env_file.name
179
+
180
+ # # Cleanup
181
+ # shutil.rmtree(source_dir, ignore_errors=True)
182
+ # shutil.rmtree(output_dir, ignore_errors=True)
183
+ # if os.path.exists(env_file.name):
184
+ # os.remove(env_file.name)
185
 
186
+ # def test_full_encryption_workflow(self, test_environment):
187
+ # """Test complete encryption workflow"""
188
+ # source_dir, output_dir, env_file = test_environment
189
 
190
+ # # Generate key
191
+ # key, fernet = get_key()
192
 
193
+ # # Encrypt files
194
+ # result = encrypt(fernet, source_dir, output_dir)
195
+ # assert result is True
196
 
197
+ # # Verify encrypted files exist
198
+ # encrypted_files = get_all_files(output_dir)
199
+ # assert len(encrypted_files) == 3
200
 
201
+ # # Save key
202
+ # result = save_key(key, env_file)
203
+ # assert result is True
204
 
205
+ # # Verify key was saved
206
+ # with open(env_file, "r") as f:
207
+ # content = f.read()
208
+ # assert "SECRET_KEY=" in content
209
 
210
+ # def test_encrypt_decrypt_roundtrip(self, test_environment):
211
+ # """Test that files can be encrypted and then decrypted"""
212
+ # source_dir, output_dir, _ = test_environment
213
 
214
+ # # Generate key
215
+ # key, fernet = get_key()
216
 
217
+ # # Encrypt
218
+ # encrypt(fernet, source_dir, output_dir)
219
 
220
+ # # Decrypt one file
221
+ # encrypted_file = os.path.join(output_dir, "test1.txt")
222
+ # decrypted_dir = tempfile.mkdtemp()
223
+ # decrypted_file = os.path.join(decrypted_dir, "test1.txt")
224
 
225
+ # try:
226
+ # decrypt(fernet, encrypted_file, decrypted_file)
227
 
228
+ # # Verify content matches original
229
+ # with open(decrypted_file, 'r') as f:
230
+ # content = f.read()
231
+ # assert content == "Content 1"
232
+ # finally:
233
+ # shutil.rmtree(decrypted_dir, ignore_errors=True)
234
 
235
+ # def test_preserves_directory_structure(self, test_environment):
236
+ # """Test that directory structure is preserved"""
237
+ # source_dir, output_dir, _ = test_environment
238
 
239
+ # key, fernet = get_key()
240
+ # encrypt(fernet, source_dir, output_dir)
241
 
242
+ # # Check that subdirectory exists
243
+ # encrypted_subdir_file = os.path.join(output_dir, "subdir", "test3.json")
244
+ # assert os.path.exists(encrypted_subdir_file)
245
 
246
+ # def test_empty_directory_handling(self):
247
+ # """Test handling of empty directory"""
248
+ # with tempfile.TemporaryDirectory() as source_dir:
249
+ # with tempfile.TemporaryDirectory() as output_dir:
250
+ # key, fernet = get_key()
251
+ # result = encrypt(fernet, source_dir, output_dir)
252
+ # assert result is False
253
 
254
 
255
+ # # ============================================================================
256
+ # # PARAMETRIZED TESTS
257
+ # # ============================================================================
258
 
259
+ # class TestEncryptionWithDifferentFileTypes:
260
+ # """Test encryption with various file types"""
261
 
262
+ # @pytest.mark.parametrize("filename,content", [
263
+ # ("test.txt", b"Plain text content"),
264
+ # ("test.json", b'{"key": "value"}'),
265
+ # ("test.csv", b"a,b,c\n1,2,3"),
266
+ # ("test.bin", bytes(range(256))),
267
+ # ("test.pdf", b"%PDF-1.4\n%\xe2\xe3\xcf\xd3"),
268
+ # ])
269
+ # def test_encrypt_different_file_types(self, filename, content):
270
+ # """Test encryption of different file types"""
271
+ # with tempfile.TemporaryDirectory() as source_dir:
272
+ # with tempfile.TemporaryDirectory() as output_dir:
273
+ # # Create test file
274
+ # file_path = Path(source_dir, filename)
275
+ # file_path.write_bytes(content)
276
 
277
+ # # Encrypt
278
+ # key, fernet = get_key()
279
+ # result = encrypt(fernet, source_dir, output_dir)
280
+ # assert result is True
281
 
282
+ # # Verify encrypted file exists
283
+ # encrypted_path = Path(output_dir, filename)
284
+ # assert encrypted_path.exists()
285
 
286
+ # # Verify content is different (encrypted)
287
+ # encrypted_content = encrypted_path.read_bytes()
288
+ # assert encrypted_content != content
289
 
290
 
291
+ # if __name__ == "__main__":
292
+ # pytest.main([__file__, "-v", "--tb=short"])
tests/test_restrict_usage.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from datetime import datetime, timedelta
3
+ from unittest.mock import patch
4
+ import sys
5
+ import os
6
+
7
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
+
9
+ from agent.restrict_usage import RateLimiter
10
+
11
+
12
+ @pytest.fixture
13
+ def limiter():
14
+ """Create a RateLimiter instance for testing."""
15
+ return RateLimiter(max_requests=5, window_minutes=60)
16
+
17
+
18
+ def test_first_request_allowed(limiter):
19
+ """First request should always be allowed."""
20
+ assert limiter.is_allowed("test_ip") is True
21
+
22
+
23
+ def test_requests_within_limit(limiter):
24
+ """All requests within limit should be allowed."""
25
+ for _ in range(5):
26
+ assert limiter.is_allowed("test_ip") is True
27
+
28
+
29
+ def test_request_exceeds_limit(limiter):
30
+ """Request exceeding limit should be blocked."""
31
+ for _ in range(5):
32
+ limiter.is_allowed("test_ip")
33
+ assert limiter.is_allowed("test_ip") is False
34
+
35
+
36
+ def test_different_identifiers_separate(limiter):
37
+ """Different identifiers should have separate limits."""
38
+ for _ in range(5):
39
+ limiter.is_allowed("ip1")
40
+ assert limiter.is_allowed("ip1") is False
41
+ assert limiter.is_allowed("ip2") is True
42
+
43
+
44
+ def test_get_remaining(limiter):
45
+ """get_remaining should return correct count."""
46
+ assert limiter.get_remaining("test_ip") == 5
47
+ limiter.is_allowed("test_ip")
48
+ assert limiter.get_remaining("test_ip") == 4
49
+ for _ in range(4):
50
+ limiter.is_allowed("test_ip")
51
+ assert limiter.get_remaining("test_ip") == 0
52
+
53
+
54
+ def test_old_requests_cleaned(limiter):
55
+ """Old requests outside window should be cleaned."""
56
+ fixed_time = datetime(2024, 1, 1, 12, 0, 0)
57
+
58
+ with patch('agent.restrict_usage.datetime') as mock_datetime:
59
+ mock_datetime.now.return_value = fixed_time
60
+
61
+ for _ in range(5):
62
+ limiter.is_allowed("test_ip")
63
+
64
+ assert limiter.is_allowed("test_ip") is False
65
+
66
+ # Move time forward past the window
67
+ mock_datetime.now.return_value = fixed_time + timedelta(minutes=61)
68
+
69
+ assert limiter.is_allowed("test_ip") is True
70
+ assert limiter.get_remaining("test_ip") == 4
71
+
72
+
73
+ def test_multiple_identifiers_tracking(limiter):
74
+ """Multiple identifiers should be tracked independently."""
75
+ limiter.is_allowed("user1")
76
+ limiter.is_allowed("user1")
77
+ limiter.is_allowed("user2")
78
+
79
+ assert limiter.get_remaining("user1") == 3
80
+ assert limiter.get_remaining("user2") == 4
81
+ assert limiter.get_remaining("user3") == 5