nilenpatel's picture
Initial release: pg_plan_cache agent
40eb9bf
"""
Tool implementations for the pg_plan_cache agent.
Each tool takes a DatabaseManager plus optional arguments and returns a
printable string; the exception is watch_stats, which prints directly.
"""
import functools
import time

from db import DatabaseManager
from normalizer import normalize_query, compute_query_hash
def _safe(fn):
"""Wrap a tool so exceptions become readable error strings."""
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception as e:
return f" Error: {type(e).__name__}: {e}"
wrapper.__name__ = fn.__name__
wrapper.__doc__ = fn.__doc__
return wrapper
# ---------------------------------------------------------------------------
# Stats & monitoring
# ---------------------------------------------------------------------------
@_safe
def get_cache_stats(db: DatabaseManager) -> str:
    """Render the first row returned by ``pg_plan_cache_stats()``.

    Produces one `` <counter>: <value>`` line per counter, in a fixed order,
    or a hint about the extension when the function yields no rows.
    """
    result = db.pg_query("SELECT * FROM pg_plan_cache_stats()")
    if not result:
        return " No stats returned. Is the pg_plan_cache extension loaded?"
    stats = result[0]
    counters = (
        "cache_hits",
        "cache_misses",
        "cache_errors",
        "cache_invalidations",
        "redis_timeouts",
        "plans_stored",
        "redis_pool_active",
        "redis_pool_idle",
    )
    return "\n".join(f" {name}: {stats[name]}" for name in counters)
@_safe
def check_redis_health(db: DatabaseManager) -> str:
    """Ping Redis and, if reachable, summarise server info and key counts.

    Selected ``INFO`` fields are shown (``'?'`` when a field is absent),
    followed by how many keys match each pg_plan_cache key pattern.
    """
    if not db.redis_ping():
        return " Redis is NOT reachable."
    info = db.redis_info()
    n_plans = len(db.redis_keys("plan:*"))
    n_table_deps = len(db.redis_keys("deps:table:*"))
    n_query_deps = len(db.redis_keys("qdeps:*"))

    report = [" Status: CONNECTED"]
    # (label shown to the user, key in the INFO mapping)
    for label, field in (
        ("redis_version", "redis_version"),
        ("uptime_seconds", "uptime_in_seconds"),
        ("connected_clients", "connected_clients"),
        ("used_memory_human", "used_memory_human"),
        ("peak_memory", "used_memory_peak_human"),
        ("total_commands", "total_commands_processed"),
    ):
        report.append(f" {label}: {info.get(field, '?')}")
    report += [
        f" plan keys: {n_plans}",
        f" table dep sets: {n_table_deps}",
        f" query dep sets: {n_query_deps}",
    ]
    return "\n".join(report)
@_safe
def check_pg_health(db: DatabaseManager) -> str:
    """Report the PostgreSQL version, pg_plan_cache install state and preload list.

    The server version string is truncated to 80 characters.
    """
    version_rows = db.pg_query("SELECT version()")
    server_version = "unknown"
    if version_rows:
        server_version = version_rows[0]["version"]

    installed = db.pg_query(
        "SELECT extname, extversion FROM pg_extension WHERE extname = 'pg_plan_cache'"
    )
    ext_line = (
        f" Extension: pg_plan_cache v{installed[0]['extversion']}"
        if installed
        else " Extension: NOT INSTALLED"
    )

    preload_result = db.pg_query("SHOW shared_preload_libraries")
    preloaded = [*preload_result[0].values()][0] if preload_result else "?"

    return "\n".join([
        " Status: CONNECTED",
        f" Server: {server_version[:80]}",
        ext_line,
        f" Preloaded: {preloaded}",
    ])
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# pg_plan_cache GUCs this module reports and allows to be changed, in display order.
GUC_PARAMS = [
    f"pg_plan_cache.{setting}"
    for setting in (
        "redis_host",
        "redis_port",
        "ttl",
        "enabled",
        "redis_timeout_ms",
        "redis_pool_size",
    )
]
@_safe
def get_extension_config(db: DatabaseManager) -> str:
    """List each GUC in GUC_PARAMS as `` name = value`` (``N/A`` when SHOW is empty)."""
    def _show(guc):
        # SHOW returns a single one-column row; take its only value.
        result = db.pg_query(f"SHOW {guc}")
        return list(result[0].values())[0] if result else "N/A"

    return "\n".join(f" {guc} = {_show(guc)}" for guc in GUC_PARAMS)
@_safe
def set_extension_config(db: DatabaseManager, param: str, value: str) -> str:
    """Set a pg_plan_cache GUC with ALTER SYSTEM, then call pg_reload_conf().

    Only names listed in GUC_PARAMS are accepted. The name is interpolated
    into the statement, so the allow-list is what keeps that safe; the value
    is passed as a bound parameter.

    NOTE(review): ALTER SYSTEM persists the setting (postgresql.auto.conf) and
    cannot run inside a transaction block, so this presumably relies on
    db.pg_execute running in autocommit mode. Confirm this. A setting that
    requires a server restart is not applied by a reload alone.
    """
    if not param.startswith("pg_plan_cache."):
        return " Refused: only pg_plan_cache.* parameters can be set."
    if param in GUC_PARAMS:
        db.pg_execute(f"ALTER SYSTEM SET {param} = %s", (value,))
        db.pg_execute("SELECT pg_reload_conf()")
        return f" Set {param} = {value} and reloaded configuration."
    return f" Unknown parameter: {param}\n Valid: {', '.join(GUC_PARAMS)}"
# ---------------------------------------------------------------------------
# Cache management
# ---------------------------------------------------------------------------
@_safe
def invalidate_all_plans(db: DatabaseManager) -> str:
    """Invalidate every cached plan by calling pg_plan_cache_invalidate().

    No Redis keys are deleted here. Per the reported message, invalidation
    happens by incrementing the cache's schema version.
    """
    invalidate_sql = "SELECT pg_plan_cache_invalidate()"
    db.pg_execute(invalidate_sql)
    return " All cached plans invalidated (schema version incremented)."
@_safe
def invalidate_table_plans(db: DatabaseManager, table_name: str) -> str:
    """Invalidate plans that depend on *table_name* and report how many were hit.

    The count comes from pg_plan_cache_invalidate_table(); 0 if no row is returned.
    """
    result = db.pg_query(
        "SELECT pg_plan_cache_invalidate_table(%s) AS count", (table_name,)
    )
    invalidated = 0
    if result:
        invalidated = result[0]["count"]
    return f" Invalidated {invalidated} plan(s) for table '{table_name}'."
@_safe
def flush_redis_plan_keys(db: DatabaseManager) -> str:
    """Delete every Redis key matching the pg_plan_cache key patterns.

    Keys are deleted one at a time. The reported total is the number of keys
    matched by the patterns.
    """
    flushed = 0
    for pattern in ("plan:*", "deps:table:*", "qdeps:*"):
        matched = db.redis_keys(pattern)
        for key in matched:
            db.redis_delete(key)
        flushed += len(matched)
    return f" Flushed {flushed} pg_plan_cache key(s) from Redis."
# ---------------------------------------------------------------------------
# Cache inspection
# ---------------------------------------------------------------------------
@_safe
def list_cached_plans(db: DatabaseManager, limit: int = 50) -> str:
    """List cached plan keys (sorted) with their TTL and value size.

    Args:
        db: Database handle; only its Redis helpers are used.
        limit: Maximum number of plans to show. Converted with ``int()`` so
            string arguments work; negative values are treated as 0.

    Returns:
        A header followed by one `` <hash> ttl=<s>s size=<n>B`` line per plan.
        When the listing is truncated, a trailer reports the total number of
        plans. ``size`` is ``len()`` of the value returned by ``redis_get``.
    """
    keys = db.redis_keys("plan:*")
    if not keys:
        return " No cached plans found in Redis."
    # A negative slice bound would silently drop keys from the end instead of
    # limiting the listing, so clamp it.
    limit = max(0, int(limit))
    shown = sorted(keys)[:limit]
    prefix_len = len("plan:")
    lines = []
    for key in shown:
        ttl = db.redis_ttl(key)
        val = db.redis_get(key)
        size = len(val) if val else 0
        # Strip only the leading "plan:" prefix, not every occurrence.
        short_hash = key[prefix_len:]
        lines.append(f" {short_hash} ttl={ttl}s size={size}B")
    result = f" Cached plans ({len(shown)} shown):\n" + "\n".join(lines)
    if len(keys) > len(shown):
        result += f"\n ... ({len(keys)} total plans, first {len(shown)} shown)"
    return result
@_safe
def get_cached_plan_detail(db: DatabaseManager, query_hash: str) -> str:
    """Show one cached plan entry, decoding its ``|``-delimited header.

    The value is split on the first three ``|`` characters into
    schema version, total cost, creation timestamp and plan data. A value
    that does not yield four fields is shown raw (first 800 characters).
    """
    redis_key = f"plan:{query_hash}"
    raw = db.redis_get(redis_key)
    if raw is None:
        return f" No cached plan found for hash {query_hash}."
    remaining = db.redis_ttl(redis_key)
    fields = raw.split("|", 3)
    if len(fields) != 4:
        return f" Raw value (unparseable): ttl={remaining}s\n {raw[:800]}"
    schema_ver, cost, created, plan = fields
    preview = plan[:500]
    if len(plan) > 500:
        preview += "..."
    return "\n".join([
        f" hash: {query_hash}",
        f" schema_version: {schema_ver}",
        f" total_cost: {cost}",
        f" created_at: {created}",
        f" ttl_remaining: {remaining}s",
        f" plan_size: {len(plan)} chars",
        f" plan_preview: {preview}",
    ])
@_safe
def get_table_dependencies(db: DatabaseManager, table_name: str = None, query_hash: str = None) -> str:
    """Show query/table dependency sets stored in Redis.

    With *table_name*, lists members of ``deps:table:<table_name>`` (queries
    depending on that table). With *query_hash*, lists members of
    ``qdeps:<query_hash>`` (tables that query depends on). Both may be given;
    with neither, a usage hint is returned.
    """
    def _describe(key, header, empty_msg):
        # Returns the section lines for one Redis set, members sorted.
        members = db.redis_smembers(key)
        if not members:
            return [empty_msg]
        return [header(len(members)), *(f" - {m}" for m in sorted(members))]

    report = []
    if table_name:
        report += _describe(
            f"deps:table:{table_name}",
            lambda n: f" Queries depending on '{table_name}' ({n}):",
            f" No queries depend on '{table_name}'.",
        )
    if query_hash:
        report += _describe(
            f"qdeps:{query_hash}",
            lambda n: f" Tables that query {query_hash} depends on ({n}):",
            f" No table dependencies for query {query_hash}.",
        )
    return "\n".join(report) if report else " Provide a table name or query hash."
# ---------------------------------------------------------------------------
# Query normalization
# ---------------------------------------------------------------------------
@_safe
def normalize_and_hash(db: DatabaseManager, query: str) -> str:
    """Normalize *query*, hash it, and report whether ``plan:<hash>`` is in Redis.

    An entry counts as cached when ``redis_get`` returns a truthy value.
    """
    normalized = normalize_query(query)
    key_hash = compute_query_hash(normalized)
    in_cache = bool(db.redis_get(f"plan:{key_hash}"))
    return "\n".join([
        f" Original: {query}",
        f" Normalized: {normalized}",
        f" Hash: {key_hash}",
        f" Status: {'CACHED' if in_cache else 'NOT CACHED'}",
    ])
# ---------------------------------------------------------------------------
# SQL passthrough
# ---------------------------------------------------------------------------
@_safe
def run_sql_query(db: DatabaseManager, sql: str, max_rows: int = 100) -> str:
    """Execute a single SELECT statement and render the result as a text table.

    Args:
        db: Database handle; only ``db.pg_query`` is used.
        sql: SQL text. After stripping whitespace, it must start with
            ``SELECT`` (case-insensitive) and contain one statement. A
            trailing ``;`` is allowed.
        max_rows: Maximum number of rows rendered (default 100, minimum 1).
            Column widths are computed from the rendered rows only.

    Returns:
        A ``|``-separated table, ``" (no rows returned)"``, or a refusal
        message. When rows are truncated, a trailer reports the total.

    NOTE: the prefix check guards against accidents, not abuse. A SELECT can
    still have side effects (``SELECT ... INTO``, or functions such as
    ``pg_plan_cache_invalidate()``). If that matters, enforce read-only
    access with a read-only role or transaction.
    """
    stripped = sql.strip()
    if not stripped.upper().startswith("SELECT"):
        return " Refused: only SELECT queries are allowed."
    # Trailing semicolons are harmless, so drop them. Any remaining semicolon
    # means stacked statements ("SELECT 1; DROP TABLE t") that would slip past
    # the prefix check. Semicolons inside string literals are also refused,
    # erring on the safe side.
    statement = stripped.rstrip("; \t\r\n")
    if ";" in statement:
        return " Refused: only a single SELECT statement is allowed."
    rows = db.pg_query(statement)
    if not rows:
        return " (no rows returned)"
    max_rows = max(1, int(max_rows))
    shown = rows[:max_rows]
    cols = list(rows[0].keys())
    # Stringify each cell once; the result is reused for widths and rendering.
    cells = [[str(row.get(c, "")) for c in cols] for row in shown]
    widths = [
        max([len(c)] + [len(r[i]) for r in cells]) for i, c in enumerate(cols)
    ]
    header = " " + " | ".join(c.ljust(w) for c, w in zip(cols, widths))
    sep = " " + "-+-".join("-" * w for w in widths)
    body = [" " + " | ".join(v.ljust(w) for v, w in zip(r, widths)) for r in cells]
    result = header + "\n" + sep + "\n" + "\n".join(body)
    if len(rows) > len(shown):
        result += f"\n ... ({len(rows)} total rows, first {len(shown)} shown)"
    return result
# ---------------------------------------------------------------------------
# Analysis & diagnosis
# ---------------------------------------------------------------------------
def _efficiency_recommendations(total, hit_ratio, error_rate, timeouts,
                                invalidations, stored, plan_count, total_size):
    """Return the recommendation bullet lines for analyze_cache_efficiency.

    The hit-ratio assessment (or the "no queries" note) always comes first,
    followed by any warnings. The "healthy" line is added only when the hit
    ratio is at least 90% and no warning fired.
    """
    if total == 0:
        recs = [" * No queries observed yet. Run some workload first."]
    elif hit_ratio >= 90:
        recs = [" * Excellent hit ratio (>=90%). Cache is very effective."]
    elif hit_ratio >= 50:
        recs = [f" * Moderate hit ratio ({hit_ratio:.0f}%). Consider increasing TTL."]
    else:
        recs = [
            f" * Low hit ratio ({hit_ratio:.0f}%). Check if queries have high "
            "literal variation or if TTL is too short."
        ]

    warnings = []
    # Lookup-based warnings only make sense once some lookups have been observed.
    if total > 0:
        if error_rate > 5:
            warnings.append(" * High error rate. Check Redis connectivity and timeout settings.")
        if timeouts > 0:
            warnings.append(" * Redis timeouts detected. Consider increasing redis_timeout_ms.")
        if stored > 0 and invalidations > stored * 0.5:
            warnings.append(" * High invalidation ratio — frequent schema changes hurt cache.")
    if plan_count > 10000:
        warnings.append(" * Many cached plans. Consider lowering TTL or setting Redis maxmemory.")
    if plan_count > 0 and total_size / plan_count > 50000:
        warnings.append(" * Large average plan size. Complex queries may bloat Redis memory.")

    recs.extend(warnings)
    # The hit-ratio line is always present, so "healthy" must key off the
    # warnings rather than off an empty recommendation list.
    if total > 0 and hit_ratio >= 90 and not warnings:
        recs.append(" * Cache looks healthy. No action needed.")
    return recs


@_safe
def analyze_cache_efficiency(db: DatabaseManager) -> str:
    """Report hit ratio, error rate and Redis footprint, with recommendations.

    Reads counters from ``pg_plan_cache_stats()``, counts ``plan:*`` keys and
    sums the ``len()`` of their values, and reads ``pg_plan_cache.ttl``.

    NOTE(review): sizing fetches every cached plan value, one ``redis_get``
    per key, which can be slow on large caches. The "bytes" figure is ``len()``
    of whatever ``redis_get`` returns (characters if it returns ``str``).
    """
    rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()")
    if not rows:
        return " Cannot retrieve stats. Is the extension loaded?"
    s = rows[0]

    def counter(name):
        # Missing or NULL counters count as 0 instead of raising TypeError.
        return int(s.get(name) or 0)

    hits = counter("cache_hits")
    misses = counter("cache_misses")
    errors = counter("cache_errors")
    invalidations = counter("cache_invalidations")
    stored = counter("plans_stored")
    timeouts = counter("redis_timeouts")
    total = hits + misses
    hit_ratio = (hits / total * 100) if total > 0 else 0
    error_rate = (errors / total * 100) if total > 0 else 0

    plan_keys = db.redis_keys("plan:*")
    plan_count = len(plan_keys)
    total_size = sum(len(v) for v in (db.redis_get(k) for k in plan_keys) if v)

    cfg_rows = db.pg_query("SHOW pg_plan_cache.ttl")
    ttl = list(cfg_rows[0].values())[0] if cfg_rows else "?"

    lines = [
        " === Cache Efficiency ===",
        f" Total lookups: {total}",
        f" Hits: {hits}",
        f" Misses: {misses}",
        f" Hit ratio: {hit_ratio:.1f}%",
        f" Errors: {errors} ({error_rate:.1f}%)",
        f" Timeouts: {timeouts}",
        f" Invalidations: {invalidations}",
        f" Plans stored: {stored}",
        f" Plans in Redis: {plan_count}",
        f" Total cache size: {total_size:,} bytes",
        f" Current TTL: {ttl}s",
        "",
        " === Recommendations ===",
    ]
    lines.extend(_efficiency_recommendations(
        total, hit_ratio, error_rate, timeouts,
        invalidations, stored, plan_count, total_size,
    ))
    return "\n".join(lines)
@_safe
def diagnose(db: DatabaseManager) -> str:
    """Run a sectioned health check of PostgreSQL, Redis, cache stats and config.

    Each section runs independently. An exception inside a section is
    reported as a `` FAILED: <message>`` line under that section's heading,
    after any lines the section had already produced, and the remaining
    sections still run.
    """
    report = []

    def _postgres():
        version_rows = db.pg_query("SELECT version()")
        report.append(f" Connected: {version_rows[0]['version'][:80]}")
        ext = db.pg_query(
            "SELECT extversion FROM pg_extension WHERE extname = 'pg_plan_cache'"
        )
        if not ext:
            report.append(" Extension: NOT INSTALLED <-- run CREATE EXTENSION pg_plan_cache;")
        else:
            report.append(f" Extension: v{ext[0]['extversion']} installed")

    def _redis():
        if not db.redis_ping():
            report.append(" FAILED: ping returned false")
            return
        info = db.redis_info()
        report.append(f" Connected: v{info.get('redis_version', '?')}")
        report.append(f" Memory: {info.get('used_memory_human', '?')}")
        report.append(f" Cached plans: {len(db.redis_keys('plan:*'))}")

    def _stats():
        stat_rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()")
        if not stat_rows:
            report.append(" No stats available")
            return
        st = stat_rows[0]
        hits, misses = int(st['cache_hits']), int(st['cache_misses'])
        lookups = hits + misses
        ratio = (hits / lookups * 100) if lookups > 0 else 0
        # Raw column values (not the int conversions) are displayed.
        report.append(f" Hits/Misses: {st['cache_hits']}/{st['cache_misses']} ({ratio:.1f}% hit)")
        report.append(f" Errors: {st['cache_errors']}, Timeouts: {st['redis_timeouts']}")
        report.append(f" Pool: {st['redis_pool_active']} active, {st['redis_pool_idle']} idle")

    def _config():
        for guc in GUC_PARAMS:
            shown = db.pg_query("SHOW " + guc)
            setting = list(shown[0].values())[0] if shown else "N/A"
            report.append(f" {guc} = {setting}")

    for heading, probe in (
        ("[PostgreSQL]", _postgres),
        ("\n[Redis]", _redis),
        ("\n[Cache Stats]", _stats),
        ("\n[Configuration]", _config),
    ):
        report.append(heading)
        try:
            probe()
        except Exception as exc:
            report.append(f" FAILED: {exc}")
    return "\n".join(report)
# ---------------------------------------------------------------------------
# Watch mode (live stats refresh)
# ---------------------------------------------------------------------------
def watch_stats(db: DatabaseManager, interval: int = 3, count: int = 10):
    """Print a live table of cache stats, taking *count* samples.

    Args:
        db: Database handle; ``pg_query`` and ``redis_keys`` are used.
        interval: Seconds to wait between consecutive samples.
        count: Number of samples to print.

    Returns nothing; output is printed directly. Ctrl+C stops early. Any
    other exception is printed as `` Error: <Type>: <message>`` (the same
    format the other tools return) and ends the watch instead of propagating.
    """
    print(f" Watching stats every {interval}s (press Ctrl+C to stop)...\n")
    header = f" {'time':>8} {'hits':>8} {'misses':>8} {'ratio':>7} {'errors':>7} {'stored':>8} {'plans':>6}"
    print(header)
    print(" " + "-" * (len(header) - 2))
    prev_hits = prev_misses = None
    try:
        for i in range(count):
            rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()")
            if not rows:
                print(" (no stats)")
                break
            s = rows[0]
            h, m = int(s['cache_hits']), int(s['cache_misses'])
            total = h + m
            ratio = f"{h/total*100:.1f}%" if total > 0 else "N/A"
            plan_count = len(db.redis_keys("plan:*"))
            ts = time.strftime("%H:%M:%S")
            delta = ""
            if prev_hits is not None:
                dh = h - prev_hits
                dm = m - prev_misses
                if dh or dm:
                    # Signed format so a counter reset renders as "-5", not "+-5".
                    delta = f" ({dh:+d} hits, {dm:+d} misses)"
            print(f" {ts:>8} {h:>8} {m:>8} {ratio:>7} {s['cache_errors']:>7} {s['plans_stored']:>8} {plan_count:>6}{delta}")
            prev_hits, prev_misses = h, m
            # Skip the pointless wait after the final sample.
            if i < count - 1:
                time.sleep(interval)
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f" Error: {type(e).__name__}: {e}")
    print()