//! //! This module defines helpers to allow optional Rayon usage. //! use rayon::iter::IterBridge; use rayon::prelude::*; use rayon_cond::CondIterator; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; // Re-export rayon current_num_threads pub use rayon::current_num_threads; pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM"; static USED_PARALLELISM: AtomicBool = AtomicBool::new(false); /// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set pub fn is_parallelism_configured() -> bool { std::env::var(ENV_VARIABLE).is_ok() } /// Check if at some point we used a parallel iterator pub fn has_parallelism_been_used() -> bool { USED_PARALLELISM.load(Ordering::SeqCst) } /// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable pub fn get_parallelism() -> bool { match std::env::var(ENV_VARIABLE) { Ok(mut v) => { v.make_ascii_lowercase(); !matches!(v.as_ref(), "" | "off" | "false" | "f" | "no" | "n" | "0") } Err(_) => true, // If we couldn't get the variable, we use the default } } /// Set the value for `TOKENIZERS_PARALLELISM` for the current process pub fn set_parallelism(val: bool) { std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" }) } /// Allows to convert into an iterator that can be executed either parallelly or serially. /// /// The choice is made according to the currently set `TOKENIZERS_PARALLELISM` environment variable. /// This variable can have one of the following values /// - False => "" (empty value), "false", "f", "off", "no", "n", "0" /// - True => Any other value /// pub trait MaybeParallelIterator where P: ParallelIterator, S: Iterator, { /// Convert ourself in a CondIterator, that will be executed either in parallel or serially, /// based solely on the `TOKENIZERS_PARALLELISM` environment variable fn into_maybe_par_iter(self) -> CondIterator; /// Convert ourself in a CondIterator, that will be executed either in parallel or serially, /// based on both the `TOKENIZERS_PARALLELISM` environment variable and the provided bool. /// Both must be true to run with parallelism activated. fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator; } impl MaybeParallelIterator for I where I: IntoParallelIterator + IntoIterator, P: ParallelIterator, S: Iterator, { fn into_maybe_par_iter(self) -> CondIterator { let parallelism = get_parallelism(); if parallelism { USED_PARALLELISM.store(true, Ordering::SeqCst); } CondIterator::new(self, parallelism) } fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator { if cond { self.into_maybe_par_iter() } else { CondIterator::from_serial(self) } } } /// Shared reference version of MaybeParallelIterator, works the same but returns an iterator /// over references, does not consume self pub trait MaybeParallelRefIterator<'data, P, S> where P: ParallelIterator, S: Iterator, P::Item: 'data, { fn maybe_par_iter(&'data self) -> CondIterator; fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator; } impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I where &'data I: MaybeParallelIterator, P: ParallelIterator, S: Iterator, P::Item: 'data, { fn maybe_par_iter(&'data self) -> CondIterator { self.into_maybe_par_iter() } fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator { self.into_maybe_par_iter_cond(cond) } } /// Exclusive reference version of MaybeParallelIterator, works the same but returns an iterator /// over mutable references, does not consume self pub trait MaybeParallelRefMutIterator<'data, P, S> where P: ParallelIterator, S: Iterator, P::Item: 'data, { fn maybe_par_iter_mut(&'data mut self) -> CondIterator; fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator; } impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I where &'data mut I: MaybeParallelIterator, P: ParallelIterator, S: Iterator, P::Item: 'data, { fn maybe_par_iter_mut(&'data mut self) -> CondIterator { self.into_maybe_par_iter() } fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator { self.into_maybe_par_iter_cond(cond) } } /// Converts any serial iterator into a CondIterator, that can either run parallelly or serially. pub trait MaybeParallelBridge where S: Iterator + Send, T: Send, { fn maybe_par_bridge(self) -> CondIterator, S>; fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator, S>; } impl MaybeParallelBridge for S where S: Iterator + Send, T: Send, { fn maybe_par_bridge(self) -> CondIterator, S> { let iter = CondIterator::from_serial(self); if get_parallelism() { USED_PARALLELISM.store(true, Ordering::SeqCst); CondIterator::from_parallel(iter.into_parallel().right().unwrap()) } else { iter } } fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator, S> { if cond { self.maybe_par_bridge() } else { CondIterator::from_serial(self) } } } /// Allows to convert into `chunks` that can be executed either parallelly or serially. pub trait MaybeParallelSlice<'data, T> where T: Sync, { /// Create a CondIterator, that will be executed either in parallel or serially, /// based solely on the `TOKENIZERS_PARALLELISM` environment variable fn maybe_par_chunks( &'_ self, chunk_size: usize, ) -> CondIterator, std::slice::Chunks<'_, T>>; /// Create a CondIterator, that will be executed either in parallel or serially, /// based on both the `TOKENIZERS_PARALLELISM` environment variable and the provided bool. /// Both must be true to run with parallelism activated. fn maybe_par_chunks_cond( &'_ self, cond: bool, chunk_size: usize, ) -> CondIterator, std::slice::Chunks<'_, T>>; } impl MaybeParallelSlice<'_, T> for [T] where T: Sync, { fn maybe_par_chunks( &'_ self, chunk_size: usize, ) -> CondIterator, std::slice::Chunks<'_, T>> { let parallelism = get_parallelism(); if parallelism { CondIterator::from_parallel(self.par_chunks(chunk_size)) } else { CondIterator::from_serial(self.chunks(chunk_size)) } } fn maybe_par_chunks_cond( &'_ self, cond: bool, chunk_size: usize, ) -> CondIterator, std::slice::Chunks<'_, T>> { if cond { self.maybe_par_chunks(chunk_size) } else { CondIterator::from_serial(self.chunks(chunk_size)) } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_maybe_parallel_iterator() { let mut v = vec![1u32, 2, 3, 4, 5, 6]; assert_eq!(v.maybe_par_iter().sum::(), 21); assert_eq!( v.maybe_par_iter_mut() .map(|v| { *v *= 2; *v }) .sum::(), 42 ); assert_eq!(v.maybe_par_iter().sum::(), 42); assert_eq!(v.into_maybe_par_iter().sum::(), 42); } #[test] fn test_maybe_parallel_slice() { let v = [1, 2, 3, 4, 5]; let chunks: Vec<_> = v.maybe_par_chunks(2).collect(); assert_eq!(chunks, vec![&[1, 2][..], &[3, 4], &[5]]); } }