File size: 5,530 Bytes
708f4a3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | #include "simd_ops.h"
#include <immintrin.h>
#include <string.h>
// Cross-platform count-trailing-zeros (CTZ) helper.
//
// ctz32(value) returns the index (0..31) of the least-significant set bit of
// `value`, or 32 when `value` is zero. The raw compiler primitives leave the
// zero case undefined (__builtin_ctz(0) is undefined behavior, and
// _BitScanForward leaves its output index unset), so both branches guard it
// explicitly to make CTZ total.
//
// CTZ(x) converts its argument to uint32_t and evaluates it exactly once.
#if defined(_MSC_VER)
#include <intrin.h>
static __inline int ctz32(uint32_t value) {
    unsigned long index;
    // _BitScanForward returns 0 when no bit is set; index is then unspecified.
    if (!_BitScanForward(&index, value)) {
        return 32;
    }
    return (int)index;
}
#else
static inline int ctz32(uint32_t value) {
    return value != 0 ? __builtin_ctz(value) : 32;
}
#endif
#define CTZ(x) ctz32((uint32_t)(x))
// Binary search for `target` within chars[0..count-1].
//
// The array is expected to be sorted ascending. Returns the index of a
// matching element, or -1 when no element matches (including when
// count <= 0). The probe sequence uses a closed interval [lo, hi] with the
// midpoint rounded toward lo.
static inline int binary_search_chars(const uint8_t* chars, int count, uint8_t target) {
    int lo = 0;
    int hi = count - 1;

    while (lo <= hi) {
        const int probe = lo + (hi - lo) / 2;   // overflow-safe midpoint
        const uint8_t seen = chars[probe];

        if (seen < target) {
            lo = probe + 1;
        } else if (seen > target) {
            hi = probe - 1;
        } else {
            return probe;
        }
    }
    return -1;
}
// Look up `target_char` among a trie node's child characters.
//
// Returns the index into node->child_chars of a child whose character equals
// `target_char`, or -1 when there is none (or the node has no children).
//
// For nodes with at most 16 children, all characters are compared in one
// SSE2 operation and the lowest matching index is returned. Larger nodes fall
// back to binary_search_chars(), which requires child_chars to be sorted
// ascending.
int find_child_simd(const TrieNode* node, uint8_t target_char) {
    // NOTE(review): child_count's declared type is not visible here; it is
    // converted to int exactly as the binary-search fallback already requires.
    const int count = (int)node->child_count;

    // Leaf nodes (or nodes without a character array) cannot match.
    // Rejecting count <= 0 also keeps the mask shift below well-defined.
    if (count <= 0 || node->child_chars == NULL) {
        return -1;
    }

    if (count > 16) {
        // Too many children for a single 128-bit register.
        return binary_search_chars(node->child_chars, count, target_char);
    }

    // Copy only the live characters into a zero-filled 16-byte lane buffer.
    // Loading this buffer, instead of child_chars directly, means the vector
    // load never reads past the end of child_chars, however it was allocated.
    // Lanes >= count are masked off below, so the zero padding cannot produce
    // a false match (even when target_char == 0).
    uint8_t lanes[16] = {0};
    memcpy(lanes, node->child_chars, (size_t)count);

    const __m128i needle = _mm_set1_epi8((char)target_char);
    const __m128i haystack = _mm_loadu_si128((const __m128i*)lanes);

    // One bit per byte lane: bit k is set when lanes[k] == target_char.
    unsigned mask = (unsigned)_mm_movemask_epi8(_mm_cmpeq_epi8(needle, haystack));

    // Keep only the bits for real children; count is in 1..16 here.
    mask &= (1u << count) - 1u;
    if (mask == 0) {
        return -1;
    }

    // The lowest set bit is the first matching child.
    return CTZ((uint32_t)mask);
}
// Lexicographically compare the first `length` bytes of two buffers.
//
// The bulk of the range is compared 32 bytes at a time with AVX2, and the
// remainder byte by byte. Returns 0 when the ranges are identical. Otherwise
// it returns (unsigned char)str1[k] - (unsigned char)str2[k] at the first
// differing offset k: negative if str1 orders first, positive if str2 does.
// Both buffers must provide at least `length` readable bytes; NUL bytes are
// compared like any other byte and do not stop the comparison.
int compare_strings_avx2(const char* str1, const char* str2, size_t length) {
    const unsigned char* lhs_bytes = (const unsigned char*)str1;
    const unsigned char* rhs_bytes = (const unsigned char*)str2;
    size_t pos = 0;

    // Vector phase: whole 32-byte blocks only.
    while (length - pos >= 32) {
        const __m256i lhs = _mm256_loadu_si256((const __m256i*)(lhs_bytes + pos));
        const __m256i rhs = _mm256_loadu_si256((const __m256i*)(rhs_bytes + pos));

        // movemask sets bit k when byte lane k compared equal, so the
        // inverted mask has a bit for every mismatching lane.
        const uint32_t equal_lanes = (uint32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(lhs, rhs));
        const uint32_t diff_lanes = ~equal_lanes;

        if (diff_lanes != 0) {
            const size_t at = pos + (size_t)CTZ(diff_lanes);
            return lhs_bytes[at] - rhs_bytes[at];
        }
        pos += 32;
    }

    // Scalar phase: the trailing (length % 32) bytes.
    for (; pos < length; ++pos) {
        if (lhs_bytes[pos] != rhs_bytes[pos]) {
            return lhs_bytes[pos] - rhs_bytes[pos];
        }
    }

    return 0;
}
// [cite: 525] Vectorized Character Classification
void classify_characters_avx2(const uint8_t* chars, uint8_t* classifications, size_t count) {
// [cite: 526-529] Pre-computed constants
const __m256i alpha_min = _mm256_set1_epi8('a');
const __m256i alpha_max = _mm256_set1_epi8('z');
const __m256i digit_min = _mm256_set1_epi8('0');
const __m256i digit_max = _mm256_set1_epi8('9');
const __m256i space_char = _mm256_set1_epi8(' ');
size_t i = 0;
// [cite: 530] Loop 32 chars at a time
for (; i + 32 <= count; i += 32) {
// [cite: 532] Load
__m256i char_vec = _mm256_loadu_si256((const __m256i*)(chars + i));
// [cite: 533-536] Is Alpha logic (simplified for AVX comparison quirks)
// Note: PCMPGT compares signed bytes. We assume ASCII range here.
__m256i is_alpha = _mm256_and_si256(
_mm256_cmpgt_epi8(char_vec, _mm256_sub_epi8(alpha_min, _mm256_set1_epi8(1))),
_mm256_cmpgt_epi8(_mm256_add_epi8(alpha_max, _mm256_set1_epi8(1)), char_vec)
);
// [cite: 537-539] Is Digit logic
__m256i is_digit = _mm256_and_si256(
_mm256_cmpgt_epi8(char_vec, _mm256_sub_epi8(digit_min, _mm256_set1_epi8(1))),
_mm256_cmpgt_epi8(_mm256_add_epi8(digit_max, _mm256_set1_epi8(1)), char_vec)
);
// [cite: 540] Is Space
__m256i is_space = _mm256_cmpeq_epi8(char_vec, space_char);
// [cite: 543-544] Combine results: Alpha=1, Digit=2, Space=4
__m256i result = _mm256_or_si256(
_mm256_and_si256(is_alpha, _mm256_set1_epi8(1)),
_mm256_or_si256(
_mm256_and_si256(is_digit, _mm256_set1_epi8(2)),
_mm256_and_si256(is_space, _mm256_set1_epi8(4))
)
);
// [cite: 546] Store
_mm256_storeu_si256((__m256i*)(classifications + i), result);
}
// Fallback for remaining
for (; i < count; i++) {
uint8_t c = chars[i];
classifications[i] = 0;
if (c >= 'a' && c <= 'z') classifications[i] |= 1;
if (c >= '0' && c <= '9') classifications[i] |= 2;
if (c == ' ') classifications[i] |= 4;
}
} |