import json
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator

import psutil
import torch

import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import (
    GESStatsNotFoundError,
    GraphExecutionStats,
    GraphExecutionStatsSummary,
    InvocationStatsSummary,
    ModelCacheStatsSummary,
    NodeExecutionStats,
    NodeExecutionStatsSummary,
)
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load.model_cache import CacheStats

GB = 2**30


class InvocationStatsService(InvocationStatsServiceBase):
    """Accumulate performance information about a running graph. Collects time spent in each node,
    as well as the maximum and current VRAM utilisation for CUDA systems."""

    def __init__(self):
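        # Maps graph_execution_state_id -> GraphExecutionStats for each graph this service has seen.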
        self._stats: dict[str, GraphExecutionStats] = {}
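        # Maps graph_execution_state_id -> model cache stats collected while that graph ran.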
        self._cache_stats: dict[str, CacheStats] = {}

    def start(self, invoker: Invoker) -> None:
        self._invoker = invoker

    @contextmanager
    def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]:
        services = self._invoker.services
        if not self._stats.get(graph_execution_state_id):
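            # First node for this graph execution state: create fresh stats containers for it.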
            self._stats[graph_execution_state_id] = GraphExecutionStats()
            self._cache_stats[graph_execution_state_id] = CacheStats()
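
        # Snapshot time and RAM before the node runs, and reset the CUDA peak-memory counter so
        # torch.cuda.max_memory_allocated() below reflects only this invocation.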
        start_time = time.time()
        start_ram = psutil.Process().memory_info().rss
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
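
        # Attach this graph's CacheStats to the model RAM cache so hits, misses, and sizes
        # observed during the node are attributed to this graph.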
        assert services.model_manager.load is not None
        services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]

        try:
            yield None
        finally:
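            # Record the node's stats even if it raised, so failed runs still appear in the summary.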
            node_stats = NodeExecutionStats(
                invocation_type=invocation.get_type(),
                start_time=start_time,
                end_time=time.time(),
                start_ram_gb=start_ram / GB,
                end_ram_gb=psutil.Process().memory_info().rss / GB,
                peak_vram_gb=torch.cuda.max_memory_allocated() / GB if torch.cuda.is_available() else 0.0,
            )
            self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)

    def reset_stats(self):
        self._stats = {}
        self._cache_stats = {}

    def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
        graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
        node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
        model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
        vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None

        return InvocationStatsSummary(
            graph_stats=graph_stats_summary,
            model_cache_stats=model_cache_stats_summary,
            node_stats=node_stats_summaries,
            vram_usage_gb=vram_usage_gb,
        )

    def log_stats(self, graph_execution_state_id: str) -> None:
        stats = self.get_stats(graph_execution_state_id)
        logger.info(str(stats))

    def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
        stats = self.get_stats(graph_execution_state_id)
        with open(output_path, "w") as f:
            f.write(json.dumps(stats.as_dict(), indent=2))

    def _get_model_cache_summary(self, graph_execution_state_id: str) -> ModelCacheStatsSummary:
        try:
            cache_stats = self._cache_stats[graph_execution_state_id]
        except KeyError as e:
            raise GESStatsNotFoundError(
                f"Attempted to get model cache statistics for unknown graph {graph_execution_state_id}: {e}."
            ) from e

        return ModelCacheStatsSummary(
            cache_hits=cache_stats.hits,
            cache_misses=cache_stats.misses,
            high_water_mark_gb=cache_stats.high_watermark / GB,
            cache_size_gb=cache_stats.cache_size / GB,
            total_usage_gb=sum(cache_stats.loaded_model_sizes.values()) / GB,
            models_cached=cache_stats.in_cache,
            models_cleared=cache_stats.cleared,
        )

    def _get_graph_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
        try:
            graph_stats = self._stats[graph_execution_state_id]
        except KeyError as e:
            raise GESStatsNotFoundError(
                f"Attempted to get graph statistics for unknown graph {graph_execution_state_id}: {e}."
            ) from e

        return graph_stats.get_graph_stats_summary(graph_execution_state_id)

    def _get_node_summaries(self, graph_execution_state_id: str) -> list[NodeExecutionStatsSummary]:
        try:
            graph_stats = self._stats[graph_execution_state_id]
        except KeyError as e:
            raise GESStatsNotFoundError(
                f"Attempted to get node statistics for unknown graph {graph_execution_state_id}: {e}."
            ) from e

        return graph_stats.get_node_stats_summaries()