import json
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator

import psutil
import torch

import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import (
    GESStatsNotFoundError,
    GraphExecutionStats,
    GraphExecutionStatsSummary,
    InvocationStatsSummary,
    ModelCacheStatsSummary,
    NodeExecutionStats,
    NodeExecutionStatsSummary,
)
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load.model_cache import CacheStats

# Size of 1GB in bytes.
GB = 2**30


class InvocationStatsService(InvocationStatsServiceBase):
    """Accumulate performance information about a running graph. Collects time spent in each node,
    as well as the maximum and current VRAM utilisation for CUDA systems"""

    def __init__(self):
        # Maps graph_execution_state_id to GraphExecutionStats.
        self._stats: dict[str, GraphExecutionStats] = {}
        # Maps graph_execution_state_id to model manager CacheStats.
        self._cache_stats: dict[str, CacheStats] = {}

    def start(self, invoker: Invoker) -> None:
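        """Remember the Invoker so its services (e.g. the model manager) are reachable during stats collection."""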
        self._invoker = invoker

    @contextmanager
    def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]:
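        """Context manager that times a single node invocation and records its RAM and peak VRAM usage."""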
        # This is to handle the case of the model manager not being initialized, which happens
        # during some tests.
        services = self._invoker.services
        if not self._stats.get(graph_execution_state_id):
            # First time we're seeing this graph_execution_state_id.
            self._stats[graph_execution_state_id] = GraphExecutionStats()
            self._cache_stats[graph_execution_state_id] = CacheStats()

        # Record state before the invocation.
        start_time = time.time()
        start_ram = psutil.Process().memory_info().rss
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        assert services.model_manager.load is not None
        services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]

        try:
            # Let the invocation run.
            yield None
        finally:
            # Record state after the invocation.
            node_stats = NodeExecutionStats(
                invocation_type=invocation.get_type(),
                start_time=start_time,
                end_time=time.time(),
                start_ram_gb=start_ram / GB,
                end_ram_gb=psutil.Process().memory_info().rss / GB,
                peak_vram_gb=torch.cuda.max_memory_allocated() / GB if torch.cuda.is_available() else 0.0,
            )
            self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)

    def reset_stats(self):
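        """Clear all accumulated graph execution and model cache statistics."""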
        self._stats = {}
        self._cache_stats = {}

    def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
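        """Build a complete summary (graph, per-node, model cache, and current VRAM usage) for the given
        graph execution state. Raises GESStatsNotFoundError if the id is unknown."""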
        graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
        node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
        model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
        vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None

        return InvocationStatsSummary(
            graph_stats=graph_stats_summary,
            model_cache_stats=model_cache_stats_summary,
            node_stats=node_stats_summaries,
            vram_usage_gb=vram_usage_gb,
        )

    def log_stats(self, graph_execution_state_id: str) -> None:
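        """Log the statistics summary for the given graph execution state at INFO level."""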
        stats = self.get_stats(graph_execution_state_id)
        logger.info(str(stats))

    def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
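        """Write the statistics summary for the given graph execution state to output_path as indented JSON."""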
        stats = self.get_stats(graph_execution_state_id)
        with open(output_path, "w") as f:
            f.write(json.dumps(stats.as_dict(), indent=2))

    def _get_model_cache_summary(self, graph_execution_state_id: str) -> ModelCacheStatsSummary:
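        """Summarize model cache activity (hits, misses, and sizes) recorded for the given graph execution state."""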
        try:
            cache_stats = self._cache_stats[graph_execution_state_id]
        except KeyError as e:
            raise GESStatsNotFoundError(
                f"Attempted to get model cache statistics for unknown graph {graph_execution_state_id}: {e}."
            ) from e

        return ModelCacheStatsSummary(
            cache_hits=cache_stats.hits,
            cache_misses=cache_stats.misses,
            high_water_mark_gb=cache_stats.high_watermark / GB,
            cache_size_gb=cache_stats.cache_size / GB,
            total_usage_gb=sum(cache_stats.loaded_model_sizes.values()) / GB,
            models_cached=cache_stats.in_cache,
            models_cleared=cache_stats.cleared,
        )

    def _get_graph_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
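        """Summarize the overall execution statistics for the given graph execution state."""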
        try:
            graph_stats = self._stats[graph_execution_state_id]
        except KeyError as e:
            raise GESStatsNotFoundError(
                f"Attempted to get graph statistics for unknown graph {graph_execution_state_id}: {e}."
            ) from e

        return graph_stats.get_graph_stats_summary(graph_execution_state_id)

    def _get_node_summaries(self, graph_execution_state_id: str) -> list[NodeExecutionStatsSummary]:
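        """Summarize per-node execution statistics for the given graph execution state."""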
        try:
            graph_stats = self._stats[graph_execution_state_id]
        except KeyError as e:
            raise GESStatsNotFoundError(
                f"Attempted to get node statistics for unknown graph {graph_execution_state_id}: {e}."
            ) from e

        return graph_stats.get_node_stats_summaries()
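

# A minimal usage sketch, assuming an already-initialized Invoker (`invoker`) and a
# node instance (`invocation`); both names are hypothetical stand-ins here:
#
#     stats = InvocationStatsService()
#     stats.start(invoker)
#     with stats.collect_stats(invocation, graph_execution_state_id="ges-1"):
#         ...  # the node runs inside this block
#     stats.log_stats("ges-1")                       # log the InvocationStatsSummary
#     stats.dump_stats("ges-1", Path("stats.json"))  # write the same summary as JSON
#     stats.reset_stats()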