//! Estimation Calibrator Module
//!
//! Learns from historical request/response pairs to improve token estimation accuracy.
//! Uses actual token counts from Google API responses to calibrate future estimates.
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use tracing::info;
/// Estimation Calibrator - learns estimation error from historical requests
///
/// This type tracks the ratio between estimated tokens (before request) and
/// actual tokens (from Google API response) to improve future estimations.
///
/// All mutable state lives in atomics or behind a lock, so one instance can be
/// shared by reference between threads (the module keeps a global one).
#[derive(Debug)]
pub struct EstimationCalibrator {
    /// Cumulative estimated tokens over all accepted samples
    total_estimated: AtomicU64,
    /// Cumulative actual tokens (from Google API) over all accepted samples
    total_actual: AtomicU64,
    /// Number of accepted samples (samples with a zero count are skipped)
    sample_count: AtomicU64,
    /// Current calibration factor (estimated * factor ≈ actual)
    calibration_factor: RwLock<f32>,
}
impl EstimationCalibrator {
    /// Factor used before any calibration data exists.
    ///
    /// 2.0 assumes the actual token count is roughly twice the raw estimate.
    /// This is a conservative starting point that real samples adjust.
    const DEFAULT_FACTOR: f32 = 2.0;
    /// Lower clamp for the observed ratio; below this we are overestimating (rare).
    const MIN_FACTOR: f32 = 0.8;
    /// Upper clamp for the observed ratio; above this is severe underestimation.
    const MAX_FACTOR: f32 = 4.0;
    /// The factor is recomputed once every this many accepted samples.
    const UPDATE_INTERVAL: u64 = 5;

    /// Create a new calibrator with no samples and the default factor (2.0).
    pub const fn new() -> Self {
        Self {
            total_estimated: AtomicU64::new(0),
            total_actual: AtomicU64::new(0),
            sample_count: AtomicU64::new(0),
            calibration_factor: RwLock::new(Self::DEFAULT_FACTOR),
        }
    }

    /// Record a request's estimated vs actual token counts.
    ///
    /// Call this after receiving a response from Google API with actual token usage.
    /// Samples where either count is zero are ignored. Every fifth accepted
    /// sample triggers a recalibration of the factor.
    pub fn record(&self, estimated: u32, actual: u32) {
        if estimated == 0 || actual == 0 {
            return;
        }
        self.total_estimated
            .fetch_add(u64::from(estimated), Ordering::Relaxed);
        self.total_actual
            .fetch_add(u64::from(actual), Ordering::Relaxed);
        let count = self.sample_count.fetch_add(1, Ordering::Relaxed) + 1;
        if count % Self::UPDATE_INTERVAL == 0 {
            self.update_calibration();
        }
    }

    /// Update the calibration factor based on accumulated data.
    ///
    /// The cumulative ratio `actual / estimated` is clamped to `[0.8, 4.0]`.
    /// It is then blended into the current factor with an exponential moving
    /// average (60% old, 40% new), which keeps the factor stable while it adapts.
    ///
    /// The two totals are loaded separately with `Relaxed` ordering. Under
    /// concurrent `record` calls, the ratio can therefore include one side of an
    /// in-flight sample but not the other. The clamp and smoothing bound that skew.
    fn update_calibration(&self) {
        let estimated = self.total_estimated.load(Ordering::Relaxed) as f64;
        let actual = self.total_actual.load(Ordering::Relaxed) as f64;
        if estimated <= 0.0 {
            return;
        }

        let new_factor = (actual / estimated) as f32;
        let clamped = new_factor.clamp(Self::MIN_FACTOR, Self::MAX_FACTOR);

        // Keep the write lock only for the update itself, not for logging.
        let (old, updated) = {
            // The guarded value is a plain f32, so a panic in another lock holder
            // cannot leave it half-written. Recover from poisoning instead of
            // silently freezing calibration forever.
            let mut factor = self
                .calibration_factor
                .write()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let old = *factor;
            *factor = old * 0.6 + clamped * 0.4;
            (old, *factor)
        };

        info!(
            "[Calibrator] Updated factor: {:.2} -> {:.2} (raw: {:.2}, samples: {})",
            old,
            updated,
            new_factor,
            self.sample_count.load(Ordering::Relaxed)
        );
    }

    /// Get a calibrated estimate from a raw estimate.
    ///
    /// Multiplies the raw estimate by the current calibration factor and rounds up.
    /// The float-to-int `as` cast saturates, so oversized products clamp to `u32::MAX`.
    pub fn calibrate(&self, estimated: u32) -> u32 {
        (estimated as f32 * self.get_factor()).ceil() as u32
    }

    /// Get the current calibration factor.
    pub fn get_factor(&self) -> f32 {
        // A poisoned lock still holds a valid f32; read it rather than
        // reporting a default that disagrees with the stored value.
        *self
            .calibration_factor
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
impl Default for EstimationCalibrator {
fn default() -> Self {
Self::new()
}
}
/// Global singleton instance.
///
/// `EstimationCalibrator::new` is a `const fn` and the type is `Sync`. The
/// instance can therefore be a plain `static` that is evaluated at compile
/// time, with no lazy-initialization check on each access.
static CALIBRATOR: EstimationCalibrator = EstimationCalibrator::new();

/// Get the global calibrator instance.
pub fn get_calibrator() -> &'static EstimationCalibrator {
    &CALIBRATOR
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing f32 calibration factors.
    const EPS: f32 = 1e-4;

    /// Assert that the calibrator's current factor is `expected` within `EPS`.
    fn assert_factor(calibrator: &EstimationCalibrator, expected: f32) {
        let factor = calibrator.get_factor();
        assert!(
            (factor - expected).abs() < EPS,
            "expected factor {expected}, got {factor}"
        );
    }

    #[test]
    fn test_calibrator_basic() {
        let calibrator = EstimationCalibrator::new();
        // Initial factor should be 2.0
        assert_factor(&calibrator, 2.0);
        // Record samples where actual is 3x estimated. The factor is recomputed
        // at samples 5 and 10 (EMA 60% old / 40% new towards 3.0):
        // 2.0 -> 2.4 -> 2.64
        for _ in 0..10 {
            calibrator.record(100, 300);
        }
        assert_factor(&calibrator, 2.64);
    }

    #[test]
    fn test_factor_updates_only_every_fifth_sample() {
        let calibrator = EstimationCalibrator::new();
        for _ in 0..4 {
            calibrator.record(100, 300);
        }
        assert_factor(&calibrator, 2.0);
        calibrator.record(100, 300);
        assert_factor(&calibrator, 2.4);
    }

    #[test]
    fn test_raw_ratio_is_clamped() {
        // Ratio 10.0 clamps to 4.0: 2.0 * 0.6 + 4.0 * 0.4 = 2.8
        let high = EstimationCalibrator::new();
        for _ in 0..5 {
            high.record(100, 1000);
        }
        assert_factor(&high, 2.8);

        // Ratio 0.1 clamps to 0.8: 2.0 * 0.6 + 0.8 * 0.4 = 1.52
        let low = EstimationCalibrator::new();
        for _ in 0..5 {
            low.record(100, 10);
        }
        assert_factor(&low, 1.52);
    }

    #[test]
    fn test_calibrate() {
        let calibrator = EstimationCalibrator::new();
        // With default factor of 2.0, 100 should become 200
        assert_eq!(calibrator.calibrate(100), 200);
        assert_eq!(calibrator.calibrate(0), 0);
    }

    #[test]
    fn test_zero_handling() {
        let calibrator = EstimationCalibrator::new();
        // Recording zeros should not affect anything
        calibrator.record(0, 100);
        calibrator.record(100, 0);
        assert_eq!(calibrator.sample_count.load(Ordering::Relaxed), 0);
        assert_eq!(calibrator.total_estimated.load(Ordering::Relaxed), 0);
        assert_eq!(calibrator.total_actual.load(Ordering::Relaxed), 0);
        assert_factor(&calibrator, 2.0);
    }

    #[test]
    fn test_global_calibrator_is_singleton() {
        assert!(std::ptr::eq(get_calibrator(), get_calibrator()));
    }
}