| | |
| |
|
| | use crate::{Error, Result}; |
| | use ndarray::{Array, IxDyn}; |
| | use std::collections::HashMap; |
| | use std::path::{Path, PathBuf}; |
| | use std::sync::{Arc, RwLock}; |
| |
|
| | |
/// Handle to a loaded ONNX model, exposing the names of its inputs and outputs.
#[derive(Debug)]
pub struct OnnxSession {
    /// Input names reported by the session.
    input_names: Vec<String>,
    /// Output names reported by the session; `run` produces one entry per name.
    output_names: Vec<String>,
}
| |
|
| | impl OnnxSession { |
| | |
| | pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
| | let path = path.as_ref(); |
| | if !path.exists() { |
| | return Err(Error::FileNotFound(path.display().to_string())); |
| | } |
| |
|
| | |
| | log::info!("Loading ONNX model from: {}", path.display()); |
| |
|
| | Ok(Self { |
| | input_names: vec!["input".to_string()], |
| | output_names: vec!["output".to_string()], |
| | }) |
| | } |
| |
|
| | |
| | pub fn run( |
| | &self, |
| | _inputs: HashMap<String, Array<f32, IxDyn>>, |
| | ) -> Result<HashMap<String, Array<f32, IxDyn>>> { |
| | |
| | let mut result = HashMap::new(); |
| | for name in &self.output_names { |
| | let dummy = Array::zeros(IxDyn(&[1, 1])); |
| | result.insert(name.clone(), dummy); |
| | } |
| | Ok(result) |
| | } |
| |
|
| | |
| | pub fn run_i64( |
| | &self, |
| | _inputs: HashMap<String, Array<i64, IxDyn>>, |
| | ) -> Result<HashMap<String, Array<f32, IxDyn>>> { |
| | let mut result = HashMap::new(); |
| | for name in &self.output_names { |
| | let dummy = Array::zeros(IxDyn(&[1, 1])); |
| | result.insert(name.clone(), dummy); |
| | } |
| | Ok(result) |
| | } |
| |
|
| | pub fn input_names(&self) -> &[String] { |
| | &self.input_names |
| | } |
| |
|
| | pub fn output_names(&self) -> &[String] { |
| | &self.output_names |
| | } |
| | } |
| |
|
| | |
| | pub struct ModelCache { |
| | sessions: RwLock<HashMap<String, Arc<OnnxSession>>>, |
| | model_dir: PathBuf, |
| | } |
| |
|
| | impl ModelCache { |
| | pub fn new<P: AsRef<Path>>(model_dir: P) -> Self { |
| | Self { |
| | sessions: RwLock::new(HashMap::new()), |
| | model_dir: model_dir.as_ref().to_path_buf(), |
| | } |
| | } |
| |
|
| | pub fn get_or_load(&self, name: &str) -> Result<Arc<OnnxSession>> { |
| | { |
| | let cache = self.sessions.read().unwrap(); |
| | if let Some(session) = cache.get(name) { |
| | return Ok(Arc::clone(session)); |
| | } |
| | } |
| |
|
| | let model_path = self.model_dir.join(format!("{}.onnx", name)); |
| | let session = OnnxSession::load(&model_path)?; |
| | let session = Arc::new(session); |
| |
|
| | { |
| | let mut cache = self.sessions.write().unwrap(); |
| | cache.insert(name.to_string(), Arc::clone(&session)); |
| | } |
| |
|
| | Ok(session) |
| | } |
| |
|
| | pub fn preload(&self, model_names: &[&str]) -> Result<()> { |
| | for name in model_names { |
| | self.get_or_load(name)?; |
| | } |
| | Ok(()) |
| | } |
| |
|
| | pub fn clear(&self) { |
| | let mut cache = self.sessions.write().unwrap(); |
| | cache.clear(); |
| | } |
| |
|
| | pub fn is_cached(&self, name: &str) -> bool { |
| | let cache = self.sessions.read().unwrap(); |
| | cache.contains_key(name) |
| | } |
| |
|
| | pub fn cached_models(&self) -> Vec<String> { |
| | let cache = self.sessions.read().unwrap(); |
| | cache.keys().cloned().collect() |
| | } |
| | } |
| |
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a fresh, empty scratch directory unique to this process and `tag`.
    fn scratch_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "model_cache_test_{}_{}",
            std::process::id(),
            tag
        ));
        // Best-effort cleanup of leftovers from an earlier aborted run.
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).expect("create scratch dir");
        dir
    }

    #[test]
    fn test_model_cache_creation() {
        let cache = ModelCache::new("/tmp/models");
        assert!(cache.cached_models().is_empty());
    }

    #[test]
    fn missing_model_is_an_error_and_not_cached() {
        let dir = scratch_dir("missing");
        let cache = ModelCache::new(&dir);

        assert!(cache.get_or_load("absent").is_err());
        assert!(!cache.is_cached("absent"));
        assert!(cache.cached_models().is_empty());

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn loaded_model_is_cached_shared_and_clearable() {
        let dir = scratch_dir("shared");
        fs::write(dir.join("demo.onnx"), b"placeholder").expect("write model file");
        let cache = ModelCache::new(&dir);

        let first = cache
            .get_or_load("demo")
            .unwrap_or_else(|_| panic!("first load of demo failed"));
        let second = cache
            .get_or_load("demo")
            .unwrap_or_else(|_| panic!("cached load of demo failed"));
        assert!(Arc::ptr_eq(&first, &second));
        assert!(cache.is_cached("demo"));
        assert_eq!(cache.cached_models(), vec!["demo".to_string()]);
        assert_eq!(first.input_names(), ["input"]);
        assert_eq!(first.output_names(), ["output"]);

        cache.clear();
        assert!(!cache.is_cached("demo"));
        assert!(cache.cached_models().is_empty());

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn preload_stops_at_first_missing_model() {
        let dir = scratch_dir("preload");
        fs::write(dir.join("demo.onnx"), b"placeholder").expect("write model file");
        let cache = ModelCache::new(&dir);

        assert!(cache.preload(&[]).is_ok());
        assert!(cache.preload(&["demo", "absent"]).is_err());
        assert!(cache.is_cached("demo"));
        assert!(!cache.is_cached("absent"));

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn run_returns_one_zero_array_per_output() {
        let dir = scratch_dir("run");
        let path = dir.join("demo.onnx");
        fs::write(&path, b"placeholder").expect("write model file");
        let session = OnnxSession::load(&path).unwrap_or_else(|_| panic!("load demo failed"));

        let outputs = session
            .run(HashMap::new())
            .unwrap_or_else(|_| panic!("run failed"));
        assert_eq!(outputs.len(), session.output_names().len());
        for name in session.output_names() {
            let array = &outputs[name];
            assert_eq!(array.shape(), [1, 1]);
            assert!(array.iter().all(|&v| v == 0.0));
        }

        let outputs_i64 = session
            .run_i64(HashMap::new())
            .unwrap_or_else(|_| panic!("run_i64 failed"));
        assert_eq!(outputs_i64.len(), session.output_names().len());

        let _ = fs::remove_dir_all(&dir);
    }
}
| |
|