| |
| |
| |
|
|
| use rayon::iter::IterBridge; |
| use rayon::prelude::*; |
| use rayon_cond::CondIterator; |
| use std::sync::atomic::AtomicBool; |
| use std::sync::atomic::Ordering; |
|
|
| |
| pub use rayon::current_num_threads; |
|
|
/// Name of the environment variable that the parallelism helpers in this
/// module read ([`is_parallelism_configured`], [`get_parallelism`]) and write
/// ([`set_parallelism`]).
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";
|
|
/// Flag recording that a maybe-parallel adaptor in this module has taken its
/// parallel path at least once.
///
/// Starts out `false`. Code in this module only ever stores `true` into it;
/// the current value is exposed through [`has_parallelism_been_used`].
static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);
|
|
| |
| pub fn is_parallelism_configured() -> bool { |
| std::env::var(ENV_VARIABLE).is_ok() |
| } |
|
|
| |
| pub fn has_parallelism_been_used() -> bool { |
| USED_PARALLELISM.load(Ordering::SeqCst) |
| } |
|
|
| |
| pub fn get_parallelism() -> bool { |
| match std::env::var(ENV_VARIABLE) { |
| Ok(mut v) => { |
| v.make_ascii_lowercase(); |
| !matches!(v.as_ref(), "" | "off" | "false" | "f" | "no" | "n" | "0") |
| } |
| Err(_) => true, |
| } |
| } |
|
|
| |
| pub fn set_parallelism(val: bool) { |
| std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" }) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| pub trait MaybeParallelIterator<P, S> |
| where |
| P: ParallelIterator, |
| S: Iterator<Item = P::Item>, |
| { |
| |
| |
| fn into_maybe_par_iter(self) -> CondIterator<P, S>; |
| |
| |
| |
| fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S>; |
| } |
|
|
| impl<P, S, I> MaybeParallelIterator<P, S> for I |
| where |
| I: IntoParallelIterator<Iter = P, Item = P::Item> + IntoIterator<IntoIter = S, Item = S::Item>, |
| P: ParallelIterator, |
| S: Iterator<Item = P::Item>, |
| { |
| fn into_maybe_par_iter(self) -> CondIterator<P, S> { |
| let parallelism = get_parallelism(); |
| if parallelism { |
| USED_PARALLELISM.store(true, Ordering::SeqCst); |
| } |
| CondIterator::new(self, parallelism) |
| } |
|
|
| fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> { |
| if cond { |
| self.into_maybe_par_iter() |
| } else { |
| CondIterator::from_serial(self) |
| } |
| } |
| } |
|
|
| |
| |
| pub trait MaybeParallelRefIterator<'data, P, S> |
| where |
| P: ParallelIterator, |
| S: Iterator<Item = P::Item>, |
| P::Item: 'data, |
| { |
| fn maybe_par_iter(&'data self) -> CondIterator<P, S>; |
| fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S>; |
| } |
|
|
| impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I |
| where |
| &'data I: MaybeParallelIterator<P, S>, |
| P: ParallelIterator, |
| S: Iterator<Item = P::Item>, |
| P::Item: 'data, |
| { |
| fn maybe_par_iter(&'data self) -> CondIterator<P, S> { |
| self.into_maybe_par_iter() |
| } |
|
|
| fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S> { |
| self.into_maybe_par_iter_cond(cond) |
| } |
| } |
|
|
| |
| |
| pub trait MaybeParallelRefMutIterator<'data, P, S> |
| where |
| P: ParallelIterator, |
| S: Iterator<Item = P::Item>, |
| P::Item: 'data, |
| { |
| fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>; |
| fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S>; |
| } |
|
|
| impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I |
| where |
| &'data mut I: MaybeParallelIterator<P, S>, |
| P: ParallelIterator, |
| S: Iterator<Item = P::Item>, |
| P::Item: 'data, |
| { |
| fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> { |
| self.into_maybe_par_iter() |
| } |
|
|
| fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S> { |
| self.into_maybe_par_iter_cond(cond) |
| } |
| } |
|
|
| |
| pub trait MaybeParallelBridge<T, S> |
| where |
| S: Iterator<Item = T> + Send, |
| T: Send, |
| { |
| fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>; |
| fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S>; |
| } |
|
|
| impl<T, S> MaybeParallelBridge<T, S> for S |
| where |
| S: Iterator<Item = T> + Send, |
| T: Send, |
| { |
| fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> { |
| let iter = CondIterator::from_serial(self); |
|
|
| if get_parallelism() { |
| USED_PARALLELISM.store(true, Ordering::SeqCst); |
| CondIterator::from_parallel(iter.into_parallel().right().unwrap()) |
| } else { |
| iter |
| } |
| } |
|
|
| fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S> { |
| if cond { |
| self.maybe_par_bridge() |
| } else { |
| CondIterator::from_serial(self) |
| } |
| } |
| } |
|
|
| |
| pub trait MaybeParallelSlice<'data, T> |
| where |
| T: Sync, |
| { |
| |
| |
| fn maybe_par_chunks( |
| &'_ self, |
| chunk_size: usize, |
| ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>>; |
| |
| |
| |
| fn maybe_par_chunks_cond( |
| &'_ self, |
| cond: bool, |
| chunk_size: usize, |
| ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>>; |
| } |
|
|
| impl<T> MaybeParallelSlice<'_, T> for [T] |
| where |
| T: Sync, |
| { |
| fn maybe_par_chunks( |
| &'_ self, |
| chunk_size: usize, |
| ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>> { |
| let parallelism = get_parallelism(); |
| if parallelism { |
| CondIterator::from_parallel(self.par_chunks(chunk_size)) |
| } else { |
| CondIterator::from_serial(self.chunks(chunk_size)) |
| } |
| } |
| fn maybe_par_chunks_cond( |
| &'_ self, |
| cond: bool, |
| chunk_size: usize, |
| ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>> { |
| if cond { |
| self.maybe_par_chunks(chunk_size) |
| } else { |
| CondIterator::from_serial(self.chunks(chunk_size)) |
| } |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the by-reference, by-mutable-reference and by-value
    /// conversions on a `Vec<u32>`.
    #[test]
    fn test_maybe_parallel_iterator() {
        let mut numbers = vec![1u32, 2, 3, 4, 5, 6];

        // Shared borrow: the elements sum to 21.
        let initial_total = numbers.maybe_par_iter().sum::<u32>();
        assert_eq!(initial_total, 21);

        // Mutable borrow: double every element in place and sum the results.
        let doubled_total = numbers
            .maybe_par_iter_mut()
            .map(|n| {
                *n *= 2;
                *n
            })
            .sum::<u32>();
        assert_eq!(doubled_total, 42);

        // The doubling must be visible through a fresh shared borrow...
        let total_after_doubling = numbers.maybe_par_iter().sum::<u32>();
        assert_eq!(total_after_doubling, 42);

        // ...and when the vector is consumed.
        let consumed_total = numbers.into_maybe_par_iter().sum::<u32>();
        assert_eq!(consumed_total, 42);
    }

    /// Chunking five elements by two yields two full chunks and a remainder.
    #[test]
    fn test_maybe_parallel_slice() {
        let data = [1, 2, 3, 4, 5];

        let pieces: Vec<_> = data.maybe_par_chunks(2).collect();
        assert_eq!(pieces, vec![&[1, 2][..], &[3, 4], &[5]]);
    }
}
|
|