| # | |
| # Exceptions and other classes | |
| # | |
class ExceededContextLengthException(Exception):
    """Raised when an input does not fit within a model's context length."""
    pass
class _LlamaStopwatch:
    """Track elapsed time for prompt processing and text generation

    All durations are accumulated as integer nanoseconds via
    `time.time_ns()`; repeated start/stop cycles add to the running
    totals until `reset()` is called.
    """
    #
    # Q: why don't you use llama_perf_context?
    #
    # A: comments in llama.h state to only use that in llama.cpp examples,
    # and to do your own performance measurements instead.
    #
    # trying to use llama_perf_context leads to output with
    # "0.00 ms per token" and "inf tokens per second"
    #

    def __init__(self):
        # Delegate to reset() so the freshly-constructed state and the
        # post-reset state are identical by construction (previously the
        # same ten assignments were duplicated in __init__ and reset,
        # which could silently drift apart).
        self.reset()

    def start_pp(self):
        """Start prompt processing stopwatch"""
        self.pp_start_time = time.time_ns()

    def stop_pp(self):
        """Stop prompt processing stopwatch (no-op if not running)"""
        if self.pp_start_time is not None:
            self.pp_elapsed_time += time.time_ns() - self.pp_start_time
            self.pp_start_time = None

    def start_tg(self):
        """Start text generation stopwatch"""
        self.tg_start_time = time.time_ns()

    def stop_tg(self):
        """Stop text generation stopwatch (no-op if not running)"""
        if self.tg_start_time is not None:
            self.tg_elapsed_time += time.time_ns() - self.tg_start_time
            self.tg_start_time = None

    def start_wall_time(self):
        """Start wall-time stopwatch"""
        self.wall_start_time = time.time_ns()

    def stop_wall_time(self):
        """Stop wall-time stopwatch (no-op if not running)"""
        if self.wall_start_time is not None:
            self.wall_elapsed_time += time.time_ns() - self.wall_start_time
            self.wall_start_time = None

    def start_generic(self):
        """Start generic stopwatch (not shown in print_stats)"""
        self.generic_start_time = time.time_ns()

    def stop_generic(self):
        """Stop generic stopwatch (no-op if not running)"""
        if self.generic_start_time is not None:
            self.generic_elapsed_time += time.time_ns() - self.generic_start_time
            self.generic_start_time = None

    def get_elapsed_time_pp(self) -> int:
        """Total nanoseconds elapsed during prompt processing"""
        return self.pp_elapsed_time

    def get_elapsed_time_tg(self) -> int:
        """Total nanoseconds elapsed during text generation"""
        return self.tg_elapsed_time

    def get_elapsed_wall_time(self) -> int:
        """Total wall-time nanoseconds elapsed"""
        return self.wall_elapsed_time

    def get_elapsed_time_generic(self) -> int:
        """Total generic nanoseconds elapsed"""
        return self.generic_elapsed_time

    def increment_pp_tokens(self, n: int):
        """Add `n` to the prompt-processing token count (n must be >= 0)"""
        if n < 0:
            raise ValueError('negative increments are not allowed')
        self.n_pp_tokens += n

    def increment_tg_tokens(self, n: int):
        """Add `n` to the text-generation token count (n must be >= 0)"""
        if n < 0:
            raise ValueError('negative increments are not allowed')
        self.n_tg_tokens += n

    def reset(self):
        """Reset the stopwatch to its original state"""
        # Running-stopwatch markers (None means "not currently running")
        self.pp_start_time = None
        self.tg_start_time = None
        self.wall_start_time = None
        self.generic_start_time = None
        # Accumulated durations, in integer nanoseconds
        self.pp_elapsed_time = 0
        self.tg_elapsed_time = 0
        self.wall_elapsed_time = 0
        self.generic_elapsed_time = 0
        # Token counters used by print_stats
        self.n_pp_tokens = 0
        self.n_tg_tokens = 0

    def print_stats(self):
        """Print performance statistics using current stopwatch state

        Output goes to stderr via `print_stopwatch` (defined elsewhere in
        this module).

        #### NOTE:
        The `n_tg_tokens` value will be equal to the number of calls to
        llama_decode which have a batch size of 1, which is technically not
        always equal to the number of tokens generated - it may be off by one."""
        print("\n", end='', file=sys.stderr, flush=True)
        if self.n_pp_tokens + self.n_tg_tokens == 0:
            print_stopwatch('print_stats was called but no tokens were processed or generated')
        if self.n_pp_tokens > 0:
            pp_elapsed_ns = self.get_elapsed_time_pp()
            pp_elapsed_ms = pp_elapsed_ns / 1e6
            pp_elapsed_s = pp_elapsed_ns / 1e9
            # Guard against ZeroDivisionError when tokens were counted but
            # the stopwatch was never started/stopped (elapsed time == 0).
            pp_tps = self.n_pp_tokens / pp_elapsed_s if pp_elapsed_s > 0 else 0.0
            print_stopwatch(
                f'prompt processing: {self.n_pp_tokens:>7} tokens in {pp_elapsed_ms:>13.3f}ms '
                f'({pp_tps:>10.2f} tok/s)'
            )
        if self.n_tg_tokens > 0:
            tg_elapsed_ns = self.get_elapsed_time_tg()
            tg_elapsed_ms = tg_elapsed_ns / 1e6
            tg_elapsed_s = tg_elapsed_ns / 1e9
            # Same zero-elapsed guard as for prompt processing above.
            tg_tps = self.n_tg_tokens / tg_elapsed_s if tg_elapsed_s > 0 else 0.0
            print_stopwatch(
                f' text generation: {self.n_tg_tokens:>7} tokens in {tg_elapsed_ms:>13.3f}ms '
                f'({tg_tps:>10.2f} tok/s)'
            )
        wall_elapsed_ns = self.get_elapsed_wall_time()
        wall_elapsed_ms = wall_elapsed_ns / 1e6
        print_stopwatch(f" wall time:{' ' * 19}{wall_elapsed_ms:>13.3f}ms")