File size: 4,984 Bytes
15b2f1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Inference statistics manager for tracking and predicting model inference times.

This module provides functionality to:
- Record historical inference statistics (token count, inference time)
- Predict inference time using k-nearest neighbors algorithm
- Persist statistics to disk for cross-session usage
"""

import json
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple


class InferenceStatsManager:
    """Manages inference statistics for time prediction.

    Records (model_name, input_tokens, inference_time) triples to a JSON
    file under a per-user cache directory, and predicts the inference time
    for a new request via k-nearest-neighbors over token count with
    inverse-distance weighting.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize the statistics manager.

        Args:
            cache_dir: Optional custom cache directory. If None, uses default.
        """
        if cache_dir is None:
            # Use user's cache directory
            if os.name == 'nt':  # Windows
                # NOTE: os.path.expandvars('%LOCALAPPDATA%') returns the
                # literal string unchanged when the variable is unset, which
                # would create a directory named '%LOCALAPPDATA%'. Read the
                # env var directly and fall back to the conventional path.
                base_cache = os.environ.get(
                    'LOCALAPPDATA',
                    os.path.expanduser(r'~\AppData\Local'),
                )
            else:  # Unix-like
                base_cache = os.path.expanduser('~/.cache')

            cache_dir = os.path.join(base_cache, 'uncheatableeval_lens')

        self.cache_dir = Path(cache_dir)
        self.stats_file = self.cache_dir / 'inference_stats.json'

        # Create cache directory if it doesn't exist
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _load_stats(self) -> List[Dict]:
        """
        Load statistics from JSON file.

        Returns:
            List of statistics records; empty list if the file doesn't
            exist, can't be read, or doesn't contain a JSON list.
        """
        if not self.stats_file.exists():
            return []

        try:
            with open(self.stats_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            print(f"Warning: Failed to load statistics file: {e}")
            return []

        # Guard against a syntactically valid but structurally wrong file
        # (e.g. a top-level dict): downstream code assumes a list of dicts.
        if not isinstance(data, list):
            print("Warning: Statistics file is malformed (expected a JSON list); ignoring it.")
            return []

        return data

    def _save_stats(self, stats: List[Dict]) -> None:
        """
        Save statistics to JSON file.

        Failures are reported as warnings rather than raised, so a
        read-only cache directory never breaks inference itself.

        Args:
            stats: List of statistics records to save.
        """
        try:
            with open(self.stats_file, 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2, ensure_ascii=False)
        except OSError as e:
            print(f"Warning: Failed to save statistics file: {e}")

    def add_record(self, model_name: str, input_tokens: int, inference_time: float) -> None:
        """
        Add a new inference record to the statistics.

        Args:
            model_name: Name of the model ("qwen" or "rwkv")
            input_tokens: Number of input tokens
            inference_time: Inference time in seconds
        """
        stats = self._load_stats()

        record = {
            "model_name": model_name,
            "input_tokens": input_tokens,
            "inference_time": inference_time,
            "timestamp": datetime.now().isoformat(),
        }

        stats.append(record)
        self._save_stats(stats)

    def _find_k_nearest(self, records: List[Dict], target_tokens: int, k: int) -> List[Tuple[Dict, float]]:
        """
        Find k nearest records by token count.

        Args:
            records: List of historical records
            target_tokens: Target token count
            k: Number of nearest neighbors to find

        Returns:
            List of (record, distance) tuples, sorted by distance.
            Returns fewer than k entries when fewer records exist.
        """
        # Distance is the absolute difference in input token counts.
        distances = [
            (record, abs(record["input_tokens"] - target_tokens))
            for record in records
        ]

        # Sort by distance and return the top k (slice handles k <= 0 and
        # k > len(records) gracefully).
        distances.sort(key=lambda pair: pair[1])
        return distances[:k]

    def predict_time(self, model_name: str, input_tokens: int, k: int = 5) -> Optional[float]:
        """
        Predict inference time using k-nearest neighbors algorithm.

        Args:
            model_name: Name of the model ("qwen" or "rwkv")
            input_tokens: Number of input tokens
            k: Number of nearest neighbors to use (default: 5)

        Returns:
            Predicted inference time in seconds, or None if no historical data
        """
        stats = self._load_stats()

        # Only records from the same model are comparable.
        model_records = [r for r in stats if r["model_name"] == model_name]

        if not model_records:
            return None

        nearest = self._find_k_nearest(model_records, input_tokens, k)

        if not nearest:
            return None

        # Inverse-distance weighting: weight = 1 / (1 + distance), so an
        # exact token-count match dominates while distant samples still
        # contribute a little.
        total_weight = 0.0
        weighted_sum = 0.0

        for record, distance in nearest:
            weight = 1.0 / (1.0 + distance)
            weighted_sum += weight * record["inference_time"]
            total_weight += weight

        # Defensive: weights are strictly positive, so this only triggers
        # if the weighting scheme changes.
        if total_weight == 0:
            return None

        return weighted_sum / total_weight