ddh0's picture
Upload 11 files
92ca806 verified
#
# Exceptions and other classes
#
class ExceededContextLengthException(Exception):
    """Raised when an input is longer than the model's context length allows."""
class _LlamaStopwatch:
"""Track elapsed time for prompt processing and text generation"""
#
# Q: why don't you use llama_perf_context?
#
# A: comments in llama.h state to only use that in llama.cpp examples,
# and to do your own performance measurements instead.
#
# trying to use llama_perf_context leads to output with
# "0.00 ms per token" and "inf tokens per second"
#
def __init__(self):
self.pp_start_time = None
self.tg_start_time = None
self.wall_start_time = None
self.generic_start_time = None
self.pp_elapsed_time = 0
self.tg_elapsed_time = 0
self.wall_elapsed_time = 0
self.generic_elapsed_time = 0
self.n_pp_tokens = 0
self.n_tg_tokens = 0
def start_pp(self):
"""Start prompt processing stopwatch"""
self.pp_start_time = time.time_ns()
def stop_pp(self):
"""Stop prompt processing stopwatch"""
if self.pp_start_time is not None:
self.pp_elapsed_time += time.time_ns() - self.pp_start_time
self.pp_start_time = None
def start_tg(self):
"""Start text generation stopwatch"""
self.tg_start_time = time.time_ns()
def stop_tg(self):
"""Stop text generation stopwatch"""
if self.tg_start_time is not None:
self.tg_elapsed_time += time.time_ns() - self.tg_start_time
self.tg_start_time = None
def start_wall_time(self):
"""Start wall-time stopwatch"""
self.wall_start_time = time.time_ns()
def stop_wall_time(self):
"""Stop wall-time stopwatch"""
if self.wall_start_time is not None:
self.wall_elapsed_time += time.time_ns() - self.wall_start_time
self.wall_start_time = None
def start_generic(self):
"""Start generic stopwatch (not shown in print_stats)"""
self.generic_start_time = time.time_ns()
def stop_generic(self):
"""Stop generic stopwatch"""
if self.generic_start_time is not None:
self.generic_elapsed_time += time.time_ns() - self.generic_start_time
self.generic_start_time = None
def get_elapsed_time_pp(self) -> int:
"""Total nanoseconds elapsed during prompt processing"""
return self.pp_elapsed_time
def get_elapsed_time_tg(self) -> int:
"""Total nanoseconds elapsed during text generation"""
return self.tg_elapsed_time
def get_elapsed_wall_time(self) -> int:
"""Total wall-time nanoseconds elapsed"""
return self.wall_elapsed_time
def get_elapsed_time_generic(self) -> int:
"""Total generic nanoseconds elapsed"""
return self.generic_elapsed_time
def increment_pp_tokens(self, n: int):
if n < 0:
raise ValueError('negative increments are not allowed')
self.n_pp_tokens += n
def increment_tg_tokens(self, n: int):
if n < 0:
raise ValueError('negative increments are not allowed')
self.n_tg_tokens += n
def reset(self):
"""Reset the stopwatch to its original state"""
self.pp_start_time = None
self.tg_start_time = None
self.wall_start_time = None
self.generic_start_time = None
self.pp_elapsed_time = 0
self.tg_elapsed_time = 0
self.wall_elapsed_time = 0
self.generic_elapsed_time = 0
self.n_pp_tokens = 0
self.n_tg_tokens = 0
def print_stats(self):
"""Print performance statistics using current stopwatch state
#### NOTE:
The `n_tg_tokens` value will be equal to the number of calls to
llama_decode which have a batch size of 1, which is technically not
always equal to the number of tokens generated - it may be off by one."""
print(f"\n", end='', file=sys.stderr, flush=True)
if self.n_pp_tokens + self.n_tg_tokens == 0:
print_stopwatch(f'print_stats was called but no tokens were processed or generated')
if self.n_pp_tokens > 0:
pp_elapsed_ns = self.get_elapsed_time_pp()
pp_elapsed_ms = pp_elapsed_ns / 1e6
pp_elapsed_s = pp_elapsed_ns / 1e9
pp_tps = self.n_pp_tokens / pp_elapsed_s
print_stopwatch(
f'prompt processing: {self.n_pp_tokens:>7} tokens in {pp_elapsed_ms:>13.3f}ms '
f'({pp_tps:>10.2f} tok/s)'
)
if self.n_tg_tokens > 0:
tg_elapsed_ns = self.get_elapsed_time_tg()
tg_elapsed_ms = tg_elapsed_ns / 1e6
tg_elapsed_s = tg_elapsed_ns / 1e9
tg_tps = self.n_tg_tokens / tg_elapsed_s
print_stopwatch(
f' text generation: {self.n_tg_tokens:>7} tokens in {tg_elapsed_ms:>13.3f}ms '
f'({tg_tps:>10.2f} tok/s)'
)
wall_elapsed_ns = self.get_elapsed_wall_time()
wall_elapsed_ms = wall_elapsed_ns / 1e6
print_stopwatch(f" wall time:{' ' * 19}{wall_elapsed_ms:>13.3f}ms")