|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use crate::core::Point; |
|
|
use std::collections::VecDeque; |
|
|
|
|
|
|
|
|
/// Hyperparameters for a learnable router's momentum-SGD weight updates
/// and feedback buffering.
#[derive(Debug, Clone)]
pub struct LearnableRoutingConfig {
    /// SGD step size; `0.0` disables learning entirely.
    pub learning_rate: f32,
    /// Momentum coefficient for the gradient accumulator, in `[0.0, 0.99]`.
    pub momentum: f32,
    /// Decay that pulls weights back toward their neutral value (1.0) each update.
    pub weight_decay: f32,
    /// Maximum number of feedback samples retained in the sliding window.
    pub max_feedback_samples: usize,
    /// Minimum buffered samples required before any weight update runs.
    pub min_samples_to_learn: usize,
    /// Attempt a weight update every this many recorded samples.
    pub update_frequency: usize,
    /// If true, learn one weight per dimension; otherwise a single scalar weight.
    pub per_dimension_weights: bool,
}

impl Default for LearnableRoutingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            momentum: 0.9,
            weight_decay: 0.001,
            max_feedback_samples: 1000,
            min_samples_to_learn: 50,
            update_frequency: 10,
            per_dimension_weights: true,
        }
    }
}

impl LearnableRoutingConfig {
    /// Creates the default configuration (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for the learning rate.
    ///
    /// Negative rates would invert the learning signal (reinforcing failures
    /// and penalizing successes), so they are clamped to `0.0` — consistent
    /// with the clamping done by [`Self::with_momentum`]. A NaN input also
    /// resolves to `0.0` via `f32::max`.
    pub fn with_learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr.max(0.0);
        self
    }

    /// Builder-style setter for momentum, clamped to `[0.0, 0.99]`.
    pub fn with_momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum.clamp(0.0, 0.99);
        self
    }

    /// A configuration with learning turned off (zero learning rate);
    /// all other fields keep their defaults.
    pub fn disabled() -> Self {
        Self {
            learning_rate: 0.0,
            ..Default::default()
        }
    }
}
|
|
|
|
|
|
|
|
/// A single routing outcome, recorded as a training sample for the router.
#[derive(Debug, Clone)]
pub struct RoutingFeedback {
    /// The query point that was routed.
    pub query: Point,
    /// The centroid the router selected for this query.
    pub selected_centroid: Point,
    /// Learning signal: positive reinforces this selection, negative penalizes
    /// it (the success/failure helpers record +1.0 / -1.0 respectively).
    pub reward: f32,
    /// Hierarchy level at which the routing decision was made.
    pub level: usize,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Routing component that learns similarity weights from feedback using
/// momentum SGD over a sliding window of recent samples.
#[derive(Debug, Clone)]
pub struct LearnableRouter {
    // Hyperparameters controlling the update rule and buffering.
    config: LearnableRoutingConfig,
    // Multiplicative similarity weights: length == dims when
    // per_dimension_weights is set, otherwise length 1. Initialized to 1.0.
    weights: Vec<f32>,
    // Momentum accumulator; always the same length as `weights`.
    momentum_buffer: Vec<f32>,
    // Sliding window of recent feedback, capped at max_feedback_samples.
    feedback_buffer: VecDeque<RoutingFeedback>,
    // Lifetime count of all feedback ever recorded (never trimmed).
    total_samples: usize,
    // Dimensionality of the points this router expects.
    dims: usize,
}
|
|
|
|
|
impl LearnableRouter { |
|
|
|
|
|
pub fn new(dims: usize, config: LearnableRoutingConfig) -> Self { |
|
|
let weight_count = if config.per_dimension_weights { dims } else { 1 }; |
|
|
|
|
|
Self { |
|
|
config, |
|
|
weights: vec![1.0; weight_count], |
|
|
momentum_buffer: vec![0.0; weight_count], |
|
|
feedback_buffer: VecDeque::new(), |
|
|
total_samples: 0, |
|
|
dims, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn default_for_dims(dims: usize) -> Self { |
|
|
Self::new(dims, LearnableRoutingConfig::default()) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn is_learning_enabled(&self) -> bool { |
|
|
self.config.learning_rate > 0.0 |
|
|
} |
|
|
|
|
|
|
|
|
pub fn weights(&self) -> &[f32] { |
|
|
&self.weights |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn weighted_similarity(&self, query: &Point, centroid: &Point) -> f32 { |
|
|
if self.config.per_dimension_weights { |
|
|
|
|
|
query.dims().iter() |
|
|
.zip(centroid.dims().iter()) |
|
|
.zip(self.weights.iter()) |
|
|
.map(|((q, c), w)| w * q * c) |
|
|
.sum() |
|
|
} else { |
|
|
|
|
|
let dot: f32 = query.dims().iter() |
|
|
.zip(centroid.dims().iter()) |
|
|
.map(|(q, c)| q * c) |
|
|
.sum(); |
|
|
self.weights[0] * dot |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn record_feedback(&mut self, feedback: RoutingFeedback) { |
|
|
self.feedback_buffer.push_back(feedback); |
|
|
self.total_samples += 1; |
|
|
|
|
|
|
|
|
while self.feedback_buffer.len() > self.config.max_feedback_samples { |
|
|
self.feedback_buffer.pop_front(); |
|
|
} |
|
|
|
|
|
|
|
|
if self.should_update() { |
|
|
self.update_weights(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn should_update(&self) -> bool { |
|
|
self.config.learning_rate > 0.0 |
|
|
&& self.feedback_buffer.len() >= self.config.min_samples_to_learn |
|
|
&& self.total_samples % self.config.update_frequency == 0 |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn update_weights(&mut self) { |
|
|
if self.feedback_buffer.is_empty() { |
|
|
return; |
|
|
} |
|
|
|
|
|
let lr = self.config.learning_rate; |
|
|
let momentum = self.config.momentum; |
|
|
let decay = self.config.weight_decay; |
|
|
|
|
|
|
|
|
let mut gradient = vec![0.0f32; self.weights.len()]; |
|
|
|
|
|
for feedback in &self.feedback_buffer { |
|
|
let reward = feedback.reward; |
|
|
|
|
|
if self.config.per_dimension_weights { |
|
|
|
|
|
for ((&q, &c), g) in feedback.query.dims().iter() |
|
|
.zip(feedback.selected_centroid.dims().iter()) |
|
|
.zip(gradient.iter_mut()) |
|
|
{ |
|
|
|
|
|
*g += reward * q * c; |
|
|
} |
|
|
} else { |
|
|
|
|
|
let dot: f32 = feedback.query.dims().iter() |
|
|
.zip(feedback.selected_centroid.dims().iter()) |
|
|
.map(|(q, c)| q * c) |
|
|
.sum(); |
|
|
gradient[0] += reward * dot; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
let n = self.feedback_buffer.len() as f32; |
|
|
for g in gradient.iter_mut() { |
|
|
*g /= n; |
|
|
} |
|
|
|
|
|
|
|
|
for (i, (w, g)) in self.weights.iter_mut().zip(gradient.iter()).enumerate() { |
|
|
|
|
|
self.momentum_buffer[i] = momentum * self.momentum_buffer[i] + (1.0 - momentum) * g; |
|
|
|
|
|
|
|
|
*w += lr * self.momentum_buffer[i] - decay * (*w - 1.0); |
|
|
|
|
|
|
|
|
*w = w.clamp(0.1, 10.0); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn record_success(&mut self, query: &Point, selected_centroid: &Point, level: usize) { |
|
|
self.record_feedback(RoutingFeedback { |
|
|
query: query.clone(), |
|
|
selected_centroid: selected_centroid.clone(), |
|
|
reward: 1.0, |
|
|
level, |
|
|
}); |
|
|
} |
|
|
|
|
|
|
|
|
pub fn record_failure(&mut self, query: &Point, selected_centroid: &Point, level: usize) { |
|
|
self.record_feedback(RoutingFeedback { |
|
|
query: query.clone(), |
|
|
selected_centroid: selected_centroid.clone(), |
|
|
reward: -1.0, |
|
|
level, |
|
|
}); |
|
|
} |
|
|
|
|
|
|
|
|
pub fn record_implicit(&mut self, query: &Point, selected_centroid: &Point, level: usize, relevance_score: f32) { |
|
|
|
|
|
let reward = 2.0 * relevance_score - 1.0; |
|
|
self.record_feedback(RoutingFeedback { |
|
|
query: query.clone(), |
|
|
selected_centroid: selected_centroid.clone(), |
|
|
reward, |
|
|
level, |
|
|
}); |
|
|
} |
|
|
|
|
|
|
|
|
pub fn stats(&self) -> RouterStats { |
|
|
RouterStats { |
|
|
total_samples: self.total_samples, |
|
|
buffer_size: self.feedback_buffer.len(), |
|
|
weight_mean: self.weights.iter().sum::<f32>() / self.weights.len() as f32, |
|
|
weight_std: { |
|
|
let mean = self.weights.iter().sum::<f32>() / self.weights.len() as f32; |
|
|
(self.weights.iter().map(|w| (w - mean).powi(2)).sum::<f32>() |
|
|
/ self.weights.len() as f32).sqrt() |
|
|
}, |
|
|
weight_min: self.weights.iter().cloned().fold(f32::INFINITY, f32::min), |
|
|
weight_max: self.weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn reset_weights(&mut self) { |
|
|
for w in self.weights.iter_mut() { |
|
|
*w = 1.0; |
|
|
} |
|
|
for m in self.momentum_buffer.iter_mut() { |
|
|
*m = 0.0; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn clear_feedback(&mut self) { |
|
|
self.feedback_buffer.clear(); |
|
|
} |
|
|
|
|
|
|
|
|
pub fn dims(&self) -> usize { |
|
|
self.dims |
|
|
} |
|
|
|
|
|
|
|
|
pub fn serialize_weights(&self) -> Vec<u8> { |
|
|
let mut bytes = Vec::with_capacity(self.weights.len() * 4); |
|
|
for w in &self.weights { |
|
|
bytes.extend_from_slice(&w.to_le_bytes()); |
|
|
} |
|
|
bytes |
|
|
} |
|
|
|
|
|
|
|
|
pub fn deserialize_weights(&mut self, bytes: &[u8]) -> Result<(), &'static str> { |
|
|
if bytes.len() != self.weights.len() * 4 { |
|
|
return Err("Weight count mismatch"); |
|
|
} |
|
|
|
|
|
for (i, chunk) in bytes.chunks(4).enumerate() { |
|
|
let arr: [u8; 4] = chunk.try_into().map_err(|_| "Invalid byte chunk")?; |
|
|
self.weights[i] = f32::from_le_bytes(arr); |
|
|
} |
|
|
|
|
|
Ok(()) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// Snapshot of a router's learning state and weight distribution.
#[derive(Debug, Clone)]
pub struct RouterStats {
    /// Total feedback samples ever recorded (never trimmed).
    pub total_samples: usize,
    /// Number of samples currently held in the sliding feedback buffer.
    pub buffer_size: usize,
    /// Mean of the current weights.
    pub weight_mean: f32,
    /// Population standard deviation of the current weights.
    pub weight_std: f32,
    /// Smallest current weight.
    pub weight_min: f32,
    /// Largest current weight.
    pub weight_max: f32,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn compute_routing_score( |
|
|
router: &LearnableRouter, |
|
|
query: &Point, |
|
|
centroid: &Point, |
|
|
temporal_distance: f32, |
|
|
temporal_weight: f32, |
|
|
) -> f32 { |
|
|
let semantic_sim = router.weighted_similarity(query, centroid); |
|
|
|
|
|
|
|
|
let semantic_dist = 1.0 - semantic_sim; |
|
|
|
|
|
|
|
|
semantic_dist * (1.0 - temporal_weight) + temporal_distance * temporal_weight |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-norm test point.
    fn make_point(v: Vec<f32>) -> Point {
        Point::new(v).normalize()
    }

    #[test]
    fn test_router_creation() {
        let router = LearnableRouter::default_for_dims(64);

        assert_eq!(router.dims(), 64);
        assert_eq!(router.weights().len(), 64);
        assert!(router.is_learning_enabled());

        // Weights start at the neutral value 1.0.
        for &w in router.weights() {
            assert!((w - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_weighted_similarity() {
        let router = LearnableRouter::default_for_dims(4);

        let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
        let centroid = make_point(vec![0.8, 0.2, 0.0, 0.0]);

        let sim = router.weighted_similarity(&query, &centroid);

        // With all weights at 1.0, the weighted similarity must equal the
        // plain dot product (cosine, since the points are normalized).
        let expected_cosine: f32 = query.dims().iter()
            .zip(centroid.dims().iter())
            .map(|(q, c)| q * c)
            .sum();

        assert!((sim - expected_cosine).abs() < 1e-5);
    }

    #[test]
    fn test_feedback_recording() {
        let mut router = LearnableRouter::new(4, LearnableRoutingConfig {
            min_samples_to_learn: 5,
            update_frequency: 5,
            ..Default::default()
        });

        let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
        let centroid = make_point(vec![0.9, 0.1, 0.0, 0.0]);

        for _ in 0..10 {
            router.record_success(&query, &centroid, 0);
        }

        let stats = router.stats();
        assert_eq!(stats.total_samples, 10);
        // All 10 samples fit the buffer (default max_feedback_samples = 1000).
        assert_eq!(stats.buffer_size, 10);

        println!("Weights after positive feedback: {:?}", router.weights());
    }

    #[test]
    fn test_learning_dynamics() {
        let mut router = LearnableRouter::new(4, LearnableRoutingConfig {
            learning_rate: 0.1,
            min_samples_to_learn: 3,
            update_frequency: 3,
            momentum: 0.0,
            weight_decay: 0.0,
            ..Default::default()
        });

        let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);

        // Aligned with the query along dimension 0.
        let centroid_good = make_point(vec![0.95, 0.05, 0.0, 0.0]);

        // Orthogonal to the query: q_i * c_i = 0 for every dimension.
        let centroid_bad = make_point(vec![0.0, 1.0, 0.0, 0.0]);

        for _ in 0..6 {
            router.record_success(&query, &centroid_good, 0);
        }

        let weights_after_positive = router.weights().to_vec();

        for _ in 0..6 {
            router.record_failure(&query, &centroid_bad, 0);
        }

        let weights_after_negative = router.weights().to_vec();

        println!("Initial weights: [1.0, 1.0, 1.0, 1.0]");
        println!("After positive: {:?}", weights_after_positive);
        println!("After negative: {:?}", weights_after_negative);

        // Positive feedback on a dim-0-aligned centroid must raise w[0]:
        // gradient[0] = avg(reward * q0 * c0) > 0 with momentum and decay 0.
        assert!(weights_after_positive[0] > 1.0);

        // Dimensions 2 and 3 carry no signal (q = c = 0) and decay is 0,
        // so their weights must be untouched.
        assert!((weights_after_positive[2] - 1.0).abs() < 1e-6);
        assert!((weights_after_positive[3] - 1.0).abs() < 1e-6);

        // The bad centroid is orthogonal to the query, so its negative reward
        // contributes zero gradient; the positive samples still in the buffer
        // can only push w[0] further up (or hold it at the clamp).
        assert!(weights_after_negative[0] >= weights_after_positive[0]);
    }

    #[test]
    fn test_disabled_learning() {
        let mut router = LearnableRouter::new(4, LearnableRoutingConfig::disabled());

        assert!(!router.is_learning_enabled());

        let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
        let centroid = make_point(vec![0.9, 0.1, 0.0, 0.0]);

        for _ in 0..100 {
            router.record_success(&query, &centroid, 0);
        }

        // With learning disabled, weights must remain at their initial 1.0.
        for &w in router.weights() {
            assert!((w - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_serialization() {
        let mut router = LearnableRouter::default_for_dims(4);

        // Give the weights distinguishable values before round-tripping.
        for (i, w) in router.weights.iter_mut().enumerate() {
            *w = (i as f32 + 1.0) * 0.5;
        }

        let bytes = router.serialize_weights();

        let mut router2 = LearnableRouter::default_for_dims(4);
        router2.deserialize_weights(&bytes).unwrap();

        for (w1, w2) in router.weights().iter().zip(router2.weights().iter()) {
            assert!((w1 - w2).abs() < 1e-6);
        }
    }
}
|
|
|