File size: 5,267 Bytes
92ca806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148

#
# Exceptions and other classes
#

class ExceededContextLengthException(Exception):
    """Raised when an input is too long to fit within a model's context window"""

class _LlamaStopwatch:
    """Track elapsed time for prompt processing and text generation"""
    #
    # Q: why don't you use llama_perf_context?
    #
    # A: comments in llama.h state to only use that in llama.cpp examples,
    #    and to do your own performance measurements instead.
    #
    #    trying to use llama_perf_context leads to output with
    #    "0.00 ms per token" and "inf tokens per second"
    #
    def __init__(self):
        self.pp_start_time = None
        self.tg_start_time = None
        self.wall_start_time = None
        self.generic_start_time = None
        self.pp_elapsed_time = 0
        self.tg_elapsed_time = 0
        self.wall_elapsed_time = 0
        self.generic_elapsed_time = 0
        self.n_pp_tokens = 0
        self.n_tg_tokens = 0

    def start_pp(self):
        """Start prompt processing stopwatch"""
        self.pp_start_time = time.time_ns()

    def stop_pp(self):
        """Stop prompt processing stopwatch"""
        if self.pp_start_time is not None:
            self.pp_elapsed_time += time.time_ns() - self.pp_start_time
            self.pp_start_time = None

    def start_tg(self):
        """Start text generation stopwatch"""
        self.tg_start_time = time.time_ns()

    def stop_tg(self):
        """Stop text generation stopwatch"""
        if self.tg_start_time is not None:
            self.tg_elapsed_time += time.time_ns() - self.tg_start_time
            self.tg_start_time = None
    
    def start_wall_time(self):
        """Start wall-time stopwatch"""
        self.wall_start_time = time.time_ns()

    def stop_wall_time(self):
        """Stop wall-time stopwatch"""
        if self.wall_start_time is not None:
            self.wall_elapsed_time += time.time_ns() - self.wall_start_time
            self.wall_start_time = None

    def start_generic(self):
        """Start generic stopwatch (not shown in print_stats)"""
        self.generic_start_time = time.time_ns()
    
    def stop_generic(self):
        """Stop generic stopwatch"""
        if self.generic_start_time is not None:
            self.generic_elapsed_time += time.time_ns() - self.generic_start_time
            self.generic_start_time = None
    
    def get_elapsed_time_pp(self) -> int:
        """Total nanoseconds elapsed during prompt processing"""
        return self.pp_elapsed_time

    def get_elapsed_time_tg(self) -> int:
        """Total nanoseconds elapsed during text generation"""
        return self.tg_elapsed_time
    
    def get_elapsed_wall_time(self) -> int:
        """Total wall-time nanoseconds elapsed"""
        return self.wall_elapsed_time

    def get_elapsed_time_generic(self) -> int:
        """Total generic nanoseconds elapsed"""
        return self.generic_elapsed_time

    def increment_pp_tokens(self, n: int):
        if n < 0:
            raise ValueError('negative increments are not allowed')
        self.n_pp_tokens += n

    def increment_tg_tokens(self, n: int):
        if n < 0:
            raise ValueError('negative increments are not allowed')
        self.n_tg_tokens += n

    def reset(self):
        """Reset the stopwatch to its original state"""
        self.pp_start_time = None
        self.tg_start_time = None
        self.wall_start_time = None
        self.generic_start_time = None
        self.pp_elapsed_time = 0
        self.tg_elapsed_time = 0
        self.wall_elapsed_time = 0
        self.generic_elapsed_time = 0
        self.n_pp_tokens = 0
        self.n_tg_tokens = 0

    def print_stats(self):
        """Print performance statistics using current stopwatch state
        
        #### NOTE:
        The `n_tg_tokens` value will be equal to the number of calls to
        llama_decode which have a batch size of 1, which is technically not
        always equal to the number of tokens generated - it may be off by one."""

        print(f"\n", end='', file=sys.stderr, flush=True)

        if self.n_pp_tokens + self.n_tg_tokens == 0:
            print_stopwatch(f'print_stats was called but no tokens were processed or generated')

        if self.n_pp_tokens > 0:
            pp_elapsed_ns = self.get_elapsed_time_pp()
            pp_elapsed_ms = pp_elapsed_ns / 1e6
            pp_elapsed_s = pp_elapsed_ns / 1e9
            pp_tps = self.n_pp_tokens / pp_elapsed_s
            print_stopwatch(
                f'prompt processing: {self.n_pp_tokens:>7} tokens in {pp_elapsed_ms:>13.3f}ms '
                f'({pp_tps:>10.2f} tok/s)'
            )

        if self.n_tg_tokens > 0:
            tg_elapsed_ns = self.get_elapsed_time_tg()
            tg_elapsed_ms = tg_elapsed_ns / 1e6
            tg_elapsed_s = tg_elapsed_ns / 1e9
            tg_tps = self.n_tg_tokens / tg_elapsed_s
            print_stopwatch(
                f'  text generation: {self.n_tg_tokens:>7} tokens in {tg_elapsed_ms:>13.3f}ms '
                f'({tg_tps:>10.2f} tok/s)'
            )
        
        wall_elapsed_ns = self.get_elapsed_wall_time()
        wall_elapsed_ms = wall_elapsed_ns / 1e6
        print_stopwatch(f"        wall time:{' ' * 19}{wall_elapsed_ms:>13.3f}ms")