Spaces:
Sleeping
Sleeping
| """ | |
| Tool implementations for the pg_plan_cache agent. | |
| Each function takes a DatabaseManager and optional args, returns a printable string. | |
| """ | |
| import time | |
| from db import DatabaseManager | |
| from normalizer import normalize_query, compute_query_hash | |
| def _safe(fn): | |
| """Wrap a tool so exceptions become readable error strings.""" | |
| def wrapper(*args, **kwargs): | |
| try: | |
| return fn(*args, **kwargs) | |
| except Exception as e: | |
| return f" Error: {type(e).__name__}: {e}" | |
| wrapper.__name__ = fn.__name__ | |
| wrapper.__doc__ = fn.__doc__ | |
| return wrapper | |
| # --------------------------------------------------------------------------- | |
| # Stats & monitoring | |
| # --------------------------------------------------------------------------- | |
def get_cache_stats(db: DatabaseManager) -> str:
    """Return the pg_plan_cache_stats() counters as a printable block.

    Runs ``SELECT * FROM pg_plan_cache_stats()`` and renders one
    ``name: value`` line per counter from the first returned row. When the
    query yields no rows, a hint about the extension is returned instead.
    """
    result = db.pg_query("SELECT * FROM pg_plan_cache_stats()")
    if not result:
        return " No stats returned. Is the pg_plan_cache extension loaded?"

    stats = result[0]
    # Display order; each name is both the row key and the printed label.
    counters = (
        "cache_hits",
        "cache_misses",
        "cache_errors",
        "cache_invalidations",
        "redis_timeouts",
        "plans_stored",
        "redis_pool_active",
        "redis_pool_idle",
    )
    return "\n".join(f" {name}: {stats[name]}" for name in counters)
def check_redis_health(db: DatabaseManager) -> str:
    """Ping Redis and summarise server info plus pg_plan_cache key counts.

    Returns a one-line "not reachable" message when ``db.redis_ping()`` is
    falsy. Otherwise reports selected fields of ``db.redis_info()`` (``?``
    for any that are absent) and the number of keys matching the
    ``plan:*``, ``deps:table:*`` and ``qdeps:*`` patterns.
    """
    if not db.redis_ping():
        return " Redis is NOT reachable."

    server = db.redis_info()
    plan_total = len(db.redis_keys("plan:*"))
    table_dep_total = len(db.redis_keys("deps:table:*"))
    query_dep_total = len(db.redis_keys("qdeps:*"))

    # (printed label, key in the info mapping)
    info_fields = (
        ("redis_version", "redis_version"),
        ("uptime_seconds", "uptime_in_seconds"),
        ("connected_clients", "connected_clients"),
        ("used_memory_human", "used_memory_human"),
        ("peak_memory", "used_memory_peak_human"),
        ("total_commands", "total_commands_processed"),
    )

    report = [" Status: CONNECTED"]
    for label, field in info_fields:
        report.append(f" {label}: {server.get(field, '?')}")
    report.append(f" plan keys: {plan_total}")
    report.append(f" table dep sets: {table_dep_total}")
    report.append(f" query dep sets: {query_dep_total}")
    return "\n".join(report)
def check_pg_health(db: DatabaseManager) -> str:
    """Report the PostgreSQL version, extension status and preload setting.

    Issues three queries through ``db.pg_query``: ``SELECT version()``, a
    lookup of ``pg_plan_cache`` in ``pg_extension``, and
    ``SHOW shared_preload_libraries``. The server version is cut to its
    first 80 characters.
    """
    version_rows = db.pg_query("SELECT version()")
    if version_rows:
        server_version = version_rows[0]["version"]
    else:
        server_version = "unknown"

    installed = db.pg_query(
        "SELECT extname, extversion FROM pg_extension WHERE extname = 'pg_plan_cache'"
    )
    extension_line = (
        f" Extension: pg_plan_cache v{installed[0]['extversion']}"
        if installed
        else " Extension: NOT INSTALLED"
    )

    shown = db.pg_query("SHOW shared_preload_libraries")
    # SHOW returns a single column; take its value regardless of the column name.
    preloaded = list(shown[0].values())[0] if shown else "?"

    return "\n".join(
        [
            " Status: CONNECTED",
            f" Server: {server_version[:80]}",
            extension_line,
            f" Preloaded: {preloaded}",
        ]
    )
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
# Fully qualified names of the pg_plan_cache settings this module reads
# (via SHOW) and is willing to change (via ALTER SYSTEM).
GUC_PARAMS = [
    "pg_plan_cache." + suffix
    for suffix in (
        "redis_host",
        "redis_port",
        "ttl",
        "enabled",
        "redis_timeout_ms",
        "redis_pool_size",
    )
]
def get_extension_config(db: DatabaseManager) -> str:
    """Return one ``name = value`` line for every entry in GUC_PARAMS.

    Each value comes from a ``SHOW <name>`` query; ``N/A`` is printed when
    the query returns no rows.
    """

    def current_value(name):
        shown = db.pg_query("SHOW " + name)
        # SHOW yields a single column; its name is not relied upon.
        return list(shown[0].values())[0] if shown else "N/A"

    return "\n".join(f" {name} = {current_value(name)}" for name in GUC_PARAMS)
def set_extension_config(db: DatabaseManager, param: str, value: str) -> str:
    """Persist a pg_plan_cache setting with ALTER SYSTEM and reload the config.

    Args:
        db: Database manager; ``pg_execute`` is called twice on success.
        param: Setting name. Must start with ``pg_plan_cache.`` and be
            listed in GUC_PARAMS.
        value: New value, passed to ``pg_execute`` as a query parameter.

    Returns:
        A confirmation message, or a refusal / unknown-parameter message
        when ``param`` fails validation (nothing is executed in that case).
    """
    has_extension_prefix = param.startswith("pg_plan_cache.")
    if not has_extension_prefix:
        return " Refused: only pg_plan_cache.* parameters can be set."

    if param not in GUC_PARAMS:
        valid_names = ", ".join(GUC_PARAMS)
        return f" Unknown parameter: {param}\n Valid: {valid_names}"

    # param is safe to interpolate: it was matched against GUC_PARAMS above.
    # NOTE(review): this relies on db.pg_execute substituting %s on the client
    # side; ALTER SYSTEM is a utility statement — confirm the driver's behaviour.
    statement = f"ALTER SYSTEM SET {param} = %s"
    db.pg_execute(statement, (value,))
    db.pg_execute("SELECT pg_reload_conf()")
    return f" Set {param} = {value} and reloaded configuration."
| # --------------------------------------------------------------------------- | |
| # Cache management | |
| # --------------------------------------------------------------------------- | |
def invalidate_all_plans(db: DatabaseManager) -> str:
    """Call pg_plan_cache_invalidate() and return a confirmation message."""
    statement = "SELECT pg_plan_cache_invalidate()"
    db.pg_execute(statement)
    confirmation = " All cached plans invalidated (schema version incremented)."
    return confirmation
def invalidate_table_plans(db: DatabaseManager, table_name: str) -> str:
    """Invalidate the cached plans that depend on ``table_name``.

    Calls ``pg_plan_cache_invalidate_table`` with the table name bound as a
    query parameter and reports the count it returns (0 when the query
    yields no rows).
    """
    result = db.pg_query(
        "SELECT pg_plan_cache_invalidate_table(%s) AS count", (table_name,)
    )
    if result:
        invalidated = result[0]["count"]
    else:
        invalidated = 0
    return f" Invalidated {invalidated} plan(s) for table '{table_name}'."
def flush_redis_plan_keys(db: DatabaseManager) -> str:
    """Delete every Redis key matching the pg_plan_cache key patterns.

    Keys matching ``plan:*``, ``deps:table:*`` and ``qdeps:*`` are deleted
    one at a time via ``db.redis_delete``. The reported total is the number
    of keys that were listed for deletion.
    """
    flushed = 0
    for glob in ("plan:*", "deps:table:*", "qdeps:*"):
        matched = db.redis_keys(glob)
        for redis_key in matched:
            db.redis_delete(redis_key)
        flushed += len(matched)
    return f" Flushed {flushed} pg_plan_cache key(s) from Redis."
| # --------------------------------------------------------------------------- | |
| # Cache inspection | |
| # --------------------------------------------------------------------------- | |
def list_cached_plans(db: DatabaseManager, limit: int = 50) -> str:
    """List cached plan keys with their TTL and stored size.

    Args:
        db: Database manager; ``redis_keys``, ``redis_ttl`` and
            ``redis_get`` are called.
        limit: Maximum number of keys (in sorted order) to report. A
            negative value is treated as 0 instead of silently dropping
            keys from the end of the list; ``None`` reports every key.

    Returns:
        A header line followed by one line per reported key, or a
        "no cached plans" message when no ``plan:*`` key exists.
    """
    prefix = "plan:"
    keys = db.redis_keys(prefix + "*")
    if not keys:
        return " No cached plans found in Redis."

    # A negative slice bound would mean "all but the last N", which is never
    # what a caller asking for a limit intends.
    if limit is not None and limit < 0:
        limit = 0
    shown = sorted(keys)[:limit]

    lines = []
    for key in shown:
        ttl = db.redis_ttl(key)
        val = db.redis_get(key)
        size = len(val) if val else 0
        # Strip only the leading prefix; str.replace would also remove any
        # later "plan:" occurrence inside the key.
        short_hash = key[len(prefix):] if key.startswith(prefix) else key
        lines.append(f" {short_hash} ttl={ttl}s size={size}B")
    return f" Cached plans ({len(shown)} shown):\n" + "\n".join(lines)
def get_cached_plan_detail(db: DatabaseManager, query_hash: str) -> str:
    """Describe the Redis entry stored under ``plan:<query_hash>``.

    The stored value is split on ``|`` into at most four fields, reported
    as schema version, total cost, creation timestamp and plan text. The
    plan text is previewed up to 500 characters. Values that do not split
    into four fields are shown raw, cut to 800 characters.
    """
    redis_key = f"plan:{query_hash}"
    raw = db.redis_get(redis_key)
    if raw is None:
        return f" No cached plan found for hash {query_hash}."

    remaining = db.redis_ttl(redis_key)
    fields = raw.split("|", 3)
    if len(fields) != 4:
        return f" Raw value (unparseable): ttl={remaining}s\n {raw[:800]}"

    schema_ver, cost, created, plan_text = fields
    preview = plan_text[:500]
    if len(plan_text) > 500:
        preview += "..."

    return "\n".join(
        [
            f" hash: {query_hash}",
            f" schema_version: {schema_ver}",
            f" total_cost: {cost}",
            f" created_at: {created}",
            f" ttl_remaining: {remaining}s",
            f" plan_size: {len(plan_text)} chars",
            f" plan_preview: {preview}",
        ]
    )
def get_table_dependencies(db: DatabaseManager, table_name: str = None, query_hash: str = None) -> str:
    """Show the members of the dependency sets for a table and/or a query.

    With ``table_name`` the set ``deps:table:<table_name>`` is listed; with
    ``query_hash`` the set ``qdeps:<query_hash>`` is listed. Both may be
    given, in which case the table section comes first. When neither is
    given, a usage hint is returned.
    """

    def section(members, heading, empty_message):
        # A heading plus sorted members, or a single "nothing found" line.
        if not members:
            return [empty_message]
        return [heading] + [f" - {member}" for member in sorted(members)]

    report = []

    if table_name:
        dependents = db.redis_smembers(f"deps:table:{table_name}")
        report += section(
            dependents,
            f" Queries depending on '{table_name}' ({len(dependents)}):" if dependents else "",
            f" No queries depend on '{table_name}'.",
        )

    if query_hash:
        tables = db.redis_smembers(f"qdeps:{query_hash}")
        report += section(
            tables,
            f" Tables that query {query_hash} depends on ({len(tables)}):" if tables else "",
            f" No table dependencies for query {query_hash}.",
        )

    if report:
        return "\n".join(report)
    return " Provide a table name or query hash."
| # --------------------------------------------------------------------------- | |
| # Query normalization | |
| # --------------------------------------------------------------------------- | |
def normalize_and_hash(db: DatabaseManager, query: str) -> str:
    """Show a query's normalized form, its hash, and whether a plan is cached.

    The query is passed through ``normalize_query`` and
    ``compute_query_hash``; the status is CACHED when ``plan:<hash>`` holds
    a truthy value in Redis.
    """
    canonical = normalize_query(query)
    digest = compute_query_hash(canonical)

    if db.redis_get(f"plan:{digest}"):
        state = "CACHED"
    else:
        state = "NOT CACHED"

    return "\n".join(
        [
            f" Original: {query}",
            f" Normalized: {canonical}",
            f" Hash: {digest}",
            f" Status: {state}",
        ]
    )
| # --------------------------------------------------------------------------- | |
| # SQL passthrough | |
| # --------------------------------------------------------------------------- | |
def run_sql_query(db: DatabaseManager, sql: str) -> str:
    """Execute a single SELECT statement and render the rows as a text table.

    Args:
        db: Database manager; ``pg_query`` is called once when the statement
            passes validation.
        sql: The statement to run. Surrounding whitespace and trailing
            semicolons are ignored.

    Returns:
        A refusal message when the text does not start with SELECT or
        contains more than one statement; `` (no rows returned)`` for an
        empty result; otherwise an aligned table of at most 100 rows, with
        a trailing note when the result was longer.

    Note:
        Any ``;`` left after trimming is treated as a statement separator,
        so a query with a semicolon inside a string literal is refused too.
        The prefix check does not by itself prove the statement is
        read-only (a SELECT may call functions with side effects).
    """
    max_rows = 100

    # Trailing semicolons are harmless; remove them so "SELECT 1;" still runs.
    statement = sql.strip().rstrip("; \t\r\n")
    if not statement.upper().startswith("SELECT"):
        return " Refused: only SELECT queries are allowed."
    # Without this check "SELECT 1; DROP TABLE t" passes the prefix test and
    # the second statement would be sent along with the first.
    if ";" in statement:
        return " Refused: only a single SELECT statement is allowed."

    rows = db.pg_query(statement)
    if not rows:
        return " (no rows returned)"

    cols = list(rows[0].keys())
    # Stringify each displayed cell once; widths and rendering both reuse it.
    cells = [[str(row.get(c, "")) for c in cols] for row in rows[:max_rows]]
    widths = [
        max([len(c)] + [len(line[i]) for line in cells])
        for i, c in enumerate(cols)
    ]

    header = " " + " | ".join(c.ljust(w) for c, w in zip(cols, widths))
    sep = " " + "-+-".join("-" * w for w in widths)
    body = [
        " " + " | ".join(cell.ljust(w) for cell, w in zip(line, widths))
        for line in cells
    ]

    result = header + "\n" + sep + "\n" + "\n".join(body)
    if len(rows) > max_rows:
        result += f"\n ... ({len(rows)} total rows, first 100 shown)"
    return result
| # --------------------------------------------------------------------------- | |
| # Analysis & diagnosis | |
| # --------------------------------------------------------------------------- | |
def analyze_cache_efficiency(db: DatabaseManager) -> str:
    """Analyze hit ratio, error rate, and provide recommendations.

    Reads the counters from ``pg_plan_cache_stats()``, measures the
    ``plan:*`` entries in Redis (count and summed value length) and the
    current ``pg_plan_cache.ttl`` setting, then renders a summary followed
    by recommendations.

    The recommendation section always starts with one line assessing the
    hit ratio (or noting that no lookups were observed). Warnings about
    errors, timeouts, invalidations and cache size follow, evaluated
    whether or not any lookups were seen. The "Cache looks healthy" line is
    added when the hit ratio is at least 90% and no warning was raised.
    Previously that line could never be printed, because the hit-ratio
    line was always present in the list it checked for emptiness.

    Returns:
        The formatted report, or a one-line message when the stats query
        returns no rows.
    """
    rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()")
    if not rows:
        return " Cannot retrieve stats. Is the extension loaded?"
    s = rows[0]

    def counter(name):
        # A missing or NULL counter counts as zero instead of failing in int().
        return int(s.get(name) or 0)

    hits = counter("cache_hits")
    misses = counter("cache_misses")
    errors = counter("cache_errors")
    invalidations = counter("cache_invalidations")
    stored = counter("plans_stored")
    timeouts = counter("redis_timeouts")

    total = hits + misses
    hit_ratio = (hits / total * 100) if total > 0 else 0
    error_rate = (errors / total * 100) if total > 0 else 0

    plan_keys = db.redis_keys("plan:*")
    plan_count = len(plan_keys)
    # Each value is fetched to measure it; empty or missing values add nothing.
    total_size = sum(len(val) for val in map(db.redis_get, plan_keys) if val)

    cfg_rows = db.pg_query("SHOW pg_plan_cache.ttl")
    ttl = list(cfg_rows[0].values())[0] if cfg_rows else "?"

    lines = [
        " === Cache Efficiency ===",
        f" Total lookups: {total}",
        f" Hits: {hits}",
        f" Misses: {misses}",
        f" Hit ratio: {hit_ratio:.1f}%",
        f" Errors: {errors} ({error_rate:.1f}%)",
        f" Timeouts: {timeouts}",
        f" Invalidations: {invalidations}",
        f" Plans stored: {stored}",
        f" Plans in Redis: {plan_count}",
        f" Total cache size: {total_size:,} bytes",
        f" Current TTL: {ttl}s",
        "",
        " === Recommendations ===",
    ]

    # Exactly one hit-ratio assessment is produced.
    if total == 0:
        assessment = " * No queries observed yet. Run some workload first."
    elif hit_ratio >= 90:
        assessment = " * Excellent hit ratio (>90%). Cache is very effective."
    elif hit_ratio >= 50:
        assessment = f" * Moderate hit ratio ({hit_ratio:.0f}%). Consider increasing TTL."
    else:
        assessment = (
            f" * Low hit ratio ({hit_ratio:.0f}%). Check if queries have high "
            "literal variation or if TTL is too short."
        )

    # Warnings are kept apart from the assessment so "no warnings" can be detected.
    warnings = []
    if error_rate > 5:
        warnings.append(" * High error rate. Check Redis connectivity and timeout settings.")
    if timeouts > 0:
        warnings.append(" * Redis timeouts detected. Consider increasing redis_timeout_ms.")
    if invalidations > stored * 0.5 and stored > 0:
        warnings.append(" * High invalidation ratio — frequent schema changes hurt cache.")
    if plan_count > 10000:
        warnings.append(" * Many cached plans. Consider lowering TTL or setting Redis maxmemory.")
    if plan_count > 0 and total_size / plan_count > 50000:
        warnings.append(" * Large average plan size. Complex queries may bloat Redis memory.")

    recs = [assessment] + warnings
    if total > 0 and hit_ratio >= 90 and not warnings:
        recs.append(" * Cache looks healthy. No action needed.")

    lines.extend(recs)
    return "\n".join(lines)
def diagnose(db: DatabaseManager) -> str:
    """Run a four-part diagnostic and return it as one printable report.

    The sections are PostgreSQL (version and extension row), Redis (ping,
    version, memory, number of ``plan:*`` keys), cache stats (from
    ``pg_plan_cache_stats()``) and configuration (``SHOW`` for each entry
    of GUC_PARAMS). Each section has its own exception handler, so a
    failure is reported as a ``FAILED: ...`` line in that section and the
    remaining sections still run.
    """
    report = []
    add = report.append

    # --- PostgreSQL ---------------------------------------------------
    add("[PostgreSQL]")
    try:
        version_rows = db.pg_query("SELECT version()")
        add(f" Connected: {version_rows[0]['version'][:80]}")
        installed = db.pg_query(
            "SELECT extversion FROM pg_extension WHERE extname = 'pg_plan_cache'"
        )
        add(
            f" Extension: v{installed[0]['extversion']} installed"
            if installed
            else " Extension: NOT INSTALLED <-- run CREATE EXTENSION pg_plan_cache;"
        )
    except Exception as exc:
        add(f" FAILED: {exc}")

    # --- Redis --------------------------------------------------------
    add("\n[Redis]")
    try:
        if not db.redis_ping():
            add(" FAILED: ping returned false")
        else:
            server = db.redis_info()
            add(f" Connected: v{server.get('redis_version', '?')}")
            add(f" Memory: {server.get('used_memory_human', '?')}")
            cached = len(db.redis_keys("plan:*"))
            add(f" Cached plans: {cached}")
    except Exception as exc:
        add(f" FAILED: {exc}")

    # --- Cache stats --------------------------------------------------
    add("\n[Cache Stats]")
    try:
        stat_rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()")
        if not stat_rows:
            add(" No stats available")
        else:
            stats = stat_rows[0]
            lookups = int(stats['cache_hits']) + int(stats['cache_misses'])
            if lookups > 0:
                pct = int(stats['cache_hits']) / lookups * 100
            else:
                pct = 0
            add(f" Hits/Misses: {stats['cache_hits']}/{stats['cache_misses']} ({pct:.1f}% hit)")
            add(f" Errors: {stats['cache_errors']}, Timeouts: {stats['redis_timeouts']}")
            add(f" Pool: {stats['redis_pool_active']} active, {stats['redis_pool_idle']} idle")
    except Exception as exc:
        add(f" FAILED: {exc}")

    # --- Configuration ------------------------------------------------
    add("\n[Configuration]")
    try:
        for guc in GUC_PARAMS:
            shown = db.pg_query("SHOW " + guc)
            current = list(shown[0].values())[0] if shown else "N/A"
            add(f" {guc} = {current}")
    except Exception as exc:
        add(f" FAILED: {exc}")

    return "\n".join(report)
| # --------------------------------------------------------------------------- | |
| # Watch mode (live stats refresh) | |
| # --------------------------------------------------------------------------- | |
def watch_stats(db: DatabaseManager, interval: int = 3, count: int = 10):
    """Print cache stats repeatedly. Returns nothing (prints directly).

    Args:
        db: Database manager; ``pg_query`` and ``redis_keys`` are called
            once per sample.
        interval: Seconds to wait between two consecutive samples.
        count: Number of samples to print.

    Each sample is one table row with the hit/miss counters, hit ratio,
    error and stored counters, and the number of ``plan:*`` keys. From the
    second sample on, a signed change in hits and misses is appended when
    either counter moved. Sampling stops early when the stats query returns
    no rows or on Ctrl+C.
    """
    print(f" Watching stats every {interval}s (press Ctrl+C to stop)...\n")
    header = f" {'time':>8} {'hits':>8} {'misses':>8} {'ratio':>7} {'errors':>7} {'stored':>8} {'plans':>6}"
    print(header)
    print(" " + "-" * (len(header) - 2))

    prev_hits = prev_misses = None
    try:
        for sample in range(count):
            # Wait between samples only, so the call returns as soon as the
            # last row is printed instead of sleeping one more interval.
            if sample:
                time.sleep(interval)

            rows = db.pg_query("SELECT * FROM pg_plan_cache_stats()")
            if not rows:
                print(" (no stats)")
                break
            s = rows[0]
            h, m = int(s['cache_hits']), int(s['cache_misses'])
            total = h + m
            ratio = f"{h/total*100:.1f}%" if total > 0 else "N/A"
            plan_count = len(db.redis_keys("plan:*"))
            ts = time.strftime("%H:%M:%S")

            delta = ""
            if prev_hits is not None:
                dh = h - prev_hits
                dm = m - prev_misses
                if dh or dm:
                    # Explicit sign formatting: a counter that went backwards
                    # shows as "-3" rather than "+-3".
                    delta = f" ({dh:+d} hits, {dm:+d} misses)"

            print(f" {ts:>8} {h:>8} {m:>8} {ratio:>7} {s['cache_errors']:>7} {s['plans_stored']:>8} {plan_count:>6}{delta}")
            prev_hits, prev_misses = h, m
    except KeyboardInterrupt:
        # Ctrl+C is the documented way to stop watching; exit quietly.
        pass
    print()