Tabular Classification
Scikit-learn
Joblib
postgresql
sql
query-cache
plan-cache
redis
database
tabular-regression
Instructions to use nilenpatel/pg-plan-cache-models with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Scikit-learn
How to use nilenpatel/pg-plan-cache-models with Scikit-learn:
```python
from huggingface_hub import hf_hub_download
import joblib

model = joblib.load(
    hf_hub_download("nilenpatel/pg-plan-cache-models", "sklearn_model.joblib")
)
# only load pickle files from sources you trust
# read more about it here https://skops.readthedocs.io/en/stable/persistence.html
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| SQL feature extraction for pg_plan_cache models. | |
| Extracts structural features from raw SQL query text to feed into | |
| the Cache Advisor, TTL Recommender, and Complexity Estimator models. | |
| """ | |
| import re | |
# Pre-compiled, case-insensitive patterns used to detect SQL constructs.
# They are applied to the raw query text (no parsing), so occurrences inside
# string literals or comments are not excluded.

# Aggregate function calls: a known name followed by "(".
AGGREGATE_FUNCS = re.compile(
    r"\b(count|sum|avg|min|max|array_agg|string_agg|bool_and|bool_or|jsonb_agg)\s*\(",
    re.IGNORECASE,
)
# Window function calls.  percent_rank and cume_dist are listed explicitly:
# the bare "rank" alternative cannot match inside "percent_rank" because "_"
# is a word character, so there is no \b in front of "rank" there.
WINDOW_FUNCS = re.compile(
    r"\b(row_number|rank|dense_rank|percent_rank|cume_dist|ntile|lag|lead"
    r"|first_value|last_value|nth_value)\s*\(",
    re.IGNORECASE,
)
# One match per JOIN keyword; the qualified forms are tried first so that
# e.g. "left join" is consumed as a single match.
JOIN_PATTERN = re.compile(
    r"\b(inner\s+join|left\s+join|right\s+join|full\s+join|cross\s+join|join)\b",
    re.IGNORECASE,
)
# "(" immediately followed by SELECT (whitespace allowed in between).
SUBQUERY_PATTERN = re.compile(r"\(\s*select\b", re.IGNORECASE)
# WITH [RECURSIVE] name [(col, ...)] AS [[NOT] MATERIALIZED] (
CTE_PATTERN = re.compile(
    r"\bwith\s+(?:recursive\s+)?\w+(?:\s*\([^)]*\))?\s+as\s*"
    r"(?:(?:not\s+)?materialized\s*)?\(",
    re.IGNORECASE,
)
UNION_PATTERN = re.compile(r"\b(union|intersect|except)\b", re.IGNORECASE)
CASE_PATTERN = re.compile(r"\bcase\b", re.IGNORECASE)
IN_PATTERN = re.compile(r"\bin\s*\(", re.IGNORECASE)
LIKE_PATTERN = re.compile(r"\b(like|ilike)\b", re.IGNORECASE)
BETWEEN_PATTERN = re.compile(r"\bbetween\b", re.IGNORECASE)
EXISTS_PATTERN = re.compile(r"\bexists\s*\(", re.IGNORECASE)
HAVING_PATTERN = re.compile(r"\bhaving\b", re.IGNORECASE)
# Either CAST( ... ) or the "::" cast operator.  "::" must not be wrapped in
# \b: both characters are non-word, so \b::\b only matches when word
# characters touch both sides and misses casts such as '42'::int or (a+b)::int.
CAST_PATTERN = re.compile(r"\bcast\s*\(|::", re.IGNORECASE)
# Ordered names of the entries in the feature vector.  The position of a name
# in this list is the index of the matching value in the list produced by
# extract_features().
FEATURE_NAMES = [
    # Size and statement kind.
    # query_type codes: 0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, 4=OTHER
    "query_length", "query_type",
    # Structural counts.
    "num_tables", "num_joins", "num_conditions",
    "num_aggregates", "num_subqueries", "num_columns",
    # Clause flags.
    "has_distinct", "has_order_by", "has_group_by", "has_having",
    "has_limit", "has_offset", "has_where",
    # Predicate / expression flags.
    "has_like", "has_in_clause", "has_between", "has_exists",
    "has_window_func", "has_cte", "has_union", "has_case", "has_cast",
    # Nesting and literal counts.
    "nesting_depth", "num_and_or",
    "num_string_literals", "num_numeric_literals",
]
# Tokens that matter while scanning a FROM clause: single-quoted strings
# (skipped whole), parentheses, commas, and the keywords / semicolon that end
# the clause.
_FROM_CLAUSE_TOKEN = re.compile(
    r"'(?:[^']|'')*'"
    r"|[(),;]"
    r"|\b(?:where|join|group|order|limit|having)\b",
    re.IGNORECASE,
)


def _count_tables(sql: str) -> int:
    """Estimate the number of tables referenced.

    The estimate is the number of comma-separated items in the first FROM
    clause plus one per JOIN keyword matched by JOIN_PATTERN anywhere in the
    text.  Only commas outside parentheses separate FROM items, so function
    arguments (``generate_series(1, 10)``) and the select list of a derived
    table (``(SELECT a, b FROM t) s``) do not inflate the count.

    Args:
        sql: Raw SQL text.

    Returns:
        A non-negative heuristic count; 0 when there is no FROM and no JOIN.
    """
    count = 0
    from_match = re.search(r"\bfrom\s+", sql, re.IGNORECASE)
    if from_match and from_match.end() < len(sql):
        count = 1
        depth = 0
        for token in _FROM_CLAUSE_TOKEN.finditer(sql, from_match.end()):
            text = token.group(0)
            if text == "(":
                depth += 1
            elif text == ")":
                # The FROM itself may sit inside parentheses opened earlier;
                # clamp so that closing them does not drive depth negative.
                depth = max(depth - 1, 0)
            elif depth > 0 or text.startswith("'"):
                # Nested content and string literals never separate or end
                # the top-level FROM list.
                continue
            elif text == ",":
                count += 1
            else:
                # Top-level WHERE / JOIN / GROUP / ORDER / LIMIT / HAVING / ";"
                # ends the FROM list; joined tables are counted below.
                break
    # JOIN tables
    count += len(JOIN_PATTERN.findall(sql))
    return count
| def _count_columns(sql: str) -> int: | |
| """Estimate the number of columns in SELECT clause.""" | |
| match = re.search(r"\bselect\s+(.*?)\bfrom\b", sql, re.IGNORECASE | re.DOTALL) | |
| if not match: | |
| return 0 | |
| select_clause = match.group(1).strip() | |
| if select_clause == "*": | |
| return 1 | |
| # Split by commas not inside parentheses | |
| depth = 0 | |
| count = 1 | |
| for ch in select_clause: | |
| if ch == '(': | |
| depth += 1 | |
| elif ch == ')': | |
| depth -= 1 | |
| elif ch == ',' and depth == 0: | |
| count += 1 | |
| return count | |
| def _nesting_depth(sql: str) -> int: | |
| """Calculate maximum parenthesis nesting depth.""" | |
| max_depth = 0 | |
| depth = 0 | |
| for ch in sql: | |
| if ch == '(': | |
| depth += 1 | |
| max_depth = max(max_depth, depth) | |
| elif ch == ')': | |
| depth -= 1 | |
| return max_depth | |
# Tokens used to find the main statement keyword after a WITH prefix:
# string literals (skipped), parentheses, and the four statement keywords.
_STATEMENT_TOKEN = re.compile(
    r"'(?:[^']|'')*'|[()]|\b(select|insert|update|delete)\b",
    re.IGNORECASE,
)
_QUERY_TYPE_CODES = {"select": 0, "insert": 1, "update": 2, "delete": 3}
_QUERY_TYPE_OTHER = 4
# Single-quoted literal; a doubled quote ('') is an escaped quote inside it.
_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'")
# Integer or decimal literal that is not a positional parameter such as $1.
_NUMERIC_LITERAL = re.compile(r"(?<!\$)\b\d+(?:\.\d+)?\b")


def _query_type(sql: str) -> int:
    """Return the query_type code (0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, 4=OTHER)."""
    # Leading "(" is skipped so "(SELECT ...) UNION (SELECT ...)" is a SELECT.
    head = sql.lstrip("( \t\r\n")
    leading = re.match(r"[A-Za-z_]+", head)
    if leading is None:
        return _QUERY_TYPE_OTHER
    keyword = leading.group(0).lower()
    if keyword != "with":
        return _QUERY_TYPE_CODES.get(keyword, _QUERY_TYPE_OTHER)

    # WITH prefix: the statement kind is the first statement keyword that is
    # not inside the parenthesised CTE bodies.
    depth = 0
    for token in _STATEMENT_TOKEN.finditer(head, leading.end()):
        text = token.group(0)
        if text == "(":
            depth += 1
        elif text == ")":
            depth = max(depth - 1, 0)
        elif depth == 0 and token.group(1):
            return _QUERY_TYPE_CODES[token.group(1).lower()]
    return _QUERY_TYPE_OTHER


def extract_features(sql: str) -> list[float]:
    """
    Extract a fixed-length feature vector from a SQL query string.

    Returns a list of floats matching FEATURE_NAMES ordering.  Counts come
    from regular-expression matches on the stripped query text; each has_*
    entry is 1.0 when its pattern matches anywhere in the text, else 0.0.

    Differences from a plain prefix / regex scan:
      * query_type looks through a leading "(" and a WITH prefix to the main
        statement keyword instead of reporting them as OTHER.
      * num_string_literals treats a doubled quote ('') as part of the literal.
      * num_numeric_literals ignores $n placeholders and digits that appear
        inside string literals.

    NOTE(review): these corrections change some feature values; models fitted
    on the earlier values presumably need re-validation or retraining - confirm.
    NOTE(review): num_conditions and num_and_or are both the AND/OR keyword
    count; confirm whether num_conditions was meant to measure something else.
    """
    sql = sql.strip()

    num_joins = len(JOIN_PATTERN.findall(sql))
    num_aggs = len(AGGREGATE_FUNCS.findall(sql))
    num_subqueries = len(SUBQUERY_PATTERN.findall(sql))
    num_conditions = len(re.findall(r"\b(and|or)\b", sql, re.IGNORECASE))
    num_string_lits = len(_STRING_LITERAL.findall(sql))
    # Blank out string literals first so '2024-01-01' contributes no numbers.
    sql_without_strings = _STRING_LITERAL.sub(" ", sql)
    num_numeric_lits = len(_NUMERIC_LITERAL.findall(sql_without_strings))

    features = [
        float(len(sql)),                                        # query_length
        float(_query_type(sql)),                                # query_type
        float(_count_tables(sql)),                              # num_tables
        float(num_joins),                                       # num_joins
        float(num_conditions),                                  # num_conditions
        float(num_aggs),                                        # num_aggregates
        float(num_subqueries),                                  # num_subqueries
        float(_count_columns(sql)),                             # num_columns
        float(bool(re.search(r"\bdistinct\b", sql, re.I))),     # has_distinct
        float(bool(re.search(r"\border\s+by\b", sql, re.I))),   # has_order_by
        float(bool(re.search(r"\bgroup\s+by\b", sql, re.I))),   # has_group_by
        float(bool(HAVING_PATTERN.search(sql))),                # has_having
        float(bool(re.search(r"\blimit\b", sql, re.I))),        # has_limit
        float(bool(re.search(r"\boffset\b", sql, re.I))),       # has_offset
        float(bool(re.search(r"\bwhere\b", sql, re.I))),        # has_where
        float(bool(LIKE_PATTERN.search(sql))),                  # has_like
        float(bool(IN_PATTERN.search(sql))),                    # has_in_clause
        float(bool(BETWEEN_PATTERN.search(sql))),               # has_between
        float(bool(EXISTS_PATTERN.search(sql))),                # has_exists
        float(bool(WINDOW_FUNCS.search(sql))),                  # has_window_func
        float(bool(CTE_PATTERN.search(sql))),                   # has_cte
        float(bool(UNION_PATTERN.search(sql))),                 # has_union
        float(bool(CASE_PATTERN.search(sql))),                  # has_case
        float(bool(CAST_PATTERN.search(sql))),                  # has_cast
        float(_nesting_depth(sql)),                             # nesting_depth
        float(num_conditions),                                  # num_and_or
        float(num_string_lits),                                 # num_string_literals
        float(num_numeric_lits),                                # num_numeric_literals
    ]
    return features