# ToF_utils.py — Anna / JohnChiu, commit 6fa13e0 ("add elegant algorithm")
import numpy as np
from scipy.ndimage import convolve
import os
from numpy.lib.stride_tricks import sliding_window_view
from numba import njit
import cv2
def mipi_raw10_to_raw8_scaled(raw10_data):
    """Unpack MIPI RAW10 into a little-endian 16-bit sample stream.

    Every 5-byte MIPI block carries four 10-bit pixels: bytes 0-3 hold the
    high 8 bits of each pixel and byte 4 packs the four 2-bit remainders
    (pixel i's low bits at bit position 2*i). Despite the name, each pixel
    is emitted as TWO bytes (low byte first), i.e. a uint16 LE stream.

    Args:
        raw10_data: bytes-like buffer of packed RAW10 data; trailing bytes
            that do not form a full 5-byte block are dropped.

    Returns:
        bytes: unpacked samples, 8 output bytes per 5-byte input block.
    """
    packed = np.frombuffer(raw10_data, dtype=np.uint8)
    n_blocks = packed.size // 5
    blocks = packed[:n_blocks * 5].reshape(-1, 5)

    lsb_byte = blocks[:, 4]
    pixels = [
        (blocks[:, idx].astype(np.uint16) << 2) | ((lsb_byte >> (2 * idx)) & 0x03)
        for idx in range(4)
    ]

    out = np.empty(n_blocks * 8, dtype=np.uint8)
    for idx, pix in enumerate(pixels):
        out[2 * idx::8] = pix & 0xFF    # low byte
        out[2 * idx + 1::8] = pix >> 8  # high byte (value 0..3)
    return out.tobytes()
def readRAW(path):
    """Load a ToF histogram cube from a raw capture file.

    Two on-disk layouts are handled:
      * exactly 7,372,800 bytes -> packed MIPI RAW10; unpacked first via
        mipi_raw10_to_raw8_scaled().
      * anything else           -> already an int16 sample stream.

    The int16 stream is interpreted as a (96, 240, 256) cube, adjacent
    int16 elements along the last axis are swapped pairwise, and the rows
    inside every group of 8 are remapped to the order 0,4,1,5,2,6,3,7.

    Args:
        path: path to the raw capture file.

    Returns:
        np.ndarray of shape (96, 240, 256), dtype int16.
    """
    size = os.path.getsize(path)
    with open(path, "rb") as fh:
        payload = fh.read()

    # MIPI RAW10 captures have this fixed, known size; unpack them first.
    if size == 7372800:
        payload = mipi_raw10_to_raw8_scaled(payload)

    cube = np.frombuffer(payload, dtype=np.int16).reshape(96, 240, 256)

    # Swap adjacent int16 elements: view the last axis as 128 pairs and
    # reverse each pair.
    pairs = cube.reshape(96, 240, 128, 2)
    swapped = pairs[..., ::-1].reshape(96, 240, 256)

    # Remap rows inside each group of 8 (interleaved readout order).
    row_order = np.array([0, 4, 1, 5, 2, 6, 3, 7])
    remapped = np.empty_like(swapped)
    for base in range(0, 96, 8):
        remapped[base:base + 8] = swapped[base:base + 8][row_order]
    return remapped.astype(np.int16)
def binning_2x2_stride2(data):
    """Sum disjoint 2x2 spatial blocks (stride 2) over the first two axes.

    Generalised from the hard-coded (96, 240, 256) layout: any input of
    shape (H, W, B) with even H and W is accepted; the original shape
    still works unchanged.

    Args:
        data: array of shape (H, W, B) with H and W even
              (e.g. (96, 240, 256)).

    Returns:
        Array of shape (H // 2, W // 2, B); each output pixel is the sum
        of one disjoint 2x2 block of input pixels.
    """
    h, w, bins = data.shape
    # Reshape exposes each 2x2 block on its own axes, then collapse them.
    return data.reshape(h // 2, 2, w // 2, 2, bins).sum(axis=(1, 3))
def binning_2x2_stride1(data):
    """Sum overlapping 2x2 spatial neighbourhoods (stride 1).

    Args:
        data: array of shape (96, 240, 256).

    Returns:
        Array of shape (95, 239, 256): entry (i, j) is the sum of the 2x2
        window whose top-left corner is (i, j). One row and one column are
        lost because the window must fit entirely inside the input.
    """
    # Accumulate the four shifted views one at a time (same order as the
    # original single expression, so float results are bit-identical).
    acc = data[:-1, :-1] + data[1:, :-1]
    acc = acc + data[:-1, 1:]
    acc = acc + data[1:, 1:]
    return acc
def ma_vectorized(data, kernel=(-2, -2, 1, 2, 2, 3, -1, -1)):
    """Weighted moving average along the last (bin) axis.

    Each length-k sliding window is correlated with *kernel* (no flip),
    normalised by the kernel sum, and negative results are clamped to
    zero. The output is zero-padded at the tail so its shape matches the
    input.

    Fixes vs. the previous version:
      * a zero-sum kernel no longer divides by zero (same guard as
        ma_vectorized_fast, for consistency);
      * the default kernel is a tuple instead of a mutable list.

    Args:
        data: array of shape (H, W, N) with N >= len(kernel).
        kernel: 1-D weights; defaults to the HW smoothing taps.

    Returns:
        float32-derived array with the same shape as *data*.
    """
    kernel = np.asarray(kernel, dtype=np.float32)
    k = kernel.size
    kernel_sum = kernel.sum()
    if kernel_sum == 0:
        kernel_sum = 1  # avoid division by zero; matches ma_vectorized_fast
    data = np.asarray(data, dtype=np.float32)
    # Sliding windows along the bin axis: shape (H, W, N - k + 1, k).
    windows = sliding_window_view(data, window_shape=k, axis=2)
    # Correlate every window with the kernel, then normalise.
    smoothed = np.tensordot(windows, kernel, axes=([3], [0])) / kernel_sum
    smoothed[smoothed < 0] = 0
    # Zero-pad the tail so the output length matches the input.
    smoothed = np.pad(smoothed, ((0, 0), (0, 0), (0, k - 1)),
                      mode='constant', constant_values=0)
    return smoothed
def ma_vectorized_fast(data, kernel=(-2, -2, 1, 2, 2, 3, -1, -1)):
    """Edge-padded weighted moving average along the last (bin) axis.

    Like ma_vectorized(), but pads with edge values so every output bin
    sees a full window, keeping the output centred on the input.

    Fixes vs. the previous version:
      * the padding is now asymmetric (k//2 on the left, (k-1)//2 on the
        right), so the output length equals the input length for
        even-sized kernels too — previously an even kernel (including the
        default, k=8) produced one extra bin per row;
      * the per-row np.apply_along_axis / np.convolve Python loop is
        replaced with one vectorised sliding-window correlation.

    Args:
        data: array of shape (H, W, N).
        kernel: 1-D weights; defaults to the HW smoothing taps.

    Returns:
        Array shaped like *data*; negative results are clamped to zero.
    """
    kernel = np.asarray(kernel, dtype=np.float32)
    k = kernel.size
    kernel_sum = kernel.sum()
    if kernel_sum == 0:
        kernel_sum = 1  # avoid division by zero
    # Asymmetric edge padding keeps the kernel centred and the length unchanged.
    pad_left = k // 2
    pad_right = (k - 1) // 2
    padded = np.pad(data, ((0, 0), (0, 0), (pad_left, pad_right)), mode='edge')
    # np.convolve(x, kernel[::-1], 'valid') is a plain correlation with
    # *kernel*; compute it for all rows at once via sliding windows.
    windows = sliding_window_view(padded, window_shape=k, axis=2)
    smoothed = np.tensordot(windows, kernel, axes=([3], [0])) / kernel_sum
    return np.maximum(smoothed, 0)  # clamp negatives to zero
# Peak-search configuration shared by the numba kernels below.
BIN_SIZE = 180  # number of leading histogram bins searched for peaks
MAX_PEAKS = 2  # maximum number of peaks extracted per pixel
@njit
def sum_hist(hist, length):
    """Return the sum of the first *length* bins of *hist* (numba JIT)."""
    total = hist[:length].sum()
    return total
@njit
def max_hist(hist, length):
    """Return the maximum of the first *length* bins of *hist* (numba JIT)."""
    peak = hist[:length].max()
    return peak
@njit
def compute_centroid(hist, start_bin, end_bin):
    """Intensity-weighted centroid of hist[start_bin .. end_bin] (inclusive).

    Falls back to the midpoint of the range when the slice sums to zero,
    avoiding a division by zero.
    """
    values = hist[start_bin:end_bin + 1]
    total = values.sum()
    if total == 0:
        return 0.5 * (start_bin + end_bin)
    indices = np.arange(start_bin, end_bin + 1)
    return (indices * values).sum() / total
@njit
def find_peaks_hw(histograms, histograms_ma):
    """Per-pixel peak extraction over smoothed ToF histograms (numba JIT).

    Inputs:
        histograms:    (H, W, 256) raw histograms
        histograms_ma: (H, W, 256) smoothed histograms
    Outputs:
        tof_data:       (H, W, MAX_PEAKS) centroid positions
        peak_data:      (H, W, MAX_PEAKS) peak intensities (scaled at the end)
        noise_data:     (H, W) noise estimate
        multishot_data: (H, W) multishot info
        totalcount:     (H, W) total count
        nt_count_n:     (H, W, MAX_PEAKS) NT counts
    """
    H, W, _ = histograms.shape
    # Output buffers.
    tof_data = np.zeros((H, W, MAX_PEAKS), dtype=np.float32)
    peak_data = np.zeros((H, W, MAX_PEAKS), dtype=np.int32)
    noise_data = np.zeros((H, W), dtype=np.int16)
    multishot_data = np.zeros((H, W), dtype=np.int16)
    totalcount = np.zeros((H, W), dtype=np.int32)
    nt_count_n = np.zeros((H, W, MAX_PEAKS), dtype=np.int32)
    for i in range(H):
        for j in range(W):
            hist_raw = histograms[i, j]
            hist = histograms_ma[i, j]
            count = 0
            bin_idx = 1
            # Multishot info: packed into the last two raw bins (14-bit value).
            multishot = (int(hist_raw[254]) << 8) + (int(hist_raw[255]) >> 2)
            multishot_data[i, j] = multishot
            # NOTE(review): multishot == 0 would raise a division-by-zero
            # here — confirm upstream data guarantees it is non-zero.
            totalcount[i, j] = int(sum_hist(hist_raw, BIN_SIZE) << 13) // multishot
            # Noise estimate: mean of the first 8 bins (+4 for rounding, >>3 = /8).
            noise_data[i, j] = int(sum_hist(hist, 8) + 4) >> 3
            noise_i = max_hist(hist, 8)
            # max_th = np.max(hist) * 0.01 + 25
            max_th = 50
            noise_i = max_th if noise_i < max_th else noise_i * 3
            th = 2 * noise_i - 24 / (2e4 * 4) * noise_i * noise_i
            # Peak detection: scan bins until MAX_PEAKS peaks found.
            while bin_idx < BIN_SIZE - 1 and count < MAX_PEAKS:
                # Find the first bin whose value reaches the threshold.
                while bin_idx < BIN_SIZE - 1 and hist[bin_idx] < th:
                    bin_idx += 1
                # Step one bin back so the rising edge is included.
                bin_idx -= 1
                start_bin = bin_idx
                start_peak = hist[start_bin]
                # Climb until a local maximum is reached.
                while bin_idx + 1 < BIN_SIZE and (
                    hist[bin_idx] < hist[bin_idx - 1] or hist[bin_idx] < hist[bin_idx + 1]
                ):
                    bin_idx += 1
                max_bin = bin_idx
                max_peak = hist[max_bin]
                # Walk right along the falling edge to find end_bin.
                while bin_idx + 1 < BIN_SIZE and (
                    hist[bin_idx] > th or hist[bin_idx] > start_peak or hist[bin_idx] > hist[bin_idx + 1]
                ):
                    bin_idx += 1
                end_bin = bin_idx
                # Reject degenerate or low-prominence candidates.
                if (
                    start_bin == end_bin
                    or start_bin == max_bin
                    or max_bin == end_bin
                    or (max_peak - start_peak) < 50
                ):
                    bin_idx += 1
                    continue
                # Centroid of the accepted peak window.
                centroid = compute_centroid(hist, start_bin, end_bin)
                tof_data[i, j, count] = centroid
                peak_data[i, j, count] = (int(hist[max_bin]) << 13) // multishot
                # NT count: sum of up to 48 bins ending 10 bins before the peak;
                # if the window is truncated at 0, pad with the noise estimate.
                nt_end_bin = max_bin - 10
                nt_start_bin = nt_end_bin - 48
                nt_start_bin = 0 if nt_start_bin < 0 else nt_start_bin
                nt_num = nt_end_bin - nt_start_bin
                est_nt = 48 * noise_data[i, j] if nt_num < 48 else 0
                nt_count_n[i, j, count] = np.sum(hist[nt_start_bin:nt_start_bin + nt_num]) + est_nt
                count += 1
    # NOTE(review): this rebinds peak_data (declared int32 above) to a float
    # array — numba's type unification may reject this; confirm it compiles.
    peak_data=peak_data*256/48000
    return tof_data, peak_data, noise_data, multishot_data, totalcount, nt_count_n
def local_threshold(img, window_size=15, C=2):
    """Adaptive threshold: keep a pixel only if it exceeds its local mean minus C.

    Fixes vs. the previous version:
      * the *C* parameter is now honoured — it was silently ignored and a
        hard-coded 0.1 offset used instead (the intended formula survives
        in the old commented-out line `thresh = np.mean(local_region) - C`);
      * the cv2.copyMakeBorder(BORDER_REFLECT) call is replaced with the
        equivalent np.pad(mode='symmetric'), and the per-pixel Python loop
        with one vectorised sliding-window mean.

    Args:
        img: 2-D array.
        window_size: side length of the (square) local averaging window.
        C: constant subtracted from the local mean to form the threshold.

    Returns:
        uint8 array of img.shape: img value where img > local_mean - C,
        else 0.
    """
    pad = window_size // 2
    # 'symmetric' repeats the edge pixel, matching cv2.BORDER_REFLECT.
    padded = np.pad(img, pad, mode='symmetric').astype(np.float64)
    # Local mean of every window_size x window_size neighbourhood.
    windows = sliding_window_view(padded, (window_size, window_size))
    local_mean = windows.mean(axis=(-2, -1))
    out = np.where(img > local_mean - C, img, 0)
    return out.astype(np.uint8)
def select_peaks_hw(tof_data, peak_data):
    """Select one peak per pixel from the per-peak candidates.

    Each candidate gets a reflectivity-like score
    (peak * tof^2 / 1200 * 6, then log2, clamped at 0); the candidate
    with the larger score wins the pixel. Ties between equal non-zero
    scores go to candidate 0.

    Fixes vs. the previous version:
      * a dead cv2 Otsu-threshold computation (its result, otsu_binary,
        was never used) and surrounding commented-out experiments removed;
      * the loop bound is derived from peak_data.shape[-1] instead of the
        module-level MAX_PEAKS constant — identical behaviour for the
        standard 2-peak inputs, but decoupled from the global.

    Inputs:
        tof_data:  (H, W, P) centroid positions (P >= 2; only the first
                   two candidates take part in the final selection).
        peak_data: (H, W, P) peak intensities.

    Returns:
        tof:     (H, W) selected centroid per pixel.
        peak:    (H, W) selected intensity per pixel.
        ref:     (H, W) winning score per pixel.
        ref_set: (H, W, P) per-candidate scores.
    """
    num_peaks = peak_data.shape[-1]
    # NOTE(review): ref_set inherits peak_data's dtype; if a caller passes
    # an integer array the log2 scores are truncated on assignment — confirm.
    ref_set = np.zeros_like(peak_data)
    for i in range(num_peaks):
        # Reflectivity-style score: intensity compensated by squared distance.
        ref_set[..., i] = peak_data[..., i] * tof_data[..., i] * tof_data[..., i] / 1200 * 6
        ref = np.log2(ref_set[..., i])
        ref[ref < 0] = 0  # also maps log2(0) = -inf to 0
        ref_set[..., i] = ref
    # Winning score per pixel across all candidates.
    ref = np.max(ref_set, axis=2)
    frame1 = ref_set[..., 0]
    frame2 = ref_set[..., 1]
    mask1 = np.zeros_like(frame1, dtype=np.uint8)
    mask2 = np.zeros_like(frame2, dtype=np.uint8)
    # Pick the candidate with the strictly larger non-zero score.
    mask1[(frame1 > frame2) & (frame1 > 0)] = 1
    mask2[(frame2 > frame1) & (frame2 > 0)] = 1
    # Equal non-zero scores: arbitrarily keep candidate 0.
    mask1[(frame1 == frame2) & (frame1 > 0)] = 1
    tof = tof_data[..., 0] * mask1 + tof_data[..., 1] * mask2
    peak = peak_data[..., 0] * mask1 + peak_data[..., 1] * mask2
    return tof, peak, ref, ref_set