import time from typing import Any, Callable, Dict, List import torch from loguru import logger as eval_logger def space_tokenizer(text: str) -> int: """ A simple tokenizer that counts the token by the splitted space Then a rough estimate of the token count is returned. (times 1.5) Args: text (str): The input text to tokenize. """ return len(text.split(" ")) * 1.5 def calculate_token_throughput(token_count: int, duration: float) -> float: """ Calculate the token throughput. Args: token_count (int): The number of tokens processed. duration (float): The time taken to process the tokens in seconds. Returns: float: The token throughput in tokens per second. """ if duration <= 0: return 0.0 return token_count / duration def log_metrics(e2e_latency: float, total_tokens: int, avg_speed: float, additional_metrics: Dict[str, Any] = None): """ Log the metrics in a structured format. Args: e2e_latency (float): The end-to-end latency in seconds. total_tokens (int): The total number of tokens processed. avg_speed (float): The average speed in tokens per second. additional_metrics (Dict[str, Any]): Additional metrics to log. """ required_stats = f"Metric summary - Total time: {e2e_latency:.3f}s, Total tokens: {total_tokens}, Avg speed: {avg_speed:.1f} tokens/s" if additional_metrics is not None: required_stats += ", Additional metrics: " required_stats += ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in additional_metrics.items()) eval_logger.info(required_stats) class GenMetrics: """ A class to manage the generation of metrics for model evaluation. """ def __init__(self, tokenize_fn: Callable = space_tokenizer): self.tokenize_fn = tokenize_fn def __enter__(self): """ Initialize the context manager. """ self.metrics = {} self.start_time = time.perf_counter() return self def stop_timer(self): self.end_time = time.perf_counter() def log_metric(self, content: List[Any], additional_metrics: Dict[str, Any] = None): num_tokens = sum(self.tokenize_fn(item) for item in content) duration = self.end_time - self.start_time throughput = calculate_token_throughput(num_tokens, duration) self.metrics = { "num_tokens": num_tokens, "duration": duration, "throughput": throughput, } if additional_metrics: self.metrics.update(additional_metrics) log_metrics(self.metrics) def __exit__(self, exc_type, exc_value, traceback): """ Finalize the context manager and return the collected metrics. """ self.end_time = time.perf_counter() self.metrics["duration"] = self.end_time - self.start_time return self.metrics