# Compression-Lens / core/inference_stats.py  (commit 15b2f1f, author Jellyfish042)
"""
Inference statistics manager for tracking and predicting model inference times.
This module provides functionality to:
- Record historical inference statistics (token count, inference time)
- Predict inference time using k-nearest neighbors algorithm
- Persist statistics to disk for cross-session usage
"""
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
class InferenceStatsManager:
    """Manages inference statistics for time prediction.

    Records (model_name, input_tokens, inference_time) samples in a JSON
    file under a per-user cache directory, and predicts the inference time
    of a new request with inverse-distance-weighted k-nearest neighbors
    over the historical token counts.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize the statistics manager.

        Args:
            cache_dir: Optional custom cache directory. If None, a per-user
                default is used (%LOCALAPPDATA% on Windows, ~/.cache on
                Unix-like systems).
        """
        if cache_dir is None:
            # Use the platform's conventional per-user cache location.
            if os.name == 'nt':  # Windows
                base_cache = os.path.expandvars(r'%LOCALAPPDATA%')
            else:  # Unix-like
                base_cache = os.path.expanduser('~/.cache')
            cache_dir = os.path.join(base_cache, 'uncheatableeval_lens')
        self.cache_dir = Path(cache_dir)
        self.stats_file = self.cache_dir / 'inference_stats.json'
        # Create cache directory if it doesn't exist
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _load_stats(self) -> List[Dict]:
        """
        Load statistics from the JSON file.

        Returns:
            List of statistics records; an empty list if the file does not
            exist, cannot be parsed, or does not contain a JSON list.
        """
        if not self.stats_file.exists():
            return []
        try:
            with open(self.stats_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            # Best-effort: a broken stats file should not break inference.
            print(f"Warning: Failed to load statistics file: {e}")
            return []
        # Guard against a corrupted/foreign file: callers append to and
        # iterate over the result, so anything other than a list is unusable.
        if not isinstance(data, list):
            print("Warning: Statistics file does not contain a list, ignoring it")
            return []
        return data

    def _save_stats(self, stats: List[Dict]) -> None:
        """
        Save statistics to the JSON file atomically.

        Writes to a temporary sibling file first and then replaces the
        target with ``os.replace`` (atomic on both POSIX and Windows), so a
        crash mid-write cannot leave a truncated JSON file behind.

        Args:
            stats: List of statistics records to save.
        """
        tmp_file = self.stats_file.with_suffix('.json.tmp')
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2, ensure_ascii=False)
            os.replace(tmp_file, self.stats_file)
        except OSError as e:
            # Best-effort persistence: warn and continue.
            print(f"Warning: Failed to save statistics file: {e}")

    def add_record(self, model_name: str, input_tokens: int, inference_time: float) -> None:
        """
        Add a new inference record to the statistics.

        Args:
            model_name: Name of the model ("qwen" or "rwkv")
            input_tokens: Number of input tokens
            inference_time: Inference time in seconds
        """
        stats = self._load_stats()
        stats.append({
            "model_name": model_name,
            "input_tokens": input_tokens,
            "inference_time": inference_time,
            "timestamp": datetime.now().isoformat(),
        })
        self._save_stats(stats)

    def _find_k_nearest(self, records: List[Dict], target_tokens: int, k: int) -> List[Tuple[Dict, float]]:
        """
        Find the k records nearest to a target token count.

        Args:
            records: List of historical records
            target_tokens: Target token count
            k: Number of nearest neighbors to find

        Returns:
            List of (record, distance) tuples, sorted by ascending distance.
            Fewer than k tuples are returned when there are fewer records.
        """
        distances = [
            (record, abs(record["input_tokens"] - target_tokens))
            for record in records
        ]
        distances.sort(key=lambda pair: pair[1])
        return distances[:k]

    def predict_time(self, model_name: str, input_tokens: int, k: int = 5) -> Optional[float]:
        """
        Predict inference time using k-nearest neighbors.

        Args:
            model_name: Name of the model ("qwen" or "rwkv")
            input_tokens: Number of input tokens
            k: Number of nearest neighbors to use (default: 5)

        Returns:
            Predicted inference time in seconds, or None if there is no
            historical data for this model (or k <= 0).
        """
        stats = self._load_stats()
        # Filter records for the specific model; .get() tolerates malformed
        # records that are missing the "model_name" key.
        model_records = [r for r in stats if r.get("model_name") == model_name]
        if not model_records:
            return None
        nearest = self._find_k_nearest(model_records, input_tokens, k)
        if not nearest:
            return None
        # Inverse distance weighting: weight = 1 / (1 + distance), so an
        # exact token-count match dominates while far samples still count.
        total_weight = 0.0
        weighted_sum = 0.0
        for record, distance in nearest:
            weight = 1.0 / (1.0 + distance)
            weighted_sum += weight * record["inference_time"]
            total_weight += weight
        if total_weight == 0:
            # Unreachable with the weighting above (weights are > 0), kept
            # as a defensive guard against division by zero.
            return None
        return weighted_sum / total_weight