Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """CPU profiler Gradio tab for Eve HMI demos. | |
| Provides a UI to start/stop ``cProfile`` on the main process and all worker | |
| processes, collect stats, display them in a sortable table, and export | |
| ``.prof`` files for offline analysis (e.g. with ``snakeviz``). | |
| Gated behind ``ENABLE_PROFILER=1`` — callers should check the env var before | |
| importing this module. | |
| """ | |
| import cProfile | |
| import io | |
| import marshal | |
| import os | |
| import pstats | |
| import time | |
| from pathlib import Path | |
| import gradio as gr | |
| from eve_worker_pool import EveWorkerPool | |
| from log_utils import setup_logger | |
# Module-wide logger used for every profiler status message.
logger = setup_logger("Profiler")

# Export target: the ``tmp`` directory two levels above this module's directory.
_TMP_DIR = Path(__file__).resolve().parents[2].joinpath("tmp")
| class _ProfilerState: | |
| """Manages cProfile instances for main + workers.""" | |
| def __init__(self, pool: EveWorkerPool) -> None: | |
| self._pool = pool | |
| self._main_profiler = cProfile.Profile() | |
| self._running = False | |
| # Collected pstats.Stats dicts keyed by source label | |
| self._collected: dict[str, dict] = {} | |
| # SDK timing data: {worker_label: {call_name: (count, total_seconds)}} | |
| self._sdk_timings: dict[str, dict[str, tuple[int, float]]] = {} | |
| def running(self) -> bool: | |
| return self._running | |
| def start(self) -> str: | |
| if self._running: | |
| return "Profiling already running." | |
| self._running = True | |
| self._main_profiler.enable() | |
| errors: list[str] = [] | |
| workers = self._pool.get_live_workers() | |
| for w in workers: | |
| try: | |
| w.send_start_profiling() | |
| except Exception as exc: | |
| errors.append(f"w{w.worker_id}: {exc}") | |
| msg = f"Profiling started (main + {len(workers)} workers)." | |
| if errors: | |
| msg += f" Errors: {'; '.join(errors)}" | |
| logger.info(msg) | |
| return msg | |
| def stop(self) -> str: | |
| if not self._running: | |
| return "Profiling not running." | |
| self._running = False | |
| self._main_profiler.disable() | |
| errors: list[str] = [] | |
| workers = self._pool.get_live_workers() | |
| for w in workers: | |
| try: | |
| w.send_stop_profiling() | |
| except Exception as exc: | |
| errors.append(f"w{w.worker_id}: {exc}") | |
| msg = f"Profiling stopped (main + {len(workers)} workers)." | |
| if errors: | |
| msg += f" Errors: {'; '.join(errors)}" | |
| logger.info(msg) | |
| return msg | |
| def reset(self) -> str: | |
| was_running = self._running | |
| if was_running: | |
| self.stop() | |
| self._main_profiler = cProfile.Profile() | |
| self._collected.clear() | |
| msg = "Profiler reset." | |
| if was_running: | |
| self.start() | |
| msg += " Profiling restarted with fresh profiler." | |
| logger.info(msg) | |
| return msg | |
| def collect(self) -> str: | |
| """Collect stats from main process and all workers.""" | |
| was_running = self._running | |
| if was_running: | |
| self._main_profiler.disable() | |
| self._collected.clear() | |
| self._sdk_timings.clear() | |
| # Main process stats | |
| stream = io.StringIO() | |
| ps = pstats.Stats(self._main_profiler, stream=stream) | |
| self._collected["main"] = ps.stats | |
| if was_running: | |
| self._main_profiler.enable() | |
| # Worker stats (cProfile + SDK timings) | |
| errors: list[str] = [] | |
| workers = self._pool.get_live_workers() | |
| for w in workers: | |
| try: | |
| data = w.send_get_profile_stats() | |
| if data: | |
| self._collected[f"worker-{w.worker_id}"] = marshal.loads(data) | |
| # Workers auto-disable on GetProfileStatsCmd; re-enable if running | |
| if was_running: | |
| w.send_start_profiling() | |
| except Exception as exc: | |
| errors.append(f"w{w.worker_id}: {exc}") | |
| try: | |
| timing = w.send_get_timing_stats(reset=True) | |
| if timing: | |
| self._sdk_timings[f"worker-{w.worker_id}"] = timing | |
| except Exception as exc: | |
| errors.append(f"w{w.worker_id} timing: {exc}") | |
| n_sources = len(self._collected) | |
| msg = f"Collected stats from {n_sources} source(s)." | |
| if self._sdk_timings: | |
| msg += f" SDK timings from {len(self._sdk_timings)} worker(s)." | |
| if errors: | |
| msg += f" Errors: {'; '.join(errors)}" | |
| logger.info(msg) | |
| return msg | |
| def get_table_data( | |
| self, | |
| source: str = "All", | |
| sort_by: str = "cumtime", | |
| filter_text: str = "", | |
| limit: int = 100, | |
| ) -> list[list]: | |
| """Return profile data as rows for a Gradio Dataframe.""" | |
| merged = self._merge_stats(source) | |
| if not merged: | |
| return [] | |
| rows: list[list] = [] | |
| for key, (nc, totcalls, tottime, cumtime, callers) in merged.items(): | |
| filename, lineno, funcname = key | |
| func_label = f"{filename}:{lineno}({funcname})" | |
| if filter_text and filter_text.lower() not in func_label.lower(): | |
| continue | |
| rows.append( | |
| [ | |
| func_label, | |
| totcalls, | |
| round(tottime, 6), | |
| round(tottime / totcalls, 6) if totcalls else 0.0, | |
| round(cumtime, 6), | |
| round(cumtime / nc, 6) if nc else 0.0, | |
| ] | |
| ) | |
| sort_col = { | |
| "ncalls": 1, | |
| "tottime": 2, | |
| "tottime/call": 3, | |
| "cumtime": 4, | |
| "cumtime/call": 5, | |
| }.get(sort_by, 4) | |
| rows.sort(key=lambda r: r[sort_col], reverse=True) | |
| return rows[:limit] | |
| def get_sdk_timing_data(self) -> list[list]: | |
| """Return merged SDK per-call timing as rows for a Gradio Dataframe.""" | |
| if not self._sdk_timings: | |
| return [] | |
| # Merge across workers | |
| merged: dict[str, list[float]] = {} # {name: [count, total]} | |
| for worker_stats in self._sdk_timings.values(): | |
| for name, (count, total) in worker_stats.items(): | |
| if name in merged: | |
| merged[name][0] += count | |
| merged[name][1] += total | |
| else: | |
| merged[name] = [count, total] | |
| rows: list[list] = [] | |
| for name, (count, total) in merged.items(): | |
| avg_ms = (total / count * 1000) if count else 0.0 | |
| rows.append([name, int(count), round(total, 4), round(avg_ms, 3)]) | |
| rows.sort(key=lambda r: r[2], reverse=True) | |
| return rows | |
| def export(self) -> str: | |
| """Export collected stats to .prof files in tmp/.""" | |
| _TMP_DIR.mkdir(exist_ok=True) | |
| ts = time.strftime("%Y%m%d_%H%M%S") | |
| exported: list[str] = [] | |
| for source, stats_dict in self._collected.items(): | |
| path = _TMP_DIR / f"profile_{source}_{ts}.prof" | |
| marshal_path = _TMP_DIR / f"profile_{source}_{ts}.marshal" | |
| # Write marshalled stats that pstats can reload | |
| with open(marshal_path, "wb") as f: | |
| marshal.dump(stats_dict, f) | |
| # Also write a human-readable text summary | |
| txt_path = _TMP_DIR / f"profile_{source}_{ts}.txt" | |
| stream = io.StringIO() | |
| ps = pstats.Stats(str(marshal_path), stream=stream) | |
| ps.sort_stats("cumulative") | |
| ps.print_stats(200) | |
| with open(txt_path, "w") as f: | |
| f.write(stream.getvalue()) | |
| # Write binary .prof via pstats dump | |
| ps.dump_stats(str(path)) | |
| exported.append(str(path)) | |
| os.remove(marshal_path) | |
| if not exported: | |
| return "No stats to export. Collect stats first." | |
| msg = f"Exported {len(exported)} file(s) to tmp/: {', '.join(exported)}" | |
| logger.info(msg) | |
| return msg | |
| def _merge_stats(self, source: str) -> dict: | |
| """Merge pstats dicts based on source selection.""" | |
| if source == "Main Process": | |
| return self._collected.get("main", {}) | |
| if source == "Workers": | |
| return self._combine([v for k, v in self._collected.items() if k != "main"]) | |
| # "All" | |
| return self._combine(list(self._collected.values())) | |
| def _combine(stats_list: list[dict]) -> dict: | |
| """Combine multiple pstats.Stats.stats dicts into one.""" | |
| if not stats_list: | |
| return {} | |
| if len(stats_list) == 1: | |
| return stats_list[0] | |
| merged: dict = {} | |
| for stats_dict in stats_list: | |
| for key, (nc, totcalls, tottime, cumtime, callers) in stats_dict.items(): | |
| if key in merged: | |
| onc, ototcalls, otottime, ocumtime, ocallers = merged[key] | |
| new_callers = {**ocallers} | |
| for ck, cv in callers.items(): | |
| if ck in new_callers: | |
| occ = new_callers[ck] | |
| new_callers[ck] = tuple(a + b for a, b in zip(occ, cv)) | |
| else: | |
| new_callers[ck] = cv | |
| merged[key] = ( | |
| onc + nc, | |
| ototcalls + totcalls, | |
| otottime + tottime, | |
| ocumtime + cumtime, | |
| new_callers, | |
| ) | |
| else: | |
| merged[key] = (nc, totcalls, tottime, cumtime, dict(callers)) | |
| return merged | |
def build_profiler_tab(pool: EveWorkerPool) -> gr.TabItem:
    """Build the CPU Profiler tab. Must be called inside a gr.Blocks context.

    The tab offers start/stop/reset/collect/export buttons, controls for the
    stats source, sort column and text filter, a cProfile stats table and an
    SDK timing table.  All handlers share one ``_ProfilerState`` bound to
    *pool*.

    Args:
        pool: Worker pool whose live workers are profiled alongside the main
            process.

    Returns:
        The created ``gr.TabItem``.
    """
    state = _ProfilerState(pool)

    with gr.TabItem("CPU Profiler") as tab:
        gr.Markdown(
            "Profile CPU usage across the main process and all Eve SDK workers. "
            "Start profiling, exercise the app, then collect and view stats."
        )
        status_text = gr.Textbox(label="Status", value="Profiler idle.", interactive=False)

        with gr.Row():
            start_btn = gr.Button("Start Profiling", variant="primary")
            stop_btn = gr.Button("Stop Profiling")
            reset_btn = gr.Button("Reset")

        with gr.Row():
            collect_btn = gr.Button("Collect Stats", variant="primary")
            export_btn = gr.Button("Export to tmp/")

        with gr.Row():
            source_radio = gr.Radio(
                choices=["All", "Main Process", "Workers"],
                value="All",
                label="Source",
            )
            sort_radio = gr.Radio(
                choices=["cumtime", "tottime", "ncalls", "tottime/call", "cumtime/call"],
                value="cumtime",
                label="Sort by",
            )

        filter_input = gr.Textbox(
            label="Filter (function name / module path)",
            placeholder="e.g. eve_wrapper, inference, encode",
        )

        stats_table = gr.Dataframe(
            headers=["Function", "ncalls", "tottime", "tottime/call", "cumtime", "cumtime/call"],
            datatype=["str", "number", "number", "number", "number", "number"],
            label="cProfile Stats (CPU time only — I/O wait excluded)",
            interactive=False,
            wrap=True,
        )

        gr.Markdown("### Eve SDK Call Timing (wall-clock, all workers merged)")
        sdk_timing_table = gr.Dataframe(
            headers=["SDK Call", "ncalls", "total (s)", "avg (ms)"],
            datatype=["str", "number", "number", "number"],
            label="Eve SDK Per-Call Timing",
            interactive=False,
        )

        # The three controls that determine how the stats table is rendered.
        view_inputs = [source_radio, sort_radio, filter_input]

        # --- Event handlers ---
        # Kept as named wrappers so handler names stay the same as before.
        def _start() -> str:
            return state.start()

        def _stop() -> str:
            return state.stop()

        def _reset() -> str:
            return state.reset()

        def _collect(
            source: str = "All",
            sort_by: str = "cumtime",
            filter_text: str = "",
        ) -> tuple[str, list[list], list[list]]:
            # Render with the currently selected source/sort/filter so the
            # table agrees with what the controls show after collecting.
            msg = state.collect()
            rows = state.get_table_data(
                source=source, sort_by=sort_by, filter_text=filter_text
            )
            sdk_rows = state.get_sdk_timing_data()
            return msg, rows, sdk_rows

        def _export() -> str:
            return state.export()

        def _refresh_table(source: str, sort_by: str, filter_text: str) -> list[list]:
            return state.get_table_data(source=source, sort_by=sort_by, filter_text=filter_text)

        start_btn.click(fn=_start, outputs=[status_text])
        stop_btn.click(fn=_stop, outputs=[status_text])
        reset_btn.click(fn=_reset, outputs=[status_text])
        collect_btn.click(
            fn=_collect,
            inputs=view_inputs,
            outputs=[status_text, stats_table, sdk_timing_table],
        )
        export_btn.click(fn=_export, outputs=[status_text])

        # Any change to a view control re-renders the table from collected data.
        for component in (source_radio, sort_radio, filter_input):
            component.change(
                fn=_refresh_table,
                inputs=view_inputs,
                outputs=[stats_table],
            )

    return tab