//! Sential Engine — Rust-native inference with llama.cpp backend. //! //! The heart of Phase 1: //! - Model lives in-process (no subprocess, no Python overhead) //! - GGUF on-the-fly dequantization (your architecture — built into llama.cpp) //! - Runtime LoRA hot-swap via `model.lora_adapter_init()` + `ctx.lora_adapter_set()` //! - ~1.3 GB VRAM saved vs PyTorch use std::collections::HashMap; use std::num::NonZeroU32; use std::path::{Path, PathBuf}; use std::ptr::NonNull; use std::sync::Mutex; use std::time::Instant; use anyhow::{bail, Context, Result}; use llama_cpp_2::context::params::{KvCacheType, LlamaContextParams}; use llama_cpp_2::context::LlamaContext; use llama_cpp_2::llama_backend::LlamaBackend; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::{AddBos, LlamaLoraAdapter, LlamaModel}; use llama_cpp_2::sampling::LlamaSampler; use llama_cpp_2::token::LlamaToken; // ─── Registered Adapter (path + scale) ───────────────────────────────────── #[derive(Clone)] struct AdapterInfo { path: PathBuf, scale: f32, } // ─── Internal Mutable Context ────────────────────────────────────────────── struct ContextState { ctx: LlamaContext<'static>, sampler: LlamaSampler, active_adapter: Option, adapters: HashMap, } // ─── Statistics ──────────────────────────────────────────────────────────── #[derive(Debug, Clone, Default)] pub struct EngineStats { pub total_prompts: u64, pub total_tokens_generated: u64, pub total_generation_time_ms: u64, pub avg_tokens_per_second: f64, } // ─── KV-Cache Configuration ───────────────────────────────────────────────── /// KV-Cache configuration for memory optimization. #[derive(Debug, Clone)] pub struct KvCacheConfig { /// KV cache quantization type for keys (Q4_0 = 4-bit, saves ~75% VRAM vs F16) pub cache_type_k: KvCacheType, /// KV cache quantization type for values pub cache_type_v: KvCacheType, /// Offload K, Q, V tensors to GPU (faster but uses VRAM) pub offload_kqv: bool, /// KV cache defrag threshold (-1.0 = disabled, 0.1 = aggressive) pub defrag_thold: f32, } impl Default for KvCacheConfig { fn default() -> Self { Self { cache_type_k: KvCacheType::Q4_0, cache_type_v: KvCacheType::Q4_0, offload_kqv: true, defrag_thold: -1.0, // disabled: llama.cpp manages cache internally } } } // ─── Engine ──────────────────────────────────────────────────────────────── // // ⚠️ Field order matters for Drop safety: // `context` (which contains LlamaContext<'_> borrowing from model) // MUST be dropped BEFORE `model`. Rust drops fields in declaration order. pub struct Engine { _backend: LlamaBackend, /// Inference context — dropped FIRST (before model). context: Mutex, /// Base model — dropped SECOND (after context, so the &LlamaModel ref stays valid). model: LlamaModel, _base_model_path: PathBuf, supports_gpu: bool, stats: EngineStats, } #[allow(dead_code)] impl Engine { /// Load base model and create inference context. pub fn new(base_model_path: &Path, n_gpu_layers: u32, n_ctx: u32) -> Result { Self::new_with_kv_config( base_model_path, n_gpu_layers, n_ctx, KvCacheConfig::default(), ) } /// Load base model with custom KV-cache configuration. pub fn new_with_kv_config( base_model_path: &Path, n_gpu_layers: u32, n_ctx: u32, kv_config: KvCacheConfig, ) -> Result { let start = Instant::now(); tracing::info!("╔══════════════════════════════════════════╗"); tracing::info!("║ Sential Engine — llama.cpp backend ║"); tracing::info!("╚══════════════════════════════════════════╝"); // 1. Backend let backend = LlamaBackend::init().context("Failed to init llama.cpp backend")?; // 2. GPU check let gpu_ok = backend.supports_gpu_offload(); tracing::info!("GPU offload: {}", if gpu_ok { "✅" } else { "❌" }); // 3. Load model let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers); tracing::info!("Loading model: {}", base_model_path.display()); tracing::info!(" n_gpu_layers: {}, n_ctx: {}", n_gpu_layers, n_ctx); let model = LlamaModel::load_from_file(&backend, base_model_path, &model_params).context( format!("Failed to load model from {}", base_model_path.display()), )?; tracing::info!( " {:.2}B params, ctx: {}, layers: {}, embd: {}", model.n_params() as f64 / 1_000_000_000.0, model.n_ctx_train(), model.n_layer(), model.n_embd(), ); // 4. Context with KV-cache optimizations let ctx_params = LlamaContextParams::default() .with_n_ctx(NonZeroU32::new(n_ctx)) // KV-cache quantization: Q4_0 = 4-bit, saves ~75% VRAM vs F16 // This is the single most effective VRAM optimization .with_type_k(kv_config.cache_type_k) .with_type_v(kv_config.cache_type_v) // Offload K, Q, V to GPU for faster attention computation .with_offload_kqv(kv_config.offload_kqv) // Defrag threshold: -1.0 disables (llama.cpp handles internally) .with_defrag_thold(kv_config.defrag_thold); tracing::info!( " KV-cache: K={:?} V={:?} offload_kqv={} defrag={:.1}", kv_config.cache_type_k, kv_config.cache_type_v, kv_config.offload_kqv, kv_config.defrag_thold, ); // new_context borrows from model (returns LlamaContext<'_>). // Both model and context live in this struct; context is dropped first. let ctx = model .new_context(&backend, ctx_params) .context("Failed to create inference context")?; // Safety: transmute to 'static since both live in Engine, context dropped before model. let ctx_static: LlamaContext<'static> = unsafe { std::mem::transmute(ctx) }; // 5. Default sampler (will be reconfigured per generation) let sampler = LlamaSampler::greedy(); // Log KV-cache size — both uncompressed (F16) and compressed with Q4_0 let n_layers = model.n_layer() as usize; let n_embd_head = (model.n_embd() as usize) / (model.n_head() as usize); let n_head_kv = model.n_head_kv() as usize; // F16: K+V per token per layer = n_embd_head × n_head_kv × 2(KV) × 2(bytes_per_f16) let kv_fp16_mb = (n_layers * n_embd_head * n_head_kv * 2 * 2 * n_ctx as usize) as f64 / (1024.0 * 1024.0); // Q4_0: 4 bits = 0.5 bytes per element vs 2 bytes for F16 → 0.25× // Plus ~1/32 overhead for block scale factors (one f16 per 32 elements) let kv_q4_mb = kv_fp16_mb * 0.25 * (1.0 + 1.0 / 32.0); tracing::info!( " KV-cache ({} ctx): {:.1} MB F16 → ~{:.1} MB Q4_0 (~{:.0}% savings)", n_ctx, kv_fp16_mb, kv_q4_mb, (1.0 - kv_q4_mb / kv_fp16_mb) * 100.0, ); tracing::info!("Engine ready in {:.1}s", start.elapsed().as_secs_f64()); Ok(Self { _backend: backend, // context before model → dropped first → model reference stays valid context: Mutex::new(ContextState { ctx: ctx_static, sampler, active_adapter: None, adapters: HashMap::new(), }), model, _base_model_path: base_model_path.to_path_buf(), supports_gpu: gpu_ok, stats: EngineStats::default(), }) } // ─── LoRA Management ───────────────────────────────────────────────── /// Register a LoRA adapter (must be in GGUF format). pub fn register_adapter(&self, name: &str, gguf_path: &Path, scale: f32) -> Result<()> { if !gguf_path.exists() { bail!("LoRA GGUF not found: {}", gguf_path.display()); } let mut state = self.context.lock().unwrap(); state.adapters.insert( name.to_string(), AdapterInfo { path: gguf_path.to_path_buf(), scale, }, ); tracing::info!("Registered adapter '{}' -> {}", name, gguf_path.display()); Ok(()) } /// Apply a LoRA adapter at runtime using the safe llama-cpp-2 API. pub fn apply_adapter(&self, name: &str) -> Result<()> { let mut state = self.context.lock().unwrap(); let info = state .adapters .get(name) .cloned() .context(format!("Adapter '{name}' not registered"))?; tracing::info!("Applying LoRA adapter: {name}"); // Load LoRA adapter via safe wrapper let mut lora_adapter = self .model .lora_adapter_init(info.path.to_str().context("Invalid UTF-8 in path")?) .context(format!("Failed to init adapter '{name}'"))?; // Apply to context state .ctx .lora_adapter_set(&mut lora_adapter, info.scale) .context(format!("Failed to set adapter '{name}'"))?; // Ownership of the raw pointer has been transferred to llama.cpp context. // Forget our wrapper to avoid double-free on drop. std::mem::forget(lora_adapter); state.active_adapter = Some(name.to_string()); tracing::info!("Adapter '{name}' applied ✅"); Ok(()) } /// Remove active LoRA adapter (revert to base model). pub fn remove_adapter(&self) -> Result<()> { let mut state = self.context.lock().unwrap(); if state.active_adapter.is_none() { return Ok(()); } tracing::info!("Removing LoRA adapter..."); // lora_adapter_remove needs a &mut LlamaLoraAdapter but the parameter is unused. // Create a dummy from NonNull::dangling() — safe: never dereferenced, then forgotten. let mut dummy_adapter: LlamaLoraAdapter = unsafe { std::mem::transmute(NonNull::::dangling()) }; state .ctx .lora_adapter_remove(&mut dummy_adapter) .context("Failed to remove adapter")?; // dummy was never actually loaded, forget to avoid freeing invalid memory. std::mem::forget(dummy_adapter); state.active_adapter = None; tracing::info!("LoRA adapter removed, base model restored"); Ok(()) } /// Currently active adapter name. pub fn active_adapter(&self) -> Option { self.context.lock().unwrap().active_adapter.clone() } /// List all registered adapters. pub fn list_adapters(&self) -> Vec<(String, PathBuf)> { self.context .lock() .unwrap() .adapters .iter() .map(|(n, a)| (n.clone(), a.path.clone())) .collect() } // ─── Generation ────────────────────────────────────────────────────── /// Generate text with full sampling control. /// /// Temperature 0.0 = greedy. top_p 0.0 = disabled. top_k 0 = disabled. pub fn generate( &mut self, prompt: &str, max_tokens: u32, temperature: f32, top_p: f32, top_k: i32, ) -> Result { let gen_start = Instant::now(); let mut state = self.context.lock().unwrap(); // 0. Clear KV-cache — prevent position mismatch errors when switching // adapters or running multiple turns in interactive mode. // M-RoPE (used by Qwen3) requires strictly increasing positions; // without clearing, old cache entries (positions 0..N) conflict // with the new batch starting from position 0. state.ctx.clear_kv_cache(); // 1. Tokenize let tokens = self .model .str_to_token(prompt, AddBos::Always) .context("Failed to tokenize prompt")?; let n_prompt = tokens.len(); if n_prompt == 0 { bail!("Prompt produced 0 tokens"); } tracing::debug!("Prompt: {n_prompt} tokens"); // 2. Context-size check with auto-truncation // Fix: cap max_tokens so prompt always has room; ensure truncation converges let n_ctx = state.ctx.n_ctx() as usize; let effective_max = (max_tokens as usize).min(n_ctx.saturating_sub(64).max(32)); // at least 32 tokens for prompt if n_prompt + effective_max > n_ctx { // Drop the lock before recursing to avoid deadlock drop(state); let keep = (n_ctx - effective_max).max(32); // guaranteed positive: effective_max <= n_ctx-32 tracing::warn!( "Prompt too long ({n_prompt} tok, max_gen={effective_max}, n_ctx={n_ctx}). Truncating to {keep} tokens." ); let truncated = self .detokenize_tokens(&tokens[tokens.len().saturating_sub(keep)..]) .context("Failed to decode truncated prompt")?; return self.generate(&truncated, effective_max as u32, temperature, top_p, top_k); } // 3. Prefill — feed all prompt tokens in one batch let mut batch = LlamaBatch::new(n_prompt, 1); for (i, &token) in tokens.iter().enumerate() { let is_last = i == n_prompt - 1; batch.add(token, i as i32, &[0], is_last)?; } state .ctx .decode(&mut batch) .context("Prefill decode failed")?; // 4. Build sampler chain, swap into state (old one gets dropped) let mut new_sampler = Self::build_sampler(temperature, top_p, top_k); std::mem::swap(&mut state.sampler, &mut new_sampler); // 5. Generate loop (capped to effective_max to fit in n_ctx) let mut output_tokens: Vec = Vec::with_capacity(effective_max); let eos = self.model.token_eos(); // Position of the last batch element with logits=True let mut sample_idx = batch.n_tokens() - 1; for _step in 0..effective_max { // NOTE: MutexGuard does not support field-split borrows // through DerefMut, so we use a raw pointer to pass ctx immutably // while sampler takes &mut self on its own field. let token = { let ctx_ptr: *const llama_cpp_2::context::LlamaContext = &state.ctx; // SAFETY: ctx_ptr is valid for the duration of sample(); // sampler only reads ctx immutably. state.sampler.sample(unsafe { &*ctx_ptr }, sample_idx) }; if token == eos || self.model.is_eog_token(token) { break; } output_tokens.push(token.0); state.sampler.accept(token); let pos = (n_prompt + output_tokens.len() - 1) as i32; let mut single = LlamaBatch::new(1, 1); single.add(token, pos, &[0], true)?; state .ctx .decode(&mut single) .context("Decode failed during generation")?; sample_idx = 0; } // 6. Detokenize — use token_to_piece_bytes with 256-byte buffer // (the deprecated tokens_to_str uses only 8 bytes, too small for some tokens) let llama_tokens: Vec = output_tokens.iter().map(|&t| LlamaToken::new(t)).collect(); let output = self .detokenize_tokens(&llama_tokens) .context("Failed to detokenize")?; // 7. Stats let elapsed = gen_start.elapsed(); let tok_count = output_tokens.len() as u64; let tps = if elapsed.as_secs_f64() > 0.0 { tok_count as f64 / elapsed.as_secs_f64() } else { 0.0 }; self.stats.total_prompts += 1; self.stats.total_tokens_generated += tok_count; self.stats.total_generation_time_ms += elapsed.as_millis() as u64; let total_secs = self.stats.total_generation_time_ms as f64 / 1000.0; if total_secs > 0.0 { self.stats.avg_tokens_per_second = self.stats.total_tokens_generated as f64 / total_secs; } tracing::info!( "Generated {tok_count} tok in {:.1}s ({tps:.1} t/s) — adapter: {:?}", elapsed.as_secs_f64(), state.active_adapter, ); Ok(output) } /// Generate with optional LoRA adapter (apply → generate → remove). pub fn generate_with_adapter( &mut self, prompt: &str, max_tokens: u32, temperature: f32, top_p: f32, adapter_name: Option<&str>, ) -> Result { if let Some(adapter) = adapter_name { if adapter != "general" { if let Err(e) = self.apply_adapter(adapter) { tracing::warn!("Failed to apply adapter '{adapter}': {e}. Using base model."); } } } else { let _ = self.remove_adapter(); } let result = self.generate(prompt, max_tokens, temperature, top_p, 40); if adapter_name.is_some() && adapter_name != Some("general") { if let Err(e) = self.remove_adapter() { tracing::warn!("Failed to remove adapter: {e}"); } } result } /// Build a sampler chain from parameters. fn build_sampler(temperature: f32, top_p: f32, top_k: i32) -> LlamaSampler { if temperature <= 0.0 { return LlamaSampler::chain_simple([LlamaSampler::greedy()]); } let mut chain: Vec = Vec::new(); if top_k > 0 { chain.push(LlamaSampler::top_k(top_k)); } if top_p > 0.0 { chain.push(LlamaSampler::top_p(top_p, 1)); } chain.push(LlamaSampler::temp(temperature)); chain.push(LlamaSampler::dist(42)); LlamaSampler::chain_simple(chain) } // ─── Utility ───────────────────────────────────────────────────────── /// Detokenize a slice of LlamaToken into a String. /// Uses `token_to_piece_bytes` with 256-byte buffer per token /// (the deprecated `tokens_to_str` uses only 8 bytes, causing errors). fn detokenize_tokens(&self, tokens: &[LlamaToken]) -> Result { let mut output = String::with_capacity(tokens.len() * 4); for &token in tokens { let bytes = self .model .token_to_piece_bytes(token, 256, true, None) .context("Failed to detokenize token")?; match String::from_utf8(bytes) { Ok(s) => output.push_str(&s), Err(e) => { tracing::warn!( "Token produced invalid UTF-8: {}. Using lossy replacement.", e ); output.push_str(&String::from_utf8_lossy(e.as_bytes())); } } } Ok(output) } pub fn clear_cache(&self) { tracing::debug!("Cache clear requested (no-op, managed by llama.cpp)"); } pub fn stats(&self) -> &EngineStats { &self.stats } pub fn is_gpu_active(&self) -> bool { self.supports_gpu } pub fn model(&self) -> &LlamaModel { &self.model } } impl Drop for Engine { fn drop(&mut self) { // context (with &model reference) is dropped first because it comes first // in the struct. Then model is dropped safely. tracing::info!( "Shutdown. {} prompts, {} tokens ({:.1} t/s avg)", self.stats.total_prompts, self.stats.total_tokens_generated, self.stats.avg_tokens_per_second, ); } }