use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; use anyhow::{Context, Result}; use camino::Utf8PathBuf; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff}; use tokio::sync::broadcast; use crate::downloads::Downloads; use crate::packages::PackageCatalog; const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); pub type RuntimeHttpClient = Arc; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ComputePolicy { PreferGpu, CpuOnly, } #[derive(Debug, Clone)] pub struct RuntimeHttpConfig { pub connect_timeout_secs: u64, pub read_timeout_secs: u64, pub max_retries: u32, } impl Default for RuntimeHttpConfig { fn default() -> Self { Self { connect_timeout_secs: 20, read_timeout_secs: 300, max_retries: 3, } } } impl RuntimeHttpConfig { pub fn build_client(&self) -> Result { let base = reqwest::Client::builder() .user_agent(USER_AGENT) .connect_timeout(Duration::from_secs(self.connect_timeout_secs)) .read_timeout(Duration::from_secs(self.read_timeout_secs)) .build()?; Ok(Arc::new( ClientBuilder::new(base) .with(RetryTransientMiddleware::new_with_policy( ExponentialBackoff::builder().build_with_max_retries(self.max_retries), )) .build(), )) } } // FIXME: move this function to a more appropriate place, e.g. a `config` module pub fn default_app_data_root() -> Utf8PathBuf { // Env-override (tests, CI, portable installs). if let Ok(v) = std::env::var("KOHARU_DATA_ROOT") && !v.is_empty() { return Utf8PathBuf::from(v); } if let Ok(exe) = std::env::current_exe() && let Some(root) = exe.parent() && root.join("config.toml").is_file() { return Utf8PathBuf::from_path_buf(root.to_path_buf()) .unwrap_or_else(|path| Utf8PathBuf::from(path.to_string_lossy().into_owned())); } let root = dirs::data_local_dir() .or_else(dirs::data_dir) .unwrap_or_else(std::env::temp_dir) .join("Koharu"); Utf8PathBuf::from_path_buf(root) .unwrap_or_else(|path| Utf8PathBuf::from(path.to_string_lossy().into_owned())) } #[derive(Clone)] pub struct Runtime { inner: Arc, } struct RuntimeInner { root: PathBuf, compute: ComputePolicy, downloads: Downloads, packages: PackageCatalog, } impl Runtime { pub fn new(root: impl Into, compute: ComputePolicy) -> Result { Self::new_with_http(root, compute, RuntimeHttpConfig::default()) } pub fn new_with_http( root: impl Into, compute: ComputePolicy, http: RuntimeHttpConfig, ) -> Result { let root = root.into(); let downloads = Downloads::new( root.join("runtime").join(".downloads"), root.join("models").join("huggingface"), &http, )?; Ok(Self { inner: Arc::new(RuntimeInner { root, compute, downloads, packages: PackageCatalog::discover(), }), }) } pub fn root(&self) -> &Path { &self.inner.root } pub fn wants_gpu(&self) -> bool { matches!(self.inner.compute, ComputePolicy::PreferGpu) } pub fn http_client(&self) -> RuntimeHttpClient { self.inner.downloads.client() } pub fn subscribe_downloads(&self) -> broadcast::Receiver { self.inner.downloads.subscribe() } pub fn downloads(&self) -> Downloads { self.inner.downloads.clone() } pub async fn prepare(&self) -> Result<()> { let dirs = [ self.root().join("runtime"), self.root().join("runtime").join(".downloads"), self.root().join("models"), self.root().join("models").join("huggingface"), ]; for dir in dirs { std::fs::create_dir_all(&dir) .with_context(|| format!("failed to create `{}`", dir.display()))?; } self.inner.packages.prepare_bootstrap(self).await } pub fn llama_directory(&self) -> Result { crate::llama::runtime_dir(self) } } pub type RuntimeManager = Runtime; #[cfg(test)] mod tests { use std::fs; use anyhow::Result; use super::*; #[tokio::test] #[ignore] async fn prepares_llama_runtime_into_configured_root() -> Result<()> { let tempdir = tempfile::tempdir()?; let runtime = Runtime::new(tempdir.path(), ComputePolicy::CpuOnly)?; runtime.prepare().await?; assert!(runtime.llama_directory()?.exists()); Ok(()) } #[tokio::test] #[ignore] async fn repeated_basename_loads_succeed_after_prepare() -> Result<()> { let tempdir = tempfile::tempdir()?; let runtime = Runtime::new(tempdir.path(), ComputePolicy::CpuOnly)?; runtime.prepare().await?; let dir = runtime.llama_directory()?; let lib_name = fs::read_dir(&dir)? .filter_map(|entry| entry.ok()) .filter_map(|entry| { let name = entry.file_name().to_string_lossy().into_owned(); name.contains("llama").then_some(name) }) .next() .ok_or_else(|| anyhow::anyhow!("no llama library found"))?; let _first = crate::load_library_by_name(&lib_name)?; let _second = crate::load_library_by_name(&lib_name)?; Ok(()) } }