|
|
""" |
|
|
Tests for the security module. |
|
|
""" |
|
|
|
|
|
import base64 |
|
|
import io |
|
|
import os |
|
|
import sys |
|
|
import unittest |
|
|
|
|
|
import pandas as pd |
|
|
|
|
|
# Prepend the parent of the current working directory to the import path so
# that the `src.folio` package imported below can be resolved.
# NOTE(review): ".." is resolved against the process's working directory, not
# against this file's location -- confirm the tests are always launched from
# the directory this assumes.
sys.path.insert(0, os.path.abspath(os.pardir))
|
|
|
|
|
from src.folio.security import ( |
|
|
sanitize_cell, |
|
|
sanitize_dangerous_content, |
|
|
sanitize_dataframe, |
|
|
sanitize_formula, |
|
|
validate_csv_upload, |
|
|
) |
|
|
|
|
|
|
|
|
class TestSecurity(unittest.TestCase):
    """Test cases for the security module."""

    def _assert_sanitized(self, sanitizer, cases):
        """Check ``sanitizer(raw) == expected`` for each ``(raw, expected)`` pair.

        Every pair runs inside its own ``subTest`` so that one failing input
        is reported without hiding the outcome of the remaining inputs.
        """
        for raw, expected in cases:
            with self.subTest(raw=raw):
                self.assertEqual(sanitizer(raw), expected)

    @staticmethod
    def _as_upload_contents(df):
        """Return *df* serialized as CSV inside a base64 ``data:text/csv`` URI.

        This is the string passed to ``validate_csv_upload`` as its first
        argument throughout these tests.
        """
        csv_buffer = io.StringIO()
        df.to_csv(csv_buffer, index=False)
        encoded = base64.b64encode(csv_buffer.getvalue().encode("utf-8"))
        return f"data:text/csv;base64,{encoded.decode('utf-8')}"

    def test_sanitize_cell(self):
        """Test sanitizing individual cell values."""
        cases = [
            # Formula-like input: "=" and "@" prefixes are expected to gain a
            # leading apostrophe, while "+SUM(...)" is expected to come back
            # unchanged from sanitize_cell (unlike the other sanitizers below).
            ("=SUM(A1:B1)", "'=SUM(A1:B1)"),
            ("@SUM(A1:B1)", "'@SUM(A1:B1)"),
            ("+SUM(A1:B1)", "+SUM(A1:B1)"),
            # HTML / JavaScript payloads are expected to be replaced, fully or
            # in part, by the "[REMOVED]" marker.
            ("<script>alert('XSS')</script>", "[REMOVED]"),
            ("javascript:alert('XSS')", "[REMOVED]alert('XSS')"),
            (
                "<img src=x onerror=alert('XSS')>",
                "<img src=x [REMOVED]=alert('XSS')>",
            ),
            # Shell metacharacters (";" and "|") are expected to be dropped.
            ("value; rm -rf /", "value rm -rf /"),
            ("value | cat /etc/passwd", "value cat /etc/passwd"),
            # Non-string inputs are expected to come back as their str() form.
            (123, "123"),
            (None, "None"),
            # Negative numbers must not be mistaken for formulas.
            ("-123", "-123"),
            ("-123.45", "-123.45"),
            # Currency values, with or without a sign, are left untouched.
            ("$123.45", "$123.45"),
            ("-$123.45", "-$123.45"),
            ("+$123.45", "+$123.45"),
            # Percentages, with or without a sign, are left untouched.
            ("-12.34%", "-12.34%"),
            ("+12.34%", "+12.34%"),
            ("12.34%", "12.34%"),
            # An ampersand inside a security name is left untouched.
            ("S&P 500", "S&P 500"),
            ("PROSHARES ULTRAPRO S&P500", "PROSHARES ULTRAPRO S&P500"),
        ]
        self._assert_sanitized(sanitize_cell, cases)

    def test_sanitize_formula(self):
        """Test sanitizing formula-like content."""
        cases = [
            # Formula prefixes, including "+" followed by a function call,
            # are expected to gain a leading apostrophe.
            ("=SUM(A1:B1)", "'=SUM(A1:B1)"),
            ("@SUM(A1:B1)", "'@SUM(A1:B1)"),
            ("+SUM(A1:B1)", "'+SUM(A1:B1)"),
            # Ordinary text is left untouched.
            ("Normal text", "Normal text"),
            # A negative number is not treated as a formula.
            ("-123", "-123"),
            # Currency values are not treated as formulas.
            ("$123.45", "$123.45"),
            ("-$123.45", "-$123.45"),
            ("+$123.45", "+$123.45"),
            # Percentages are not treated as formulas.
            ("-12.34%", "-12.34%"),
            ("+12.34%", "+12.34%"),
            ("12.34%", "12.34%"),
        ]
        self._assert_sanitized(sanitize_formula, cases)

    def test_sanitize_dangerous_content(self):
        """Test sanitizing dangerous content while preserving financial data."""
        cases = [
            # Currency values are preserved.
            ("$123.45", "$123.45"),
            ("-$123.45", "-$123.45"),
            ("+$123.45", "+$123.45"),
            # Percentages are preserved.
            ("-12.34%", "-12.34%"),
            ("+12.34%", "+12.34%"),
            ("12.34%", "12.34%"),
            # An ampersand inside a security name is preserved.
            ("S&P 500", "S&P 500"),
            ("PROSHARES ULTRAPRO S&P500", "PROSHARES ULTRAPRO S&P500"),
            # Formula prefixes are expected to gain a leading apostrophe.
            ("=SUM(A1:B1)", "'=SUM(A1:B1)"),
            ("@SUM(A1:B1)", "'@SUM(A1:B1)"),
            ("+SUM(A1:B1)", "'+SUM(A1:B1)"),
            # HTML / JavaScript payloads are replaced by "[REMOVED]".
            ("<script>alert('XSS')</script>", "[REMOVED]"),
            ("javascript:alert('XSS')", "[REMOVED]alert('XSS')"),
            # Shell metacharacters (";" and "|") are expected to be dropped.
            ("value; rm -rf /", "value rm -rf /"),
            ("value | cat /etc/passwd", "value cat /etc/passwd"),
        ]
        self._assert_sanitized(sanitize_dangerous_content, cases)

    def test_sanitize_dataframe(self):
        """Test sanitizing a DataFrame."""
        # Row 1 carries the dangerous values; rows 0 and 2 are benign.
        df = pd.DataFrame(
            {
                "Symbol": ["AAPL", "=SUM(A1:B1)", "MSFT"],
                "Description": [
                    "Apple Inc",
                    '<script>alert("XSS")</script>',
                    "Microsoft Corp",
                ],
                "Quantity": [100, 200, 300],
                "Last Price": ["$150.00", '=HYPERLINK("malicious.com")', "$250.00"],
            }
        )

        sanitized_df = sanitize_dataframe(df)

        # The dangerous cells in row 1 must have been neutralized.
        self.assertEqual(sanitized_df.loc[1, "Symbol"], "'=SUM(A1:B1)")
        self.assertEqual(sanitized_df.loc[1, "Description"], "[REMOVED]")
        self.assertEqual(
            sanitized_df.loc[1, "Last Price"], '\'=HYPERLINK("malicious.com")'
        )

        # The benign cells in row 0 must be unchanged.
        self.assertEqual(sanitized_df.loc[0, "Symbol"], "AAPL")
        self.assertEqual(sanitized_df.loc[0, "Description"], "Apple Inc")
        self.assertEqual(sanitized_df.loc[0, "Quantity"], 100)

    def test_validate_csv_upload(self):
        """Test validating a CSV upload."""
        # --- A well-formed CSV is accepted without an error. ---
        valid_df = pd.DataFrame(
            {
                "Symbol": ["AAPL", "MSFT", "GOOGL"],
                "Quantity": [100, 200, 300],
                "Last Price": ["$150.00", "$250.00", "$2,500.00"],
            }
        )
        result_df, error = validate_csv_upload(
            self._as_upload_contents(valid_df), "valid.csv"
        )

        self.assertIsNone(error)
        self.assertEqual(len(result_df), 3)

        # --- A CSV with formula payloads is accepted but sanitized. ---
        malicious_df = pd.DataFrame(
            {
                "Symbol": ["AAPL", "=SUM(A1:B1)", "MSFT"],
                "Quantity": [100, 200, 300],
                "Last Price": ["$150.00", '=HYPERLINK("malicious.com")', "$250.00"],
            }
        )
        malicious_contents = self._as_upload_contents(malicious_df)
        result_df, error = validate_csv_upload(malicious_contents, "malicious.csv")

        self.assertIsNone(error)
        self.assertEqual(result_df.loc[1, "Symbol"], "'=SUM(A1:B1)")
        self.assertEqual(
            result_df.loc[1, "Last Price"], '\'=HYPERLINK("malicious.com")'
        )

        # --- A non-.csv filename is rejected (same payload as above). ---
        with self.assertRaises(ValueError) as context:
            validate_csv_upload(malicious_contents, "invalid.txt")
        # Checked after the ``with`` block: a statement placed inside it,
        # after the raising call, would never execute.
        self.assertIn("Only CSV files are supported", str(context.exception))

        # --- A CSV that only has a Symbol column is rejected. ---
        incomplete_df = pd.DataFrame(
            {
                "Symbol": ["AAPL", "MSFT", "GOOGL"],
            }
        )
        incomplete_contents = self._as_upload_contents(incomplete_df)

        with self.assertRaises(ValueError) as context:
            validate_csv_upload(incomplete_contents, "missing_columns.csv")
        self.assertIn("Missing required columns", str(context.exception))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()
|
|
|