""" Inference statistics manager for tracking and predicting model inference times. This module provides functionality to: - Record historical inference statistics (token count, inference time) - Predict inference time using k-nearest neighbors algorithm - Persist statistics to disk for cross-session usage """ import json import os from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Tuple class InferenceStatsManager: """Manages inference statistics for time prediction.""" def __init__(self, cache_dir: Optional[str] = None): """ Initialize the statistics manager. Args: cache_dir: Optional custom cache directory. If None, uses default. """ if cache_dir is None: # Use user's cache directory if os.name == 'nt': # Windows base_cache = os.path.expandvars(r'%LOCALAPPDATA%') else: # Unix-like base_cache = os.path.expanduser('~/.cache') cache_dir = os.path.join(base_cache, 'uncheatableeval_lens') self.cache_dir = Path(cache_dir) self.stats_file = self.cache_dir / 'inference_stats.json' # Create cache directory if it doesn't exist self.cache_dir.mkdir(parents=True, exist_ok=True) def _load_stats(self) -> List[Dict]: """ Load statistics from JSON file. Returns: List of statistics records, empty list if file doesn't exist. """ if not self.stats_file.exists(): return [] try: with open(self.stats_file, 'r', encoding='utf-8') as f: return json.load(f) except (json.JSONDecodeError, IOError) as e: print(f"Warning: Failed to load statistics file: {e}") return [] def _save_stats(self, stats: List[Dict]) -> None: """ Save statistics to JSON file. Args: stats: List of statistics records to save. """ try: with open(self.stats_file, 'w', encoding='utf-8') as f: json.dump(stats, f, indent=2, ensure_ascii=False) except IOError as e: print(f"Warning: Failed to save statistics file: {e}") def add_record(self, model_name: str, input_tokens: int, inference_time: float) -> None: """ Add a new inference record to the statistics. Args: model_name: Name of the model ("qwen" or "rwkv") input_tokens: Number of input tokens inference_time: Inference time in seconds """ stats = self._load_stats() record = { "model_name": model_name, "input_tokens": input_tokens, "inference_time": inference_time, "timestamp": datetime.now().isoformat() } stats.append(record) self._save_stats(stats) def _find_k_nearest(self, records: List[Dict], target_tokens: int, k: int) -> List[Tuple[Dict, float]]: """ Find k nearest records by token count. Args: records: List of historical records target_tokens: Target token count k: Number of nearest neighbors to find Returns: List of (record, distance) tuples, sorted by distance """ # Calculate distances distances = [] for record in records: distance = abs(record["input_tokens"] - target_tokens) distances.append((record, distance)) # Sort by distance and return top k distances.sort(key=lambda x: x[1]) return distances[:k] def predict_time(self, model_name: str, input_tokens: int, k: int = 5) -> Optional[float]: """ Predict inference time using k-nearest neighbors algorithm. Args: model_name: Name of the model ("qwen" or "rwkv") input_tokens: Number of input tokens k: Number of nearest neighbors to use (default: 5) Returns: Predicted inference time in seconds, or None if no historical data """ stats = self._load_stats() # Filter records for the specific model model_records = [r for r in stats if r["model_name"] == model_name] if not model_records: return None # Find k nearest neighbors nearest = self._find_k_nearest(model_records, input_tokens, k) if not nearest: return None # Calculate weighted average using inverse distance weighting total_weight = 0.0 weighted_sum = 0.0 for record, distance in nearest: # Inverse distance weighting: weight = 1 / (1 + distance) weight = 1.0 / (1.0 + distance) weighted_sum += weight * record["inference_time"] total_weight += weight if total_weight == 0: return None return weighted_sum / total_weight