""" Tool implementations for the pg_plan_cache agent. Each function takes a DatabaseManager and optional args, returns a printable string. """ import time from db import DatabaseManager from normalizer import normalize_query, compute_query_hash def _safe(fn): """Wrap a tool so exceptions become readable error strings.""" def wrapper(*args, **kwargs): try: return fn(*args, **kwargs) except Exception as e: return f" Error: {type(e).__name__}: {e}" wrapper.__name__ = fn.__name__ wrapper.__doc__ = fn.__doc__ return wrapper # --------------------------------------------------------------------------- # Stats & monitoring # --------------------------------------------------------------------------- @_safe def get_cache_stats(db: DatabaseManager) -> str: """Fetch cache statistics from pg_plan_cache_stats().""" rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()") if not rows: return " No stats returned. Is the pg_plan_cache extension loaded?" s = rows[0] return ( f" cache_hits: {s['cache_hits']}\n" f" cache_misses: {s['cache_misses']}\n" f" cache_errors: {s['cache_errors']}\n" f" cache_invalidations: {s['cache_invalidations']}\n" f" redis_timeouts: {s['redis_timeouts']}\n" f" plans_stored: {s['plans_stored']}\n" f" redis_pool_active: {s['redis_pool_active']}\n" f" redis_pool_idle: {s['redis_pool_idle']}" ) @_safe def check_redis_health(db: DatabaseManager) -> str: """Check Redis connectivity and report server info.""" alive = db.redis_ping() if not alive: return " Redis is NOT reachable." info = db.redis_info() plan_keys = db.redis_keys("plan:*") dep_keys = db.redis_keys("deps:table:*") qdep_keys = db.redis_keys("qdeps:*") return ( f" Status: CONNECTED\n" f" redis_version: {info.get('redis_version', '?')}\n" f" uptime_seconds: {info.get('uptime_in_seconds', '?')}\n" f" connected_clients: {info.get('connected_clients', '?')}\n" f" used_memory_human: {info.get('used_memory_human', '?')}\n" f" peak_memory: {info.get('used_memory_peak_human', '?')}\n" f" total_commands: {info.get('total_commands_processed', '?')}\n" f" plan keys: {len(plan_keys)}\n" f" table dep sets: {len(dep_keys)}\n" f" query dep sets: {len(qdep_keys)}" ) @_safe def check_pg_health(db: DatabaseManager) -> str: """Check PostgreSQL connectivity and extension status.""" rows = db.pg_query("SELECT version()") version = rows[0]["version"] if rows else "unknown" ext_rows = db.pg_query( "SELECT extname, extversion FROM pg_extension WHERE extname = 'pg_plan_cache'" ) if ext_rows: ext_info = f" Extension: pg_plan_cache v{ext_rows[0]['extversion']}" else: ext_info = " Extension: NOT INSTALLED" preload_rows = db.pg_query("SHOW shared_preload_libraries") preload = list(preload_rows[0].values())[0] if preload_rows else "?" return ( f" Status: CONNECTED\n" f" Server: {version[:80]}\n" f"{ext_info}\n" f" Preloaded: {preload}" ) # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- GUC_PARAMS = [ "pg_plan_cache.redis_host", "pg_plan_cache.redis_port", "pg_plan_cache.ttl", "pg_plan_cache.enabled", "pg_plan_cache.redis_timeout_ms", "pg_plan_cache.redis_pool_size", ] @_safe def get_extension_config(db: DatabaseManager) -> str: """Show current GUC parameter values.""" lines = [] for name in GUC_PARAMS: rows = db.pg_query("SHOW " + name) val = list(rows[0].values())[0] if rows else "N/A" lines.append(f" {name} = {val}") return "\n".join(lines) @_safe def set_extension_config(db: DatabaseManager, param: str, value: str) -> str: """Change a GUC parameter at runtime.""" if not param.startswith("pg_plan_cache."): return " Refused: only pg_plan_cache.* parameters can be set." if param not in GUC_PARAMS: return f" Unknown parameter: {param}\n Valid: {', '.join(GUC_PARAMS)}" db.pg_execute(f"ALTER SYSTEM SET {param} = %s", (value,)) db.pg_execute("SELECT pg_reload_conf()") return f" Set {param} = {value} and reloaded configuration." # --------------------------------------------------------------------------- # Cache management # --------------------------------------------------------------------------- @_safe def invalidate_all_plans(db: DatabaseManager) -> str: """Invalidate all cached plans by bumping schema version.""" db.pg_execute("SELECT pg_plan_cache_invalidate()") return " All cached plans invalidated (schema version incremented)." @_safe def invalidate_table_plans(db: DatabaseManager, table_name: str) -> str: """Invalidate cached plans depending on a specific table.""" rows = db.pg_query( "SELECT pg_plan_cache_invalidate_table(%s) AS count", (table_name,) ) count = rows[0]["count"] if rows else 0 return f" Invalidated {count} plan(s) for table '{table_name}'." @_safe def flush_redis_plan_keys(db: DatabaseManager) -> str: """Hard-delete ALL pg_plan_cache keys from Redis.""" patterns = ["plan:*", "deps:table:*", "qdeps:*"] total = 0 for pattern in patterns: keys = db.redis_keys(pattern) for k in keys: db.redis_delete(k) total += len(keys) return f" Flushed {total} pg_plan_cache key(s) from Redis." # --------------------------------------------------------------------------- # Cache inspection # --------------------------------------------------------------------------- @_safe def list_cached_plans(db: DatabaseManager, limit: int = 50) -> str: """List cached plan keys with TTL and size.""" keys = db.redis_keys("plan:*") if not keys: return " No cached plans found in Redis." keys = sorted(keys)[:limit] lines = [] for key in keys: ttl = db.redis_ttl(key) val = db.redis_get(key) size = len(val) if val else 0 short_hash = key.replace("plan:", "") lines.append(f" {short_hash} ttl={ttl}s size={size}B") return f" Cached plans ({len(keys)} shown):\n" + "\n".join(lines) @_safe def get_cached_plan_detail(db: DatabaseManager, query_hash: str) -> str: """Inspect a specific cached plan entry.""" key = f"plan:{query_hash}" val = db.redis_get(key) if val is None: return f" No cached plan found for hash {query_hash}." ttl = db.redis_ttl(key) parts = val.split("|", 3) if len(parts) == 4: schema_ver, cost, ts, plan_data = parts preview = plan_data[:500] + ("..." if len(plan_data) > 500 else "") return ( f" hash: {query_hash}\n" f" schema_version: {schema_ver}\n" f" total_cost: {cost}\n" f" created_at: {ts}\n" f" ttl_remaining: {ttl}s\n" f" plan_size: {len(plan_data)} chars\n" f" plan_preview: {preview}" ) return f" Raw value (unparseable): ttl={ttl}s\n {val[:800]}" @_safe def get_table_dependencies(db: DatabaseManager, table_name: str = None, query_hash: str = None) -> str: """Show dependency relationships between queries and tables.""" lines = [] if table_name: members = db.redis_smembers(f"deps:table:{table_name}") if members: lines.append(f" Queries depending on '{table_name}' ({len(members)}):") for m in sorted(members): lines.append(f" - {m}") else: lines.append(f" No queries depend on '{table_name}'.") if query_hash: members = db.redis_smembers(f"qdeps:{query_hash}") if members: lines.append(f" Tables that query {query_hash} depends on ({len(members)}):") for m in sorted(members): lines.append(f" - {m}") else: lines.append(f" No table dependencies for query {query_hash}.") if not lines: return " Provide a table name or query hash." return "\n".join(lines) # --------------------------------------------------------------------------- # Query normalization # --------------------------------------------------------------------------- @_safe def normalize_and_hash(db: DatabaseManager, query: str) -> str: """Normalize a SQL query and compute its cache key.""" normalized = normalize_query(query) qhash = compute_query_hash(normalized) cached = db.redis_get(f"plan:{qhash}") status = "CACHED" if cached else "NOT CACHED" return ( f" Original: {query}\n" f" Normalized: {normalized}\n" f" Hash: {qhash}\n" f" Status: {status}" ) # --------------------------------------------------------------------------- # SQL passthrough # --------------------------------------------------------------------------- @_safe def run_sql_query(db: DatabaseManager, sql: str) -> str: """Execute a read-only SELECT query.""" stripped = sql.strip() if not stripped.upper().startswith("SELECT"): return " Refused: only SELECT queries are allowed." rows = db.pg_query(stripped) if not rows: return " (no rows returned)" cols = list(rows[0].keys()) widths = [max(len(c), max(len(str(r.get(c, ""))) for r in rows[:100])) for c in cols] header = " " + " | ".join(c.ljust(w) for c, w in zip(cols, widths)) sep = " " + "-+-".join("-" * w for w in widths) body = [] for row in rows[:100]: body.append(" " + " | ".join(str(row.get(c, "")).ljust(w) for c, w in zip(cols, widths))) result = header + "\n" + sep + "\n" + "\n".join(body) if len(rows) > 100: result += f"\n ... ({len(rows)} total rows, first 100 shown)" return result # --------------------------------------------------------------------------- # Analysis & diagnosis # --------------------------------------------------------------------------- @_safe def analyze_cache_efficiency(db: DatabaseManager) -> str: """Analyze hit ratio, error rate, and provide recommendations.""" rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()") if not rows: return " Cannot retrieve stats. Is the extension loaded?" s = rows[0] hits = int(s.get("cache_hits", 0)) misses = int(s.get("cache_misses", 0)) errors = int(s.get("cache_errors", 0)) invalidations = int(s.get("cache_invalidations", 0)) stored = int(s.get("plans_stored", 0)) timeouts = int(s.get("redis_timeouts", 0)) total = hits + misses hit_ratio = (hits / total * 100) if total > 0 else 0 error_rate = (errors / total * 100) if total > 0 else 0 plan_keys = db.redis_keys("plan:*") plan_count = len(plan_keys) total_size = 0 for k in plan_keys: val = db.redis_get(k) if val: total_size += len(val) cfg_rows = db.pg_query("SHOW pg_plan_cache.ttl") ttl = list(cfg_rows[0].values())[0] if cfg_rows else "?" lines = [ " === Cache Efficiency ===", f" Total lookups: {total}", f" Hits: {hits}", f" Misses: {misses}", f" Hit ratio: {hit_ratio:.1f}%", f" Errors: {errors} ({error_rate:.1f}%)", f" Timeouts: {timeouts}", f" Invalidations: {invalidations}", f" Plans stored: {stored}", f" Plans in Redis: {plan_count}", f" Total cache size: {total_size:,} bytes", f" Current TTL: {ttl}s", "", " === Recommendations ===", ] recs = [] if total == 0: recs.append(" * No queries observed yet. Run some workload first.") else: if hit_ratio >= 90: recs.append(" * Excellent hit ratio (>90%). Cache is very effective.") elif hit_ratio >= 50: recs.append(f" * Moderate hit ratio ({hit_ratio:.0f}%). Consider increasing TTL.") else: recs.append( f" * Low hit ratio ({hit_ratio:.0f}%). Check if queries have high " "literal variation or if TTL is too short." ) if error_rate > 5: recs.append(" * High error rate. Check Redis connectivity and timeout settings.") if timeouts > 0: recs.append(" * Redis timeouts detected. Consider increasing redis_timeout_ms.") if invalidations > stored * 0.5 and stored > 0: recs.append(" * High invalidation ratio — frequent schema changes hurt cache.") if plan_count > 10000: recs.append(" * Many cached plans. Consider lowering TTL or setting Redis maxmemory.") if plan_count > 0 and total_size / plan_count > 50000: recs.append(" * Large average plan size. Complex queries may bloat Redis memory.") if not recs: recs.append(" * Cache looks healthy. No action needed.") lines.extend(recs) return "\n".join(lines) @_safe def diagnose(db: DatabaseManager) -> str: """Run a full diagnostic: connectivity, extension, stats, Redis state.""" sections = [] # 1. PostgreSQL sections.append("[PostgreSQL]") try: rows = db.pg_query("SELECT version()") sections.append(f" Connected: {rows[0]['version'][:80]}") ext = db.pg_query( "SELECT extversion FROM pg_extension WHERE extname = 'pg_plan_cache'" ) if ext: sections.append(f" Extension: v{ext[0]['extversion']} installed") else: sections.append(" Extension: NOT INSTALLED <-- run CREATE EXTENSION pg_plan_cache;") except Exception as e: sections.append(f" FAILED: {e}") # 2. Redis sections.append("\n[Redis]") try: if db.redis_ping(): info = db.redis_info() sections.append(f" Connected: v{info.get('redis_version', '?')}") sections.append(f" Memory: {info.get('used_memory_human', '?')}") plan_count = len(db.redis_keys("plan:*")) sections.append(f" Cached plans: {plan_count}") else: sections.append(" FAILED: ping returned false") except Exception as e: sections.append(f" FAILED: {e}") # 3. Stats sections.append("\n[Cache Stats]") try: rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()") if rows: s = rows[0] total = int(s['cache_hits']) + int(s['cache_misses']) ratio = (int(s['cache_hits']) / total * 100) if total > 0 else 0 sections.append(f" Hits/Misses: {s['cache_hits']}/{s['cache_misses']} ({ratio:.1f}% hit)") sections.append(f" Errors: {s['cache_errors']}, Timeouts: {s['redis_timeouts']}") sections.append(f" Pool: {s['redis_pool_active']} active, {s['redis_pool_idle']} idle") else: sections.append(" No stats available") except Exception as e: sections.append(f" FAILED: {e}") # 4. Config sections.append("\n[Configuration]") try: for name in GUC_PARAMS: rows = db.pg_query("SHOW " + name) val = list(rows[0].values())[0] if rows else "N/A" sections.append(f" {name} = {val}") except Exception as e: sections.append(f" FAILED: {e}") return "\n".join(sections) # --------------------------------------------------------------------------- # Watch mode (live stats refresh) # --------------------------------------------------------------------------- def watch_stats(db: DatabaseManager, interval: int = 3, count: int = 10): """Print cache stats repeatedly. Returns nothing (prints directly).""" print(f" Watching stats every {interval}s (press Ctrl+C to stop)...\n") header = f" {'time':>8} {'hits':>8} {'misses':>8} {'ratio':>7} {'errors':>7} {'stored':>8} {'plans':>6}" print(header) print(" " + "-" * (len(header) - 2)) prev_hits = prev_misses = None try: for _ in range(count): rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()") if not rows: print(" (no stats)") break s = rows[0] h, m = int(s['cache_hits']), int(s['cache_misses']) total = h + m ratio = f"{h/total*100:.1f}%" if total > 0 else "N/A" plan_count = len(db.redis_keys("plan:*")) ts = time.strftime("%H:%M:%S") delta = "" if prev_hits is not None: dh = h - prev_hits dm = m - prev_misses if dh or dm: delta = f" (+{dh} hits, +{dm} misses)" print(f" {ts:>8} {h:>8} {m:>8} {ratio:>7} {s['cache_errors']:>7} {s['plans_stored']:>8} {plan_count:>6}{delta}") prev_hits, prev_misses = h, m time.sleep(interval) except KeyboardInterrupt: pass print()