File size: 5,267 Bytes
92ca806 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
#
# Exceptions and other classes
#
class ExceededContextLengthException(Exception):
    """Signal that an input does not fit within a model's context length.

    Carries no extra state; the message passed at raise time describes the
    overflow.
    """

    pass
class _LlamaStopwatch:
    """Track elapsed time for prompt processing and text generation.

    Keeps four independent stopwatches -- prompt processing (pp), text
    generation (tg), wall time, and a generic one. Each accumulates
    nanoseconds over any number of start/stop pairs. Token counters for pp
    and tg are kept alongside so that ``print_stats`` can report throughput.
    """
    #
    # Q: why don't you use llama_perf_context?
    #
    # A: comments in llama.h state to only use that in llama.cpp examples,
    #    and to do your own performance measurements instead.
    #
    #    trying to use llama_perf_context leads to output with
    #    "0.00 ms per token" and "inf tokens per second"
    #
    # Timestamps come from time.perf_counter_ns(). It is monotonic and has
    # the highest available resolution. time.time_ns() follows the system
    # clock instead: that clock can be adjusted (so intervals could come out
    # negative or wrong), and on some platforms it is coarse enough that
    # short intervals measure as 0 ns.

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset the stopwatch to its original state (all stopped, all zero)."""
        # A start time is None while the corresponding stopwatch is stopped.
        self.pp_start_time = None
        self.tg_start_time = None
        self.wall_start_time = None
        self.generic_start_time = None
        # Accumulated elapsed times, in nanoseconds.
        self.pp_elapsed_time = 0
        self.tg_elapsed_time = 0
        self.wall_elapsed_time = 0
        self.generic_elapsed_time = 0
        # Token counters used to compute throughput in print_stats.
        self.n_pp_tokens = 0
        self.n_tg_tokens = 0

    def start_pp(self):
        """Start prompt processing stopwatch"""
        self.pp_start_time = time.perf_counter_ns()

    def stop_pp(self):
        """Stop prompt processing stopwatch (no-op if it is not running)"""
        if self.pp_start_time is not None:
            self.pp_elapsed_time += time.perf_counter_ns() - self.pp_start_time
            self.pp_start_time = None

    def start_tg(self):
        """Start text generation stopwatch"""
        self.tg_start_time = time.perf_counter_ns()

    def stop_tg(self):
        """Stop text generation stopwatch (no-op if it is not running)"""
        if self.tg_start_time is not None:
            self.tg_elapsed_time += time.perf_counter_ns() - self.tg_start_time
            self.tg_start_time = None

    def start_wall_time(self):
        """Start wall-time stopwatch"""
        self.wall_start_time = time.perf_counter_ns()

    def stop_wall_time(self):
        """Stop wall-time stopwatch (no-op if it is not running)"""
        if self.wall_start_time is not None:
            self.wall_elapsed_time += time.perf_counter_ns() - self.wall_start_time
            self.wall_start_time = None

    def start_generic(self):
        """Start generic stopwatch (not shown in print_stats)"""
        self.generic_start_time = time.perf_counter_ns()

    def stop_generic(self):
        """Stop generic stopwatch (no-op if it is not running)"""
        if self.generic_start_time is not None:
            self.generic_elapsed_time += time.perf_counter_ns() - self.generic_start_time
            self.generic_start_time = None

    def get_elapsed_time_pp(self) -> int:
        """Total nanoseconds elapsed during prompt processing"""
        return self.pp_elapsed_time

    def get_elapsed_time_tg(self) -> int:
        """Total nanoseconds elapsed during text generation"""
        return self.tg_elapsed_time

    def get_elapsed_wall_time(self) -> int:
        """Total wall-time nanoseconds elapsed"""
        return self.wall_elapsed_time

    def get_elapsed_time_generic(self) -> int:
        """Total generic nanoseconds elapsed"""
        return self.generic_elapsed_time

    def increment_pp_tokens(self, n: int):
        """Add ``n`` to the count of prompt tokens processed.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError('negative increments are not allowed')
        self.n_pp_tokens += n

    def increment_tg_tokens(self, n: int):
        """Add ``n`` to the count of tokens generated.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError('negative increments are not allowed')
        self.n_tg_tokens += n

    @staticmethod
    def _format_tps(n_tokens: int, elapsed_ns: int) -> str:
        """Format throughput in tokens/second, right-aligned to 10 columns.

        Returns a padded ``'n/a'`` when no time was recorded, rather than
        dividing by zero.
        """
        if elapsed_ns <= 0:
            return 'n/a'.rjust(10)
        return f'{n_tokens / (elapsed_ns / 1e9):>10.2f}'

    def print_stats(self):
        """Print performance statistics using current stopwatch state.

        Only completed start/stop intervals are included; a stopwatch that
        is still running contributes nothing for its current interval.

        #### NOTE:
        The `n_tg_tokens` value will be equal to the number of calls to
        llama_decode which have a batch size of 1, which is technically not
        always equal to the number of tokens generated - it may be off by one."""
        print("\n", end='', file=sys.stderr, flush=True)
        if self.n_pp_tokens + self.n_tg_tokens == 0:
            print_stopwatch('print_stats was called but no tokens were processed or generated')
        # Labels are right-aligned to a common width so that the token,
        # millisecond and tok/s columns line up across all three lines.
        if self.n_pp_tokens > 0:
            pp_elapsed_ns = self.get_elapsed_time_pp()
            pp_elapsed_ms = pp_elapsed_ns / 1e6
            pp_tps = self._format_tps(self.n_pp_tokens, pp_elapsed_ns)
            print_stopwatch(
                f'prompt processing: {self.n_pp_tokens:>7} tokens in {pp_elapsed_ms:>13.3f}ms '
                f'({pp_tps} tok/s)'
            )
        if self.n_tg_tokens > 0:
            tg_elapsed_ns = self.get_elapsed_time_tg()
            tg_elapsed_ms = tg_elapsed_ns / 1e6
            tg_tps = self._format_tps(self.n_tg_tokens, tg_elapsed_ns)
            print_stopwatch(
                f'  text generation: {self.n_tg_tokens:>7} tokens in {tg_elapsed_ms:>13.3f}ms '
                f'({tg_tps} tok/s)'
            )
        wall_elapsed_ms = self.get_elapsed_wall_time() / 1e6
        print_stopwatch(f"        wall time:{' ' * 19}{wall_elapsed_ms:>13.3f}ms")
|