File size: 5,159 Bytes
a21c316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
//! Estimation Calibrator Module
//!
//! Learns from historical request/response pairs to improve token estimation accuracy.
//! Uses actual token counts from Google API responses to calibrate future estimates.

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use tracing::info;

/// Estimation Calibrator - learns estimation error from historical requests
///
/// Tracks the ratio between estimated tokens (computed before a request) and
/// actual tokens (reported by the Google API response) so that future raw
/// estimates can be scaled toward reality.
///
/// All state lives behind atomics or an `RwLock`, so a single instance can be
/// shared by reference (e.g. as a `static`) and updated through `&self`.
#[derive(Debug)]
pub struct EstimationCalibrator {
    /// Cumulative estimated tokens over all accepted samples
    total_estimated: AtomicU64,
    /// Cumulative actual tokens (from Google API) over all accepted samples
    total_actual: AtomicU64,
    /// Number of accepted (non-zero) samples
    sample_count: AtomicU64,
    /// Current calibration factor (estimated * factor ≈ actual)
    calibration_factor: RwLock<f32>,
}

impl EstimationCalibrator {
    /// Factor used until real data arrives: estimates are assumed ~2x too low.
    /// This is conservative and is adjusted as samples are recorded.
    const DEFAULT_FACTOR: f32 = 2.0;
    /// Lower clamp for a freshly measured ratio (below this we overestimate; rare).
    const MIN_FACTOR: f32 = 0.8;
    /// Upper clamp for a freshly measured ratio (above this is severe underestimation).
    const MAX_FACTOR: f32 = 4.0;
    /// Weight kept from the previous factor in the exponential moving average.
    const OLD_WEIGHT: f32 = 0.6;
    /// Weight given to the newly measured (clamped) ratio in the moving average.
    const NEW_WEIGHT: f32 = 0.4;
    /// The factor is recomputed once every this many accepted samples.
    const UPDATE_INTERVAL: u64 = 5;

    /// Create a new calibrator with zeroed counters and the default factor (2.0).
    pub const fn new() -> Self {
        Self {
            total_estimated: AtomicU64::new(0),
            total_actual: AtomicU64::new(0),
            sample_count: AtomicU64::new(0),
            calibration_factor: RwLock::new(Self::DEFAULT_FACTOR),
        }
    }

    /// Record a request's estimated vs actual token counts.
    ///
    /// Call this after receiving a response from Google API with actual token usage.
    /// Samples where either count is zero are ignored. Every `UPDATE_INTERVAL`-th
    /// accepted sample triggers a recalibration.
    pub fn record(&self, estimated: u32, actual: u32) {
        if estimated == 0 || actual == 0 {
            return;
        }

        // Relaxed is sufficient: these are statistics counters, and a sample that
        // races with a concurrent recalibration is simply included in the next one.
        self.total_estimated
            .fetch_add(u64::from(estimated), Ordering::Relaxed);
        self.total_actual
            .fetch_add(u64::from(actual), Ordering::Relaxed);
        let count = self.sample_count.fetch_add(1, Ordering::Relaxed) + 1;

        if count % Self::UPDATE_INTERVAL == 0 {
            self.update_calibration(count);
        }
    }

    /// Update the calibration factor from the accumulated totals.
    ///
    /// `samples` is the sample count that triggered this update (used for logging).
    fn update_calibration(&self, samples: u64) {
        let estimated = self.total_estimated.load(Ordering::Relaxed) as f64;
        let actual = self.total_actual.load(Ordering::Relaxed) as f64;

        if estimated <= 0.0 {
            return;
        }

        let raw = (actual / estimated) as f32;
        let clamped = raw.clamp(Self::MIN_FACTOR, Self::MAX_FACTOR);

        // Hold the write lock only for the read-modify-write, and log after it is
        // released so concurrent `calibrate` readers never wait on tracing I/O.
        // The guarded value is a plain f32 that is always valid, so a poisoned lock
        // is recovered instead of silently disabling calibration.
        let (old, updated) = {
            let mut factor = self
                .calibration_factor
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let old = *factor;
            // Exponential moving average: stability while still adapting to changes.
            *factor = old * Self::OLD_WEIGHT + clamped * Self::NEW_WEIGHT;
            (old, *factor)
        };

        info!(
            "[Calibrator] Updated factor: {:.2} -> {:.2} (raw: {:.2}, samples: {})",
            old, updated, raw, samples
        );
    }

    /// Get a calibrated estimate from a raw estimate.
    ///
    /// Multiplies the raw estimate by the current calibration factor and rounds up.
    /// The product is computed in `f64` so large estimates keep integer precision;
    /// results above `u32::MAX` saturate (float-to-int `as` casts saturate).
    pub fn calibrate(&self, estimated: u32) -> u32 {
        let scaled = f64::from(estimated) * f64::from(self.get_factor());
        scaled.ceil() as u32
    }

    /// Get the current calibration factor.
    pub fn get_factor(&self) -> f32 {
        *self
            .calibration_factor
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for EstimationCalibrator {
    fn default() -> Self {
        Self::new()
    }
}

// Global singleton instance
use std::sync::OnceLock;

/// Lazily initialised storage backing [`get_calibrator`].
static CALIBRATOR: OnceLock<EstimationCalibrator> = OnceLock::new();

/// Returns the process-wide [`EstimationCalibrator`].
///
/// The first call constructs the instance through its `Default` impl, which is
/// identical to `EstimationCalibrator::new`. Every later call returns a reference
/// to that same instance.
pub fn get_calibrator() -> &'static EstimationCalibrator {
    CALIBRATOR.get_or_init(EstimationCalibrator::default)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance when comparing calibration factors (stored as `f32`).
    const EPS: f32 = 1e-3;

    /// Assert that the calibrator's current factor is within `EPS` of `expected`.
    fn assert_factor(calibrator: &EstimationCalibrator, expected: f32) {
        let factor = calibrator.get_factor();
        assert!(
            (factor - expected).abs() < EPS,
            "expected factor {expected}, got {factor}"
        );
    }

    #[test]
    fn test_calibrator_basic() {
        let calibrator = EstimationCalibrator::new();

        // Initial factor should be 2.0
        assert_factor(&calibrator, 2.0);

        // Actual is 3x estimated. Updates fire at samples 5 and 10, moving the
        // factor 2.0 -> 2.4 -> 2.64 via the 60% old / 40% new moving average.
        for _ in 0..10 {
            calibrator.record(100, 300);
        }
        assert_factor(&calibrator, 2.64);
    }

    #[test]
    fn test_no_update_before_interval() {
        let calibrator = EstimationCalibrator::new();

        // Four samples are below the update interval, so the factor is untouched.
        for _ in 0..4 {
            calibrator.record(100, 300);
        }
        assert_factor(&calibrator, 2.0);
        assert_eq!(calibrator.sample_count.load(Ordering::Relaxed), 4);
    }

    #[test]
    fn test_raw_ratio_is_clamped() {
        // Raw ratio 10.0 is clamped to 4.0: 2.0 * 0.6 + 4.0 * 0.4 = 2.8
        let high = EstimationCalibrator::new();
        for _ in 0..5 {
            high.record(100, 1000);
        }
        assert_factor(&high, 2.8);

        // Raw ratio 0.1 is clamped to 0.8: 2.0 * 0.6 + 0.8 * 0.4 = 1.52
        let low = EstimationCalibrator::new();
        for _ in 0..5 {
            low.record(1000, 100);
        }
        assert_factor(&low, 1.52);
    }

    #[test]
    fn test_calibrate() {
        let calibrator = EstimationCalibrator::new();

        // With default factor of 2.0, 100 should become 200
        assert_eq!(calibrator.calibrate(100), 200);
        assert_eq!(calibrator.calibrate(0), 0);
        // Float-to-int casts saturate instead of wrapping.
        assert_eq!(calibrator.calibrate(u32::MAX), u32::MAX);
    }

    #[test]
    fn test_zero_handling() {
        let calibrator = EstimationCalibrator::new();

        // Recording zeros should not affect anything
        calibrator.record(0, 100);
        calibrator.record(100, 0);

        assert_eq!(calibrator.sample_count.load(Ordering::Relaxed), 0);
        assert_factor(&calibrator, 2.0);
    }

    #[test]
    fn test_default_matches_new() {
        let calibrator = EstimationCalibrator::default();
        assert_factor(&calibrator, 2.0);
        assert_eq!(calibrator.sample_count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_global_calibrator_is_singleton() {
        assert!(std::ptr::eq(get_calibrator(), get_calibrator()));
    }
}