SelentialCore / src /inference.rs
S4ntyC1t's picture
Upload 23 files
18e0633 verified
Raw
History Blame Contribute Delete
8.86 kB
//! Inference pipeline — orchestrates routing + generation via Sential Engine.
//!
//! 1. Pipeline pre-processes: hashtags, language detection, KB cache lookup
//! 2. Router classifies the query (keyword + hashtag matching)
//! 3. Engine (llama.cpp) generates with or without LoRA adapter
//! 4. Chat history maintained for context
use std::path::{Path, PathBuf};

use anyhow::Context;
use anyhow::Result;
use llama_cpp_2::context::params::KvCacheType;

use crate::config::Config;
use crate::engine::{Engine, KvCacheConfig};
use crate::pipeline::{ConversationTurn, Pipeline, PipelineResult};
/// One entry in the chat history kept by the inference engine.
///
/// Messages are stored in user/assistant order and paired into conversation
/// turns before being handed to the pipeline as context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// Text submitted by the user.
    User(String),
    /// Response returned for the preceding user message (generated or served
    /// from the KB cache).
    Assistant(String),
}
/// Inference engine: routes queries, generates with LoRA adapters.
///
/// Owns the llama.cpp [`Engine`] (base model plus registered adapters), the
/// query [`Pipeline`], the chat history passed to the pipeline as context, and
/// running query/cache counters reported by `stats()` and `pipeline_info()`.
// NOTE(review): `dead_code` is presumably allowed because `config` is only read by
// `kv_cache_info`, which is itself marked `#[allow(dead_code)]` — confirm before removing.
#[allow(dead_code)]
pub struct InferenceEngine {
    /// llama.cpp engine holding the base model and the registered LoRA adapters.
    engine: Engine,
    /// Preprocess → KB lookup → route → generate pipeline that answers queries.
    pipeline: Pipeline,
    /// Configuration the engine was built from.
    config: Config,
    /// Expert that handled the latest query; `"general"` initially and after a reset.
    active_expert: String,
    /// Alternating user/assistant messages, appended after each successful query.
    conversation: Vec<Message>,
    /// Accumulated pipeline stats: queries submitted. Incremented before the
    /// pipeline runs, so queries whose run failed are included.
    total_queries: u64,
    /// Accumulated pipeline stats: queries answered from the KB cache.
    total_cache_hits: u64,
}
impl InferenceEngine {
/// Initialise: load base model into Engine, register adapters, set up pipeline.
pub fn new(config: Config) -> Result<Self> {
tracing::info!("Initialising Sential engine with llama.cpp backend");
// Offload most layers to GPU (20/25 for Qwen3.5-0.8B, leaves headroom for compute buffers on 6 GB VRAM)
let n_gpu_layers: u32 = 20;
let n_ctx: u32 = config.max_seq_len as u32;
// Build KV-cache config from Config
let kv_config = KvCacheConfig {
cache_type_k: parse_cache_type(&config.kv_cache_type_k),
cache_type_v: parse_cache_type(&config.kv_cache_type_v),
offload_kqv: config.kv_offload_kqv,
defrag_thold: config.kv_defrag_thold,
};
// Initialise Rust-native engine with KV-cache optimizations
let engine =
Engine::new_with_kv_config(&config.base_model_path, n_gpu_layers, n_ctx, kv_config)?;
// Register all LoRA adapters
for expert in &config.experts {
if let Some(adapter_file) = &expert.adapter_file {
// Support both .gguf (new) and .safetensors (legacy) extensions
let gguf_path = if adapter_file.ends_with(".gguf") {
config.adapters_dir.join(adapter_file)
} else {
let stem = adapter_file.trim_end_matches(".safetensors");
config.adapters_dir.join(format!("{stem}.gguf"))
};
if !gguf_path.exists() {
tracing::warn!(
"Adapter GGUF not found: {}. Skipping expert '{}'.",
gguf_path.display(),
expert.name,
);
continue;
}
let scale: f32 = 1.0; // Standard LoRA scale
engine.register_adapter(&expert.name, &gguf_path, scale)?;
tracing::info!(
" Registered adapter '{}' -> {}",
expert.name,
gguf_path.display()
);
}
}
// Initialise pipeline (with KB cache)
let kb_path = config.kb_path.clone();
let pipeline = Pipeline::new(config.clone(), kb_path)?;
tracing::info!(
"Pipeline initialised: KB entries={}, translate={}, cache={}",
pipeline.kb_len(),
true,
pipeline.has_kb(),
);
Ok(Self {
engine,
pipeline,
active_expert: "general".to_string(),
conversation: Vec::new(),
config,
total_queries: 0,
total_cache_hits: 0,
})
}
/// Process a user query (auto-route).
pub fn process_query(&mut self, query: &str) -> Result<String> {
self.process_query_with_expert(query, None)
}
/// Process a query with an optional expert override.
pub fn process_query_with_expert(
&mut self,
query: &str,
expert_override: Option<&str>,
) -> Result<String> {
self.total_queries += 1;
tracing::info!("Processing query through pipeline...");
// Run the full pipeline (preprocess → KB lookup → route → generate)
let history: Vec<ConversationTurn> = self.build_conversation_turns();
let result: PipelineResult =
self.pipeline
.run(query, &mut self.engine, expert_override, &history)?;
// Track cache hits
if result.from_cache {
self.total_cache_hits += 1;
}
// Log timing
tracing::info!(
"Pipeline timing: hash={}µs tr={}µs kb={}µs route={}µs gen={}ms total={}ms | cache={} | expert={}",
result.timing.hashtag_ms * 1000,
result.timing.translate_ms * 1000,
result.timing.kb_lookup_ms * 1000,
result.timing.routing_ms * 1000,
result.timing.generation_ms,
result.timing.total_ms,
if result.from_cache { "HIT" } else { "MISS" },
result.expert,
);
// Update conversation history
self.conversation.push(Message::User(query.to_string()));
self.conversation
.push(Message::Assistant(result.response.clone()));
self.active_expert = result.expert.clone();
tracing::info!("Response ready ({} chars)", result.response.len());
Ok(result.response)
}
/// Convert conversation Message pairs into ConversationTurn slices for the pipeline.
fn build_conversation_turns(&self) -> Vec<ConversationTurn> {
let mut turns = Vec::new();
let mut i = 0;
while i + 1 < self.conversation.len() {
if let (Message::User(user), Message::Assistant(assistant)) =
(&self.conversation[i], &self.conversation[i + 1])
{
turns.push(ConversationTurn {
user: user.clone(),
assistant: assistant.clone(),
});
}
i += 2;
}
turns
}
/// Reset conversation.
pub fn reset(&mut self) {
self.conversation.clear();
self.active_expert = "general".to_string();
let _ = self.engine.remove_adapter();
}
pub fn active_expert(&self) -> &str {
&self.active_expert
}
pub fn stats(&self) -> serde_json::Value {
serde_json::json!({
"active_expert": self.active_expert,
"conversation_length": self.conversation.len(),
"gpu_active": self.engine.is_gpu_active(),
"pipeline": {
"total_queries": self.total_queries,
"cache_hits": self.total_cache_hits,
"cache_hit_rate": if self.total_queries > 0 {
format!("{:.1}%", 100.0 * self.total_cache_hits as f64 / self.total_queries as f64)
} else {
"0%".to_string()
},
"kb_entries": self.pipeline.kb_len(),
},
"engine_stats": {
"total_prompts": self.engine.stats().total_prompts,
"total_tokens": self.engine.stats().total_tokens_generated,
"avg_tokens_per_second": self.engine.stats().avg_tokens_per_second,
}
})
}
/// Get KV-cache configuration summary
#[allow(dead_code)]
pub fn kv_cache_info(&self) -> String {
format!(
"KV-cache: K={} V={} offload_kqv={} defrag={:.1}",
self.config.kv_cache_type_k,
self.config.kv_cache_type_v,
self.config.kv_offload_kqv,
self.config.kv_defrag_thold,
)
}
/// Get pipeline info for display
pub fn pipeline_info(&self) -> String {
format!(
"Pipeline: KB={} entries, Cache hits={}/{}, Hashtag extractor=on, Translator=on",
self.pipeline.kb_len(),
self.total_cache_hits,
self.total_queries,
)
}
}
/// Map a KV-cache type name from the config to the llama.cpp [`KvCacheType`].
///
/// Matching is case-insensitive (`"Q8_0"` and `"q8_0"` are equivalent).
/// Unrecognised names log a warning and fall back to [`KvCacheType::Q4_0`].
fn parse_cache_type(s: &str) -> KvCacheType {
    let name = s.to_lowercase();
    let known = match name.as_str() {
        // Floating-point types.
        "f32" => Some(KvCacheType::F32),
        "f16" => Some(KvCacheType::F16),
        // Q*_0 / Q*_1 quantisations.
        "q4_0" => Some(KvCacheType::Q4_0),
        "q4_1" => Some(KvCacheType::Q4_1),
        "q5_0" => Some(KvCacheType::Q5_0),
        "q5_1" => Some(KvCacheType::Q5_1),
        "q8_0" => Some(KvCacheType::Q8_0),
        "q8_1" => Some(KvCacheType::Q8_1),
        // Q*_K quantisations.
        "q2_k" => Some(KvCacheType::Q2_K),
        "q3_k" => Some(KvCacheType::Q3_K),
        "q4_k" => Some(KvCacheType::Q4_K),
        "q5_k" => Some(KvCacheType::Q5_K),
        "q6_k" => Some(KvCacheType::Q6_K),
        "iq4_nl" => Some(KvCacheType::IQ4_NL),
        _ => None,
    };
    known.unwrap_or_else(|| {
        tracing::warn!("Unknown KV-cache type '{s}', falling back to Q4_0");
        KvCacheType::Q4_0
    })
}