| |
| |
| |
| |
| |
| |
|
|
| use anyhow::Result; |
|
|
| use crate::config::Config; |
| use crate::engine::{Engine, KvCacheConfig}; |
| use crate::pipeline::{ConversationTurn, Pipeline, PipelineResult}; |
| use llama_cpp_2::context::params::KvCacheType; |
|
|
| |
/// One entry in the running conversation kept by [`InferenceEngine`].
#[derive(Clone, Debug)]
pub enum Message {
    /// Text the user submitted.
    User(String),
    /// Response text returned for the preceding user message.
    Assistant(String),
}
|
|
| |
| #[allow(dead_code)] |
| pub struct InferenceEngine { |
| engine: Engine, |
| pipeline: Pipeline, |
| config: Config, |
| active_expert: String, |
| conversation: Vec<Message>, |
| |
| total_queries: u64, |
| total_cache_hits: u64, |
| } |
|
|
| impl InferenceEngine { |
| |
| pub fn new(config: Config) -> Result<Self> { |
| tracing::info!("Initialising Sential engine with llama.cpp backend"); |
|
|
| |
| let n_gpu_layers: u32 = 20; |
| let n_ctx: u32 = config.max_seq_len as u32; |
|
|
| |
| let kv_config = KvCacheConfig { |
| cache_type_k: parse_cache_type(&config.kv_cache_type_k), |
| cache_type_v: parse_cache_type(&config.kv_cache_type_v), |
| offload_kqv: config.kv_offload_kqv, |
| defrag_thold: config.kv_defrag_thold, |
| }; |
|
|
| |
| let engine = |
| Engine::new_with_kv_config(&config.base_model_path, n_gpu_layers, n_ctx, kv_config)?; |
|
|
| |
| for expert in &config.experts { |
| if let Some(adapter_file) = &expert.adapter_file { |
| |
| let gguf_path = if adapter_file.ends_with(".gguf") { |
| config.adapters_dir.join(adapter_file) |
| } else { |
| let stem = adapter_file.trim_end_matches(".safetensors"); |
| config.adapters_dir.join(format!("{stem}.gguf")) |
| }; |
|
|
| if !gguf_path.exists() { |
| tracing::warn!( |
| "Adapter GGUF not found: {}. Skipping expert '{}'.", |
| gguf_path.display(), |
| expert.name, |
| ); |
| continue; |
| } |
|
|
| let scale: f32 = 1.0; |
| engine.register_adapter(&expert.name, &gguf_path, scale)?; |
| tracing::info!( |
| " Registered adapter '{}' -> {}", |
| expert.name, |
| gguf_path.display() |
| ); |
| } |
| } |
|
|
| |
| let kb_path = config.kb_path.clone(); |
| let pipeline = Pipeline::new(config.clone(), kb_path)?; |
| tracing::info!( |
| "Pipeline initialised: KB entries={}, translate={}, cache={}", |
| pipeline.kb_len(), |
| true, |
| pipeline.has_kb(), |
| ); |
|
|
| Ok(Self { |
| engine, |
| pipeline, |
| active_expert: "general".to_string(), |
| conversation: Vec::new(), |
| config, |
| total_queries: 0, |
| total_cache_hits: 0, |
| }) |
| } |
|
|
| |
| pub fn process_query(&mut self, query: &str) -> Result<String> { |
| self.process_query_with_expert(query, None) |
| } |
|
|
| |
| pub fn process_query_with_expert( |
| &mut self, |
| query: &str, |
| expert_override: Option<&str>, |
| ) -> Result<String> { |
| self.total_queries += 1; |
| tracing::info!("Processing query through pipeline..."); |
|
|
| |
| let history: Vec<ConversationTurn> = self.build_conversation_turns(); |
| let result: PipelineResult = |
| self.pipeline |
| .run(query, &mut self.engine, expert_override, &history)?; |
|
|
| |
| if result.from_cache { |
| self.total_cache_hits += 1; |
| } |
|
|
| |
| tracing::info!( |
| "Pipeline timing: hash={}µs tr={}µs kb={}µs route={}µs gen={}ms total={}ms | cache={} | expert={}", |
| result.timing.hashtag_ms * 1000, |
| result.timing.translate_ms * 1000, |
| result.timing.kb_lookup_ms * 1000, |
| result.timing.routing_ms * 1000, |
| result.timing.generation_ms, |
| result.timing.total_ms, |
| if result.from_cache { "HIT" } else { "MISS" }, |
| result.expert, |
| ); |
|
|
| |
| self.conversation.push(Message::User(query.to_string())); |
| self.conversation |
| .push(Message::Assistant(result.response.clone())); |
| self.active_expert = result.expert.clone(); |
|
|
| tracing::info!("Response ready ({} chars)", result.response.len()); |
| Ok(result.response) |
| } |
|
|
| |
| fn build_conversation_turns(&self) -> Vec<ConversationTurn> { |
| let mut turns = Vec::new(); |
| let mut i = 0; |
| while i + 1 < self.conversation.len() { |
| if let (Message::User(user), Message::Assistant(assistant)) = |
| (&self.conversation[i], &self.conversation[i + 1]) |
| { |
| turns.push(ConversationTurn { |
| user: user.clone(), |
| assistant: assistant.clone(), |
| }); |
| } |
| i += 2; |
| } |
| turns |
| } |
|
|
| |
| pub fn reset(&mut self) { |
| self.conversation.clear(); |
| self.active_expert = "general".to_string(); |
| let _ = self.engine.remove_adapter(); |
| } |
|
|
| pub fn active_expert(&self) -> &str { |
| &self.active_expert |
| } |
|
|
| pub fn stats(&self) -> serde_json::Value { |
| serde_json::json!({ |
| "active_expert": self.active_expert, |
| "conversation_length": self.conversation.len(), |
| "gpu_active": self.engine.is_gpu_active(), |
| "pipeline": { |
| "total_queries": self.total_queries, |
| "cache_hits": self.total_cache_hits, |
| "cache_hit_rate": if self.total_queries > 0 { |
| format!("{:.1}%", 100.0 * self.total_cache_hits as f64 / self.total_queries as f64) |
| } else { |
| "0%".to_string() |
| }, |
| "kb_entries": self.pipeline.kb_len(), |
| }, |
| "engine_stats": { |
| "total_prompts": self.engine.stats().total_prompts, |
| "total_tokens": self.engine.stats().total_tokens_generated, |
| "avg_tokens_per_second": self.engine.stats().avg_tokens_per_second, |
| } |
| }) |
| } |
|
|
| |
| #[allow(dead_code)] |
| pub fn kv_cache_info(&self) -> String { |
| format!( |
| "KV-cache: K={} V={} offload_kqv={} defrag={:.1}", |
| self.config.kv_cache_type_k, |
| self.config.kv_cache_type_v, |
| self.config.kv_offload_kqv, |
| self.config.kv_defrag_thold, |
| ) |
| } |
|
|
| |
| pub fn pipeline_info(&self) -> String { |
| format!( |
| "Pipeline: KB={} entries, Cache hits={}/{}, Hashtag extractor=on, Translator=on", |
| self.pipeline.kb_len(), |
| self.total_cache_hits, |
| self.total_queries, |
| ) |
| } |
| } |
|
|
| |
| fn parse_cache_type(s: &str) -> KvCacheType { |
| match s.to_lowercase().as_str() { |
| "q4_0" => KvCacheType::Q4_0, |
| "q4_1" => KvCacheType::Q4_1, |
| "q5_0" => KvCacheType::Q5_0, |
| "q5_1" => KvCacheType::Q5_1, |
| "q8_0" => KvCacheType::Q8_0, |
| "q8_1" => KvCacheType::Q8_1, |
| "q2_k" => KvCacheType::Q2_K, |
| "q3_k" => KvCacheType::Q3_K, |
| "q4_k" => KvCacheType::Q4_K, |
| "q5_k" => KvCacheType::Q5_K, |
| "q6_k" => KvCacheType::Q6_K, |
| "iq4_nl" => KvCacheType::IQ4_NL, |
| "f16" => KvCacheType::F16, |
| "f32" => KvCacheType::F32, |
| _ => { |
| tracing::warn!("Unknown KV-cache type '{s}', falling back to Q4_0"); |
| KvCacheType::Q4_0 |
| } |
| } |
| } |
|
|