Learn2Splat / optgs /misc /benchmarker.py
SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
import json
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from time import perf_counter, time

import numpy as np
import torch
class Benchmarker:
    """Collect named timing samples, in milliseconds, and report or dump them.

    Samples live in ``execution_times`` as a mapping of tag -> list of floats.
    Both :meth:`time` and :meth:`record` store milliseconds, which is the unit
    :meth:`summarize` assumes when formatting.
    """

    def __init__(self):
        # tag -> list of per-call elapsed times in milliseconds
        self.execution_times = defaultdict(list)

    @contextmanager
    def time(self, tag: str, num_calls: int = 1):
        """Time the enclosed block and record the result under ``tag`` in milliseconds.

        The elapsed time is split evenly into ``num_calls`` samples, so a block
        that performs an operation ``num_calls`` times yields per-call values.
        A sample is recorded even if the block raises. A non-positive
        ``num_calls`` records nothing.

        Only host-side time is measured: nothing here waits for asynchronously
        launched GPU work, so synchronize inside the block if that matters.
        """
        # perf_counter is monotonic and high resolution; time.time() is a wall
        # clock that can jump when the system clock is adjusted.
        start = perf_counter()
        try:
            yield
        finally:
            # Convert seconds -> ms so samples match record()/summarize().
            elapsed_ms = (perf_counter() - start) * 1000.0
            if num_calls > 0:
                per_call_ms = elapsed_ms / num_calls
                self.execution_times[tag].extend([per_call_ms] * num_calls)

    def record(self, tag: str, elapsed_ms: float) -> None:
        """Record a pre-measured elapsed time (in milliseconds) under the given tag."""
        self.execution_times[tag].append(elapsed_ms)

    def merge(self, other: "Benchmarker") -> None:
        """Merge another benchmarker's recorded times into this one."""
        for tag, times in other.execution_times.items():
            self.execution_times[tag].extend(times)

    def dump(self, path: Path) -> None:
        """Write all samples to ``path`` as a JSON object (tag -> list of ms).

        Parent directories are created as needed; ``path`` may also be a ``str``.
        """
        path = Path(path)
        path.parent.mkdir(exist_ok=True, parents=True)
        with path.open("w") as f:
            json.dump(dict(self.execution_times), f)

    def dump_memory(self, path: Path) -> None:
        """Write the peak CUDA memory allocated, in bytes, to ``path`` as a JSON number.

        ``torch.cuda.max_memory_allocated`` reports the ``allocated_bytes.all.peak``
        statistic but returns 0 instead of raising ``KeyError`` when CUDA has
        not been initialized. ``path`` may also be a ``str``.
        """
        path = Path(path)
        path.parent.mkdir(exist_ok=True, parents=True)
        with path.open("w") as f:
            json.dump(torch.cuda.max_memory_allocated(), f)

    def summarize(self) -> None:
        """Print each tag's call count, mean time (ms/call) and total time (s).

        Tags with no recorded samples are skipped.
        """
        for tag, times in self.execution_times.items():
            if not times:
                continue  # avoid np.mean([]) -> nan + RuntimeWarning
            avg_ms = np.mean(times)
            total_s = sum(times) / 1000
            print(f"{tag}: {len(times)} calls, avg {avg_ms:.1f} ms/call, total {total_s:.1f} s")

    def clear_history(self) -> None:
        """Discard all recorded samples."""
        self.execution_times = defaultdict(list)