beaupreda's picture
Upload sensAI-Generic-Object-Detection with upload_repo.py
13170f7 verified
Raw
History Blame Contribute Delete
13 kB
"""CPU profiler Gradio tab for Eve HMI demos.
Provides a UI to start/stop ``cProfile`` on the main process and all worker
processes, collect stats, display them in a sortable table, and export
``.prof`` files for offline analysis (e.g. with ``snakeviz``).
Gated behind ``ENABLE_PROFILER=1`` — callers should check the env var before
importing this module.
"""
import cProfile
import io
import marshal
import os
import pstats
import time
from pathlib import Path
import gradio as gr
from eve_worker_pool import EveWorkerPool
from log_utils import setup_logger
# Module-wide logger used for all profiler status messages.
logger = setup_logger("Profiler")
# Directory that export() writes into: "tmp" under the third ancestor of this
# file's resolved path (parents[0] is the containing directory, parents[2] is
# two levels above it).
_TMP_DIR = Path(__file__).resolve().parents[2] / "tmp"
class _ProfilerState:
    """Manage cProfile instances for the main process and the pool workers.

    The object owns one ``cProfile.Profile`` for the current process and
    forwards start/stop/collect commands to every live worker returned by
    ``pool.get_live_workers()``. Collected data is kept as raw
    ``pstats.Stats.stats``-style dicts keyed by a source label (``"main"`` or
    ``"worker-<id>"``) so that it can be merged, tabulated and exported.
    """

    def __init__(self, pool: "EveWorkerPool") -> None:
        """Create an idle profiler bound to *pool*."""
        self._pool = pool
        self._main_profiler = cProfile.Profile()
        self._running = False
        # Collected pstats.Stats dicts keyed by source label
        self._collected: dict[str, dict] = {}
        # SDK timing data: {worker_label: {call_name: (count, total_seconds)}}
        self._sdk_timings: dict[str, dict[str, tuple[int, float]]] = {}

    @property
    def running(self) -> bool:
        """Whether profiling is currently enabled."""
        return self._running

    def start(self) -> str:
        """Enable profiling on the main process and every live worker.

        Returns:
            A status message; per-worker failures are appended to it rather
            than raised.
        """
        if self._running:
            return "Profiling already running."
        self._running = True
        self._main_profiler.enable()
        errors: list[str] = []
        workers = self._pool.get_live_workers()
        for w in workers:
            # Best effort: one failing worker must not block the others.
            try:
                w.send_start_profiling()
            except Exception as exc:
                errors.append(f"w{w.worker_id}: {exc}")
        msg = f"Profiling started (main + {len(workers)} workers)."
        if errors:
            msg += f" Errors: {'; '.join(errors)}"
        logger.info(msg)
        return msg

    def stop(self) -> str:
        """Disable profiling on the main process and every live worker.

        Returns:
            A status message; per-worker failures are appended to it rather
            than raised.
        """
        if not self._running:
            return "Profiling not running."
        self._running = False
        self._main_profiler.disable()
        errors: list[str] = []
        workers = self._pool.get_live_workers()
        for w in workers:
            try:
                w.send_stop_profiling()
            except Exception as exc:
                errors.append(f"w{w.worker_id}: {exc}")
        msg = f"Profiling stopped (main + {len(workers)} workers)."
        if errors:
            msg += f" Errors: {'; '.join(errors)}"
        logger.info(msg)
        return msg

    def reset(self) -> str:
        """Discard collected data and start over with a fresh main profiler.

        If profiling was running it is stopped first and restarted afterwards.

        Returns:
            A status message.
        """
        was_running = self._running
        if was_running:
            self.stop()
        self._main_profiler = cProfile.Profile()
        self._collected.clear()
        # SDK timings belong to the same collection; drop them too so the
        # two result sets never get out of sync after a reset.
        self._sdk_timings.clear()
        msg = "Profiler reset."
        if was_running:
            self.start()
            msg += " Profiling restarted with fresh profiler."
        logger.info(msg)
        return msg

    def _snapshot_main_stats(self) -> dict:
        """Return the main profiler's stats dict, or ``{}`` if it has no data.

        The profiler is left disabled; the caller re-enables it if needed.
        """
        self._main_profiler.disable()
        try:
            return pstats.Stats(self._main_profiler, stream=io.StringIO()).stats
        except TypeError:
            # pstats.Stats refuses to be built from a profiler that holds no
            # samples (e.g. "collect" pressed before profiling ever started).
            return {}

    def collect(self) -> str:
        """Collect stats from main process and all workers.

        Previously collected data is replaced. Profiling keeps running
        afterwards if it was running before the call.

        Returns:
            A status message; per-worker failures are appended to it rather
            than raised.
        """
        was_running = self._running
        self._collected.clear()
        self._sdk_timings.clear()

        # Main process stats (only recorded when there is something to show)
        main_stats = self._snapshot_main_stats()
        if main_stats:
            self._collected["main"] = main_stats
        if was_running:
            self._main_profiler.enable()

        # Worker stats (cProfile + SDK timings)
        errors: list[str] = []
        workers = self._pool.get_live_workers()
        for w in workers:
            label = f"worker-{w.worker_id}"
            try:
                data = w.send_get_profile_stats()
                if data:
                    # Payload comes from our own worker processes; marshal
                    # must never be fed data from an untrusted source.
                    self._collected[label] = marshal.loads(data)
                # Workers auto-disable on GetProfileStatsCmd; re-enable if running
                if was_running:
                    w.send_start_profiling()
            except Exception as exc:
                errors.append(f"w{w.worker_id}: {exc}")
            try:
                timing = w.send_get_timing_stats(reset=True)
                if timing:
                    self._sdk_timings[label] = timing
            except Exception as exc:
                errors.append(f"w{w.worker_id} timing: {exc}")

        n_sources = len(self._collected)
        msg = f"Collected stats from {n_sources} source(s)."
        if self._sdk_timings:
            msg += f" SDK timings from {len(self._sdk_timings)} worker(s)."
        if errors:
            msg += f" Errors: {'; '.join(errors)}"
        logger.info(msg)
        return msg

    def get_table_data(
        self,
        source: str = "All",
        sort_by: str = "cumtime",
        filter_text: str = "",
        limit: int = 100,
    ) -> list[list]:
        """Return profile data as rows for a Gradio Dataframe.

        Args:
            source: ``"Main Process"``, ``"Workers"`` or anything else for all
                collected sources merged together.
            sort_by: Column to sort by in descending order; unknown names fall
                back to ``"cumtime"``.
            filter_text: Case-insensitive substring that the
                ``file:line(function)`` label must contain; empty keeps all.
            limit: Maximum number of rows returned.

        Returns:
            Rows of ``[label, ncalls, tottime, tottime/call, cumtime,
            cumtime/call]`` with times rounded to 6 decimals.
        """
        merged = self._merge_stats(source)
        if not merged:
            return []

        needle = filter_text.lower() if filter_text else ""
        rows: list[list] = []
        for key, (prim_calls, total_calls, tottime, cumtime, _callers) in merged.items():
            filename, lineno, funcname = key
            func_label = f"{filename}:{lineno}({funcname})"
            if needle and needle not in func_label.lower():
                continue
            # tottime is averaged over the second count and cumtime over the
            # first one, exactly as the two per-call columns are defined here.
            rows.append(
                [
                    func_label,
                    total_calls,
                    round(tottime, 6),
                    round(tottime / total_calls, 6) if total_calls else 0.0,
                    round(cumtime, 6),
                    round(cumtime / prim_calls, 6) if prim_calls else 0.0,
                ]
            )

        sort_col = {
            "ncalls": 1,
            "tottime": 2,
            "tottime/call": 3,
            "cumtime": 4,
            "cumtime/call": 5,
        }.get(sort_by, 4)
        rows.sort(key=lambda r: r[sort_col], reverse=True)
        return rows[:limit]

    def get_sdk_timing_data(self) -> list[list]:
        """Return merged SDK per-call timing as rows for a Gradio Dataframe.

        Returns:
            Rows of ``[name, count, total_seconds, average_milliseconds]``
            summed across workers and sorted by total time, largest first.
        """
        if not self._sdk_timings:
            return []

        # Merge across workers
        merged: dict[str, list[float]] = {}  # {name: [count, total]}
        for worker_stats in self._sdk_timings.values():
            for name, (count, total) in worker_stats.items():
                if name in merged:
                    merged[name][0] += count
                    merged[name][1] += total
                else:
                    merged[name] = [count, total]

        rows: list[list] = []
        for name, (count, total) in merged.items():
            avg_ms = (total / count * 1000) if count else 0.0
            rows.append([name, int(count), round(total, 4), round(avg_ms, 3)])
        rows.sort(key=lambda r: r[2], reverse=True)
        return rows

    def export(self) -> str:
        """Export collected stats to .prof files in tmp/.

        For every non-empty source a binary ``profile_<source>_<ts>.prof`` and
        a text summary ``profile_<source>_<ts>.txt`` (top 200 entries by
        cumulative time) are written to ``_TMP_DIR``.

        Returns:
            A status message listing the ``.prof`` files, or a hint when there
            is nothing to export.
        """
        # Empty dicts cannot be loaded back by pstats, so they are skipped.
        exportable = {src: stats for src, stats in self._collected.items() if stats}
        if not exportable:
            return "No stats to export. Collect stats first."

        _TMP_DIR.mkdir(parents=True, exist_ok=True)
        ts = time.strftime("%Y%m%d_%H%M%S")
        exported: list[str] = []
        for source, stats_dict in exportable.items():
            prof_path = _TMP_DIR / f"profile_{source}_{ts}.prof"
            # A .prof file is the marshal dump of the stats dict, which is the
            # format pstats loads; write it directly (no temp file to clean up).
            with open(prof_path, "wb") as f:
                marshal.dump(stats_dict, f)

            # Also write a human-readable text summary
            stream = io.StringIO()
            ps = pstats.Stats(str(prof_path), stream=stream)
            ps.sort_stats("cumulative")
            ps.print_stats(200)
            txt_path = _TMP_DIR / f"profile_{source}_{ts}.txt"
            txt_path.write_text(stream.getvalue(), encoding="utf-8")

            exported.append(str(prof_path))

        msg = f"Exported {len(exported)} file(s) to tmp/: {', '.join(exported)}"
        logger.info(msg)
        return msg

    def _merge_stats(self, source: str) -> dict:
        """Merge pstats dicts based on source selection."""
        if source == "Main Process":
            return self._collected.get("main", {})
        if source == "Workers":
            return self._combine([v for k, v in self._collected.items() if k != "main"])
        # "All"
        return self._combine(list(self._collected.values()))

    @staticmethod
    def _combine(stats_list: list[dict]) -> dict:
        """Combine multiple pstats.Stats.stats dicts into one.

        Entries sharing a function key have their four numeric fields added
        and their caller maps merged the same way. A single input dict is
        returned as-is (not copied).
        """
        if not stats_list:
            return {}
        if len(stats_list) == 1:
            return stats_list[0]

        merged: dict = {}
        for stats_dict in stats_list:
            for key, (nc, totcalls, tottime, cumtime, callers) in stats_dict.items():
                if key not in merged:
                    # Copy the caller map so later merges never mutate input.
                    merged[key] = (nc, totcalls, tottime, cumtime, dict(callers))
                    continue
                onc, ototcalls, otottime, ocumtime, ocallers = merged[key]
                new_callers = {**ocallers}
                for ck, cv in callers.items():
                    if ck in new_callers:
                        occ = new_callers[ck]
                        new_callers[ck] = tuple(a + b for a, b in zip(occ, cv))
                    else:
                        new_callers[ck] = cv
                merged[key] = (
                    onc + nc,
                    ototcalls + totcalls,
                    otottime + tottime,
                    ocumtime + cumtime,
                    new_callers,
                )
        return merged
def build_profiler_tab(pool: EveWorkerPool) -> gr.TabItem:
    """Build the CPU Profiler tab. Must be called inside a gr.Blocks context.

    Args:
        pool: Worker pool whose live workers are profiled alongside the main
            process.

    Returns:
        The created ``gr.TabItem`` with all controls and event handlers wired
        to a private ``_ProfilerState``.
    """
    state = _ProfilerState(pool)

    with gr.TabItem("CPU Profiler") as tab:
        gr.Markdown(
            "Profile CPU usage across the main process and all Eve SDK workers. "
            "Start profiling, exercise the app, then collect and view stats."
        )
        status_text = gr.Textbox(label="Status", value="Profiler idle.", interactive=False)

        with gr.Row():
            start_btn = gr.Button("Start Profiling", variant="primary")
            stop_btn = gr.Button("Stop Profiling")
            reset_btn = gr.Button("Reset")

        with gr.Row():
            collect_btn = gr.Button("Collect Stats", variant="primary")
            export_btn = gr.Button("Export to tmp/")

        with gr.Row():
            source_radio = gr.Radio(
                choices=["All", "Main Process", "Workers"],
                value="All",
                label="Source",
            )
            sort_radio = gr.Radio(
                choices=["cumtime", "tottime", "ncalls", "tottime/call", "cumtime/call"],
                value="cumtime",
                label="Sort by",
            )

        filter_input = gr.Textbox(
            label="Filter (function name / module path)",
            placeholder="e.g. eve_wrapper, inference, encode",
        )

        stats_table = gr.Dataframe(
            headers=["Function", "ncalls", "tottime", "tottime/call", "cumtime", "cumtime/call"],
            datatype=["str", "number", "number", "number", "number", "number"],
            label="cProfile Stats (CPU time only — I/O wait excluded)",
            interactive=False,
            wrap=True,
        )

        gr.Markdown("### Eve SDK Call Timing (wall-clock, all workers merged)")
        sdk_timing_table = gr.Dataframe(
            headers=["SDK Call", "ncalls", "total (s)", "avg (ms)"],
            datatype=["str", "number", "number", "number"],
            label="Eve SDK Per-Call Timing",
            interactive=False,
        )

        # The three controls that determine how the stats table is rendered.
        view_inputs = [source_radio, sort_radio, filter_input]

        # --- Event handlers ---
        def _start() -> str:
            return state.start()

        def _stop() -> str:
            return state.stop()

        def _reset() -> str:
            return state.reset()

        def _collect(
            source: str, sort_by: str, filter_text: str
        ) -> tuple[str, list[list], list[list]]:
            # Render with the currently selected view so the table always
            # matches what the source/sort/filter controls show.
            msg = state.collect()
            rows = state.get_table_data(
                source=source, sort_by=sort_by, filter_text=filter_text or ""
            )
            sdk_rows = state.get_sdk_timing_data()
            return msg, rows, sdk_rows

        def _export() -> str:
            return state.export()

        def _refresh_table(source: str, sort_by: str, filter_text: str) -> list[list]:
            return state.get_table_data(source=source, sort_by=sort_by, filter_text=filter_text)

        start_btn.click(fn=_start, outputs=[status_text])
        stop_btn.click(fn=_stop, outputs=[status_text])
        reset_btn.click(fn=_reset, outputs=[status_text])
        collect_btn.click(
            fn=_collect,
            inputs=view_inputs,
            outputs=[status_text, stats_table, sdk_timing_table],
        )
        export_btn.click(fn=_export, outputs=[status_text])

        # Any change to a view control re-renders the table from cached stats.
        for component in (source_radio, sort_radio, filter_input):
            component.change(
                fn=_refresh_table,
                inputs=view_inputs,
                outputs=[stats_table],
            )

    return tab