"""CPU profiler Gradio tab for Eve HMI demos. Provides a UI to start/stop ``cProfile`` on the main process and all worker processes, collect stats, display them in a sortable table, and export ``.prof`` files for offline analysis (e.g. with ``snakeviz``). Gated behind ``ENABLE_PROFILER=1`` — callers should check the env var before importing this module. """ import cProfile import io import marshal import os import pstats import time from pathlib import Path import gradio as gr from eve_worker_pool import EveWorkerPool from log_utils import setup_logger logger = setup_logger("Profiler") _TMP_DIR = Path(__file__).resolve().parents[2] / "tmp" class _ProfilerState: """Manages cProfile instances for main + workers.""" def __init__(self, pool: EveWorkerPool) -> None: self._pool = pool self._main_profiler = cProfile.Profile() self._running = False # Collected pstats.Stats dicts keyed by source label self._collected: dict[str, dict] = {} # SDK timing data: {worker_label: {call_name: (count, total_seconds)}} self._sdk_timings: dict[str, dict[str, tuple[int, float]]] = {} @property def running(self) -> bool: return self._running def start(self) -> str: if self._running: return "Profiling already running." self._running = True self._main_profiler.enable() errors: list[str] = [] workers = self._pool.get_live_workers() for w in workers: try: w.send_start_profiling() except Exception as exc: errors.append(f"w{w.worker_id}: {exc}") msg = f"Profiling started (main + {len(workers)} workers)." if errors: msg += f" Errors: {'; '.join(errors)}" logger.info(msg) return msg def stop(self) -> str: if not self._running: return "Profiling not running." self._running = False self._main_profiler.disable() errors: list[str] = [] workers = self._pool.get_live_workers() for w in workers: try: w.send_stop_profiling() except Exception as exc: errors.append(f"w{w.worker_id}: {exc}") msg = f"Profiling stopped (main + {len(workers)} workers)." if errors: msg += f" Errors: {'; '.join(errors)}" logger.info(msg) return msg def reset(self) -> str: was_running = self._running if was_running: self.stop() self._main_profiler = cProfile.Profile() self._collected.clear() msg = "Profiler reset." if was_running: self.start() msg += " Profiling restarted with fresh profiler." logger.info(msg) return msg def collect(self) -> str: """Collect stats from main process and all workers.""" was_running = self._running if was_running: self._main_profiler.disable() self._collected.clear() self._sdk_timings.clear() # Main process stats stream = io.StringIO() ps = pstats.Stats(self._main_profiler, stream=stream) self._collected["main"] = ps.stats if was_running: self._main_profiler.enable() # Worker stats (cProfile + SDK timings) errors: list[str] = [] workers = self._pool.get_live_workers() for w in workers: try: data = w.send_get_profile_stats() if data: self._collected[f"worker-{w.worker_id}"] = marshal.loads(data) # Workers auto-disable on GetProfileStatsCmd; re-enable if running if was_running: w.send_start_profiling() except Exception as exc: errors.append(f"w{w.worker_id}: {exc}") try: timing = w.send_get_timing_stats(reset=True) if timing: self._sdk_timings[f"worker-{w.worker_id}"] = timing except Exception as exc: errors.append(f"w{w.worker_id} timing: {exc}") n_sources = len(self._collected) msg = f"Collected stats from {n_sources} source(s)." if self._sdk_timings: msg += f" SDK timings from {len(self._sdk_timings)} worker(s)." if errors: msg += f" Errors: {'; '.join(errors)}" logger.info(msg) return msg def get_table_data( self, source: str = "All", sort_by: str = "cumtime", filter_text: str = "", limit: int = 100, ) -> list[list]: """Return profile data as rows for a Gradio Dataframe.""" merged = self._merge_stats(source) if not merged: return [] rows: list[list] = [] for key, (nc, totcalls, tottime, cumtime, callers) in merged.items(): filename, lineno, funcname = key func_label = f"{filename}:{lineno}({funcname})" if filter_text and filter_text.lower() not in func_label.lower(): continue rows.append( [ func_label, totcalls, round(tottime, 6), round(tottime / totcalls, 6) if totcalls else 0.0, round(cumtime, 6), round(cumtime / nc, 6) if nc else 0.0, ] ) sort_col = { "ncalls": 1, "tottime": 2, "tottime/call": 3, "cumtime": 4, "cumtime/call": 5, }.get(sort_by, 4) rows.sort(key=lambda r: r[sort_col], reverse=True) return rows[:limit] def get_sdk_timing_data(self) -> list[list]: """Return merged SDK per-call timing as rows for a Gradio Dataframe.""" if not self._sdk_timings: return [] # Merge across workers merged: dict[str, list[float]] = {} # {name: [count, total]} for worker_stats in self._sdk_timings.values(): for name, (count, total) in worker_stats.items(): if name in merged: merged[name][0] += count merged[name][1] += total else: merged[name] = [count, total] rows: list[list] = [] for name, (count, total) in merged.items(): avg_ms = (total / count * 1000) if count else 0.0 rows.append([name, int(count), round(total, 4), round(avg_ms, 3)]) rows.sort(key=lambda r: r[2], reverse=True) return rows def export(self) -> str: """Export collected stats to .prof files in tmp/.""" _TMP_DIR.mkdir(exist_ok=True) ts = time.strftime("%Y%m%d_%H%M%S") exported: list[str] = [] for source, stats_dict in self._collected.items(): path = _TMP_DIR / f"profile_{source}_{ts}.prof" marshal_path = _TMP_DIR / f"profile_{source}_{ts}.marshal" # Write marshalled stats that pstats can reload with open(marshal_path, "wb") as f: marshal.dump(stats_dict, f) # Also write a human-readable text summary txt_path = _TMP_DIR / f"profile_{source}_{ts}.txt" stream = io.StringIO() ps = pstats.Stats(str(marshal_path), stream=stream) ps.sort_stats("cumulative") ps.print_stats(200) with open(txt_path, "w") as f: f.write(stream.getvalue()) # Write binary .prof via pstats dump ps.dump_stats(str(path)) exported.append(str(path)) os.remove(marshal_path) if not exported: return "No stats to export. Collect stats first." msg = f"Exported {len(exported)} file(s) to tmp/: {', '.join(exported)}" logger.info(msg) return msg def _merge_stats(self, source: str) -> dict: """Merge pstats dicts based on source selection.""" if source == "Main Process": return self._collected.get("main", {}) if source == "Workers": return self._combine([v for k, v in self._collected.items() if k != "main"]) # "All" return self._combine(list(self._collected.values())) @staticmethod def _combine(stats_list: list[dict]) -> dict: """Combine multiple pstats.Stats.stats dicts into one.""" if not stats_list: return {} if len(stats_list) == 1: return stats_list[0] merged: dict = {} for stats_dict in stats_list: for key, (nc, totcalls, tottime, cumtime, callers) in stats_dict.items(): if key in merged: onc, ototcalls, otottime, ocumtime, ocallers = merged[key] new_callers = {**ocallers} for ck, cv in callers.items(): if ck in new_callers: occ = new_callers[ck] new_callers[ck] = tuple(a + b for a, b in zip(occ, cv)) else: new_callers[ck] = cv merged[key] = ( onc + nc, ototcalls + totcalls, otottime + tottime, ocumtime + cumtime, new_callers, ) else: merged[key] = (nc, totcalls, tottime, cumtime, dict(callers)) return merged def build_profiler_tab(pool: EveWorkerPool) -> gr.TabItem: """Build the CPU Profiler tab. Must be called inside a gr.Blocks context.""" state = _ProfilerState(pool) with gr.TabItem("CPU Profiler") as tab: gr.Markdown( "Profile CPU usage across the main process and all Eve SDK workers. " "Start profiling, exercise the app, then collect and view stats." ) status_text = gr.Textbox(label="Status", value="Profiler idle.", interactive=False) with gr.Row(): start_btn = gr.Button("Start Profiling", variant="primary") stop_btn = gr.Button("Stop Profiling") reset_btn = gr.Button("Reset") with gr.Row(): collect_btn = gr.Button("Collect Stats", variant="primary") export_btn = gr.Button("Export to tmp/") with gr.Row(): source_radio = gr.Radio( choices=["All", "Main Process", "Workers"], value="All", label="Source", ) sort_radio = gr.Radio( choices=["cumtime", "tottime", "ncalls", "tottime/call", "cumtime/call"], value="cumtime", label="Sort by", ) filter_input = gr.Textbox( label="Filter (function name / module path)", placeholder="e.g. eve_wrapper, inference, encode", ) stats_table = gr.Dataframe( headers=["Function", "ncalls", "tottime", "tottime/call", "cumtime", "cumtime/call"], datatype=["str", "number", "number", "number", "number", "number"], label="cProfile Stats (CPU time only — I/O wait excluded)", interactive=False, wrap=True, ) gr.Markdown("### Eve SDK Call Timing (wall-clock, all workers merged)") sdk_timing_table = gr.Dataframe( headers=["SDK Call", "ncalls", "total (s)", "avg (ms)"], datatype=["str", "number", "number", "number"], label="Eve SDK Per-Call Timing", interactive=False, ) # --- Event handlers --- def _start() -> str: return state.start() def _stop() -> str: return state.stop() def _reset() -> str: return state.reset() def _collect() -> tuple[str, list[list], list[list]]: msg = state.collect() rows = state.get_table_data() sdk_rows = state.get_sdk_timing_data() return msg, rows, sdk_rows def _export() -> str: return state.export() def _refresh_table(source: str, sort_by: str, filter_text: str) -> list[list]: return state.get_table_data(source=source, sort_by=sort_by, filter_text=filter_text) start_btn.click(fn=_start, outputs=[status_text]) stop_btn.click(fn=_stop, outputs=[status_text]) reset_btn.click(fn=_reset, outputs=[status_text]) collect_btn.click(fn=_collect, outputs=[status_text, stats_table, sdk_timing_table]) export_btn.click(fn=_export, outputs=[status_text]) for component in (source_radio, sort_radio, filter_input): component.change( fn=_refresh_table, inputs=[source_radio, sort_radio, filter_input], outputs=[stats_table], ) return tab