# stocks/core/expression_parser.py
# Initial Commit by Arrechenash (da67450)
"""Expression parser for filter expressions with date filtering and sorting.
Required Syntax (Full Column Names Only):
gap_pct > 10 # Gap % > 10%
run_pct > 20 # Run % > 20%
change_pct > 5 # Change % > 5%
volume > 5M # Volume > 5,000,000
gap_pct > 10 in 5d # Gap > 10% in last 5 days
$close[1] # Close 1 day after event
close[-1] # Yesterday's close
max(high, 20) # 20-day max (excludes today)
volume > 5M sort volume desc
"""
import re
from dataclasses import dataclass, field
@dataclass
class FilterCondition:
    """A single filter condition (e.g., date >= '2026-01-01')."""

    # Column the condition applies to (e.g., "date").
    column: str
    # Comparison operator as written in the expression (>=, <=, ==, >, <, =).
    operator: str
    # Right-hand-side literal kept as a raw string (e.g., "2026-01-01").
    value: str
@dataclass
class SortSpec:
    """Sort specification parsed from a trailing `sort <expr> [asc|desc]` clause."""

    column: str  # Full expression (e.g., "close", "close / close[-10]")
    # Sort direction; defaults to ascending when the clause omits asc/desc.
    direction: str = "asc"
@dataclass
class ParsedExpression:
    """Result of parsing a filter expression."""

    date_conditions: list[FilterCondition] = field(default_factory=list)
    sort: SortSpec | None = None
    remaining_filter: str = ""  # For pandas evaluation
    # Detected features for SQL generation
    metrics: set[tuple[str, int]] = field(default_factory=set)  # (col, offset)
    aggregations: set[tuple[str, str, int]] = field(default_factory=set)  # (func, col, lookback)
    window_chained: set[tuple[str, str, int, int]] = field(default_factory=set)  # (func, col, lookback, offset)
    chained_aggs: set[tuple[str, str, str, int]] = field(default_factory=set)  # (func, col1, col2, offset)
    binary_aggs: set[tuple[str, str, str]] = field(default_factory=set)  # (func, col1, col2)
    # Event features: in N days and $column[offset]
    event_windows: list[tuple[str, int]] = field(default_factory=list)  # [(condition, days)]
    event_refs: list[tuple[str, int]] = field(default_factory=list)  # [(column, offset)]

    def _first_date_value(self, operators: tuple[str, ...]) -> str | None:
        """Return the value of the first date condition using one of *operators*."""
        return next(
            (cond.value for cond in self.date_conditions if cond.operator in operators),
            None,
        )

    def get_start_date(self) -> str | None:
        """Extract start date from conditions (>= or >)."""
        return self._first_date_value((">=", ">"))

    def get_end_date(self) -> str | None:
        """Extract end date from conditions (<= or <)."""
        return self._first_date_value(("<=", "<"))

    def get_exact_date(self) -> str | None:
        """Extract exact date from conditions (= or ==)."""
        return self._first_date_value(("=", "=="))
class ExpressionParser:
    """Parse filter expressions into structured data for SQL and Pandas.

    Full column names required:
      - gap_pct, run_pct, change_pct, range_pct (percentage columns)
      - gap_dollar, run_dollar, change_dollar, range_dollar (dollar columns)
      - volume (not vol)
      - streak_run_pct, rel_vol, vol_ratio_52wk
      - up_streak, down_streak
    """

    # Structural patterns
    METRIC_PATTERN = r"(\w+)\[(-?\d+)\]"  # col[offset], e.g. close[-1]
    AGG_PATTERN = r"(max|min|avg)\((\w+),\s*(\d+)\)"  # func(col, lookback)
    # func(col, lookback)[offset], e.g. max(high, 20)[1]
    WINDOW_CHAINED_PATTERN = r"(max|min|avg)\((\w+),\s*(\d+)\)\[(-?\d+)\]"
    # func(col1, col2)[offset]. The second argument must be an identifier
    # (letter/underscore first) so that numeric lookbacks such as
    # max(high, 20)[1] are captured only by WINDOW_CHAINED_PATTERN and never
    # misread as a column pair ("high", "20").
    CHAINED_AGG_PATTERN = r"(max|min|avg)\(([\w]+),\s*([a-zA-Z_][\w]*)\)\[(-?\d+)\]"
    BINARY_AGG_PATTERN = r"(max|min|avg)\(([\w]+),\s*([a-zA-Z_][\w]*)\)(?!\s*\[)"
    # UI/UX Patterns
    SORT_PATTERN = r"(?:^|\s+)sort\s+(.+?)(?:\s+(asc|desc))?\s*$"
    DATE_PATTERN = r'date\s*(>=|<=|==|>|<|=)\s*[\'"](\d{4}-\d{2}-\d{2})[\'"]'
    # Event patterns
    EVENT_WINDOW_PATTERN = r"\s+in\s+(\d+)d\b"
    EVENT_REF_PATTERN = r"\$([a-z_][\w]*)\[(-?\d+)\]"
    # Bare $column reference. The lookahead forbids both a following "[" (that
    # form belongs to EVENT_REF_PATTERN) and a word character, which also stops
    # the regex engine from backtracking into the middle of an identifier —
    # so "$close[1]" no longer also yields a spurious ("close", 0) ref.
    EVENT_REF_SIMPLE = r"\$([a-z_]\w*)(?![\w\[])"
    # Percentage columns (full names only)
    PCT_COLS = ["gap_pct", "run_pct", "change_pct", "range_pct", "streak_run_pct"]

    def parse(self, expr: str) -> ParsedExpression:
        """Fully parse expression into structured features.

        Pipeline: normalize -> strip sort clause -> strip date conditions ->
        cleanup -> numeric conversions -> feature extraction.
        """
        if not expr or not expr.strip():
            return ParsedExpression()
        # Normalize trader syntax (lowercase, K/M/B suffixes)
        processed = self._normalize(expr)
        # Structural parsing
        sort_spec = self._parse_sort(processed)
        expr_without_sort = self._remove_sort(processed)
        date_conditions = self._parse_dates(expr_without_sort)
        remaining = self._remove_dates(expr_without_sort)
        remaining = self._cleanup_expression(remaining)
        # Convert percentages without % suffix to decimal
        remaining = self._convert_percentages(remaining)
        # Convert % suffix to decimal: gap_pct > 10% -> gap_pct > 0.10
        remaining = self._convert_percent_suffix(remaining)
        # Convert $amount: gap_pct > $5 -> gap_dollar > 5
        remaining = self._convert_dollar(remaining)
        # Extract features.
        metrics = {(m, int(off)) for m, off in re.findall(self.METRIC_PATTERN, remaining)}
        # NOTE(review): AGG_PATTERN also matches the "func(col, n)" prefix of a
        # chained "func(col, n)[off]", so such expressions appear in both
        # `aggregations` and `window_chained`. Preserved as-is in case SQL
        # generation relies on the union — confirm before tightening.
        aggs = {(f.lower(), m, int(lb)) for f, m, lb in re.findall(self.AGG_PATTERN, remaining, re.IGNORECASE)}
        window_chained = {
            (f.lower(), c, int(lb), int(off))
            for f, c, lb, off in re.findall(self.WINDOW_CHAINED_PATTERN, remaining, re.IGNORECASE)
        }
        chained_aggs = {
            (f.lower(), c1, c2, int(off))
            for f, c1, c2, off in re.findall(self.CHAINED_AGG_PATTERN, remaining, re.IGNORECASE)
        }
        binary_aggs = {
            (f.lower(), c1, c2) for f, c1, c2 in re.findall(self.BINARY_AGG_PATTERN, remaining, re.IGNORECASE)
        }
        # Extract event features
        remaining, event_windows, event_refs = self._extract_event_features(remaining)
        return ParsedExpression(
            date_conditions=date_conditions,
            sort=sort_spec,
            remaining_filter=remaining,
            metrics=metrics,
            aggregations=aggs,
            window_chained=window_chained,
            chained_aggs=chained_aggs,
            binary_aggs=binary_aggs,
            event_windows=event_windows,
            event_refs=event_refs,
        )

    def _normalize(self, expr: str) -> str:
        """Normalize trader syntax: lowercase everything and expand K/M/B suffixes.

        String literals are protected behind placeholders so their contents are
        neither lowercased nor suffix-expanded.
        """
        # Protect string literals first
        temp_strings = {}
        result = expr
        for counter, match in enumerate(re.finditer(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"", result)):
            placeholder = f"__STR_{counter}__"
            temp_strings[placeholder] = match.group(0)
            result = result.replace(match.group(0), placeholder)
        # Lowercase
        result = result.lower()

        # K/M/B suffixes for numbers, e.g. 5m -> 5000000. The lookahead skips
        # suffixed numbers immediately before "[" (offset syntax).
        def replace_suffix(match):
            num = int(match.group(1))
            suffix = match.group(2).lower()  # defensive: pattern is IGNORECASE
            multipliers = {"k": 1_000, "m": 1_000_000, "b": 1_000_000_000}
            return str(num * multipliers.get(suffix, 1))

        result = re.sub(r"\b(\d+)([kmb])\b(?!\s*\[)", replace_suffix, result, flags=re.IGNORECASE)
        # Restore string literals (placeholders were lowercased with the rest)
        for placeholder, original in temp_strings.items():
            result = result.replace(placeholder.lower(), original)
        return result

    def _convert_dollar(self, expr: str) -> str:
        """Convert `pct_col op $number` to the matching `_dollar` column.

        Example: gap_pct > $5 -> gap_dollar > 5
        """

        def dollar_to_column(match):
            col = match.group(1)
            op = match.group(2)
            num = match.group(3)
            dollar_col = col.replace("_pct", "_dollar")
            return f"{dollar_col} {op} {num}"

        # Match: column op $number (where column is a pct column)
        pct_cols_pattern = r"\b(gap_pct|run_pct|change_pct|range_pct)\s*([><=!]+)\s*\$\s*(\d+\.?\d*)"
        expr = re.sub(pct_cols_pattern, dollar_to_column, expr, flags=re.IGNORECASE)
        return expr

    def _convert_percentages(self, expr: str) -> str:
        """Convert percentage columns without % suffix to decimal.

        Rules:
          - gap_pct > 5    -> gap_pct > 0.05 (5%, divided by 100 because 5 > 1)
          - gap_pct > 0.10 -> gap_pct > 0.10 (already decimal, no conversion)

        Values > 1 without a decimal point are divided by 100; values <= 1 or
        containing a decimal point are kept as-is. Values with a % suffix are
        handled separately by _convert_percent_suffix.
        """
        pct_cols_pattern = "|".join(self.PCT_COLS)

        def convert_to_decimal(match):
            col = match.group(1)
            offset = match.group(2) if match.group(2) else ""  # None when no [offset]
            op = match.group(3)
            num = match.group(4)
            if "." in num:
                return f"{col}{offset} {op} {num}"
            if float(num) <= 1:
                return f"{col}{offset} {op} {num}"
            decimal = float(num) / 100
            return f"{col}{offset} {op} {decimal}"

        # Pattern: column[offset] operator number, not followed by %.
        # Supports negative numbers and offsets.
        pattern = rf"\b({pct_cols_pattern})(\[-?\d+\])?\s*([><=!]+)\s*(-?\d+\.?\d*)\b(?!\s*%)"
        expr = re.sub(pattern, convert_to_decimal, expr, flags=re.IGNORECASE)
        return expr

    def _convert_percent_suffix(self, expr: str) -> str:
        """Convert percentage columns with % suffix to decimal.

        Examples:
          - gap_pct > 10%    -> gap_pct > 0.1
          - gap_pct[-1] > 5% -> gap_pct[-1] > 0.05
          - gap_pct > %5     -> gap_pct > 0.05 (old format)
        """

        def percent_to_decimal(match):
            col = match.group(1)
            offset = match.group(2) if match.group(2) else ""  # None when no [offset]
            op = match.group(3)
            num = float(match.group(4))
            decimal = num / 100
            return f"{col}{offset} {op} {decimal}"

        pct_cols_pattern = "|".join(self.PCT_COLS)
        # New format: column op number% (e.g., gap_pct > 10%)
        pct_cols_pattern_new = rf"\b({pct_cols_pattern})(\[-?\d+\])?\s*([><=!]+)\s*(-?\d+\.?\d*)\s*%"
        expr = re.sub(pct_cols_pattern_new, percent_to_decimal, expr, flags=re.IGNORECASE)
        # Old format: column op %number (e.g., gap_pct > %5)
        pct_cols_pattern_old = rf"\b({pct_cols_pattern})(\[-?\d+\])?\s*([><=!]+)\s*%\s*(-?\d+\.?\d*)"
        expr = re.sub(pct_cols_pattern_old, percent_to_decimal, expr, flags=re.IGNORECASE)
        return expr

    def _parse_sort(self, expr: str) -> SortSpec | None:
        """Parse a trailing `sort <expr> [asc|desc]` clause, if present."""
        match = re.search(self.SORT_PATTERN, expr, re.IGNORECASE)
        if match:
            return SortSpec(column=match.group(1), direction=(match.group(2) or "asc").lower())
        return None

    def _remove_sort(self, expr: str) -> str:
        """Strip the trailing sort clause from the expression."""
        return re.sub(self.SORT_PATTERN, "", expr, flags=re.IGNORECASE)

    def _parse_dates(self, expr: str) -> list[FilterCondition]:
        """Collect all date comparisons (date OP 'YYYY-MM-DD') as conditions."""
        matches = re.findall(self.DATE_PATTERN, expr, re.IGNORECASE)
        return [FilterCondition(column="date", operator=op, value=val) for op, val in matches]

    def _remove_dates(self, expr: str) -> str:
        """Strip date comparisons together with an adjoining `and`, if any."""
        result = re.sub(r"\s+and\s+" + self.DATE_PATTERN, "", expr, flags=re.IGNORECASE)
        result = re.sub(self.DATE_PATTERN + r"\s+and\s+", "", result, flags=re.IGNORECASE)
        result = re.sub(self.DATE_PATTERN, "", result, flags=re.IGNORECASE)
        return result

    def _cleanup_expression(self, expr: str) -> str:
        """Remove dangling/duplicate `and`s and collapse whitespace."""
        result = re.sub(r"^\s*and\s+", "", expr, flags=re.IGNORECASE)
        # BUGFIX: this sub previously ran on `expr`, discarding the
        # leading-`and` removal above. It must chain on `result`.
        result = re.sub(r"\s+and\s*$", "", result, flags=re.IGNORECASE)
        result = re.sub(r"\s+and\s+and\s+", " and ", result, flags=re.IGNORECASE)
        # Normalize boolean keywords to lowercase for pandas/AST evaluation.
        result = re.sub(r"\band\b", "and", result, flags=re.IGNORECASE)
        result = re.sub(r"\bor\b", "or", result, flags=re.IGNORECASE)
        return " ".join(result.split())

    def _extract_event_features(self, expr: str) -> tuple[str, list[tuple[str, int]], list[tuple[str, int]]]:
        """Extract event windows (`cond in Nd`) and event refs (`$col[off]`).

        Returns (remaining_expression, windows, refs) where windows is
        [(condition, days)] and refs is [(column, offset)] in first-seen order.
        """
        windows = []
        refs = []
        remaining = expr
        # Split by OR so each alternative can carry its own event window.
        or_parts = re.split(r"\s+or\s+", remaining, flags=re.IGNORECASE)
        remaining_or_parts = []
        for or_part in or_parts:
            # Look for "in Nd" pattern
            match = re.search(self.EVENT_WINDOW_PATTERN, or_part, re.IGNORECASE)
            if match:
                # Everything before "in Nd" is the event condition.
                condition = or_part[: match.start()].strip()
                days = int(match.group(1))
                # Remove one pair of wrapping parentheses, if present.
                condition = re.sub(r"^\((.+)\)$", r"\1", condition)
                if condition:
                    windows.append((condition, days))
                # Anything after "in Nd" stays in the remaining filter.
                after_event = or_part[match.end() :].strip()
                if after_event.startswith("and "):
                    after_event = after_event[4:].strip()
                # Split by AND for additional conditions
                if after_event:
                    and_parts = re.split(r"\s+and\s+", after_event, flags=re.IGNORECASE)
                    for part in and_parts:
                        if part.strip():
                            remaining_or_parts.append(part.strip())
            else:
                remaining_or_parts.append(or_part)
        remaining = " or ".join(remaining_or_parts)
        # Extract event refs, offset form first ($col[off]), then bare $col.
        # Order-preserving de-dup: a ref is recorded only once.
        for match in re.finditer(self.EVENT_REF_PATTERN, remaining, re.IGNORECASE):
            col, off = match.group(1).lower(), int(match.group(2))
            if (col, off) not in refs:
                refs.append((col, off))
        for match in re.finditer(self.EVENT_REF_SIMPLE, remaining, re.IGNORECASE):
            col = match.group(1).lower()
            if (col, 0) not in refs:
                refs.append((col, 0))
        # Clean up
        remaining = remaining.strip()
        remaining = re.sub(r"\s+", " ", remaining)
        return remaining, windows, refs

    def extract_lookback(self, expr_or_parsed: str | ParsedExpression) -> int:
        """Calculate required lookback days from expression or parsed object.

        Takes the maximum span implied by metrics, aggregations and chained
        forms, plus a 5-day safety buffer (e.g. for weekends/holidays —
        presumably; confirm with data-loading callers).
        """
        parsed = self.parse(expr_or_parsed) if isinstance(expr_or_parsed, str) else expr_or_parsed
        lookback = 0
        if parsed.metrics:
            lookback = max(lookback, max(abs(off) for _, off in parsed.metrics))
        if parsed.aggregations:
            lookback = max(lookback, max(lb for _, _, lb in parsed.aggregations))
        if parsed.window_chained:
            # A windowed agg shifted by `off` needs window + shift days.
            lookback = max(lookback, max(lb + abs(off) for _, _, lb, off in parsed.window_chained))
        if parsed.chained_aggs:
            lookback = max(lookback, max(abs(off) for _, _, _, off in parsed.chained_aggs))
        return lookback + 5

    def compile_safe(self, expr_str: str):
        """Validate and compile an expression for safe eval().

        Validates the expression using Python's AST to ensure only allowed
        nodes and names are used, preventing code injection.

        NOTE: the return type is mixed by design of the original API — a
        trivial `lambda ctx: bool` for empty/blank expressions, otherwise a
        compiled code object to be passed to eval() with a metric context.

        Raises:
            ValueError: if the expression is syntactically invalid or uses a
                prohibited node, name, function, or attribute.
        """
        import ast

        if not expr_str:
            return lambda ctx: False
        # 1. Full normalization and cleanup using existing parser logic
        parsed = self.parse(expr_str)
        # Use the remaining_filter which has dates and sort removed
        normalized = parsed.remaining_filter
        # Remove $ from event refs for AST validation (e.g., $close -> close)
        normalized = normalized.replace("$", "")
        if not normalized or not normalized.strip():
            return lambda ctx: True
        # Handle some edge cases with 'and/or' and whitespace for Python AST
        normalized = re.sub(r"\band\b", " and ", normalized, flags=re.IGNORECASE)
        normalized = re.sub(r"\bor\b", " or ", normalized, flags=re.IGNORECASE)
        try:
            tree = ast.parse(normalized, mode="eval")
        except SyntaxError as e:
            raise ValueError(f"Invalid filter: Invalid expression syntax: {e}") from e
        # Allowed AST nodes for stock scanning and backtesting
        allowed_nodes = {
            ast.Expression,
            ast.BinOp,
            ast.UnaryOp,
            ast.Compare,
            ast.BoolOp,
            ast.Name,
            ast.Constant,
            ast.Subscript,
            ast.Slice,  # For open[-1]
            ast.Call,
            ast.Attribute,
            # Operators
            ast.Add,
            ast.Sub,
            ast.Mult,
            ast.Div,
            ast.Mod,
            ast.Pow,
            ast.And,
            ast.Or,
            ast.Not,
            ast.Eq,
            ast.NotEq,
            ast.Lt,
            ast.LtE,
            ast.Gt,
            ast.GtE,
            ast.In,
            ast.NotIn,
            ast.USub,
            ast.UAdd,
            ast.Load,
            ast.Index,  # Required for Python < 3.9 (deprecated on newer versions)
        }
        # Prohibited variable/function names (security blocklist).
        # Note: 'open' is excluded because it's a valid OHLC stock metric.
        prohibited_names = {
            "eval", "exec", "compile", "__import__", "getattr", "setattr", "delattr",
            "hasattr", "globals", "locals", "vars", "dir", "input", "breakpoint",
            "exit", "quit", "help", "repr", "str", "int", "float", "list", "dict", "set",
            "tuple", "type", "object", "class", "def", "return", "yield", "raise", "assert",
            "import", "from", "global", "nonlocal", "try", "except", "finally", "with", "as",
            "if", "else", "elif", "for", "while", "pass", "continue", "break", "del",
        }
        # Allowed variable/function names (extensive whitelist)
        allowed_names = {
            # Metrics
            "open",
            "high",
            "low",
            "close",
            "volume",
            "price",
            "time",
            "gap_pct",
            "run_pct",
            "change_pct",
            "range_pct",
            "rel_vol",
            "gap_dollar",
            "run_dollar",
            "change_dollar",
            "range_dollar",
            "volume_dollar",
            "streak_run_pct",
            "vol_ratio_52wk",
            "up_streak",
            "down_streak",
            "rs",
            # Metadata
            "sector",
            "industry",
            "market_cap",
            "country",
            "name",
            "symbol",
            "date",
            # Functions
            "max",
            "min",
            "avg",
            "ret",
            "entry_price",
            # Constants (expression text is lowercased by _normalize)
            "true",
            "false",
            "none",
        }
        for node in ast.walk(tree):
            if type(node) not in allowed_nodes:
                raise ValueError(f"Invalid filter: Prohibited expression element: {type(node).__name__}")
            if isinstance(node, ast.Name):
                name_id = node.id.lower()
                if name_id in allowed_names or name_id.startswith("__str_"):
                    pass  # Whitelisted — skip further checks
                elif name_id in prohibited_names:
                    raise ValueError(f"Invalid filter: Prohibited expression element: {node.id}")
                else:
                    raise ValueError(f"Invalid filter: Prohibited variable name: {node.id}")
            if (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id.lower() in prohibited_names
            ):
                raise ValueError(f"Invalid filter: Prohibited expression element: {node.func.id}")
            if (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id.lower() not in allowed_names
            ):
                raise ValueError(f"Invalid filter: Prohibited function call: {node.func.id}")
            # Block attribute access on builtins that could be dangerous (e.g., open.__globals__)
            if isinstance(node, ast.Attribute):
                dangerous_attrs = {"__globals__", "__builtins__", "__class__", "__dict__",
                                   "__module__", "__doc__", "__code__", "__defaults__",
                                   "__kwdefaults__", "__annotations__", "__closure__"}
                if node.attr in dangerous_attrs:
                    raise ValueError(f"Invalid filter: Prohibited expression element: {node.attr}")
        # If we get here, it's safe to compile
        try:
            compiled = compile(tree, "<string>", "eval")
            return compiled
        except Exception as e:
            raise ValueError(f"Invalid filter: Compilation failed: {e}") from e