File size: 4,766 Bytes
0ab3d49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
/// Encoding alphabet: maps a 6-bit value (0..=63) to its output character.
///
/// The order is lowercase letters, uppercase letters, digits, then `-` and `_`.
/// Note that this ordering differs from the RFC 4648 alphabets, which start
/// with the uppercase letters.
const BASE64_CHARS: &[u8] = b"abcdefghijklmnopqrstuvwxyz\
                              ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                              0123456789\
                              -_";
/// Decoding table built at compile time from `BASE64_CHARS`.
///
/// Indexed by raw byte value. An entry holds the 6-bit value of that byte in
/// the alphabet, or `-1` when the byte is not part of the alphabet.
const BASE64_LOOKUP: [i8; 256] = {
    let mut table = [-1i8; 256];
    let mut position = 0;
    // A `while` loop is used because this initializer is evaluated in a const context.
    while position < BASE64_CHARS.len() {
        let symbol = BASE64_CHARS[position];
        // Alphabet positions are below 64, so they fit in an `i8`.
        table[symbol as usize] = position as i8;
        position += 1;
    }
    table
};
/// 将字节切片编码为 Base64 字符串。
///
/// # Arguments
///
/// * `bytes`: 要编码的字节切片
///
/// # Returns
///
/// 编码后的 Base64 字符串
pub fn to_base64(bytes: &[u8]) -> String {
// 预分配足够容量,避免多次分配内存
let capacity = (bytes.len() + 2) / 3 * 4;
let mut result = Vec::with_capacity(capacity);
// 每三个字节为一组进行处理
for chunk in bytes.chunks(3) {
// 将三个字节合并为一个 u32
let b1 = chunk[0] as u32;
let b2 = chunk.get(1).map_or(0, |&b| b as u32);
let b3 = chunk.get(2).map_or(0, |&b| b as u32);
let n = (b1 << 16) | (b2 << 8) | b3;
// 将 u32 拆分成四个 6 位的值,并根据查找表转换为 Base64 字符
result.push(BASE64_CHARS[(n >> 18) as usize]);
result.push(BASE64_CHARS[((n >> 12) & 0x3F) as usize]);
// 如果 chunk 长度大于 1,则需要处理第二个字符
if chunk.len() > 1 {
result.push(BASE64_CHARS[((n >> 6) & 0x3F) as usize]);
// 如果 chunk 长度大于 2,则需要处理第三个字符
if chunk.len() > 2 {
result.push(BASE64_CHARS[(n & 0x3F) as usize]);
}
}
}
// 使用 from_utf8_unchecked 提高性能,因为 BASE64_CHARS 都是有效的 ASCII 字符
unsafe { String::from_utf8_unchecked(result) }
}
/// 将 Base64 字符串解码为字节数组。
///
/// # Arguments
///
/// * `input`: 要解码的 Base64 字符串
///
/// # Returns
///
/// 如果解码成功,返回 Some(解码后的字节数组);如果输入无效,返回 None
pub fn from_base64(input: &str) -> Option<Vec<u8>> {
let input = input.as_bytes();
// 检查输入长度,Base64 编码的长度必须是 4 的倍数或余 2/3
if input.is_empty() || input.len() % 4 == 1 {
return None;
}
// 检查是否包含无效字符,无效字符直接返回None
if input.iter().any(|&b| BASE64_LOOKUP[b as usize] == -1) {
return None;
}
// 预分配足够容量,避免多次分配内存
let capacity = input.len() / 4 * 3;
let mut result = Vec::with_capacity(capacity);
// 每四个字符为一组进行处理
let mut chunks = input.chunks_exact(4);
for chunk in &mut chunks {
// 使用查找表将 Base64 字符转换为 6 位的值
let n1 = BASE64_LOOKUP[chunk[0] as usize] as u32;
let n2 = BASE64_LOOKUP[chunk[1] as usize] as u32;
let n3 = BASE64_LOOKUP[chunk[2] as usize] as u32;
let n4 = BASE64_LOOKUP[chunk[3] as usize] as u32;
// 将四个 6 位的值合并为一个 u32,并拆分成三个字节
let n = (n1 << 18) | (n2 << 12) | (n3 << 6) | n4;
result.push((n >> 16) as u8);
result.push(((n >> 8) & 0xFF) as u8);
result.push((n & 0xFF) as u8);
}
// 处理剩余的字符
let remainder = chunks.remainder();
if !remainder.is_empty() {
let n1 = BASE64_LOOKUP[remainder[0] as usize] as u32;
let n2 = BASE64_LOOKUP[remainder[1] as usize] as u32;
let mut n = (n1 << 18) | (n2 << 12);
result.push((n >> 16) as u8);
// 如果剩余字符长度大于 2,则需要处理第二个字节
if remainder.len() > 2 {
let n3 = BASE64_LOOKUP[remainder[2] as usize] as u32;
n |= n3 << 6;
result.push(((n >> 8) & 0xFF) as u8);
}
}
Some(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding followed by decoding must reproduce the input for every
    /// trailing-group length (1, 2 and 3 bytes) and for a larger buffer.
    #[test]
    fn test_base64_roundtrip() {
        let test_cases = vec![
            vec![0u8, 1, 2, 3],
            vec![255u8, 254, 253],
            vec![0u8],
            vec![0u8, 1],
            vec![0u8, 1, 2],
            vec![255u8; 1000],
        ];
        for case in test_cases {
            let encoded = to_base64(&case);
            let decoded = from_base64(&encoded).unwrap();
            assert_eq!(case, decoded);
        }
    }

    /// Known-answer vectors, so that a matching bug in both directions cannot
    /// hide behind a successful round trip.
    #[test]
    fn test_known_vectors() {
        assert_eq!(to_base64(&[]), "");
        assert_eq!(to_base64(&[0]), "aa");
        assert_eq!(to_base64(&[0, 1]), "aae");
        assert_eq!(to_base64(&[0, 1, 2]), "aaec");
        assert_eq!(to_base64(&[255, 254, 253]), "__79");

        assert_eq!(from_base64("aa"), Some(vec![0]));
        assert_eq!(from_base64("aae"), Some(vec![0, 1]));
        assert_eq!(from_base64("aaec"), Some(vec![0, 1, 2]));
        assert_eq!(from_base64("__79"), Some(vec![255, 254, 253]));
        assert_eq!(from_base64("YWJj"), Some(vec![0xCB, 0x08, 0xC9]));
    }

    #[test]
    fn test_invalid_input() {
        assert_eq!(from_base64(""), None); // empty string
        assert_eq!(from_base64("a"), None); // length 1
        assert_eq!(from_base64("!@#$"), None); // only invalid characters
        assert_eq!(from_base64("YWJj!"), None); // length 5 leaves a remainder of 1
        assert_eq!(from_base64("YW!j"), None); // invalid character inside a full group
        assert_eq!(from_base64("YWJjY!"), None); // invalid character in the trailing group
        assert_eq!(from_base64("YQ=="), None); // `=` padding is not part of the alphabet
        assert_eq!(from_base64("é"), None); // non-ASCII bytes are not part of the alphabet
        assert!(from_base64("YWJj").is_some()); // valid input
    }
}
|