SelentialCore / src /engine.rs
S4ntyC1t's picture
Upload 23 files
18e0633 verified
Raw
History Blame Contribute Delete
21.3 kB
//! Sential Engine — Rust-native inference with llama.cpp backend.
//!
//! The heart of Phase 1:
//! - Model lives in-process (no subprocess, no Python overhead)
//! - GGUF on-the-fly dequantization (your architecture — built into llama.cpp)
//! - Runtime LoRA hot-swap via `model.lora_adapter_init()` + `ctx.lora_adapter_set()`
//! - ~1.3 GB VRAM saved vs PyTorch
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::path::{Path, PathBuf};
use std::ptr::NonNull;
use std::sync::Mutex;
use std::time::Instant;
use anyhow::{bail, Context, Result};
use llama_cpp_2::context::params::{KvCacheType, LlamaContextParams};
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaLoraAdapter, LlamaModel};
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
// ─── Registered Adapter (path + scale) ─────────────────────────────────────
/// A LoRA adapter known to the engine by name.
///
/// Only the location and scale are stored; the adapter file itself is loaded
/// when the adapter is applied.
#[derive(Clone)]
struct AdapterInfo {
/// Path of the adapter file, as supplied at registration time.
path: PathBuf,
/// Scale passed to `lora_adapter_set` when the adapter is applied.
scale: f32,
}
// ─── Internal Mutable Context ──────────────────────────────────────────────
/// Mutable inference state; the engine keeps it behind a `Mutex`.
struct ContextState {
/// llama.cpp inference context. Typed `'static` because the constructor
/// transmutes away the borrow of the model it was created from.
ctx: LlamaContext<'static>,
/// Sampler used by the generation loop; replaced on every `generate` call.
sampler: LlamaSampler,
/// Name of the adapter most recently applied, or `None` for the base model.
active_adapter: Option<String>,
/// Registered adapters, keyed by the name given at registration.
adapters: HashMap<String, AdapterInfo>,
}
// ─── Statistics ────────────────────────────────────────────────────────────
/// Cumulative generation counters, updated at the end of each successful
/// `Engine::generate` call.
#[derive(Debug, Clone, Default)]
pub struct EngineStats {
/// Number of completed `generate` calls.
pub total_prompts: u64,
/// Total number of output tokens produced across all calls.
pub total_tokens_generated: u64,
/// Total wall-clock time spent in `generate`, in milliseconds.
pub total_generation_time_ms: u64,
/// `total_tokens_generated` divided by the total generation time in seconds.
pub avg_tokens_per_second: f64,
}
// ─── KV-Cache Configuration ─────────────────────────────────────────────────
/// KV-Cache configuration for memory optimization.
///
/// Each field is forwarded to the matching `LlamaContextParams` builder method
/// (`with_type_k`, `with_type_v`, `with_offload_kqv`, `with_defrag_thold`)
/// when the engine creates its context.
#[derive(Debug, Clone)]
pub struct KvCacheConfig {
/// KV cache quantization type for keys (Q4_0 = 4-bit, saves ~75% VRAM vs F16)
pub cache_type_k: KvCacheType,
/// KV cache quantization type for values
pub cache_type_v: KvCacheType,
/// Offload K, Q, V tensors to GPU (faster but uses VRAM)
pub offload_kqv: bool,
/// KV cache defrag threshold (-1.0 = disabled, 0.1 = aggressive)
pub defrag_thold: f32,
}
impl Default for KvCacheConfig {
fn default() -> Self {
Self {
cache_type_k: KvCacheType::Q4_0,
cache_type_v: KvCacheType::Q4_0,
offload_kqv: true,
defrag_thold: -1.0, // disabled: llama.cpp manages cache internally
}
}
}
// ─── Engine ────────────────────────────────────────────────────────────────
//
// Field order matters for Drop safety — Rust drops fields in declaration order:
//   1. `context` (its `LlamaContext` was created from `model`) is dropped first,
//   2. `model` second,
//   3. `_backend` last, so the llama.cpp backend is not shut down while the
//      context and model that were created through it are still alive.
//
// NOTE(review): `context` holds a `LlamaContext<'static>` obtained via
// `transmute` in the constructor. If `LlamaContext` stores a `&LlamaModel`
// (rather than only a raw llama.cpp pointer), that reference stops pointing at
// a live value as soon as `model` is moved into this struct — boxing `model`
// would pin its address. TODO confirm against the llama-cpp-2 version in use.
/// In-process inference engine: one base model, one inference context, and a
/// registry of LoRA adapters that can be applied to that context at runtime.
pub struct Engine {
    /// Inference context and adapter registry — dropped FIRST (before `model`).
    context: Mutex<ContextState>,
    /// Base model — dropped SECOND (after the context created from it).
    model: LlamaModel,
    /// Path the base model was loaded from; stored but not read in this file.
    _base_model_path: PathBuf,
    /// Whether the backend reported GPU offload support at start-up.
    supports_gpu: bool,
    /// Cumulative generation statistics.
    stats: EngineStats,
    /// llama.cpp backend handle — declared LAST so it is dropped after the
    /// context and the model.
    _backend: LlamaBackend,
}
#[allow(dead_code)]
impl Engine {
/// Load the base model and create an inference context using
/// `KvCacheConfig::default()` for the KV cache.
///
/// Thin convenience wrapper that delegates to `new_with_kv_config`.
pub fn new(base_model_path: &Path, n_gpu_layers: u32, n_ctx: u32) -> Result<Self> {
    let default_kv = KvCacheConfig::default();
    Self::new_with_kv_config(base_model_path, n_gpu_layers, n_ctx, default_kv)
}
/// Load base model with custom KV-cache configuration.
///
/// Steps, in order: initialise the llama.cpp backend, load the model from
/// `base_model_path` with `n_gpu_layers` layers offloaded, create a context
/// using the settings in `kv_config`, install a greedy sampler as placeholder,
/// and log an estimate of the KV-cache footprint.
///
/// `n_ctx` is passed to the context parameters as `NonZeroU32::new(n_ctx)`, so
/// `0` becomes `None`. The size estimate in the log is therefore computed from
/// the size the created context actually reports, not from the request.
///
/// # Errors
/// Fails if the backend cannot be initialised, the model cannot be loaded, or
/// the inference context cannot be created.
pub fn new_with_kv_config(
    base_model_path: &Path,
    n_gpu_layers: u32,
    n_ctx: u32,
    kv_config: KvCacheConfig,
) -> Result<Self> {
    let start = Instant::now();
    tracing::info!("╔══════════════════════════════════════════╗");
    tracing::info!("║ Sential Engine — llama.cpp backend ║");
    tracing::info!("╚══════════════════════════════════════════╝");

    // 1. Backend
    let backend = LlamaBackend::init().context("Failed to init llama.cpp backend")?;

    // 2. GPU check
    let gpu_ok = backend.supports_gpu_offload();
    tracing::info!("GPU offload: {}", if gpu_ok { "✅" } else { "❌" });

    // 3. Load model
    let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);
    tracing::info!("Loading model: {}", base_model_path.display());
    tracing::info!(" n_gpu_layers: {}, n_ctx: {}", n_gpu_layers, n_ctx);
    // `with_context` builds the message only if loading actually fails.
    let model = LlamaModel::load_from_file(&backend, base_model_path, &model_params)
        .with_context(|| format!("Failed to load model from {}", base_model_path.display()))?;
    tracing::info!(
        " {:.2}B params, ctx: {}, layers: {}, embd: {}",
        model.n_params() as f64 / 1_000_000_000.0,
        model.n_ctx_train(),
        model.n_layer(),
        model.n_embd(),
    );

    // 4. Context with KV-cache optimizations
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(n_ctx))
        // KV-cache quantization type for keys / values (default Q4_0 = 4-bit).
        .with_type_k(kv_config.cache_type_k)
        .with_type_v(kv_config.cache_type_v)
        // Offload K, Q, V to GPU for faster attention computation
        .with_offload_kqv(kv_config.offload_kqv)
        // Defrag threshold: -1.0 disables (llama.cpp handles internally)
        .with_defrag_thold(kv_config.defrag_thold);
    tracing::info!(
        " KV-cache: K={:?} V={:?} offload_kqv={} defrag={:.1}",
        kv_config.cache_type_k,
        kv_config.cache_type_v,
        kv_config.offload_kqv,
        kv_config.defrag_thold,
    );

    // new_context borrows from model (returns LlamaContext<'_>).
    let ctx = model
        .new_context(&backend, ctx_params)
        .context("Failed to create inference context")?;
    // Read the real context size before the lifetime is erased; it can differ
    // from the requested `n_ctx` (e.g. when 0 was requested).
    let actual_n_ctx = ctx.n_ctx();

    // SAFETY: the context is stored in the same `Engine` as the model it was
    // created from, and `Engine` declares `context` before `model`, so the
    // context is dropped first.
    // NOTE(review): this only covers drop order. If `LlamaContext` keeps a
    // `&LlamaModel`, that reference is invalidated when `model` is moved into
    // the returned struct below — TODO confirm / consider boxing the model.
    let ctx_static: LlamaContext<'static> = unsafe { std::mem::transmute(ctx) };

    // 5. Default sampler (will be reconfigured per generation)
    let sampler = LlamaSampler::greedy();

    // Log KV-cache size — uncompressed (F16) versus Q4_0.
    // The Q4_0 figure is an estimate for the default cache types; it is logged
    // even when `kv_config` selects something else.
    let n_layers = model.n_layer() as usize;
    // `.max(1)` keeps a model reporting zero heads from causing a divide-by-zero.
    let n_head = (model.n_head() as usize).max(1);
    let n_embd_head = (model.n_embd() as usize) / n_head;
    let n_head_kv = model.n_head_kv() as usize;
    // F16: K+V per token per layer = n_embd_head × n_head_kv × 2 (K and V) × 2 (bytes per f16).
    // Computed in f64 so the product cannot overflow an integer type.
    let kv_fp16_bytes = n_layers as f64
        * n_embd_head as f64
        * n_head_kv as f64
        * 2.0
        * 2.0
        * actual_n_ctx as f64;
    let kv_fp16_mb = kv_fp16_bytes / (1024.0 * 1024.0);
    // Q4_0: 4 bits = 0.5 bytes per element vs 2 bytes for F16 → 0.25×,
    // plus ~1/32 overhead for block scale factors (one f16 per 32 elements).
    let kv_q4_mb = kv_fp16_mb * 0.25 * (1.0 + 1.0 / 32.0);
    // Avoid printing NaN when the estimate is zero.
    let savings_pct = if kv_fp16_mb > 0.0 {
        (1.0 - kv_q4_mb / kv_fp16_mb) * 100.0
    } else {
        0.0
    };
    tracing::info!(
        " KV-cache ({} ctx): {:.1} MB F16 → ~{:.1} MB Q4_0 (~{:.0}% savings)",
        actual_n_ctx,
        kv_fp16_mb,
        kv_q4_mb,
        savings_pct,
    );
    tracing::info!("Engine ready in {:.1}s", start.elapsed().as_secs_f64());

    Ok(Self {
        _backend: backend,
        context: Mutex::new(ContextState {
            ctx: ctx_static,
            sampler,
            active_adapter: None,
            adapters: HashMap::new(),
        }),
        model,
        _base_model_path: base_model_path.to_path_buf(),
        supports_gpu: gpu_ok,
        stats: EngineStats::default(),
    })
}
// ─── LoRA Management ─────────────────────────────────────────────────
/// Record a LoRA adapter (GGUF format) under `name` so it can be applied later.
///
/// Only the path and `scale` are stored; the file is not loaded here. A
/// previous registration under the same name is replaced.
///
/// # Errors
/// Fails if `gguf_path` does not exist.
///
/// # Panics
/// Panics if the context mutex is poisoned.
pub fn register_adapter(&self, name: &str, gguf_path: &Path, scale: f32) -> Result<()> {
    // Reject missing files before touching shared state.
    if !gguf_path.exists() {
        bail!("LoRA GGUF not found: {}", gguf_path.display());
    }

    let mut guard = self.context.lock().unwrap();
    let entry = AdapterInfo {
        scale,
        path: gguf_path.to_path_buf(),
    };
    guard.adapters.insert(name.to_string(), entry);

    tracing::info!("Registered adapter '{}' -> {}", name, gguf_path.display());
    Ok(())
}
/// Apply a LoRA adapter at runtime using the safe llama-cpp-2 API.
pub fn apply_adapter(&self, name: &str) -> Result<()> {
let mut state = self.context.lock().unwrap();
let info = state
.adapters
.get(name)
.cloned()
.context(format!("Adapter '{name}' not registered"))?;
tracing::info!("Applying LoRA adapter: {name}");
// Load LoRA adapter via safe wrapper
let mut lora_adapter = self
.model
.lora_adapter_init(info.path.to_str().context("Invalid UTF-8 in path")?)
.context(format!("Failed to init adapter '{name}'"))?;
// Apply to context
state
.ctx
.lora_adapter_set(&mut lora_adapter, info.scale)
.context(format!("Failed to set adapter '{name}'"))?;
// Ownership of the raw pointer has been transferred to llama.cpp context.
// Forget our wrapper to avoid double-free on drop.
std::mem::forget(lora_adapter);
state.active_adapter = Some(name.to_string());
tracing::info!("Adapter '{name}' applied βœ…");
Ok(())
}
/// Remove active LoRA adapter (revert to base model).
///
/// Does nothing and returns `Ok(())` when no adapter is recorded as active.
/// Otherwise calls `lora_adapter_remove` on the context and clears
/// `active_adapter`.
///
/// # Errors
/// Fails if `lora_adapter_remove` reports an error.
///
/// # Panics
/// Panics if the context mutex is poisoned.
pub fn remove_adapter(&self) -> Result<()> {
let mut state = self.context.lock().unwrap();
// Nothing recorded as active: nothing to do.
if state.active_adapter.is_none() {
return Ok(());
}
tracing::info!("Removing LoRA adapter...");
// lora_adapter_remove needs a &mut LlamaLoraAdapter; the original author states the
// parameter is unused, so a dummy built from NonNull::dangling() is passed and then forgotten.
// NOTE(review): this relies on (a) `lora_adapter_remove` never reading its argument and
// (b) `LlamaLoraAdapter` having exactly the layout of one `NonNull` pointer. Neither is
// visible from this file — confirm against the llama-cpp-2 version in use. If the pointer
// IS forwarded to llama.cpp, a dangling value cannot match the adapter that was attached.
let mut dummy_adapter: LlamaLoraAdapter = unsafe {
std::mem::transmute(NonNull::<llama_cpp_sys_2::llama_adapter_lora>::dangling())
};
// NOTE(review): if this `?` returns early, `dummy_adapter` is dropped normally instead of
// being forgotten — only harmless if `LlamaLoraAdapter` has no destructor. TODO confirm.
state
.ctx
.lora_adapter_remove(&mut dummy_adapter)
.context("Failed to remove adapter")?;
// dummy was never actually loaded, forget to avoid freeing invalid memory.
std::mem::forget(dummy_adapter);
state.active_adapter = None;
tracing::info!("LoRA adapter removed, base model restored");
Ok(())
}
/// Name of the adapter currently recorded as active, or `None` when the base
/// model is in use.
///
/// # Panics
/// Panics if the context mutex is poisoned.
pub fn active_adapter(&self) -> Option<String> {
    let guard = self.context.lock().unwrap();
    guard.active_adapter.as_deref().map(str::to_owned)
}
/// Snapshot of every registered adapter as `(name, path)` pairs.
///
/// The order follows the registry's `HashMap` iteration order and is therefore
/// unspecified.
///
/// # Panics
/// Panics if the context mutex is poisoned.
pub fn list_adapters(&self) -> Vec<(String, PathBuf)> {
    let guard = self.context.lock().unwrap();
    let mut listing = Vec::with_capacity(guard.adapters.len());
    for (name, info) in &guard.adapters {
        listing.push((name.clone(), info.path.clone()));
    }
    listing
}
// ─── Generation ──────────────────────────────────────────────────────
/// Generate text with full sampling control.
///
/// Temperature 0.0 = greedy. top_p 0.0 = disabled. top_k 0 = disabled.
pub fn generate(
&mut self,
prompt: &str,
max_tokens: u32,
temperature: f32,
top_p: f32,
top_k: i32,
) -> Result<String> {
let gen_start = Instant::now();
let mut state = self.context.lock().unwrap();
// 0. Clear KV-cache β€” prevent position mismatch errors when switching
// adapters or running multiple turns in interactive mode.
// M-RoPE (used by Qwen3) requires strictly increasing positions;
// without clearing, old cache entries (positions 0..N) conflict
// with the new batch starting from position 0.
state.ctx.clear_kv_cache();
// 1. Tokenize
let tokens = self
.model
.str_to_token(prompt, AddBos::Always)
.context("Failed to tokenize prompt")?;
let n_prompt = tokens.len();
if n_prompt == 0 {
bail!("Prompt produced 0 tokens");
}
tracing::debug!("Prompt: {n_prompt} tokens");
// 2. Context-size check with auto-truncation
// Fix: cap max_tokens so prompt always has room; ensure truncation converges
let n_ctx = state.ctx.n_ctx() as usize;
let effective_max = (max_tokens as usize).min(n_ctx.saturating_sub(64).max(32)); // at least 32 tokens for prompt
if n_prompt + effective_max > n_ctx {
// Drop the lock before recursing to avoid deadlock
drop(state);
let keep = (n_ctx - effective_max).max(32); // guaranteed positive: effective_max <= n_ctx-32
tracing::warn!(
"Prompt too long ({n_prompt} tok, max_gen={effective_max}, n_ctx={n_ctx}). Truncating to {keep} tokens."
);
let truncated = self
.detokenize_tokens(&tokens[tokens.len().saturating_sub(keep)..])
.context("Failed to decode truncated prompt")?;
return self.generate(&truncated, effective_max as u32, temperature, top_p, top_k);
}
// 3. Prefill β€” feed all prompt tokens in one batch
let mut batch = LlamaBatch::new(n_prompt, 1);
for (i, &token) in tokens.iter().enumerate() {
let is_last = i == n_prompt - 1;
batch.add(token, i as i32, &[0], is_last)?;
}
state
.ctx
.decode(&mut batch)
.context("Prefill decode failed")?;
// 4. Build sampler chain, swap into state (old one gets dropped)
let mut new_sampler = Self::build_sampler(temperature, top_p, top_k);
std::mem::swap(&mut state.sampler, &mut new_sampler);
// 5. Generate loop (capped to effective_max to fit in n_ctx)
let mut output_tokens: Vec<i32> = Vec::with_capacity(effective_max);
let eos = self.model.token_eos();
// Position of the last batch element with logits=True
let mut sample_idx = batch.n_tokens() - 1;
for _step in 0..effective_max {
// NOTE: MutexGuard<ContextState> does not support field-split borrows
// through DerefMut, so we use a raw pointer to pass ctx immutably
// while sampler takes &mut self on its own field.
let token = {
let ctx_ptr: *const llama_cpp_2::context::LlamaContext = &state.ctx;
// SAFETY: ctx_ptr is valid for the duration of sample();
// sampler only reads ctx immutably.
state.sampler.sample(unsafe { &*ctx_ptr }, sample_idx)
};
if token == eos || self.model.is_eog_token(token) {
break;
}
output_tokens.push(token.0);
state.sampler.accept(token);
let pos = (n_prompt + output_tokens.len() - 1) as i32;
let mut single = LlamaBatch::new(1, 1);
single.add(token, pos, &[0], true)?;
state
.ctx
.decode(&mut single)
.context("Decode failed during generation")?;
sample_idx = 0;
}
// 6. Detokenize β€” use token_to_piece_bytes with 256-byte buffer
// (the deprecated tokens_to_str uses only 8 bytes, too small for some tokens)
let llama_tokens: Vec<LlamaToken> =
output_tokens.iter().map(|&t| LlamaToken::new(t)).collect();
let output = self
.detokenize_tokens(&llama_tokens)
.context("Failed to detokenize")?;
// 7. Stats
let elapsed = gen_start.elapsed();
let tok_count = output_tokens.len() as u64;
let tps = if elapsed.as_secs_f64() > 0.0 {
tok_count as f64 / elapsed.as_secs_f64()
} else {
0.0
};
self.stats.total_prompts += 1;
self.stats.total_tokens_generated += tok_count;
self.stats.total_generation_time_ms += elapsed.as_millis() as u64;
let total_secs = self.stats.total_generation_time_ms as f64 / 1000.0;
if total_secs > 0.0 {
self.stats.avg_tokens_per_second =
self.stats.total_tokens_generated as f64 / total_secs;
}
tracing::info!(
"Generated {tok_count} tok in {:.1}s ({tps:.1} t/s) β€” adapter: {:?}",
elapsed.as_secs_f64(),
state.active_adapter,
);
Ok(output)
}
/// Generate with an optional LoRA adapter (revert → apply → generate → remove).
///
/// `adapter_name` of `None` or `Some("general")` means "use the base model".
/// Any adapter still recorded as active is removed first, so the base-model
/// fallback logged when applying fails is actually true. Adapter failures are
/// logged as warnings and never abort generation.
///
/// Sampling uses a fixed `top_k` of 40.
///
/// # Errors
/// Returns whatever `generate` returns.
pub fn generate_with_adapter(
    &mut self,
    prompt: &str,
    max_tokens: u32,
    temperature: f32,
    top_p: f32,
    adapter_name: Option<&str>,
) -> Result<String> {
    /// `top_k` used for every adapter-routed generation.
    const TOP_K: i32 = 40;

    // "general" is an alias for the base model.
    let requested = adapter_name.filter(|name| *name != "general");

    // Start from the base model regardless of what an earlier call left
    // active. `remove_adapter` is a no-op when nothing is active.
    if let Err(e) = self.remove_adapter() {
        tracing::warn!("Failed to remove adapter: {e}");
    }

    if let Some(adapter) = requested {
        if let Err(e) = self.apply_adapter(adapter) {
            tracing::warn!("Failed to apply adapter '{adapter}': {e}. Using base model.");
        }
    }

    let result = self.generate(prompt, max_tokens, temperature, top_p, TOP_K);

    // Revert so the adapter does not leak into the next request.
    if requested.is_some() {
        if let Err(e) = self.remove_adapter() {
            tracing::warn!("Failed to remove adapter: {e}");
        }
    }
    result
}
/// Assemble the sampler chain for one generation call.
///
/// * `temperature <= 0.0` → a chain containing only the greedy sampler.
/// * otherwise → optional top-k (when `top_k > 0`), optional top-p (when
///   `top_p > 0.0`), then temperature, then a `dist` sampler seeded with 42.
fn build_sampler(temperature: f32, top_p: f32, top_k: i32) -> LlamaSampler {
    let stages: Vec<LlamaSampler> = if temperature <= 0.0 {
        vec![LlamaSampler::greedy()]
    } else {
        let mut stages = Vec::new();
        // Filters first, in the order top-k → top-p.
        if top_k > 0 {
            stages.push(LlamaSampler::top_k(top_k));
        }
        if top_p > 0.0 {
            stages.push(LlamaSampler::top_p(top_p, 1));
        }
        // Then temperature scaling and the final random pick.
        stages.push(LlamaSampler::temp(temperature));
        stages.push(LlamaSampler::dist(42));
        stages
    };
    LlamaSampler::chain_simple(stages)
}
// ─── Utility ─────────────────────────────────────────────────────────
/// Detokenize a slice of `LlamaToken` into a `String`.
///
/// Uses `token_to_piece_bytes` with a 256-byte buffer per token (the
/// deprecated `tokens_to_str` uses only 8 bytes, causing errors).
///
/// The raw bytes of all tokens are concatenated first and decoded as UTF-8
/// once at the end. A single character can be split across several tokens, so
/// decoding token by token would turn its halves into replacement characters.
/// If the combined bytes are still not valid UTF-8, a warning is logged and
/// the invalid sequences are replaced with U+FFFD.
///
/// # Errors
/// Fails if any token cannot be converted to bytes.
fn detokenize_tokens(&self, tokens: &[LlamaToken]) -> Result<String> {
    let mut raw: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
    for &token in tokens {
        let piece = self
            .model
            .token_to_piece_bytes(token, 256, true, None)
            .context("Failed to detokenize token")?;
        raw.extend_from_slice(&piece);
    }
    match String::from_utf8(raw) {
        Ok(text) => Ok(text),
        Err(e) => {
            tracing::warn!(
                "Detokenized output contains invalid UTF-8: {}. Using lossy replacement.",
                e
            );
            Ok(String::from_utf8_lossy(e.as_bytes()).into_owned())
        }
    }
}
/// Placeholder cache-clear hook: only emits a debug log line and does not
/// touch the inference context.
pub fn clear_cache(&self) {
tracing::debug!("Cache clear requested (no-op, managed by llama.cpp)");
}
/// Borrow the cumulative generation statistics.
pub fn stats(&self) -> &EngineStats {
&self.stats
}
/// Whether the backend reported GPU offload support when the engine was
/// created. This is the stored capability flag, not a measurement of whether
/// layers are actually running on a GPU.
pub fn is_gpu_active(&self) -> bool {
self.supports_gpu
}
/// Borrow the loaded base model.
pub fn model(&self) -> &LlamaModel {
&self.model
}
}
impl Drop for Engine {
fn drop(&mut self) {
// context (with &model reference) is dropped first because it comes first
// in the struct. Then model is dropped safely.
tracing::info!(
"Shutdown. {} prompts, {} tokens ({:.1} t/s avg)",
self.stats.total_prompts,
self.stats.total_tokens_generated,
self.stats.avg_tokens_per_second,
);
}
}