"""
Tests for the security module.
"""
import base64
import io
import os
import sys
import unittest
import pandas as pd
# Prepend the parent of the *current working directory* to the import path,
# presumably so the ``src.folio`` import below resolves. The entry is computed
# from the cwd, not from this file's location.
sys.path.insert(0, os.path.abspath(os.pardir))
from src.folio.security import (
sanitize_cell,
sanitize_dangerous_content,
sanitize_dataframe,
sanitize_formula,
validate_csv_upload,
)
class TestSecurity(unittest.TestCase):
    """Test cases for the security module.

    Exercises ``sanitize_cell``, ``sanitize_formula``,
    ``sanitize_dangerous_content``, ``sanitize_dataframe`` and
    ``validate_csv_upload``.
    """

    @staticmethod
    def _csv_upload_contents(frame):
        """Return *frame* serialised as CSV inside a base64 ``data:`` URI.

        The frame is written without its index, UTF-8 encoded, base64
        encoded and prefixed with ``data:text/csv;base64,``.
        """
        csv_buffer = io.StringIO()
        frame.to_csv(csv_buffer, index=False)
        csv_str = csv_buffer.getvalue()
        b64_content = base64.b64encode(csv_str.encode("utf-8")).decode("utf-8")
        return f"data:text/csv;base64,{b64_content}"

    def test_sanitize_cell(self):
        """Test sanitizing individual cell values."""
        # Test formula sanitization
        self.assertEqual(sanitize_cell("=SUM(A1:B1)"), "'=SUM(A1:B1)")
        self.assertEqual(sanitize_cell("@SUM(A1:B1)"), "'@SUM(A1:B1)")
        # NOTE(review): the sanitize_formula and sanitize_dangerous_content
        # tests in this class expect "'+SUM(A1:B1)" for the same input --
        # confirm that leaving it unquoted here is intended.
        self.assertEqual(sanitize_cell("+SUM(A1:B1)"), "+SUM(A1:B1)")

        # Test HTML/script sanitization
        # NOTE(review): the input literal is empty although "[REMOVED]" is
        # expected; the markup payload looks lost -- restore it from history.
        self.assertEqual(sanitize_cell(""), "[REMOVED]")
        self.assertEqual(
            sanitize_cell("javascript:alert('XSS')"), "[REMOVED]alert('XSS')"
        )
        # NOTE(review): both literals here were split across physical lines
        # (a syntax error); they are kept as a newline escape. The original
        # payload looks lost -- restore it from history.
        self.assertEqual(
            sanitize_cell("\n"),
            "\n",
        )

        # Test command injection sanitization
        self.assertEqual(sanitize_cell("value; rm -rf /"), "value rm -rf /")
        self.assertEqual(
            sanitize_cell("value | cat /etc/passwd"), "value cat /etc/passwd"
        )

        # Test non-string values
        self.assertEqual(sanitize_cell(123), "123")
        self.assertEqual(sanitize_cell(None), "None")

        # Test negative numbers (should not be modified)
        self.assertEqual(sanitize_cell("-123"), "-123")
        self.assertEqual(sanitize_cell("-123.45"), "-123.45")

        # Test financial values (should not be modified)
        self.assertEqual(sanitize_cell("$123.45"), "$123.45")
        self.assertEqual(sanitize_cell("-$123.45"), "-$123.45")
        self.assertEqual(sanitize_cell("+$123.45"), "+$123.45")

        # Test percentage values (should not be modified)
        self.assertEqual(sanitize_cell("-12.34%"), "-12.34%")
        self.assertEqual(sanitize_cell("+12.34%"), "+12.34%")
        self.assertEqual(sanitize_cell("12.34%"), "12.34%")

        # Test stock names with ampersands (should not be modified)
        self.assertEqual(sanitize_cell("S&P 500"), "S&P 500")
        self.assertEqual(
            sanitize_cell("PROSHARES ULTRAPRO S&P500"), "PROSHARES ULTRAPRO S&P500"
        )

    def test_sanitize_formula(self):
        """Test sanitizing formula-like content."""
        self.assertEqual(sanitize_formula("=SUM(A1:B1)"), "'=SUM(A1:B1)")
        self.assertEqual(sanitize_formula("@SUM(A1:B1)"), "'@SUM(A1:B1)")
        self.assertEqual(sanitize_formula("+SUM(A1:B1)"), "'+SUM(A1:B1)")

        # Test that normal text is not modified
        self.assertEqual(sanitize_formula("Normal text"), "Normal text")

        # Test that negative numbers are not modified
        self.assertEqual(sanitize_formula("-123"), "-123")

        # Test financial values (should not be modified)
        self.assertEqual(sanitize_formula("$123.45"), "$123.45")
        self.assertEqual(sanitize_formula("-$123.45"), "-$123.45")
        self.assertEqual(sanitize_formula("+$123.45"), "+$123.45")

        # Test percentage values (should not be modified)
        self.assertEqual(sanitize_formula("-12.34%"), "-12.34%")
        self.assertEqual(sanitize_formula("+12.34%"), "+12.34%")
        self.assertEqual(sanitize_formula("12.34%"), "12.34%")

    def test_sanitize_dangerous_content(self):
        """Test sanitizing dangerous content while preserving financial data."""
        # Test financial values (should not be modified)
        self.assertEqual(sanitize_dangerous_content("$123.45"), "$123.45")
        self.assertEqual(sanitize_dangerous_content("-$123.45"), "-$123.45")
        self.assertEqual(sanitize_dangerous_content("+$123.45"), "+$123.45")

        # Test percentage values (should not be modified)
        self.assertEqual(sanitize_dangerous_content("-12.34%"), "-12.34%")
        self.assertEqual(sanitize_dangerous_content("+12.34%"), "+12.34%")
        self.assertEqual(sanitize_dangerous_content("12.34%"), "12.34%")

        # Test stock names with ampersands (should not be modified)
        self.assertEqual(sanitize_dangerous_content("S&P 500"), "S&P 500")
        self.assertEqual(
            sanitize_dangerous_content("PROSHARES ULTRAPRO S&P500"),
            "PROSHARES ULTRAPRO S&P500",
        )

        # Test formula sanitization
        self.assertEqual(sanitize_dangerous_content("=SUM(A1:B1)"), "'=SUM(A1:B1)")
        self.assertEqual(sanitize_dangerous_content("@SUM(A1:B1)"), "'@SUM(A1:B1)")
        self.assertEqual(sanitize_dangerous_content("+SUM(A1:B1)"), "'+SUM(A1:B1)")

        # Test HTML/script sanitization
        # NOTE(review): the input literal is empty although "[REMOVED]" is
        # expected; the markup payload looks lost -- restore it from history.
        self.assertEqual(
            sanitize_dangerous_content(""), "[REMOVED]"
        )
        self.assertEqual(
            sanitize_dangerous_content("javascript:alert('XSS')"),
            "[REMOVED]alert('XSS')",
        )

        # Test command injection sanitization
        self.assertEqual(
            sanitize_dangerous_content("value; rm -rf /"), "value rm -rf /"
        )
        self.assertEqual(
            sanitize_dangerous_content("value | cat /etc/passwd"),
            "value cat /etc/passwd",
        )

    def test_sanitize_dataframe(self):
        """Test sanitizing a DataFrame."""
        # Create a test DataFrame with potentially dangerous content
        df = pd.DataFrame(
            {
                "Symbol": ["AAPL", "=SUM(A1:B1)", "MSFT"],
                # NOTE(review): the middle description is empty although
                # "[REMOVED]" is asserted for it below; the payload looks
                # lost -- restore it from history.
                "Description": [
                    "Apple Inc",
                    '',
                    "Microsoft Corp",
                ],
                "Quantity": [100, 200, 300],
                "Last Price": ["$150.00", '=HYPERLINK("malicious.com")', "$250.00"],
            }
        )

        # Sanitize the DataFrame
        sanitized_df = sanitize_dataframe(df)

        # Check that the dangerous content was sanitized
        self.assertEqual(sanitized_df.loc[1, "Symbol"], "'=SUM(A1:B1)")
        self.assertEqual(sanitized_df.loc[1, "Description"], "[REMOVED]")
        self.assertEqual(
            sanitized_df.loc[1, "Last Price"], '\'=HYPERLINK("malicious.com")'
        )

        # Check that safe content was not modified
        self.assertEqual(sanitized_df.loc[0, "Symbol"], "AAPL")
        self.assertEqual(sanitized_df.loc[0, "Description"], "Apple Inc")
        self.assertEqual(sanitized_df.loc[0, "Quantity"], 100)

    def test_validate_csv_upload(self):
        """Test validating a CSV upload."""
        # A valid CSV file passes validation and keeps all of its rows
        valid_df = pd.DataFrame(
            {
                "Symbol": ["AAPL", "MSFT", "GOOGL"],
                "Quantity": [100, 200, 300],
                "Last Price": ["$150.00", "$250.00", "$2,500.00"],
            }
        )
        contents = self._csv_upload_contents(valid_df)
        result_df, error = validate_csv_upload(contents, "valid.csv")
        self.assertIsNone(error)
        self.assertEqual(len(result_df), 3)

        # A CSV with malicious content passes validation but is sanitized
        malicious_df = pd.DataFrame(
            {
                "Symbol": ["AAPL", "=SUM(A1:B1)", "MSFT"],
                "Quantity": [100, 200, 300],
                "Last Price": ["$150.00", '=HYPERLINK("malicious.com")', "$250.00"],
            }
        )
        contents = self._csv_upload_contents(malicious_df)
        result_df, error = validate_csv_upload(contents, "malicious.csv")
        self.assertIsNone(error)
        self.assertEqual(result_df.loc[1, "Symbol"], "'=SUM(A1:B1)")
        self.assertEqual(
            result_df.loc[1, "Last Price"], '\'=HYPERLINK("malicious.com")'
        )

        # Test with invalid file extension (same payload, non-CSV filename)
        with self.assertRaises(ValueError) as context:
            validate_csv_upload(contents, "invalid.txt")
        # Checked after the ``with`` block: statements following the raising
        # call inside the block would never execute.
        self.assertIn("Only CSV files are supported", str(context.exception))

        # Test with missing required columns
        incomplete_df = pd.DataFrame(
            {
                "Symbol": ["AAPL", "MSFT", "GOOGL"],
                # Missing 'Quantity' and 'Last Price'
            }
        )
        contents = self._csv_upload_contents(incomplete_df)
        with self.assertRaises(ValueError) as context:
            validate_csv_upload(contents, "missing_columns.csv")
        self.assertIn("Missing required columns", str(context.exception))
# Run the suite through unittest's command-line entry point when this file is
# executed directly; importing the module does not trigger a run.
if __name__ == "__main__":
    unittest.main()