|
|
| # |
| # Exceptions and other classes |
| # |
|
|
class ExceededContextLengthException(Exception):
    """Raised when an input does not fit within a model's context length."""
|
|
| class _LlamaStopwatch: |
| """Track elapsed time for prompt processing and text generation""" |
| # |
| # Q: why don't you use llama_perf_context? |
| # |
| # A: comments in llama.h state to only use that in llama.cpp examples, |
| # and to do your own performance measurements instead. |
| # |
| # trying to use llama_perf_context leads to output with |
| # "0.00 ms per token" and "inf tokens per second" |
| # |
| def __init__(self): |
| self.pp_start_time = None |
| self.tg_start_time = None |
| self.wall_start_time = None |
| self.generic_start_time = None |
| self.pp_elapsed_time = 0 |
| self.tg_elapsed_time = 0 |
| self.wall_elapsed_time = 0 |
| self.generic_elapsed_time = 0 |
| self.n_pp_tokens = 0 |
| self.n_tg_tokens = 0 |
|
|
| def start_pp(self): |
| """Start prompt processing stopwatch""" |
| self.pp_start_time = time.time_ns() |
|
|
| def stop_pp(self): |
| """Stop prompt processing stopwatch""" |
| if self.pp_start_time is not None: |
| self.pp_elapsed_time += time.time_ns() - self.pp_start_time |
| self.pp_start_time = None |
|
|
| def start_tg(self): |
| """Start text generation stopwatch""" |
| self.tg_start_time = time.time_ns() |
|
|
| def stop_tg(self): |
| """Stop text generation stopwatch""" |
| if self.tg_start_time is not None: |
| self.tg_elapsed_time += time.time_ns() - self.tg_start_time |
| self.tg_start_time = None |
| |
| def start_wall_time(self): |
| """Start wall-time stopwatch""" |
| self.wall_start_time = time.time_ns() |
|
|
| def stop_wall_time(self): |
| """Stop wall-time stopwatch""" |
| if self.wall_start_time is not None: |
| self.wall_elapsed_time += time.time_ns() - self.wall_start_time |
| self.wall_start_time = None |
|
|
| def start_generic(self): |
| """Start generic stopwatch (not shown in print_stats)""" |
| self.generic_start_time = time.time_ns() |
| |
| def stop_generic(self): |
| """Stop generic stopwatch""" |
| if self.generic_start_time is not None: |
| self.generic_elapsed_time += time.time_ns() - self.generic_start_time |
| self.generic_start_time = None |
| |
| def get_elapsed_time_pp(self) -> int: |
| """Total nanoseconds elapsed during prompt processing""" |
| return self.pp_elapsed_time |
|
|
| def get_elapsed_time_tg(self) -> int: |
| """Total nanoseconds elapsed during text generation""" |
| return self.tg_elapsed_time |
| |
| def get_elapsed_wall_time(self) -> int: |
| """Total wall-time nanoseconds elapsed""" |
| return self.wall_elapsed_time |
|
|
| def get_elapsed_time_generic(self) -> int: |
| """Total generic nanoseconds elapsed""" |
| return self.generic_elapsed_time |
|
|
| def increment_pp_tokens(self, n: int): |
| if n < 0: |
| raise ValueError('negative increments are not allowed') |
| self.n_pp_tokens += n |
|
|
| def increment_tg_tokens(self, n: int): |
| if n < 0: |
| raise ValueError('negative increments are not allowed') |
| self.n_tg_tokens += n |
|
|
| def reset(self): |
| """Reset the stopwatch to its original state""" |
| self.pp_start_time = None |
| self.tg_start_time = None |
| self.wall_start_time = None |
| self.generic_start_time = None |
| self.pp_elapsed_time = 0 |
| self.tg_elapsed_time = 0 |
| self.wall_elapsed_time = 0 |
| self.generic_elapsed_time = 0 |
| self.n_pp_tokens = 0 |
| self.n_tg_tokens = 0 |
|
|
| def print_stats(self): |
| """Print performance statistics using current stopwatch state |
| |
| #### NOTE: |
| The `n_tg_tokens` value will be equal to the number of calls to |
| llama_decode which have a batch size of 1, which is technically not |
| always equal to the number of tokens generated - it may be off by one.""" |
|
|
| print(f"\n", end='', file=sys.stderr, flush=True) |
|
|
| if self.n_pp_tokens + self.n_tg_tokens == 0: |
| print_stopwatch(f'print_stats was called but no tokens were processed or generated') |
|
|
| if self.n_pp_tokens > 0: |
| pp_elapsed_ns = self.get_elapsed_time_pp() |
| pp_elapsed_ms = pp_elapsed_ns / 1e6 |
| pp_elapsed_s = pp_elapsed_ns / 1e9 |
| pp_tps = self.n_pp_tokens / pp_elapsed_s |
| print_stopwatch( |
| f'prompt processing: {self.n_pp_tokens:>7} tokens in {pp_elapsed_ms:>13.3f}ms ' |
| f'({pp_tps:>10.2f} tok/s)' |
| ) |
|
|
| if self.n_tg_tokens > 0: |
| tg_elapsed_ns = self.get_elapsed_time_tg() |
| tg_elapsed_ms = tg_elapsed_ns / 1e6 |
| tg_elapsed_s = tg_elapsed_ns / 1e9 |
| tg_tps = self.n_tg_tokens / tg_elapsed_s |
| print_stopwatch( |
| f' text generation: {self.n_tg_tokens:>7} tokens in {tg_elapsed_ms:>13.3f}ms ' |
| f'({tg_tps:>10.2f} tok/s)' |
| ) |
| |
| wall_elapsed_ns = self.get_elapsed_wall_time() |
| wall_elapsed_ms = wall_elapsed_ns / 1e6 |
| print_stopwatch(f" wall time:{' ' * 19}{wall_elapsed_ms:>13.3f}ms") |
|
|
|
|