| #include "simd_ops.h" |
| #include <immintrin.h> |
| #include <string.h> |
|
|
| |
/*
 * CTZ(x): zero-based index of the least-significant set bit of a 32-bit value.
 * The result is unspecified when x == 0 (both __builtin_ctz and
 * _BitScanForward leave it undefined), so callers must pass a non-zero value.
 */
#ifndef _MSC_VER
#define CTZ(x) __builtin_ctz(x)
#else
#include <intrin.h>
static __inline int ctz32(uint32_t value) {
    unsigned long lowest_bit;
    /* Return flag (0 when value == 0) is ignored; see the precondition above. */
    (void)_BitScanForward(&lowest_bit, value);
    return (int)lowest_bit;
}
#define CTZ(x) ctz32(x)
#endif
|
|
| |
/*
 * Binary search for `target` in chars[0..count), which must be sorted
 * ascending.
 *
 * Returns the index of a matching element, or -1 when there is none
 * (including when count <= 0). Uses an inclusive [lo, hi] search window.
 */
static inline int binary_search_chars(const uint8_t* chars, int count, uint8_t target) {
    int lo = 0;
    int hi = count - 1;

    while (lo <= hi) {
        const int probe = lo + (hi - lo) / 2;   /* avoids (lo + hi) overflow */
        const uint8_t value = chars[probe];

        if (value < target) {
            lo = probe + 1;
        } else if (value > target) {
            hi = probe - 1;
        } else {
            return probe;
        }
    }
    return -1;
}
|
|
| |
| int find_child_simd(const TrieNode* node, uint8_t target_char) { |
| |
| if (node->child_count == 0 || node->child_chars == NULL) { |
| return -1; |
| } |
| |
| |
| if (node->child_count <= 16) { |
| |
| __m128i target_vec = _mm_set1_epi8((char)target_char); |
| |
| |
| |
| __m128i chars_vec = _mm_loadu_si128((__m128i*)node->child_chars); |
| |
| |
| __m128i cmp_result = _mm_cmpeq_epi8(target_vec, chars_vec); |
| |
| |
| int mask = _mm_movemask_epi8(cmp_result); |
| |
| |
| mask &= (1 << node->child_count) - 1; |
| |
| |
| if (mask == 0) return -1; |
| |
| |
| return CTZ((uint32_t)mask); |
| } else { |
| |
| return binary_search_chars(node->child_chars, node->child_count, target_char); |
| } |
| } |
|
|
| |
/*
 * Compare exactly `length` bytes of str1 and str2, memcmp-style. Embedded NUL
 * bytes do not end the comparison.
 *
 * Returns 0 when the two ranges are identical. Otherwise it returns
 * (unsigned char)str1[k] - (unsigned char)str2[k], where k is the first
 * differing offset. Full 32-byte chunks are compared with unaligned AVX2
 * loads, and the remaining bytes are compared one at a time. Both pointers
 * must be readable for `length` bytes.
 */
int compare_strings_avx2(const char* str1, const char* str2, size_t length) {
    const unsigned char *lhs = (const unsigned char *)str1;
    const unsigned char *rhs = (const unsigned char *)str2;
    size_t pos = 0;

    /* Vector phase: 32 bytes per step. */
    while (pos + 32 <= length) {
        const __m256i same = _mm256_cmpeq_epi8(
            _mm256_loadu_si256((const __m256i *)(lhs + pos)),
            _mm256_loadu_si256((const __m256i *)(rhs + pos)));

        /* A set bit in `diff` marks a lane whose bytes differ. */
        const uint32_t diff = ~(uint32_t)_mm256_movemask_epi8(same);
        if (diff != 0) {
            const size_t at = pos + (size_t)CTZ(diff);
            return lhs[at] - rhs[at];
        }
        pos += 32;
    }

    /* Scalar phase: the trailing bytes that did not fill a full chunk. */
    for (; pos < length; ++pos) {
        if (lhs[pos] != rhs[pos]) {
            return lhs[pos] - rhs[pos];
        }
    }

    return 0;
}
|
|
| |
| void classify_characters_avx2(const uint8_t* chars, uint8_t* classifications, size_t count) { |
| |
| const __m256i alpha_min = _mm256_set1_epi8('a'); |
| const __m256i alpha_max = _mm256_set1_epi8('z'); |
| const __m256i digit_min = _mm256_set1_epi8('0'); |
| const __m256i digit_max = _mm256_set1_epi8('9'); |
| const __m256i space_char = _mm256_set1_epi8(' '); |
| |
| size_t i = 0; |
| |
| for (; i + 32 <= count; i += 32) { |
| |
| __m256i char_vec = _mm256_loadu_si256((const __m256i*)(chars + i)); |
| |
| |
| |
| __m256i is_alpha = _mm256_and_si256( |
| _mm256_cmpgt_epi8(char_vec, _mm256_sub_epi8(alpha_min, _mm256_set1_epi8(1))), |
| _mm256_cmpgt_epi8(_mm256_add_epi8(alpha_max, _mm256_set1_epi8(1)), char_vec) |
| ); |
|
|
| |
| __m256i is_digit = _mm256_and_si256( |
| _mm256_cmpgt_epi8(char_vec, _mm256_sub_epi8(digit_min, _mm256_set1_epi8(1))), |
| _mm256_cmpgt_epi8(_mm256_add_epi8(digit_max, _mm256_set1_epi8(1)), char_vec) |
| ); |
| |
| |
| __m256i is_space = _mm256_cmpeq_epi8(char_vec, space_char); |
| |
| |
| __m256i result = _mm256_or_si256( |
| _mm256_and_si256(is_alpha, _mm256_set1_epi8(1)), |
| _mm256_or_si256( |
| _mm256_and_si256(is_digit, _mm256_set1_epi8(2)), |
| _mm256_and_si256(is_space, _mm256_set1_epi8(4)) |
| ) |
| ); |
| |
| |
| _mm256_storeu_si256((__m256i*)(classifications + i), result); |
| } |
| |
| |
| for (; i < count; i++) { |
| uint8_t c = chars[i]; |
| classifications[i] = 0; |
| if (c >= 'a' && c <= 'z') classifications[i] |= 1; |
| if (c >= '0' && c <= '9') classifications[i] |= 2; |
| if (c == ' ') classifications[i] |= 4; |
| } |
| } |