"""
Inference statistics manager for tracking and predicting model inference times.
This module provides functionality to:
- Record historical inference statistics (token count, inference time)
- Predict inference time using k-nearest neighbors algorithm
- Persist statistics to disk for cross-session usage
"""
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
class InferenceStatsManager:
    """Manages historical inference statistics and predicts inference time.

    Records (model name, input token count, inference time) entries in a JSON
    file under a per-user cache directory, and estimates the inference time
    for a new request via inverse-distance-weighted k-nearest-neighbors over
    the recorded token counts.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize the statistics manager.

        Args:
            cache_dir: Optional custom cache directory. If None, a
                platform-appropriate per-user cache location is used
                (%LOCALAPPDATA% on Windows, ~/.cache elsewhere).
        """
        if cache_dir is None:
            # Use the user's cache directory
            if os.name == 'nt':  # Windows
                base_cache = os.path.expandvars(r'%LOCALAPPDATA%')
            else:  # Unix-like
                base_cache = os.path.expanduser('~/.cache')
            cache_dir = os.path.join(base_cache, 'uncheatableeval_lens')
        self.cache_dir = Path(cache_dir)
        self.stats_file = self.cache_dir / 'inference_stats.json'
        # Create cache directory if it doesn't exist
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _load_stats(self) -> List[Dict]:
        """
        Load statistics from the JSON file.

        Returns:
            List of statistics records; an empty list if the file does not
            exist, cannot be read/parsed, or does not contain a JSON array.
        """
        if not self.stats_file.exists():
            return []
        try:
            with open(self.stats_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            print(f"Warning: Failed to load statistics file: {e}")
            return []
        # Guard against a corrupted/foreign file whose top level is not a
        # list -- callers append to and iterate over the returned value.
        if not isinstance(data, list):
            print("Warning: Statistics file does not contain a list; ignoring it.")
            return []
        return data

    def _save_stats(self, stats: List[Dict]) -> None:
        """
        Save statistics to the JSON file, overwriting its previous contents.

        Args:
            stats: List of statistics records to save.
        """
        try:
            with open(self.stats_file, 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2, ensure_ascii=False)
        except IOError as e:
            print(f"Warning: Failed to save statistics file: {e}")

    def add_record(self, model_name: str, input_tokens: int, inference_time: float) -> None:
        """
        Add a new inference record to the statistics and persist it.

        NOTE: this is a read-modify-write of the whole file, so concurrent
        writers may lose records.

        Args:
            model_name: Name of the model ("qwen" or "rwkv")
            input_tokens: Number of input tokens
            inference_time: Inference time in seconds
        """
        stats = self._load_stats()
        stats.append({
            "model_name": model_name,
            "input_tokens": input_tokens,
            "inference_time": inference_time,
            "timestamp": datetime.now().isoformat(),
        })
        self._save_stats(stats)

    def _find_k_nearest(self, records: List[Dict], target_tokens: int, k: int) -> List[Tuple[Dict, float]]:
        """
        Find the k records whose token count is closest to the target.

        Args:
            records: List of historical records
            target_tokens: Target token count
            k: Number of nearest neighbors to find

        Returns:
            List of (record, distance) tuples sorted by ascending distance.
            Contains fewer than k entries when fewer records exist, and is
            empty when k <= 0.
        """
        distances = [(record, abs(record["input_tokens"] - target_tokens))
                     for record in records]
        # Sort by distance; stable sort keeps insertion order among ties.
        distances.sort(key=lambda pair: pair[1])
        # Clamp k: a negative k would otherwise be a negative slice bound
        # and silently return all-but-the-farthest records.
        return distances[:max(k, 0)]

    def predict_time(self, model_name: str, input_tokens: int, k: int = 5) -> Optional[float]:
        """
        Predict inference time using inverse-distance-weighted k-NN.

        Args:
            model_name: Name of the model ("qwen" or "rwkv")
            input_tokens: Number of input tokens
            k: Number of nearest neighbors to use (default: 5)

        Returns:
            Predicted inference time in seconds, or None if there is no
            historical data for this model (or k <= 0).
        """
        # Only records from the requested model are comparable.
        model_records = [r for r in self._load_stats()
                         if r["model_name"] == model_name]
        if not model_records:
            return None
        nearest = self._find_k_nearest(model_records, input_tokens, k)
        if not nearest:
            return None
        # Inverse distance weighting: weight = 1 / (1 + distance), so an
        # exact token-count match gets weight 1 and far records contribute less.
        total_weight = 0.0
        weighted_sum = 0.0
        for record, distance in nearest:
            weight = 1.0 / (1.0 + distance)
            weighted_sum += weight * record["inference_time"]
            total_weight += weight
        # Weights are strictly positive, so this branch is purely defensive.
        if total_weight == 0:
            return None
        return weighted_sum / total_weight