dystomachina commited on
Commit
ff73b92
·
1 Parent(s): 779ae91

feat: entrypoint for focli simulate

Browse files

can run the `focli` with --simulate to get directly to simulate
various other formatting and lint fixes

.pre-commit-config.yaml CHANGED
@@ -3,8 +3,11 @@ repos:
3
  rev: v0.11.4
4
  hooks:
5
  - id: ruff
6
- args: [--fix, --unsafe-fixes]
 
7
  - id: ruff-format
 
 
8
 
9
  - repo: https://github.com/pre-commit/pre-commit-hooks
10
  rev: v4.5.0
 
3
  rev: v0.11.4
4
  hooks:
5
  - id: ruff
6
+ args: [--fix]
7
+ files: ^(src|tests)/
8
  - id: ruff-format
9
+ files: ^(src|tests)/
10
+ types_or: [python, pyi]
11
 
12
  - repo: https://github.com/pre-commit/pre-commit-hooks
13
  rev: v4.5.0
.windsurfrules ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ description: Rules to get the AI to behave
3
+ alwaysApply: true
4
+ ---
5
+ # General rules for AI
6
+ - Prior to generating any code, carefully read the project conventions
7
+ - Read [project-design.md](docs/project-design.md) to understand the codebase
8
+ - Read [project-conventions.md](docs/project-conventions.md) to understand _how_ to write code for the codebase
9
+ - Run `make lint` and `make test` after every change. `lint` in particular can be run very frequently.
10
+ - When user starts a prompt with `QQ:` or `Question:`, just answer the question or prompt without producing code.
11
+ - Prefer small testable steps, after each step give a summary to the user and summarize the next step
12
+ - Maintain strict separation of concerns: Business logic MUST reside in the core library (`src/folio/`), not in interface layers (`src/focli/`). Interface layers should only handle user interaction, command parsing, and result presentation.
13
+ - Use `.docs/` for temporary documentation such as project plans or logs
14
+
15
+ ## Prohibited actions
16
+
17
+ - Do not run `make folio`. This is for the user to run only.
18
+ - Do not use `git` commands unless explicitly asked.
19
+ - Do not use `docker` commands unless explicitly asked.
Makefile CHANGED
@@ -112,7 +112,7 @@ lint:
112
  @mkdir -p $(LOGS_DIR)
113
  @(echo "=== Code Check Log $(TIMESTAMP) ===" && \
114
  echo "Starting checks at: $$(date)" && \
115
- $(POETRY) run ruff check --fix --unsafe-fixes . \
116
  2>&1) | tee $(LOGS_DIR)/code_check_latest.log
117
  @echo "Check log saved to: $(LOGS_DIR)/code_check_latest.log"
118
 
 
112
  @mkdir -p $(LOGS_DIR)
113
  @(echo "=== Code Check Log $(TIMESTAMP) ===" && \
114
  echo "Starting checks at: $$(date)" && \
115
+ $(POETRY) run ruff check --fix src/ tests/ \
116
  2>&1) | tee $(LOGS_DIR)/code_check_latest.log
117
  @echo "Check log saved to: $(LOGS_DIR)/code_check_latest.log"
118
 
poetry.lock CHANGED
@@ -1426,6 +1426,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
1426
  [package.extras]
1427
  dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
1428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1429
  [[package]]
1430
  name = "python-dateutil"
1431
  version = "2.9.0.post0"
@@ -1921,4 +1939,4 @@ type = ["pytest-mypy"]
1921
  [metadata]
1922
  lock-version = "2.1"
1923
  python-versions = "^3.9"
1924
- content-hash = "98fdada28ba7fcb3cfdbafe0bddcfcd41eec8384218d942c8991abe7040a4f4d"
 
1426
  [package.extras]
1427
  dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
1428
 
1429
+ [[package]]
1430
+ name = "pytest-mock"
1431
+ version = "3.14.0"
1432
+ description = "Thin-wrapper around the mock package for easier use with pytest"
1433
+ optional = false
1434
+ python-versions = ">=3.8"
1435
+ groups = ["dev"]
1436
+ files = [
1437
+ {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
1438
+ {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
1439
+ ]
1440
+
1441
+ [package.dependencies]
1442
+ pytest = ">=6.2.5"
1443
+
1444
+ [package.extras]
1445
+ dev = ["pre-commit", "pytest-asyncio", "tox"]
1446
+
1447
  [[package]]
1448
  name = "python-dateutil"
1449
  version = "2.9.0.post0"
 
1939
  [metadata]
1940
  lock-version = "2.1"
1941
  python-versions = "^3.9"
1942
+ content-hash = "85ddf4d6c12a985cab4eec4cca4a115f95cb23840d6a04793f54f5b53d2a110c"
pyproject.toml CHANGED
@@ -25,6 +25,7 @@ google-generativeai = ">=0.3.0"
25
  [tool.poetry.group.dev.dependencies]
26
  ruff = "^0.11.7"
27
  pytest = "^8.3.5"
 
28
  rich = ">=13.9.0"
29
  prompt-toolkit = ">=3.0.43"
30
  pre-commit = "^4.2.0"
 
25
  [tool.poetry.group.dev.dependencies]
26
  ruff = "^0.11.7"
27
  pytest = "^8.3.5"
28
+ pytest-mock = "^3.14.0" # Added for mocking in tests
29
  rich = ">=13.9.0"
30
  prompt-toolkit = ">=3.0.43"
31
  pre-commit = "^4.2.0"
scripts/run_mlflow.py CHANGED
@@ -11,9 +11,19 @@ import sys
11
 
12
  def main():
13
  """Start the MLflow UI server"""
14
- parser = argparse.ArgumentParser(description='Start the MLflow UI server')
15
- parser.add_argument('--port', type=int, default=5000, help='Port to run the server on (default: 5000)')
16
- parser.add_argument('--host', type=str, default='127.0.0.1', help='Host to run the server on (default: 127.0.0.1)')
 
 
 
 
 
 
 
 
 
 
17
  args = parser.parse_args()
18
 
19
  # Get the project root directory
@@ -21,16 +31,19 @@ def main():
21
  project_root = os.path.dirname(script_dir)
22
 
23
  # Set the MLflow tracking URI
24
- mlruns_dir = os.path.join(project_root, 'mlruns')
25
  tracking_uri = f"file:{mlruns_dir}"
26
 
27
-
28
  # Start the MLflow UI
29
  cmd = [
30
- "mlflow", "ui",
31
- "--backend-store-uri", tracking_uri,
32
- "--host", args.host,
33
- "--port", str(args.port)
 
 
 
 
34
  ]
35
 
36
  try:
@@ -40,5 +53,6 @@ def main():
40
  except Exception:
41
  sys.exit(1)
42
 
 
43
  if __name__ == "__main__":
44
  main()
 
11
 
12
  def main():
13
  """Start the MLflow UI server"""
14
+ parser = argparse.ArgumentParser(description="Start the MLflow UI server")
15
+ parser.add_argument(
16
+ "--port",
17
+ type=int,
18
+ default=5000,
19
+ help="Port to run the server on (default: 5000)",
20
+ )
21
+ parser.add_argument(
22
+ "--host",
23
+ type=str,
24
+ default="127.0.0.1",
25
+ help="Host to run the server on (default: 127.0.0.1)",
26
+ )
27
  args = parser.parse_args()
28
 
29
  # Get the project root directory
 
31
  project_root = os.path.dirname(script_dir)
32
 
33
  # Set the MLflow tracking URI
34
+ mlruns_dir = os.path.join(project_root, "mlruns")
35
  tracking_uri = f"file:{mlruns_dir}"
36
 
 
37
  # Start the MLflow UI
38
  cmd = [
39
+ "mlflow",
40
+ "ui",
41
+ "--backend-store-uri",
42
+ tracking_uri,
43
+ "--host",
44
+ args.host,
45
+ "--port",
46
+ str(args.port),
47
  ]
48
 
49
  try:
 
53
  except Exception:
54
  sys.exit(1)
55
 
56
+
57
  if __name__ == "__main__":
58
  main()
src/focli/README.md CHANGED
@@ -16,6 +16,23 @@ python src/focli/focli.py
16
 
17
  This will launch the interactive shell where you can enter commands.
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  ## Why Use Folio CLI?
20
 
21
  - **Speed**: Get answers in seconds without waiting for GUI elements to load
@@ -57,7 +74,7 @@ See how your portfolio might perform across different market scenarios:
57
  - `--steps 13` - Set the number of data points in the simulation
58
  - `--detailed` - Show position-level details in the simulation
59
  - `--focus SPY,AAPL` - Focus on specific positions
60
- - `--preset <name>` - Use a saved parameter preset
61
  - `--save-preset <name>` - Save current parameters as a preset
62
  - `--filter options` - Run simulation only on positions with options
63
 
@@ -91,6 +108,7 @@ Exit the application.
91
  3. **Drill down with position commands** to understand specific holdings
92
  4. **Save presets** for analyses you run frequently
93
  5. **Use filtering** to focus on segments of your portfolio
 
94
 
95
  ## Example Workflow
96
 
 
16
 
17
  This will launch the interactive shell where you can enter commands.
18
 
19
+ ### Direct Simulation Mode
20
+
21
+ You can also run simulations directly from the command line without entering the interactive shell:
22
+
23
+ ```bash
24
+ # Run simulation with default parameters
25
+ python src/focli/focli.py --simulate
26
+
27
+ # Run a quick simulation (fewer steps, smaller range)
28
+ python src/focli/focli.py --simulate --preset quick
29
+
30
+ # Run a detailed simulation (more steps)
31
+ python src/focli/focli.py --simulate --preset detailed
32
+ ```
33
+
34
+ This is useful for quickly checking how your portfolio might perform under different market conditions.
35
+
36
  ## Why Use Folio CLI?
37
 
38
  - **Speed**: Get answers in seconds without waiting for GUI elements to load
 
74
  - `--steps 13` - Set the number of data points in the simulation
75
  - `--detailed` - Show position-level details in the simulation
76
  - `--focus SPY,AAPL` - Focus on specific positions
77
+ - `--preset <name>` - Use a saved parameter preset (default, quick, detailed)
78
  - `--save-preset <name>` - Save current parameters as a preset
79
  - `--filter options` - Run simulation only on positions with options
80
 
 
108
  3. **Drill down with position commands** to understand specific holdings
109
  4. **Save presets** for analyses you run frequently
110
  5. **Use filtering** to focus on segments of your portfolio
111
+ 6. **Use direct simulation mode** for quick portfolio checks
112
 
113
  ## Example Workflow
114
 
src/focli/focli.py CHANGED
@@ -6,9 +6,15 @@ This script provides an interactive shell for running portfolio simulations,
6
  analyzing positions, and exploring investment scenarios.
7
 
8
  Usage:
9
- python src/focli/focli.py
 
 
10
 
11
- Commands:
 
 
 
 
12
  help Show help information
13
  simulate spy Simulate portfolio performance with SPY changes
14
  position <ticker> Analyze a specific position group
 
6
  analyzing positions, and exploring investment scenarios.
7
 
8
  Usage:
9
+ python src/focli/focli.py # Start interactive shell
10
+ python src/focli/focli.py --simulate # Run simulation directly
11
+ python src/focli/focli.py --simulate --preset quick # Run quick simulation
12
 
13
+ Command-line Options:
14
+ --simulate Run portfolio simulation directly without entering interactive shell
15
+ --preset NAME Use a specific simulation preset (default, quick, detailed)
16
+
17
+ Interactive Commands:
18
  help Show help information
19
  simulate spy Simulate portfolio performance with SPY changes
20
  position <ticker> Analyze a specific position group
src/focli/shell.py CHANGED
@@ -4,7 +4,9 @@ Interactive shell for the Folio CLI.
4
  This module provides the main entry point for the Folio CLI interactive shell.
5
  """
6
 
 
7
  import os
 
8
 
9
  from prompt_toolkit import PromptSession
10
  from prompt_toolkit.completion import NestedCompleter
@@ -13,6 +15,7 @@ from prompt_toolkit.shortcuts import confirm
13
  from rich.console import Console
14
 
15
  from src.focli.commands import execute_command, get_command_registry
 
16
  from src.focli.utils import load_portfolio
17
 
18
 
@@ -38,52 +41,106 @@ def create_completer():
38
  return NestedCompleter.from_nested_dict(completion_dict)
39
 
40
 
41
- def main():
42
- """Main entry point for the Folio CLI."""
43
- console = Console()
44
- console.print("[bold cyan]Folio Interactive Shell[/bold cyan]")
45
- console.print("Type 'help' for available commands.")
46
-
47
- # Create history file in user's home directory
48
- history_file = os.path.expanduser("~/.folio_history")
49
-
50
- # Create session with auto-completion and history
51
- session = PromptSession(
52
- completer=create_completer(), history=FileHistory(history_file)
53
- )
54
 
55
- # Initialize application state
56
- state = {
57
- # Portfolio data
 
58
  "portfolio_groups": None,
59
  "portfolio_summary": None,
60
  "loaded_portfolio": None,
61
- # Simulation results
62
  "last_simulation": None,
63
  "simulation_history": [],
64
- # Position analysis
65
  "last_position": None,
66
  "position_simulations": {},
67
  "filtered_groups": None,
68
- # Parameter presets
69
  "simulation_presets": {
70
  "default": {"range": 20.0, "steps": 13},
71
  "detailed": {"range": 20.0, "steps": 21, "detailed": True},
72
  "quick": {"range": 10.0, "steps": 5},
73
  },
74
- # Session history
75
  "command_history": [],
76
  }
77
 
78
- # Try to load default portfolio
 
 
 
 
 
 
 
 
 
 
79
  default_portfolio = "private-data/portfolio-private.csv"
80
  try:
81
  load_portfolio(default_portfolio, state, console)
 
82
  except Exception as e:
83
  console.print(f"[yellow]Could not load default portfolio: {e}[/yellow]")
84
  console.print(
85
  "[yellow]Use 'portfolio load <path>' to load a portfolio.[/yellow]"
86
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Main REPL loop
89
  while True:
@@ -100,25 +157,22 @@ def main():
100
  break
101
  continue
102
 
 
 
 
103
  # Add to command history
104
  state["command_history"].append(text)
105
 
106
- # Process the command
107
- execute_command(text, state, console)
108
-
109
  except KeyboardInterrupt:
110
  # Handle Ctrl+C
111
- console.print("\n[yellow]Use 'exit' to exit the application.[/yellow]")
112
- continue
113
  except EOFError:
114
  # Handle Ctrl+D
115
- console.print("\nGoodbye!")
116
  break
117
  except Exception as e:
118
- # Handle unexpected errors
119
- console.print(f"[bold red]Error:[/bold red] {e!s}")
120
-
121
- console.print("Goodbye!")
122
 
123
 
124
  def confirm_exit():
 
4
  This module provides the main entry point for the Folio CLI interactive shell.
5
  """
6
 
7
+ import argparse
8
  import os
9
+ import traceback
10
 
11
  from prompt_toolkit import PromptSession
12
  from prompt_toolkit.completion import NestedCompleter
 
15
  from rich.console import Console
16
 
17
  from src.focli.commands import execute_command, get_command_registry
18
+ from src.focli.commands.simulate import simulate_command
19
  from src.focli.utils import load_portfolio
20
 
21
 
 
41
  return NestedCompleter.from_nested_dict(completion_dict)
42
 
43
 
44
+ def initialize_state():
45
+ """Initialize the application state.
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ Returns:
48
+ Dictionary containing the initial application state
49
+ """
50
+ return {
51
  "portfolio_groups": None,
52
  "portfolio_summary": None,
53
  "loaded_portfolio": None,
 
54
  "last_simulation": None,
55
  "simulation_history": [],
 
56
  "last_position": None,
57
  "position_simulations": {},
58
  "filtered_groups": None,
 
59
  "simulation_presets": {
60
  "default": {"range": 20.0, "steps": 13},
61
  "detailed": {"range": 20.0, "steps": 21, "detailed": True},
62
  "quick": {"range": 10.0, "steps": 5},
63
  },
 
64
  "command_history": [],
65
  }
66
 
67
+
68
+ def load_default_portfolio(state, console):
69
+ """Try to load the default portfolio.
70
+
71
+ Args:
72
+ state: Application state
73
+ console: Rich console for output
74
+
75
+ Returns:
76
+ True if portfolio was loaded successfully, False otherwise
77
+ """
78
  default_portfolio = "private-data/portfolio-private.csv"
79
  try:
80
  load_portfolio(default_portfolio, state, console)
81
+ return True
82
  except Exception as e:
83
  console.print(f"[yellow]Could not load default portfolio: {e}[/yellow]")
84
  console.print(
85
  "[yellow]Use 'portfolio load <path>' to load a portfolio.[/yellow]"
86
  )
87
+ return False
88
+
89
+
90
+ def main():
91
+ """Main entry point for the Folio CLI."""
92
+ # Parse command-line arguments
93
+ parser = argparse.ArgumentParser(description="Folio CLI")
94
+ parser.add_argument(
95
+ "--simulate", action="store_true", help="Run simulation directly"
96
+ )
97
+ parser.add_argument(
98
+ "--preset", type=str, help="Simulation preset to use (default, quick, detailed)"
99
+ )
100
+ args = parser.parse_args()
101
+
102
+ console = Console()
103
+
104
+ # Initialize application state
105
+ state = initialize_state()
106
+
107
+ # If direct simulation is requested
108
+ if args.simulate:
109
+ console.print("[bold cyan]Folio CLI - Direct Simulation[/bold cyan]")
110
+
111
+ # Try to load default portfolio
112
+ if load_default_portfolio(state, console):
113
+ # Run simulation with optional preset
114
+ sim_args = []
115
+ if args.preset:
116
+ sim_args = ["-p", args.preset]
117
+
118
+ # Execute simulation command
119
+ simulate_command(sim_args, state, console)
120
+ return
121
+ else:
122
+ console.print(
123
+ "[bold red]Error:[/bold red] Cannot run simulation without a portfolio."
124
+ )
125
+ console.print(
126
+ "Please run the CLI without --simulate to load a portfolio first."
127
+ )
128
+ return
129
+
130
+ # Regular interactive mode
131
+ console.print("[bold cyan]Folio Interactive Shell[/bold cyan]")
132
+ console.print("Type 'help' for available commands.")
133
+
134
+ # Create history file in user's home directory
135
+ history_file = os.path.expanduser("~/.folio_history")
136
+
137
+ # Create session with auto-completion and history
138
+ session = PromptSession(
139
+ completer=create_completer(), history=FileHistory(history_file)
140
+ )
141
+
142
+ # Try to load default portfolio
143
+ load_default_portfolio(state, console)
144
 
145
  # Main REPL loop
146
  while True:
 
157
  break
158
  continue
159
 
160
+ # Execute the command
161
+ execute_command(text, state, console)
162
+
163
  # Add to command history
164
  state["command_history"].append(text)
165
 
 
 
 
166
  except KeyboardInterrupt:
167
  # Handle Ctrl+C
168
+ console.print("[yellow]Use 'exit' to exit the application.[/yellow]")
 
169
  except EOFError:
170
  # Handle Ctrl+D
 
171
  break
172
  except Exception as e:
173
+ # Handle other exceptions
174
+ console.print(f"[bold red]Error:[/bold red] {e}")
175
+ console.print(traceback.format_exc())
 
176
 
177
 
178
  def confirm_exit():
src/folio/exceptions.py CHANGED
@@ -8,26 +8,31 @@ and handling for different error conditions.
8
 
9
  class FolioError(Exception):
10
  """Base class for all Folio application exceptions."""
 
11
  pass
12
 
13
 
14
  class DataError(FolioError):
15
  """Raised when there are issues with data processing or validation."""
 
16
  pass
17
 
18
 
19
  class PortfolioError(FolioError):
20
  """Raised when there are issues with portfolio operations."""
 
21
  pass
22
 
23
 
24
  class UIError(FolioError):
25
  """Raised when there are issues with the UI components."""
 
26
  pass
27
 
28
 
29
  class ConfigurationError(FolioError):
30
  """Raised when there are issues with application configuration."""
 
31
  pass
32
 
33
 
 
8
 
9
  class FolioError(Exception):
10
  """Base class for all Folio application exceptions."""
11
+
12
  pass
13
 
14
 
15
  class DataError(FolioError):
16
  """Raised when there are issues with data processing or validation."""
17
+
18
  pass
19
 
20
 
21
  class PortfolioError(FolioError):
22
  """Raised when there are issues with portfolio operations."""
23
+
24
  pass
25
 
26
 
27
  class UIError(FolioError):
28
  """Raised when there are issues with the UI components."""
29
+
30
  pass
31
 
32
 
33
  class ConfigurationError(FolioError):
34
  """Raised when there are issues with application configuration."""
35
+
36
  pass
37
 
38
 
src/folio/security.py CHANGED
@@ -99,7 +99,7 @@ def validate_csv_upload(
99
  # Check file size
100
  if len(decoded) > MAX_FILE_SIZE:
101
  logger.warning(f"File too large: {len(decoded)} bytes (max {MAX_FILE_SIZE})")
102
- raise ValueError(f"File too large (max {MAX_FILE_SIZE/1024/1024:.1f}MB)")
103
 
104
  # Parse CSV
105
  try:
 
99
  # Check file size
100
  if len(decoded) > MAX_FILE_SIZE:
101
  logger.warning(f"File too large: {len(decoded)} bytes (max {MAX_FILE_SIZE})")
102
+ raise ValueError(f"File too large (max {MAX_FILE_SIZE / 1024 / 1024:.1f}MB)")
103
 
104
  # Parse CSV
105
  try:
tests/e2e/conftest.py CHANGED
@@ -67,28 +67,31 @@ def processed_portfolio(portfolio_data):
67
  # Possible alternative: (groups, cash_like_positions)
68
  groups, cash_like_positions = result
69
  from src.folio.portfolio import calculate_portfolio_summary
 
70
  summary = calculate_portfolio_summary(groups, cash_like_positions, 0.0)
71
  else:
72
  # If result is not a tuple, it's likely just the groups
73
  groups = result
74
  from src.folio.portfolio import calculate_portfolio_summary
 
75
  summary = calculate_portfolio_summary(groups, [], 0.0)
76
  cash_like_positions = []
77
 
78
  # Ensure we have a valid summary object
79
- if not hasattr(summary, 'to_dict'):
80
  logger.error("Error: summary object does not have to_dict method")
81
  logger.error(f"Type of summary: {type(summary)}")
82
  # Create a minimal summary for testing
83
  # Import here to avoid circular imports
84
  from src.folio.data_model import ExposureBreakdown, PortfolioSummary
 
85
  empty_exposure = ExposureBreakdown()
86
  summary = PortfolioSummary(
87
  net_market_exposure=0.0,
88
  portfolio_beta=0.0,
89
  long_exposure=empty_exposure,
90
  short_exposure=empty_exposure,
91
- options_exposure=empty_exposure
92
  )
93
 
94
  # Convert summary to dictionary for use in tests
 
67
  # Possible alternative: (groups, cash_like_positions)
68
  groups, cash_like_positions = result
69
  from src.folio.portfolio import calculate_portfolio_summary
70
+
71
  summary = calculate_portfolio_summary(groups, cash_like_positions, 0.0)
72
  else:
73
  # If result is not a tuple, it's likely just the groups
74
  groups = result
75
  from src.folio.portfolio import calculate_portfolio_summary
76
+
77
  summary = calculate_portfolio_summary(groups, [], 0.0)
78
  cash_like_positions = []
79
 
80
  # Ensure we have a valid summary object
81
+ if not hasattr(summary, "to_dict"):
82
  logger.error("Error: summary object does not have to_dict method")
83
  logger.error(f"Type of summary: {type(summary)}")
84
  # Create a minimal summary for testing
85
  # Import here to avoid circular imports
86
  from src.folio.data_model import ExposureBreakdown, PortfolioSummary
87
+
88
  empty_exposure = ExposureBreakdown()
89
  summary = PortfolioSummary(
90
  net_market_exposure=0.0,
91
  portfolio_beta=0.0,
92
  long_exposure=empty_exposure,
93
  short_exposure=empty_exposure,
94
+ options_exposure=empty_exposure,
95
  )
96
 
97
  # Convert summary to dictionary for use in tests
tests/e2e/test_exposures.py CHANGED
@@ -91,7 +91,6 @@ class TestExposures:
91
 
92
  summary_dict.get("pending_activity_value", 0.0)
93
 
94
-
95
  # Test that summary card values match position details
96
  assert abs(summary_net_exposure - total_ui_market_value) < 0.01, (
97
  f"Net Exposure in summary cards ({format_currency(summary_net_exposure)}) does not match the total market value shown in the UI ({format_currency(total_ui_market_value)})"
 
91
 
92
  summary_dict.get("pending_activity_value", 0.0)
93
 
 
94
  # Test that summary card values match position details
95
  assert abs(summary_net_exposure - total_ui_market_value) < 0.01, (
96
  f"Net Exposure in summary cards ({format_currency(summary_net_exposure)}) does not match the total market value shown in the UI ({format_currency(total_ui_market_value)})"
tests/test_data_model_serialization.py CHANGED
@@ -66,7 +66,9 @@ class TestDataModelSerialization(unittest.TestCase):
66
  stock = StockPosition.from_dict(stock_dict)
67
 
68
  # Check that market_value was calculated correctly
69
- self.assertEqual(stock.market_value, stock_dict["price"] * stock_dict["quantity"])
 
 
70
 
71
  def test_option_position_serialization(self):
72
  """Test that OptionPosition objects can be serialized and deserialized."""
@@ -138,7 +140,9 @@ class TestDataModelSerialization(unittest.TestCase):
138
  option = OptionPosition.from_dict(option_dict)
139
 
140
  # Check that market_value was calculated correctly with 100x multiplier
141
- self.assertEqual(option.market_value, option_dict["price"] * option_dict["quantity"] * 100)
 
 
142
 
143
  def test_portfolio_group_serialization(self):
144
  """Test that PortfolioGroup objects can be serialized and deserialized."""
@@ -202,50 +206,77 @@ class TestDataModelSerialization(unittest.TestCase):
202
  self.assertEqual(group.stock_position.ticker, group2.stock_position.ticker)
203
  self.assertEqual(group.stock_position.quantity, group2.stock_position.quantity)
204
  self.assertEqual(group.stock_position.beta, group2.stock_position.beta)
205
- self.assertEqual(group.stock_position.market_exposure, group2.stock_position.market_exposure)
 
 
206
  self.assertEqual(
207
  group.stock_position.beta_adjusted_exposure,
208
  group2.stock_position.beta_adjusted_exposure,
209
  )
210
  self.assertEqual(group.stock_position.price, group2.stock_position.price)
211
- self.assertEqual(group.stock_position.cost_basis, group2.stock_position.cost_basis)
212
- self.assertEqual(group.stock_position.market_value, group2.stock_position.market_value)
 
 
 
 
213
 
214
  # Check that the option position was deserialized correctly
215
- self.assertEqual(group.option_positions[0].ticker, group2.option_positions[0].ticker)
216
  self.assertEqual(
217
- group.option_positions[0].position_type, group2.option_positions[0].position_type
 
 
 
 
 
 
 
 
 
 
218
  )
219
- self.assertEqual(group.option_positions[0].quantity, group2.option_positions[0].quantity)
220
- self.assertEqual(group.option_positions[0].beta, group2.option_positions[0].beta)
221
  self.assertEqual(
222
  group.option_positions[0].beta_adjusted_exposure,
223
  group2.option_positions[0].beta_adjusted_exposure,
224
  )
225
- self.assertEqual(group.option_positions[0].strike, group2.option_positions[0].strike)
226
- self.assertEqual(group.option_positions[0].expiry, group2.option_positions[0].expiry)
227
  self.assertEqual(
228
- group.option_positions[0].option_type, group2.option_positions[0].option_type
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  )
230
- self.assertEqual(group.option_positions[0].delta, group2.option_positions[0].delta)
231
  self.assertEqual(
232
- group.option_positions[0].delta_exposure, group2.option_positions[0].delta_exposure
 
233
  )
234
  self.assertEqual(
235
- group.option_positions[0].notional_value, group2.option_positions[0].notional_value
 
236
  )
237
  self.assertEqual(
238
- group.option_positions[0].underlying_beta, group2.option_positions[0].underlying_beta
 
239
  )
240
  self.assertEqual(
241
- group.option_positions[0].market_exposure, group2.option_positions[0].market_exposure
242
  )
243
- self.assertEqual(group.option_positions[0].price, group2.option_positions[0].price)
244
  self.assertEqual(
245
  group.option_positions[0].cost_basis, group2.option_positions[0].cost_basis
246
  )
247
  self.assertEqual(
248
- group.option_positions[0].market_value, group2.option_positions[0].market_value
 
249
  )
250
 
251
  def test_portfolio_summary_serialization(self):
@@ -345,8 +376,12 @@ class TestDataModelSerialization(unittest.TestCase):
345
  self.assertEqual(summary.cash_percentage, summary2.cash_percentage)
346
  self.assertEqual(summary.stock_value, summary2.stock_value)
347
  self.assertEqual(summary.option_value, summary2.option_value)
348
- self.assertEqual(summary.pending_activity_value, summary2.pending_activity_value)
349
- self.assertEqual(summary.portfolio_estimate_value, summary2.portfolio_estimate_value)
 
 
 
 
350
  self.assertEqual(summary.price_updated_at, summary2.price_updated_at)
351
 
352
  # Check that the exposure breakdowns were deserialized correctly
@@ -354,7 +389,8 @@ class TestDataModelSerialization(unittest.TestCase):
354
  summary.long_exposure.stock_exposure, summary2.long_exposure.stock_exposure
355
  )
356
  self.assertEqual(
357
- summary.long_exposure.stock_beta_adjusted, summary2.long_exposure.stock_beta_adjusted
 
358
  )
359
  self.assertEqual(
360
  summary.long_exposure.option_delta_exposure,
@@ -368,7 +404,8 @@ class TestDataModelSerialization(unittest.TestCase):
368
  summary.long_exposure.total_exposure, summary2.long_exposure.total_exposure
369
  )
370
  self.assertEqual(
371
- summary.long_exposure.total_beta_adjusted, summary2.long_exposure.total_beta_adjusted
 
372
  )
373
 
374
  def test_portfolio_summary_serialization_without_pending_activity(self):
 
66
  stock = StockPosition.from_dict(stock_dict)
67
 
68
  # Check that market_value was calculated correctly
69
+ self.assertEqual(
70
+ stock.market_value, stock_dict["price"] * stock_dict["quantity"]
71
+ )
72
 
73
  def test_option_position_serialization(self):
74
  """Test that OptionPosition objects can be serialized and deserialized."""
 
140
  option = OptionPosition.from_dict(option_dict)
141
 
142
  # Check that market_value was calculated correctly with 100x multiplier
143
+ self.assertEqual(
144
+ option.market_value, option_dict["price"] * option_dict["quantity"] * 100
145
+ )
146
 
147
  def test_portfolio_group_serialization(self):
148
  """Test that PortfolioGroup objects can be serialized and deserialized."""
 
206
  self.assertEqual(group.stock_position.ticker, group2.stock_position.ticker)
207
  self.assertEqual(group.stock_position.quantity, group2.stock_position.quantity)
208
  self.assertEqual(group.stock_position.beta, group2.stock_position.beta)
209
+ self.assertEqual(
210
+ group.stock_position.market_exposure, group2.stock_position.market_exposure
211
+ )
212
  self.assertEqual(
213
  group.stock_position.beta_adjusted_exposure,
214
  group2.stock_position.beta_adjusted_exposure,
215
  )
216
  self.assertEqual(group.stock_position.price, group2.stock_position.price)
217
+ self.assertEqual(
218
+ group.stock_position.cost_basis, group2.stock_position.cost_basis
219
+ )
220
+ self.assertEqual(
221
+ group.stock_position.market_value, group2.stock_position.market_value
222
+ )
223
 
224
  # Check that the option position was deserialized correctly
 
225
  self.assertEqual(
226
+ group.option_positions[0].ticker, group2.option_positions[0].ticker
227
+ )
228
+ self.assertEqual(
229
+ group.option_positions[0].position_type,
230
+ group2.option_positions[0].position_type,
231
+ )
232
+ self.assertEqual(
233
+ group.option_positions[0].quantity, group2.option_positions[0].quantity
234
+ )
235
+ self.assertEqual(
236
+ group.option_positions[0].beta, group2.option_positions[0].beta
237
  )
 
 
238
  self.assertEqual(
239
  group.option_positions[0].beta_adjusted_exposure,
240
  group2.option_positions[0].beta_adjusted_exposure,
241
  )
 
 
242
  self.assertEqual(
243
+ group.option_positions[0].strike, group2.option_positions[0].strike
244
+ )
245
+ self.assertEqual(
246
+ group.option_positions[0].expiry, group2.option_positions[0].expiry
247
+ )
248
+ self.assertEqual(
249
+ group.option_positions[0].option_type,
250
+ group2.option_positions[0].option_type,
251
+ )
252
+ self.assertEqual(
253
+ group.option_positions[0].delta, group2.option_positions[0].delta
254
+ )
255
+ self.assertEqual(
256
+ group.option_positions[0].delta_exposure,
257
+ group2.option_positions[0].delta_exposure,
258
  )
 
259
  self.assertEqual(
260
+ group.option_positions[0].notional_value,
261
+ group2.option_positions[0].notional_value,
262
  )
263
  self.assertEqual(
264
+ group.option_positions[0].underlying_beta,
265
+ group2.option_positions[0].underlying_beta,
266
  )
267
  self.assertEqual(
268
+ group.option_positions[0].market_exposure,
269
+ group2.option_positions[0].market_exposure,
270
  )
271
  self.assertEqual(
272
+ group.option_positions[0].price, group2.option_positions[0].price
273
  )
 
274
  self.assertEqual(
275
  group.option_positions[0].cost_basis, group2.option_positions[0].cost_basis
276
  )
277
  self.assertEqual(
278
+ group.option_positions[0].market_value,
279
+ group2.option_positions[0].market_value,
280
  )
281
 
282
  def test_portfolio_summary_serialization(self):
 
376
  self.assertEqual(summary.cash_percentage, summary2.cash_percentage)
377
  self.assertEqual(summary.stock_value, summary2.stock_value)
378
  self.assertEqual(summary.option_value, summary2.option_value)
379
+ self.assertEqual(
380
+ summary.pending_activity_value, summary2.pending_activity_value
381
+ )
382
+ self.assertEqual(
383
+ summary.portfolio_estimate_value, summary2.portfolio_estimate_value
384
+ )
385
  self.assertEqual(summary.price_updated_at, summary2.price_updated_at)
386
 
387
  # Check that the exposure breakdowns were deserialized correctly
 
389
  summary.long_exposure.stock_exposure, summary2.long_exposure.stock_exposure
390
  )
391
  self.assertEqual(
392
+ summary.long_exposure.stock_beta_adjusted,
393
+ summary2.long_exposure.stock_beta_adjusted,
394
  )
395
  self.assertEqual(
396
  summary.long_exposure.option_delta_exposure,
 
404
  summary.long_exposure.total_exposure, summary2.long_exposure.total_exposure
405
  )
406
  self.assertEqual(
407
+ summary.long_exposure.total_beta_adjusted,
408
+ summary2.long_exposure.total_beta_adjusted,
409
  )
410
 
411
  def test_portfolio_summary_serialization_without_pending_activity(self):
tests/test_security.py CHANGED
@@ -10,7 +10,7 @@ import unittest
10
 
11
  import pandas as pd
12
 
13
- sys.path.insert(0, os.path.abspath('..'))
14
 
15
  from src.folio.security import (
16
  sanitize_cell,
@@ -33,12 +33,19 @@ class TestSecurity(unittest.TestCase):
33
 
34
  # Test HTML/script sanitization
35
  self.assertEqual(sanitize_cell("<script>alert('XSS')</script>"), "[REMOVED]")
36
- self.assertEqual(sanitize_cell("javascript:alert('XSS')"), "[REMOVED]alert('XSS')")
37
- self.assertEqual(sanitize_cell("<img src=x onerror=alert('XSS')>"), "<img src=x [REMOVED]=alert('XSS')>")
 
 
 
 
 
38
 
39
  # Test command injection sanitization
40
  self.assertEqual(sanitize_cell("value; rm -rf /"), "value rm -rf /")
41
- self.assertEqual(sanitize_cell("value | cat /etc/passwd"), "value cat /etc/passwd")
 
 
42
 
43
  # Test non-string values
44
  self.assertEqual(sanitize_cell(123), "123")
@@ -60,7 +67,9 @@ class TestSecurity(unittest.TestCase):
60
 
61
  # Test stock names with ampersands (should not be modified)
62
  self.assertEqual(sanitize_cell("S&P 500"), "S&P 500")
63
- self.assertEqual(sanitize_cell("PROSHARES ULTRAPRO S&P500"), "PROSHARES ULTRAPRO S&P500")
 
 
64
 
65
  def test_sanitize_formula(self):
66
  """Test sanitizing formula-like content."""
@@ -98,7 +107,10 @@ class TestSecurity(unittest.TestCase):
98
 
99
  # Test stock names with ampersands (should not be modified)
100
  self.assertEqual(sanitize_dangerous_content("S&P 500"), "S&P 500")
101
- self.assertEqual(sanitize_dangerous_content("PROSHARES ULTRAPRO S&P500"), "PROSHARES ULTRAPRO S&P500")
 
 
 
102
 
103
  # Test formula sanitization
104
  self.assertEqual(sanitize_dangerous_content("=SUM(A1:B1)"), "'=SUM(A1:B1)")
@@ -106,50 +118,70 @@ class TestSecurity(unittest.TestCase):
106
  self.assertEqual(sanitize_dangerous_content("+SUM(A1:B1)"), "'+SUM(A1:B1)")
107
 
108
  # Test HTML/script sanitization
109
- self.assertEqual(sanitize_dangerous_content("<script>alert('XSS')</script>"), "[REMOVED]")
110
- self.assertEqual(sanitize_dangerous_content("javascript:alert('XSS')"), "[REMOVED]alert('XSS')")
 
 
 
 
 
111
 
112
  # Test command injection sanitization
113
- self.assertEqual(sanitize_dangerous_content("value; rm -rf /"), "value rm -rf /")
114
- self.assertEqual(sanitize_dangerous_content("value | cat /etc/passwd"), "value cat /etc/passwd")
 
 
 
 
 
115
 
116
  def test_sanitize_dataframe(self):
117
  """Test sanitizing a DataFrame."""
118
  # Create a test DataFrame with potentially dangerous content
119
- df = pd.DataFrame({
120
- 'Symbol': ['AAPL', '=SUM(A1:B1)', 'MSFT'],
121
- 'Description': ['Apple Inc', '<script>alert("XSS")</script>', 'Microsoft Corp'],
122
- 'Quantity': [100, 200, 300],
123
- 'Last Price': ['$150.00', '=HYPERLINK("malicious.com")', '$250.00'],
124
- })
 
 
 
 
 
 
125
 
126
  # Sanitize the DataFrame
127
  sanitized_df = sanitize_dataframe(df)
128
 
129
  # Check that the dangerous content was sanitized
130
- self.assertEqual(sanitized_df.loc[1, 'Symbol'], "'=SUM(A1:B1)")
131
- self.assertEqual(sanitized_df.loc[1, 'Description'], "[REMOVED]")
132
- self.assertEqual(sanitized_df.loc[1, 'Last Price'], "'=HYPERLINK(\"malicious.com\")")
 
 
133
 
134
  # Check that safe content was not modified
135
- self.assertEqual(sanitized_df.loc[0, 'Symbol'], 'AAPL')
136
- self.assertEqual(sanitized_df.loc[0, 'Description'], 'Apple Inc')
137
- self.assertEqual(sanitized_df.loc[0, 'Quantity'], 100)
138
 
139
  def test_validate_csv_upload(self):
140
  """Test validating a CSV upload."""
141
  # Create a valid CSV file
142
- df = pd.DataFrame({
143
- 'Symbol': ['AAPL', 'MSFT', 'GOOGL'],
144
- 'Quantity': [100, 200, 300],
145
- 'Last Price': ['$150.00', '$250.00', '$2,500.00'],
146
- })
 
 
147
 
148
  # Convert to CSV and encode as base64
149
  csv_buffer = io.StringIO()
150
  df.to_csv(csv_buffer, index=False)
151
  csv_str = csv_buffer.getvalue()
152
- b64_content = base64.b64encode(csv_str.encode('utf-8')).decode('utf-8')
153
  contents = f"data:text/csv;base64,{b64_content}"
154
 
155
  # Validate the CSV upload
@@ -160,17 +192,19 @@ class TestSecurity(unittest.TestCase):
160
  self.assertEqual(len(result_df), 3)
161
 
162
  # Create a CSV with malicious content
163
- df = pd.DataFrame({
164
- 'Symbol': ['AAPL', '=SUM(A1:B1)', 'MSFT'],
165
- 'Quantity': [100, 200, 300],
166
- 'Last Price': ['$150.00', '=HYPERLINK("malicious.com")', '$250.00'],
167
- })
 
 
168
 
169
  # Convert to CSV and encode as base64
170
  csv_buffer = io.StringIO()
171
  df.to_csv(csv_buffer, index=False)
172
  csv_str = csv_buffer.getvalue()
173
- b64_content = base64.b64encode(csv_str.encode('utf-8')).decode('utf-8')
174
  contents = f"data:text/csv;base64,{b64_content}"
175
 
176
  # Validate the CSV upload
@@ -178,8 +212,10 @@ class TestSecurity(unittest.TestCase):
178
 
179
  # Check that validation passed but content was sanitized
180
  self.assertIsNone(error)
181
- self.assertEqual(result_df.loc[1, 'Symbol'], "'=SUM(A1:B1)")
182
- self.assertEqual(result_df.loc[1, 'Last Price'], "'=HYPERLINK(\"malicious.com\")")
 
 
183
 
184
  # Test with invalid file extension
185
  with self.assertRaises(ValueError) as context:
@@ -187,16 +223,18 @@ class TestSecurity(unittest.TestCase):
187
  self.assertIn("Only CSV files are supported", str(context.exception))
188
 
189
  # Test with missing required columns
190
- df = pd.DataFrame({
191
- 'Symbol': ['AAPL', 'MSFT', 'GOOGL'],
192
- # Missing 'Quantity' and 'Last Price'
193
- })
 
 
194
 
195
  # Convert to CSV and encode as base64
196
  csv_buffer = io.StringIO()
197
  df.to_csv(csv_buffer, index=False)
198
  csv_str = csv_buffer.getvalue()
199
- b64_content = base64.b64encode(csv_str.encode('utf-8')).decode('utf-8')
200
  contents = f"data:text/csv;base64,{b64_content}"
201
 
202
  # Validate the CSV upload
@@ -205,5 +243,5 @@ class TestSecurity(unittest.TestCase):
205
  self.assertIn("Missing required columns", str(context.exception))
206
 
207
 
208
- if __name__ == '__main__':
209
  unittest.main()
 
10
 
11
  import pandas as pd
12
 
13
+ sys.path.insert(0, os.path.abspath(".."))
14
 
15
  from src.folio.security import (
16
  sanitize_cell,
 
33
 
34
  # Test HTML/script sanitization
35
  self.assertEqual(sanitize_cell("<script>alert('XSS')</script>"), "[REMOVED]")
36
+ self.assertEqual(
37
+ sanitize_cell("javascript:alert('XSS')"), "[REMOVED]alert('XSS')"
38
+ )
39
+ self.assertEqual(
40
+ sanitize_cell("<img src=x onerror=alert('XSS')>"),
41
+ "<img src=x [REMOVED]=alert('XSS')>",
42
+ )
43
 
44
  # Test command injection sanitization
45
  self.assertEqual(sanitize_cell("value; rm -rf /"), "value rm -rf /")
46
+ self.assertEqual(
47
+ sanitize_cell("value | cat /etc/passwd"), "value cat /etc/passwd"
48
+ )
49
 
50
  # Test non-string values
51
  self.assertEqual(sanitize_cell(123), "123")
 
67
 
68
  # Test stock names with ampersands (should not be modified)
69
  self.assertEqual(sanitize_cell("S&P 500"), "S&P 500")
70
+ self.assertEqual(
71
+ sanitize_cell("PROSHARES ULTRAPRO S&P500"), "PROSHARES ULTRAPRO S&P500"
72
+ )
73
 
74
  def test_sanitize_formula(self):
75
  """Test sanitizing formula-like content."""
 
107
 
108
  # Test stock names with ampersands (should not be modified)
109
  self.assertEqual(sanitize_dangerous_content("S&P 500"), "S&P 500")
110
+ self.assertEqual(
111
+ sanitize_dangerous_content("PROSHARES ULTRAPRO S&P500"),
112
+ "PROSHARES ULTRAPRO S&P500",
113
+ )
114
 
115
  # Test formula sanitization
116
  self.assertEqual(sanitize_dangerous_content("=SUM(A1:B1)"), "'=SUM(A1:B1)")
 
118
  self.assertEqual(sanitize_dangerous_content("+SUM(A1:B1)"), "'+SUM(A1:B1)")
119
 
120
  # Test HTML/script sanitization
121
+ self.assertEqual(
122
+ sanitize_dangerous_content("<script>alert('XSS')</script>"), "[REMOVED]"
123
+ )
124
+ self.assertEqual(
125
+ sanitize_dangerous_content("javascript:alert('XSS')"),
126
+ "[REMOVED]alert('XSS')",
127
+ )
128
 
129
  # Test command injection sanitization
130
+ self.assertEqual(
131
+ sanitize_dangerous_content("value; rm -rf /"), "value rm -rf /"
132
+ )
133
+ self.assertEqual(
134
+ sanitize_dangerous_content("value | cat /etc/passwd"),
135
+ "value cat /etc/passwd",
136
+ )
137
 
138
  def test_sanitize_dataframe(self):
139
  """Test sanitizing a DataFrame."""
140
  # Create a test DataFrame with potentially dangerous content
141
+ df = pd.DataFrame(
142
+ {
143
+ "Symbol": ["AAPL", "=SUM(A1:B1)", "MSFT"],
144
+ "Description": [
145
+ "Apple Inc",
146
+ '<script>alert("XSS")</script>',
147
+ "Microsoft Corp",
148
+ ],
149
+ "Quantity": [100, 200, 300],
150
+ "Last Price": ["$150.00", '=HYPERLINK("malicious.com")', "$250.00"],
151
+ }
152
+ )
153
 
154
  # Sanitize the DataFrame
155
  sanitized_df = sanitize_dataframe(df)
156
 
157
  # Check that the dangerous content was sanitized
158
+ self.assertEqual(sanitized_df.loc[1, "Symbol"], "'=SUM(A1:B1)")
159
+ self.assertEqual(sanitized_df.loc[1, "Description"], "[REMOVED]")
160
+ self.assertEqual(
161
+ sanitized_df.loc[1, "Last Price"], '\'=HYPERLINK("malicious.com")'
162
+ )
163
 
164
  # Check that safe content was not modified
165
+ self.assertEqual(sanitized_df.loc[0, "Symbol"], "AAPL")
166
+ self.assertEqual(sanitized_df.loc[0, "Description"], "Apple Inc")
167
+ self.assertEqual(sanitized_df.loc[0, "Quantity"], 100)
168
 
169
  def test_validate_csv_upload(self):
170
  """Test validating a CSV upload."""
171
  # Create a valid CSV file
172
+ df = pd.DataFrame(
173
+ {
174
+ "Symbol": ["AAPL", "MSFT", "GOOGL"],
175
+ "Quantity": [100, 200, 300],
176
+ "Last Price": ["$150.00", "$250.00", "$2,500.00"],
177
+ }
178
+ )
179
 
180
  # Convert to CSV and encode as base64
181
  csv_buffer = io.StringIO()
182
  df.to_csv(csv_buffer, index=False)
183
  csv_str = csv_buffer.getvalue()
184
+ b64_content = base64.b64encode(csv_str.encode("utf-8")).decode("utf-8")
185
  contents = f"data:text/csv;base64,{b64_content}"
186
 
187
  # Validate the CSV upload
 
192
  self.assertEqual(len(result_df), 3)
193
 
194
  # Create a CSV with malicious content
195
+ df = pd.DataFrame(
196
+ {
197
+ "Symbol": ["AAPL", "=SUM(A1:B1)", "MSFT"],
198
+ "Quantity": [100, 200, 300],
199
+ "Last Price": ["$150.00", '=HYPERLINK("malicious.com")', "$250.00"],
200
+ }
201
+ )
202
 
203
  # Convert to CSV and encode as base64
204
  csv_buffer = io.StringIO()
205
  df.to_csv(csv_buffer, index=False)
206
  csv_str = csv_buffer.getvalue()
207
+ b64_content = base64.b64encode(csv_str.encode("utf-8")).decode("utf-8")
208
  contents = f"data:text/csv;base64,{b64_content}"
209
 
210
  # Validate the CSV upload
 
212
 
213
  # Check that validation passed but content was sanitized
214
  self.assertIsNone(error)
215
+ self.assertEqual(result_df.loc[1, "Symbol"], "'=SUM(A1:B1)")
216
+ self.assertEqual(
217
+ result_df.loc[1, "Last Price"], '\'=HYPERLINK("malicious.com")'
218
+ )
219
 
220
  # Test with invalid file extension
221
  with self.assertRaises(ValueError) as context:
 
223
  self.assertIn("Only CSV files are supported", str(context.exception))
224
 
225
  # Test with missing required columns
226
+ df = pd.DataFrame(
227
+ {
228
+ "Symbol": ["AAPL", "MSFT", "GOOGL"],
229
+ # Missing 'Quantity' and 'Last Price'
230
+ }
231
+ )
232
 
233
  # Convert to CSV and encode as base64
234
  csv_buffer = io.StringIO()
235
  df.to_csv(csv_buffer, index=False)
236
  csv_str = csv_buffer.getvalue()
237
+ b64_content = base64.b64encode(csv_str.encode("utf-8")).decode("utf-8")
238
  contents = f"data:text/csv;base64,{b64_content}"
239
 
240
  # Validate the CSV upload
 
243
  self.assertIn("Missing required columns", str(context.exception))
244
 
245
 
246
+ if __name__ == "__main__":
247
  unittest.main()