| |
| |
| |
| |
| |
| |
|
|
use std::collections::{HashMap, HashSet};
use std::path::Path;

use anyhow::Context;
use serde::{Deserialize, Serialize};
|
|
| |
| #[derive(Debug, Clone, Serialize, Deserialize)] |
| pub struct KnowledgeEntry { |
| |
| pub id: String, |
| |
| pub hashtags: Vec<String>, |
| |
| pub question: String, |
| |
| pub question_en: String, |
| |
| pub answer: String, |
| |
| pub language: String, |
| } |
|
|
| |
| #[derive(Debug, Clone, Serialize, Deserialize)] |
| pub struct KnowledgeBase { |
| pub version: u32, |
| pub entries: Vec<KnowledgeEntry>, |
| } |
|
|
| |
| pub struct KnowledgeIndex { |
| |
| tag_index: HashMap<String, Vec<usize>>, |
| |
| entries: Vec<KnowledgeEntry>, |
| } |
|
|
| |
/// Outcome of a knowledge-base lookup (`KnowledgeIndex::lookup`).
#[derive(Debug, Clone, PartialEq)]
pub enum KbLookup {
    /// High-confidence match: the best candidate scored at least 0.75.
    Hit {
        /// The matched entry's answer.
        answer: String,
        /// Identifier of the matched entry.
        entry_id: String,
        /// Combined relevance score of the match.
        score: f64,
    },
    /// Weaker match: the best candidate scored at least 0.35 but below 0.75.
    Partial {
        /// The matched entry's full answer, offered only as a hint.
        answer_hint: String,
        /// Identifier of the matched entry.
        entry_id: String,
        /// Combined relevance score of the match.
        score: f64,
    },
    /// No usable match: no hashtags were given, the index is empty, no entry shared a
    /// hashtag with the query, or the best candidate scored below 0.35.
    Miss,
}
|
|
| #[allow(dead_code)] |
| impl KnowledgeIndex { |
| |
| pub fn load(path: &Path) -> anyhow::Result<Self> { |
| let content = std::fs::read_to_string(path)?; |
| let kb: KnowledgeBase = serde_json::from_str(&content)?; |
| tracing::info!( |
| "Knowledge base loaded: {} entries, version {}", |
| kb.entries.len(), |
| kb.version |
| ); |
| Self::from_entries(kb.entries) |
| } |
|
|
| |
| pub fn empty() -> Self { |
| Self { |
| tag_index: HashMap::new(), |
| entries: Vec::new(), |
| } |
| } |
|
|
| fn from_entries(entries: Vec<KnowledgeEntry>) -> anyhow::Result<Self> { |
| let mut tag_index: HashMap<String, Vec<usize>> = HashMap::new(); |
| for (i, entry) in entries.iter().enumerate() { |
| for tag in &entry.hashtags { |
| tag_index.entry(tag.clone()).or_default().push(i); |
| } |
| } |
| Ok(Self { tag_index, entries }) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| pub fn lookup(&self, query: &str, query_en: &str, hashtags: &[String]) -> KbLookup { |
| if hashtags.is_empty() || self.entries.is_empty() { |
| return KbLookup::Miss; |
| } |
|
|
| |
| let mut candidates: Vec<(usize, usize)> = Vec::new(); |
| let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new(); |
|
|
| for tag in hashtags { |
| let clean_tag = tag.trim_start_matches('#'); |
| if let Some(indices) = self.tag_index.get(clean_tag) { |
| for &idx in indices { |
| if seen.insert(idx) { |
| |
| let entry = &self.entries[idx]; |
| let overlap = entry |
| .hashtags |
| .iter() |
| .filter(|t| { |
| hashtags |
| .iter() |
| .any(|h| h.trim_start_matches('#') == t.as_str()) |
| }) |
| .count(); |
| candidates.push((idx, overlap)); |
| } |
| } |
| } |
| } |
|
|
| if candidates.is_empty() { |
| return KbLookup::Miss; |
| } |
|
|
| |
| let mut best_score = 0.0f64; |
| let mut best_idx = 0usize; |
|
|
| for (idx, tag_overlap) in &candidates { |
| let entry = &self.entries[*idx]; |
|
|
| |
| let total_hashtags = hashtags.len().max(entry.hashtags.len()).max(1); |
| let tag_score = 0.6 * (*tag_overlap as f64 / total_hashtags as f64); |
|
|
| |
| let text_score_en = 0.3 * str_similarity(query_en, &entry.question_en); |
| let text_score_orig = 0.1 * str_similarity(query, &entry.question); |
|
|
| let score = tag_score + text_score_en + text_score_orig; |
|
|
| if score > best_score { |
| best_score = score; |
| best_idx = *idx; |
| } |
| } |
|
|
| let entry = &self.entries[best_idx]; |
|
|
| if best_score >= 0.75 { |
| |
| KbLookup::Hit { |
| answer: entry.answer.clone(), |
| entry_id: entry.id.clone(), |
| score: best_score, |
| } |
| } else if best_score >= 0.35 { |
| |
| KbLookup::Partial { |
| answer_hint: entry.answer.clone(), |
| entry_id: entry.id.clone(), |
| score: best_score, |
| } |
| } else { |
| KbLookup::Miss |
| } |
| } |
|
|
| |
| pub fn len(&self) -> usize { |
| self.entries.len() |
| } |
|
|
| |
| pub fn is_empty(&self) -> bool { |
| self.entries.is_empty() |
| } |
| } |
|
|
| |
| |
| |
/// Scores how similar two short texts are, returning a value in `[0.0, 1.0]`.
///
/// Both inputs are trimmed and lowercased first. Identical non-empty texts score `1.0`.
/// Otherwise the score is `0.7 * bigram_jaccard + 0.3 * word_jaccard`. `bigram_jaccard` is
/// computed over the sets of adjacent character pairs, and `word_jaccard` over the sets of
/// whitespace-separated words. An empty (or whitespace-only) input scores `0.0`, as does a
/// non-identical input too short to form a single bigram.
fn str_similarity(a: &str, b: &str) -> f64 {
    let a = a.trim().to_lowercase();
    let b = b.trim().to_lowercase();

    // Empty text carries no signal. Without this guard two empty strings would compare equal
    // and score a perfect 1.0, inflating matches whenever a text field is missing.
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }
    if a == b {
        return 1.0;
    }

    let bigrams_a = char_bigrams(&a);
    let bigrams_b = char_bigrams(&b);
    if bigrams_a.is_empty() || bigrams_b.is_empty() {
        return 0.0;
    }
    let bigram_score = jaccard(&bigrams_a, &bigrams_b);

    // Both strings are non-empty after trimming, so each has at least one word.
    let words_a: HashSet<&str> = a.split_whitespace().collect();
    let words_b: HashSet<&str> = b.split_whitespace().collect();
    let word_score = jaccard(&words_a, &words_b);

    0.7 * bigram_score + 0.3 * word_score
}

/// Set of adjacent character pairs in `s` (empty for strings shorter than two chars).
fn char_bigrams(s: &str) -> HashSet<(char, char)> {
    s.chars().zip(s.chars().skip(1)).collect()
}

/// Jaccard index `|x ∩ y| / |x ∪ y|`, defined as `0.0` when both sets are empty.
fn jaccard<T: Eq + std::hash::Hash>(x: &HashSet<T>, y: &HashSet<T>) -> f64 {
    let shared = x.intersection(y).count();
    let union = x.len() + y.len() - shared;
    if union == 0 {
        0.0
    } else {
        shared as f64 / union as f64
    }
}
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-entry index shared by the lookup tests.
    fn calculator_index() -> KnowledgeIndex {
        let entries = vec![KnowledgeEntry {
            id: "test_001".into(),
            hashtags: vec!["rust".into(), "make".into(), "math".into()],
            question: "Напиши калькулятор на Rust".into(),
            question_en: "Write a calculator in Rust".into(),
            answer: "Here is a Rust calculator...".into(),
            language: "ru".into(),
        }];
        KnowledgeIndex::from_entries(entries).unwrap()
    }

    /// Converts string literals into the owned hashtag list `lookup` expects.
    fn tags(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|t| t.to_string()).collect()
    }

    #[test]
    fn test_similarity_identical() {
        assert!((str_similarity("hello world", "hello world") - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_similarity_case_insensitive() {
        assert!((str_similarity("Hello World", "hello world") - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_similarity_different() {
        assert!(str_similarity("rust", "python") < 0.3);
    }

    #[test]
    fn test_similarity_similar() {
        let score = str_similarity("write a calculator in rust", "write a calc in rust");
        assert!(score > 0.5, "score was {score}");
    }

    #[test]
    fn test_kb_lookup_exact() {
        let index = calculator_index();
        let result = index.lookup(
            "напиши калькулятор на раст",
            "write a calculator in rust",
            &tags(&["#rust", "#make", "#math"]),
        );
        match result {
            KbLookup::Hit { score, .. } => assert!(score > 0.7),
            _ => panic!("Expected Hit, got {:?}", result),
        }
    }

    #[test]
    fn test_kb_lookup_partial_on_tags_only() {
        // Full tag overlap (0.6) with no textual similarity lands in the Partial band.
        let index = calculator_index();
        let result = index.lookup("xyz", "xyz", &tags(&["#rust", "#make", "#math"]));
        match result {
            KbLookup::Partial {
                answer_hint,
                entry_id,
                score,
            } => {
                assert_eq!(entry_id, "test_001");
                assert_eq!(answer_hint, "Here is a Rust calculator...");
                assert!((score - 0.6).abs() < 1e-9, "score was {score}");
            }
            other => panic!("Expected Partial, got {other:?}"),
        }
    }

    #[test]
    fn test_kb_lookup_miss_on_weak_tag_overlap() {
        // One of three tags shared (0.6 * 1/3 = 0.2) and no text match stays below 0.35.
        let index = calculator_index();
        let result = index.lookup("xyz", "xyz", &tags(&["#rust"]));
        assert!(matches!(result, KbLookup::Miss), "got {result:?}");
    }

    #[test]
    fn test_kb_lookup_miss_on_unknown_tag() {
        let index = calculator_index();
        let result = index.lookup(
            "напиши калькулятор",
            "write a calculator in rust",
            &tags(&["#python"]),
        );
        assert!(matches!(result, KbLookup::Miss), "got {result:?}");
    }

    #[test]
    fn test_kb_lookup_miss_without_hashtags() {
        let index = calculator_index();
        let result = index.lookup("q", "write a calculator in rust", &[]);
        assert!(matches!(result, KbLookup::Miss), "got {result:?}");
    }

    #[test]
    fn test_empty_index() {
        let index = KnowledgeIndex::empty();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        let result = index.lookup("q", "q", &tags(&["#rust"]));
        assert!(matches!(result, KbLookup::Miss), "got {result:?}");
    }

    #[test]
    fn test_len() {
        let index = calculator_index();
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }
}
|
|