Tabular Classification
Scikit-learn
Joblib
postgresql
sql
query-cache
plan-cache
redis
database
tabular-regression
Instructions to use nilenpatel/pg-plan-cache-models with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Scikit-learn
How to use nilenpatel/pg-plan-cache-models with Scikit-learn:
```python
from huggingface_hub import hf_hub_download
import joblib

model = joblib.load(
    hf_hub_download("nilenpatel/pg-plan-cache-models", "sklearn_model.joblib")
)
# only load pickle files from sources you trust
# read more about it here https://skops.readthedocs.io/en/stable/persistence.html
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Inference API for pg_plan_cache models. | |
| Loads trained models and provides prediction functions for: | |
| 1. Cache benefit (high / medium / low) | |
| 2. Recommended TTL (seconds) | |
| 3. Complexity score (1-100) | |
| """ | |
| import os | |
| import json | |
| import joblib | |
| import numpy as np | |
| from features import extract_features, FEATURE_NAMES | |
# Directory holding the serialized estimators, resolved next to this file.
MODEL_DIR = os.path.join(os.path.split(__file__)[0], "trained")

# Model handles start empty and are populated on first use by _load_models().
_cache_advisor = _ttl_recommender = _complexity_estimator = _label_encoder = None
# Flipped to True by _load_models() once every artifact has been read.
_loaded = False
def _load_models():
    """Read every trained artifact from MODEL_DIR the first time it is needed.

    Once ``_loaded`` is True this is a no-op. Otherwise the four artifacts
    are read with ``joblib.load`` in a fixed order, and ``_loaded`` is set
    only after all of them have been assigned.
    """
    global _cache_advisor, _ttl_recommender, _complexity_estimator, _label_encoder, _loaded
    if not _loaded:
        def _read(filename):
            # joblib artifacts are pickles: only load them from a trusted MODEL_DIR.
            return joblib.load(os.path.join(MODEL_DIR, filename))

        _cache_advisor = _read("cache_advisor.joblib")
        _ttl_recommender = _read("ttl_recommender.joblib")
        _complexity_estimator = _read("complexity_estimator.joblib")
        _label_encoder = _read("label_encoder.joblib")
        _loaded = True
def _humanize_ttl(ttl):
    """Render a TTL in whole seconds as "<h>h <m>m", or "<m>m" under an hour.

    Leftover seconds below a full minute are dropped (e.g. 90 -> "1m").
    """
    hours, mins = divmod(ttl // 60, 60)
    return f"{hours}h {mins}m" if hours else f"{mins}m"


def _complexity_label(score):
    """Map a 1-100 complexity score onto its coarse human-readable bucket."""
    if score <= 20:
        return "simple"
    if score <= 45:
        return "moderate"
    if score <= 75:
        return "complex"
    return "very complex"


def predict(sql: str) -> dict:
    """
    Run all three models on a SQL query.

    Models are lazily loaded on the first call.

    Returns:
        {
            "query": str,
            "cache_benefit": "high" | "medium" | "low",
            "cache_benefit_probabilities": {"high": 0.8, "medium": 0.15, "low": 0.05},
            "recommended_ttl": int,       # seconds, clamped to >= 0
            "ttl_human": str,             # e.g. "1h 0m"
            "complexity_score": int,      # clamped to 1-100
            "complexity_label": str,      # "simple" | "moderate" | "complex" | "very complex"
            "features": {name: value, ...},
        }
    """
    _load_models()
    features = extract_features(sql)
    X = np.array([features])  # single-row 2-D input, as sklearn estimators expect

    # Cache advisor: the classifier works on label-encoded classes, which the
    # label encoder maps back to their string labels.
    benefit_idx = _cache_advisor.predict(X)[0]
    benefit_label = _label_encoder.inverse_transform([benefit_idx])[0]
    benefit_probs = _cache_advisor.predict_proba(X)[0]
    # predict_proba columns follow the classifier's ``classes_`` order, which
    # is not guaranteed to be 0..k-1 (e.g. a class absent at training time).
    # Fall back to positional indices only if the estimator lacks ``classes_``.
    class_codes = getattr(_cache_advisor, "classes_", np.arange(len(benefit_probs)))
    class_labels = _label_encoder.inverse_transform(class_codes)
    prob_dict = {
        label: round(float(p), 4) for label, p in zip(class_labels, benefit_probs)
    }

    # TTL recommender: regression output, rounded and clamped to non-negative.
    ttl_raw = _ttl_recommender.predict(X)[0]
    ttl = max(0, int(round(ttl_raw)))

    # Complexity estimator: regression output, rounded and clamped to 1-100.
    cplx_raw = _complexity_estimator.predict(X)[0]
    cplx = max(1, min(100, int(round(cplx_raw))))

    return {
        "query": sql,
        "cache_benefit": benefit_label,
        "cache_benefit_probabilities": prob_dict,
        "recommended_ttl": ttl,
        "ttl_human": _humanize_ttl(ttl),
        "complexity_score": cplx,
        "complexity_label": _complexity_label(cplx),
        "features": dict(zip(FEATURE_NAMES, features)),
    }
def predict_batch(queries: list[str]) -> list[dict]:
    """Score every SQL string in *queries*, returning results in input order.

    Each element is passed to ``predict`` one at a time.
    """
    results = []
    for query in queries:
        results.append(predict(query))
    return results
def format_prediction(result: dict) -> str:
    """Render a ``predict`` result dict as a five-line human-readable report.

    The query is cut to its first 100 characters, with "..." appended when
    it was longer than that.
    """
    query = result['query']
    ellipsis = '...' if len(query) > 100 else ''
    rows = (
        f" Query: {query[:100]}{ellipsis}",
        f" Cache Benefit: {result['cache_benefit'].upper()}",
        f" Probabilities: {result['cache_benefit_probabilities']}",
        f" Recommended TTL: {result['recommended_ttl']}s ({result['ttl_human']})",
        f" Complexity: {result['complexity_score']}/100 ({result['complexity_label']})",
    )
    return "\n".join(rows)
def get_model_info(model_dir: "str | None" = None) -> dict:
    """Return model metadata.

    Args:
        model_dir: Directory containing ``metadata.json``. Defaults to
            ``MODEL_DIR``, so existing zero-argument callers are unaffected.

    Returns:
        The parsed contents of ``metadata.json``, or an ``{"error": ...}``
        dict when the file does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    directory = MODEL_DIR if model_dir is None else model_dir
    meta_path = os.path.join(directory, "metadata.json")
    # EAFP: opening directly avoids the exists()/open() race and a second stat.
    try:
        f = open(meta_path, encoding="utf-8")  # JSON text is UTF-8 (RFC 8259)
    except FileNotFoundError:
        return {"error": "metadata.json not found. Run train.py first."}
    with f:
        return json.load(f)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # Usage errors go to stderr so stdout stays clean for piped output.
        print("Usage: python predict.py \"SELECT * FROM users WHERE id = 42\"", file=sys.stderr)
        sys.exit(1)
    # Join all arguments so a query the shell split into words still works.
    sql = " ".join(sys.argv[1:])
    result = predict(sql)
    print(format_prediction(result))