File size: 9,500 Bytes
ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 ff73b92 ce4bc73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
"""
Tests for the security module.
"""
import base64
import io
import os
import sys
import unittest
import pandas as pd
sys.path.insert(0, os.path.abspath(".."))
from src.folio.security import (
sanitize_cell,
sanitize_dangerous_content,
sanitize_dataframe,
sanitize_formula,
validate_csv_upload,
)
class TestSecurity(unittest.TestCase):
"""Test cases for the security module."""
def test_sanitize_cell(self):
"""Test sanitizing individual cell values."""
# Test formula sanitization
self.assertEqual(sanitize_cell("=SUM(A1:B1)"), "'=SUM(A1:B1)")
self.assertEqual(sanitize_cell("@SUM(A1:B1)"), "'@SUM(A1:B1)")
self.assertEqual(sanitize_cell("+SUM(A1:B1)"), "+SUM(A1:B1)")
# Test HTML/script sanitization
self.assertEqual(sanitize_cell("<script>alert('XSS')</script>"), "[REMOVED]")
self.assertEqual(
sanitize_cell("javascript:alert('XSS')"), "[REMOVED]alert('XSS')"
)
self.assertEqual(
sanitize_cell("<img src=x onerror=alert('XSS')>"),
"<img src=x [REMOVED]=alert('XSS')>",
)
# Test command injection sanitization
self.assertEqual(sanitize_cell("value; rm -rf /"), "value rm -rf /")
self.assertEqual(
sanitize_cell("value | cat /etc/passwd"), "value cat /etc/passwd"
)
# Test non-string values
self.assertEqual(sanitize_cell(123), "123")
self.assertEqual(sanitize_cell(None), "None")
# Test negative numbers (should not be modified)
self.assertEqual(sanitize_cell("-123"), "-123")
self.assertEqual(sanitize_cell("-123.45"), "-123.45")
# Test financial values (should not be modified)
self.assertEqual(sanitize_cell("$123.45"), "$123.45")
self.assertEqual(sanitize_cell("-$123.45"), "-$123.45")
self.assertEqual(sanitize_cell("+$123.45"), "+$123.45")
# Test percentage values (should not be modified)
self.assertEqual(sanitize_cell("-12.34%"), "-12.34%")
self.assertEqual(sanitize_cell("+12.34%"), "+12.34%")
self.assertEqual(sanitize_cell("12.34%"), "12.34%")
# Test stock names with ampersands (should not be modified)
self.assertEqual(sanitize_cell("S&P 500"), "S&P 500")
self.assertEqual(
sanitize_cell("PROSHARES ULTRAPRO S&P500"), "PROSHARES ULTRAPRO S&P500"
)
def test_sanitize_formula(self):
"""Test sanitizing formula-like content."""
self.assertEqual(sanitize_formula("=SUM(A1:B1)"), "'=SUM(A1:B1)")
self.assertEqual(sanitize_formula("@SUM(A1:B1)"), "'@SUM(A1:B1)")
self.assertEqual(sanitize_formula("+SUM(A1:B1)"), "'+SUM(A1:B1)")
# Test that normal text is not modified
self.assertEqual(sanitize_formula("Normal text"), "Normal text")
# Test that negative numbers are not modified
self.assertEqual(sanitize_formula("-123"), "-123")
# Test financial values (should not be modified)
self.assertEqual(sanitize_formula("$123.45"), "$123.45")
self.assertEqual(sanitize_formula("-$123.45"), "-$123.45")
self.assertEqual(sanitize_formula("+$123.45"), "+$123.45")
# Test percentage values (should not be modified)
self.assertEqual(sanitize_formula("-12.34%"), "-12.34%")
self.assertEqual(sanitize_formula("+12.34%"), "+12.34%")
self.assertEqual(sanitize_formula("12.34%"), "12.34%")
def test_sanitize_dangerous_content(self):
"""Test sanitizing dangerous content while preserving financial data."""
# Test financial values (should not be modified)
self.assertEqual(sanitize_dangerous_content("$123.45"), "$123.45")
self.assertEqual(sanitize_dangerous_content("-$123.45"), "-$123.45")
self.assertEqual(sanitize_dangerous_content("+$123.45"), "+$123.45")
# Test percentage values (should not be modified)
self.assertEqual(sanitize_dangerous_content("-12.34%"), "-12.34%")
self.assertEqual(sanitize_dangerous_content("+12.34%"), "+12.34%")
self.assertEqual(sanitize_dangerous_content("12.34%"), "12.34%")
# Test stock names with ampersands (should not be modified)
self.assertEqual(sanitize_dangerous_content("S&P 500"), "S&P 500")
self.assertEqual(
sanitize_dangerous_content("PROSHARES ULTRAPRO S&P500"),
"PROSHARES ULTRAPRO S&P500",
)
# Test formula sanitization
self.assertEqual(sanitize_dangerous_content("=SUM(A1:B1)"), "'=SUM(A1:B1)")
self.assertEqual(sanitize_dangerous_content("@SUM(A1:B1)"), "'@SUM(A1:B1)")
self.assertEqual(sanitize_dangerous_content("+SUM(A1:B1)"), "'+SUM(A1:B1)")
# Test HTML/script sanitization
self.assertEqual(
sanitize_dangerous_content("<script>alert('XSS')</script>"), "[REMOVED]"
)
self.assertEqual(
sanitize_dangerous_content("javascript:alert('XSS')"),
"[REMOVED]alert('XSS')",
)
# Test command injection sanitization
self.assertEqual(
sanitize_dangerous_content("value; rm -rf /"), "value rm -rf /"
)
self.assertEqual(
sanitize_dangerous_content("value | cat /etc/passwd"),
"value cat /etc/passwd",
)
def test_sanitize_dataframe(self):
"""Test sanitizing a DataFrame."""
# Create a test DataFrame with potentially dangerous content
df = pd.DataFrame(
{
"Symbol": ["AAPL", "=SUM(A1:B1)", "MSFT"],
"Description": [
"Apple Inc",
'<script>alert("XSS")</script>',
"Microsoft Corp",
],
"Quantity": [100, 200, 300],
"Last Price": ["$150.00", '=HYPERLINK("malicious.com")', "$250.00"],
}
)
# Sanitize the DataFrame
sanitized_df = sanitize_dataframe(df)
# Check that the dangerous content was sanitized
self.assertEqual(sanitized_df.loc[1, "Symbol"], "'=SUM(A1:B1)")
self.assertEqual(sanitized_df.loc[1, "Description"], "[REMOVED]")
self.assertEqual(
sanitized_df.loc[1, "Last Price"], '\'=HYPERLINK("malicious.com")'
)
# Check that safe content was not modified
self.assertEqual(sanitized_df.loc[0, "Symbol"], "AAPL")
self.assertEqual(sanitized_df.loc[0, "Description"], "Apple Inc")
self.assertEqual(sanitized_df.loc[0, "Quantity"], 100)
def test_validate_csv_upload(self):
"""Test validating a CSV upload."""
# Create a valid CSV file
df = pd.DataFrame(
{
"Symbol": ["AAPL", "MSFT", "GOOGL"],
"Quantity": [100, 200, 300],
"Last Price": ["$150.00", "$250.00", "$2,500.00"],
}
)
# Convert to CSV and encode as base64
csv_buffer = io.StringIO()
df.to_csv(csv_buffer, index=False)
csv_str = csv_buffer.getvalue()
b64_content = base64.b64encode(csv_str.encode("utf-8")).decode("utf-8")
contents = f"data:text/csv;base64,{b64_content}"
# Validate the CSV upload
result_df, error = validate_csv_upload(contents, "valid.csv")
# Check that validation passed
self.assertIsNone(error)
self.assertEqual(len(result_df), 3)
# Create a CSV with malicious content
df = pd.DataFrame(
{
"Symbol": ["AAPL", "=SUM(A1:B1)", "MSFT"],
"Quantity": [100, 200, 300],
"Last Price": ["$150.00", '=HYPERLINK("malicious.com")', "$250.00"],
}
)
# Convert to CSV and encode as base64
csv_buffer = io.StringIO()
df.to_csv(csv_buffer, index=False)
csv_str = csv_buffer.getvalue()
b64_content = base64.b64encode(csv_str.encode("utf-8")).decode("utf-8")
contents = f"data:text/csv;base64,{b64_content}"
# Validate the CSV upload
result_df, error = validate_csv_upload(contents, "malicious.csv")
# Check that validation passed but content was sanitized
self.assertIsNone(error)
self.assertEqual(result_df.loc[1, "Symbol"], "'=SUM(A1:B1)")
self.assertEqual(
result_df.loc[1, "Last Price"], '\'=HYPERLINK("malicious.com")'
)
# Test with invalid file extension
with self.assertRaises(ValueError) as context:
validate_csv_upload(contents, "invalid.txt")
self.assertIn("Only CSV files are supported", str(context.exception))
# Test with missing required columns
df = pd.DataFrame(
{
"Symbol": ["AAPL", "MSFT", "GOOGL"],
# Missing 'Quantity' and 'Last Price'
}
)
# Convert to CSV and encode as base64
csv_buffer = io.StringIO()
df.to_csv(csv_buffer, index=False)
csv_str = csv_buffer.getvalue()
b64_content = base64.b64encode(csv_str.encode("utf-8")).decode("utf-8")
contents = f"data:text/csv;base64,{b64_content}"
# Validate the CSV upload
with self.assertRaises(ValueError) as context:
validate_csv_upload(contents, "missing_columns.csv")
self.assertIn("Missing required columns", str(context.exception))
if __name__ == "__main__":
unittest.main()
|